Unverified commit 4397f59a, authored by Sayak Paul and committed by GitHub

[bitsandbytes] improve dtype mismatch handling for bnb + lora. (#11270)

* improve dtype mismatch handling for bnb + lora.

* add a test

* fix and updates

* update
Parent commit: 05679329
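
For context, the "bnb + lora" combination in the title arises when LoRA weights (typically fp16/bf16) are loaded onto or fused into a model whose linear layers are quantized with bitsandbytes: that path has to dequantize the base weight, and the dequantized tensor may come back in a dtype different from the one the LoRA math expects. A rough sketch of such a scenario, assuming a Flux checkpoint quantized through diffusers' BitsAndBytesConfig; the model id and LoRA repo below are illustrative placeholders, not taken from this commit or its test:

import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

# Quantize the transformer to 4-bit NF4, keeping bf16 as the compute dtype.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative model id
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=transformer,
    torch_dtype=torch.bfloat16,
)  # device placement / offloading elided for brevity

# Loading a LoRA on top of the 4-bit transformer can force a dequantize step;
# the change in this commit lets that step cast the dequantized weight to the
# expected dtype instead of leaving it in the quant-state dtype.
pipe.load_lora_weights("some-user/some-flux-lora")  # illustrative LoRA repo
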
@@ -171,9 +171,11 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
     if cls_name == "Params4bit":
         output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
-        logger.warning_once(
-            f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
-        )
+        msg = f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
+        if dtype:
+            msg = f"The model is going to be first dequantized in {output_tensor.dtype} and type-casted to {dtype}"
+            output_tensor = output_tensor.to(dtype)
+        logger.warning_once(msg)
         return output_tensor
 
     if state.SCB is None:
...
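
Seen from the caller's side, the helper now takes an optional target dtype: without it the dequantized tensor keeps the dtype recorded in the quant state (and the warning says so); with it the tensor is cast right after dequantization. A minimal sketch against the private helper, assuming the post-change signature shown in the hunk header; dequantize_bnb_weight is an internal utility, and the toy Params4bit construction below is illustrative and needs a CUDA device:

import torch
import bitsandbytes as bnb
from diffusers.quantizers.bitsandbytes.utils import dequantize_bnb_weight

# Quantize a small fp16 weight to NF4; Params4bit quantizes when moved to CUDA.
w_fp16 = torch.randn(16, 16, dtype=torch.float16)
q = bnb.nn.Params4bit(w_fp16, requires_grad=False, quant_type="nf4").cuda()

# Previous behaviour: the result keeps the dtype from the quant state.
w_default = dequantize_bnb_weight(q, state=None)
# New behaviour exercised by this hunk: request a target dtype explicitly,
# so e.g. bf16 LoRA math does not hit a dtype mismatch.
w_bf16 = dequantize_bnb_weight(q, state=None, dtype=torch.bfloat16)
print(w_default.dtype, w_bf16.dtype)  # e.g. torch.float16 torch.bfloat16
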