Unverified commit 5e8e6cb4 authored by Sayak Paul, committed by GitHub

[bitsandbytes] Simplify bnb int8 dequant (#10401)

* fix dequantization for latest bnb.

* smol fixes.

* fix type annotation

* update peft link

* updates
parent 3e35f56b
@@ -153,8 +153,8 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
     return model
 
 
-# Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
-def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
+# Adapted from PEFT: https://github.com/huggingface/peft/blob/6d458b300fc2ed82e19f796b53af4c97d03ea604/src/peft/utils/integrations.py#L81
+def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
     """
     Helper function to dequantize 4bit or 8bit bnb weights.
@@ -177,13 +177,16 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
     if state.SCB is None:
         state.SCB = weight.SCB
 
-    im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
-    im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
-    im, Sim = bnb.functional.transform(im, "col32")
-    if state.CxB is None:
-        state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
-    out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
-    return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
+    if hasattr(bnb.functional, "int8_vectorwise_dequant"):
+        # Use bitsandbytes API if available (requires v0.45.0+)
+        dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
+    else:
+        # Multiply by (scale/127) to dequantize.
+        dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
+
+    if dtype:
+        dequantized = dequantized.to(dtype)
+    return dequantized
 
 
 def _create_accelerate_new_hook(old_hook):
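The fallback branch encodes the standard LLM.int8() row-wise dequantization: each int8 row is rescaled by its per-row absmax statistic `SCB` divided by 127 (the constant 7.874015718698502e-3 is 1/127). A minimal round-trip sketch of that math, using hypothetical stand-ins `w_fp`, `w_int8`, and `scb` for a weight matrix, `weight.data`, and `state.SCB`:

import torch

# Sketch only (not diffusers code): row-wise int8 quantization and the
# scale/127 dequantization used in the fallback branch above.
w_fp = torch.randn(4, 8)
scb = w_fp.abs().amax(dim=1)  # per-row absmax scale (what bnb stores in SCB)
w_int8 = torch.round(w_fp * 127 / scb.view(-1, 1)).to(torch.int8)

# Dequantize: multiply each row by scale / 127 (7.874015718698502e-3 == 1 / 127).
w_dq = w_int8 * scb.view(-1, 1) / 127

# Rounding error is at most half a quantization step per element.
assert torch.allclose(w_fp, w_dq, atol=(scb.max() / 254).item())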
@@ -205,6 +208,7 @@ def _create_accelerate_new_hook(old_hook):
 
 def _dequantize_and_replace(
     model,
+    dtype,
     modules_to_not_convert=None,
     current_key_name=None,
     quantization_config=None,
@@ -244,7 +248,7 @@ def _dequantize_and_replace(
             else:
                 state = None
 
-            new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
+            new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))
 
             if bias is not None:
                 new_module.bias = bias
@@ -263,9 +267,10 @@ def _dequantize_and_replace(
         if len(list(module.children())) > 0:
             _, has_been_replaced = _dequantize_and_replace(
                 module,
-                modules_to_not_convert,
-                current_key_name,
-                quantization_config,
+                dtype=dtype,
+                modules_to_not_convert=modules_to_not_convert,
+                current_key_name=current_key_name,
+                quantization_config=quantization_config,
                 has_been_replaced=has_been_replaced,
             )
         # Remove the last key for recursion
@@ -280,6 +285,7 @@ def dequantize_and_replace(
 ):
     model, has_been_replaced = _dequantize_and_replace(
         model,
+        dtype=model.dtype,
         modules_to_not_convert=modules_to_not_convert,
         quantization_config=quantization_config,
     )