Unverified commit 4397f59a, authored by Sayak Paul, committed by GitHub
Browse files

[bitsandbytes] improve dtype mismatch handling for bnb + lora. (#11270)

* improve dtype mismatch handling for bnb + lora.

* add a test

* fix and updates

* update
parent 05679329
...@@ -171,9 +171,11 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torc ...@@ -171,9 +171,11 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torc
if cls_name == "Params4bit": if cls_name == "Params4bit":
output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state) output_tensor = bnb.functional.dequantize_4bit(weight.data, weight.quant_state)
logger.warning_once( msg = f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`"
f"The model is going to be dequantized in {output_tensor.dtype} - if you want to upcast it to another dtype, make sure to pass the desired dtype when quantizing the model through `bnb_4bit_quant_type` argument of `BitsAndBytesConfig`" if dtype:
) msg = f"The model is going to be first dequantized in {output_tensor.dtype} and type-casted to {dtype}"
output_tensor = output_tensor.to(dtype)
logger.warning_once(msg)
return output_tensor return output_tensor
if state.SCB is None: if state.SCB is None:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment