Unverified commit 08effbff authored by Sangchun Ha (Patrick), committed by GitHub

Fix an error that occurs when loading the Gemma model in bitsandbytes format (#2557)

parent 60bd3272
@@ -770,6 +770,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             quant_state_dict,
         )
 
+    def _is_8bit_weight_name(self, weight_name: str):
+        quantized_suffix = {".scb", ".weight_format"}
+        return any(
+            weight_name.lower().endswith(suffix) for suffix in quantized_suffix
+        )
+
+    def _is_4bit_weight_name(self, weight_name: str):
+        quantized_suffix = {
+            "absmax",
+            "quant_map",
+            "nested_absmax",
+            "nested_quant_map",
+            "bitsandbytes",
+        }
+        suffix = weight_name.split(".")[-1]
+        return any(q_suffix in suffix for q_suffix in quantized_suffix)
+
     def _quantized_8bit_generator(
         self, hf_weights_files, use_safetensors, quant_state_dict
     ) -> Generator:
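
Note: the two new predicates classify checkpoint tensor names by the suffixes bitsandbytes emits: 8-bit checkpoints store a per-channel scale under ".SCB" (plus a ".weight_format" marker), while 4-bit checkpoints store quant-state tensors such as ".absmax" and ".quant_state.bitsandbytes__nf4". A minimal standalone sketch of the same logic; the key names below are hypothetical examples, not taken from the diff:

    def is_8bit_weight_name(weight_name: str) -> bool:
        # matches the 8-bit quant-state tensors (scale and format marker)
        quantized_suffix = {".scb", ".weight_format"}
        return any(weight_name.lower().endswith(s) for s in quantized_suffix)

    def is_4bit_weight_name(weight_name: str) -> bool:
        # matches 4-bit quant-state tensors by their last name component
        quantized_suffix = {"absmax", "quant_map", "nested_absmax",
                            "nested_quant_map", "bitsandbytes"}
        suffix = weight_name.split(".")[-1]
        return any(q in suffix for q in quantized_suffix)

    assert is_8bit_weight_name("model.layers.0.mlp.down_proj.SCB")
    assert is_4bit_weight_name("model.layers.0.mlp.down_proj.weight.absmax")
    assert is_4bit_weight_name(
        "model.layers.0.mlp.down_proj.weight.quant_state.bitsandbytes__nf4")
    # plain weights are matched by neither predicate
    assert not is_8bit_weight_name("model.layers.0.mlp.down_proj.weight")
    assert not is_4bit_weight_name("model.layers.0.mlp.down_proj.weight")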
@@ -779,21 +794,18 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if not weight_name.lower().endswith(".scb"):
                 continue
 
-            weight_key = weight_name.lower().replace(".scb", ".qweight")
+            weight_key = weight_name.lower().replace(".scb", ".weight")
             quant_state_dict[weight_key] = weight_tensor
 
         for weight_name, weight_tensor in self._hf_weight_iter(
             hf_weights_files, use_safetensors
         ):
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_8bit_weight_name(weight_name):
                 continue
 
-            qweight_name = weight_name.replace(".weight", ".qweight")
-            if qweight_name in quant_state_dict:
+            if weight_name in quant_state_dict:
                 set_weight_attrs(weight_tensor, {"load_in_8bit": True})
-                yield qweight_name, weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
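
Note: the crux of the 8-bit fix is that checkpoint keys now stay under their original ".weight" names instead of being renamed to ".qweight", so the scale recorded in the first pass and the weight yielded in the second pass use the same key. A hypothetical example of the new key derivation:

    scb_name = "model.layers.0.mlp.down_proj.SCB"
    weight_key = scb_name.lower().replace(".scb", ".weight")
    # the scale is stored under the same name the weight itself is yielded with
    assert weight_key == "model.layers.0.mlp.down_proj.weight"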
@@ -806,7 +818,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
         temp_state_dict = {}
         for weight_name, weight_tensor in weight_iterator:
-            if weight_name.endswith((".weight", ".bias")):
+            if not self._is_4bit_weight_name(weight_name):
                 continue
             # bitsandbytes library requires
             # weight.quant_state.bitsandbytes__* in CPU
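
Note: the first 4-bit pass used to keep every tensor that did not end in ".weight" or ".bias"; it now selects the quant-state tensors directly by their bitsandbytes suffixes, so tensors with other names no longer slip into the temporary state dict. Using the sketch predicate from the note above (hypothetical names):

    # only the second name is collected into temp_state_dict by the first pass
    assert not is_4bit_weight_name("model.embed_tokens.weight")
    assert is_4bit_weight_name("model.layers.0.mlp.up_proj.weight.nested_absmax")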
@@ -830,16 +842,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             hf_weights_files, use_safetensors
         ):
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_4bit_weight_name(weight_name):
                 continue
 
             if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or (
                 f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
             ):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
-                weight_name = weight_name.replace(".weight", ".qweight")
                 quant_state_dict[weight_name] = quant_state
-                yield weight_name.replace(".weight", ".qweight"), weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
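
Note: the 4-bit path gets the same treatment: the quant state parsed from the "*.quant_state.bitsandbytes__nf4" / "__fp4" entries is now stored and yielded under the original ".weight" name, with no ".qweight" rename. A hypothetical key layout for one layer (tensor values stubbed with None):

    temp_state_dict = {
        "layers.0.mlp.down_proj.weight.absmax": None,
        "layers.0.mlp.down_proj.weight.quant_map": None,
        "layers.0.mlp.down_proj.weight.quant_state.bitsandbytes__nf4": None,
    }
    weight_name = "layers.0.mlp.down_proj.weight"
    # the membership test the generator performs before parsing a quant state
    assert f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict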
@@ -307,6 +307,25 @@ class Gemma2Model(nn.Module):
 
 class Gemma2ForCausalLM(nn.Module):
 
+    # BitsAndBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
......
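
Note: vLLM fuses q_proj/k_proj/v_proj into qkv_proj and gate_proj/up_proj into gate_up_proj, while bitsandbytes checkpoints keep the unfused names; the stacked-params mapping tells the loader which shard of the fused parameter each checkpoint tensor fills. A sketch of how such a mapping is typically consumed (the helper and names below are illustrative, not part of the commit):

    stacked = {
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def map_to_fused(param_name: str):
        # rewrite an unfused checkpoint name to (fused name, shard index)
        for shard_name, (fused_name, index) in stacked.items():
            if shard_name in param_name:
                return param_name.replace(shard_name, fused_name), index
        return param_name, 0  # unfused modules such as o_proj / down_proj

    assert map_to_fused("model.layers.0.self_attn.q_proj.weight") == (
        "model.layers.0.self_attn.qkv_proj.weight", 0)
    assert map_to_fused("model.layers.0.mlp.up_proj.weight") == (
        "model.layers.0.mlp.gate_up_proj.weight", 1)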