Unverified commit 49599150, authored by Jee Jee Li, committed by GitHub
Browse files

[Quantization] Modify the logic of BNB double quantization (#19742)


Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
parent 8d1e89d9
......@@ -492,8 +492,6 @@ class BitsAndBytesModelLoader(BaseModelLoader):
raise ValueError("Following weights were not initialized from "
f"checkpoint: {weights_not_loaded}")
torch.cuda.empty_cache()
param_dict = dict(model.named_parameters())
stacked_quant_state_dict: dict[str, dict[int, Any]] = {}
# TODO: Change this lazy import to normal import
......@@ -545,6 +543,8 @@ class BitsAndBytesModelLoader(BaseModelLoader):
for param_name, param in param_dict.items():
if param_name in stacked_quant_state_dict:
quant_states = stacked_quant_state_dict[param_name]
# Dequantize double quantized values during weight loading.
dequantize_dq(quant_states)
set_weight_attrs(param, {"bnb_quant_state": quant_states})
pack_ratio = getattr(param, "pack_factor", -1)
......@@ -565,6 +565,28 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if load_8bit:
set_weight_attrs(
param, {"matmul_state": [None] * len(quant_states)})
torch.cuda.empty_cache()
def download_model(self, model_config: ModelConfig) -> None:
    """Prepare the weights for the model described by ``model_config``.

    Reads ``model_config.model`` and ``model_config.revision`` and hands
    both, positionally and in that order, to ``self._prepare_weights``.
    Whatever ``_prepare_weights`` returns is discarded.
    """
    model_ref = model_config.model
    model_revision = model_config.revision
    self._prepare_weights(model_ref, model_revision)
def dequantize_dq(quant_states: dict) -> None:
    """Eagerly undo BNB double quantization for every state in the mapping.

    When BNB employs Double Quantization, the ``absmax`` constants of a
    quant state are themselves stored quantized (``nested`` is true and
    ``state2`` / ``offset`` describe how to recover them). We perform the
    dequantization of these constants during weight loading rather than at
    inference time, thereby avoiding this computational overhead during
    inference. This comes at the cost of increased memory usage.

    Args:
        quant_states: Mapping whose values are quant-state objects exposing
            ``nested``, ``absmax``, ``state2`` and ``offset`` attributes.
            The keys are not used. Nested states are updated in place;
            states whose ``nested`` flag is falsy are left untouched.

    Returns:
        None. The result is the in-place mutation of the nested states.
    """
    # Nothing is double quantized: return before touching bitsandbytes so
    # the lazy import is only paid for (and only required) when needed.
    if not any(state.nested for state in quant_states.values()):
        return

    from bitsandbytes.functional import dequantize_blockwise

    # Only the values matter here; the mapping keys are never consulted.
    for quant_state in quant_states.values():
        if not quant_state.nested:
            continue
        # Copied from: https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.45.3/bitsandbytes/functional.py#L1352-#L1356
        absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
        absmax += quant_state.offset
        if absmax.dtype != torch.float32:
            absmax = absmax.float()
        quant_state.absmax = absmax
        # The constants are now stored dequantized, so clear the nested
        # flag and drop the second-level state that is no longer needed.
        quant_state.nested = False
        quant_state.offset = None
        quant_state.state2 = None
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment