"vscode:/vscode.git/clone" did not exist on "e5558cb6e109eeb3ea510a89e7fef37f1a0beac3"
Unverified commit 08effbff, authored by Sangchun Ha (Patrick), committed by GitHub

Fix error that occurs when loading the Gemma model in bitsandbytes format (#2557)

parent 60bd3272
@@ -770,6 +770,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             quant_state_dict,
         )
 
+    def _is_8bit_weight_name(self, weight_name: str):
+        quantized_suffix = {".scb", ".weight_format"}
+        return any(
+            weight_name.lower().endswith(suffix)
+            for suffix in quantized_suffix)
+
+    def _is_4bit_weight_name(self, weight_name: str):
+        quantized_suffix = {
+            "absmax",
+            "quant_map",
+            "nested_absmax",
+            "nested_quant_map",
+            "bitsandbytes",
+        }
+        suffix = weight_name.split(".")[-1]
+        return any(q_suffix in suffix for q_suffix in quantized_suffix)
+
     def _quantized_8bit_generator(
         self, hf_weights_files, use_safetensors, quant_state_dict
     ) -> Generator:
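For reference, a minimal standalone sketch of how the two new predicates classify tensor names. The example names are assumptions about bitsandbytes-style checkpoints (8-bit per-column scales stored under ".SCB" plus a ".weight_format" marker, 4-bit quant state stored under "absmax"/"quant_map"/"nested_*" suffixes), not names taken from this commit:

def _is_8bit_weight_name(weight_name: str) -> bool:
    # 8-bit metadata tensors: per-column scales (.SCB) and the format marker.
    quantized_suffix = {".scb", ".weight_format"}
    return any(weight_name.lower().endswith(s) for s in quantized_suffix)

def _is_4bit_weight_name(weight_name: str) -> bool:
    # 4-bit quant-state tensors extend the weight's own name, so only the
    # final dot-separated component needs inspecting.
    quantized_suffix = {
        "absmax", "quant_map", "nested_absmax",
        "nested_quant_map", "bitsandbytes",
    }
    suffix = weight_name.split(".")[-1]
    return any(q in suffix for q in quantized_suffix)

print(_is_8bit_weight_name("model.layers.0.mlp.down_proj.SCB"))     # True
print(_is_8bit_weight_name("model.layers.0.mlp.down_proj.weight"))  # False
print(_is_4bit_weight_name("model.layers.0.mlp.down_proj.weight.absmax"))  # True
print(_is_4bit_weight_name(
    "model.layers.0.mlp.down_proj.weight.quant_state.bitsandbytes__nf4"))  # True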
@@ -779,21 +794,18 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             if not weight_name.lower().endswith(".scb"):
                 continue
 
-            weight_key = weight_name.lower().replace(".scb", ".qweight")
+            weight_key = weight_name.lower().replace(".scb", ".weight")
             quant_state_dict[weight_key] = weight_tensor
 
         for weight_name, weight_tensor in self._hf_weight_iter(
             hf_weights_files, use_safetensors
         ):
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_8bit_weight_name(weight_name):
                 continue
 
-            qweight_name = weight_name.replace(".weight", ".qweight")
-            if qweight_name in quant_state_dict:
+            if weight_name in quant_state_dict:
                 set_weight_attrs(weight_tensor, {"load_in_8bit": True})
-                yield qweight_name, weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
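To make the 8-bit change concrete, a small sketch (illustrative names, not from the commit) of the key alignment: the scale tensor's dictionary key is now derived by mapping ".SCB" back to ".weight", so it coincides with the checkpoint's own weight name and the ".qweight" renaming on the weight side becomes unnecessary:

scb_name = "model.layers.0.mlp.down_proj.SCB"
old_key = scb_name.lower().replace(".scb", ".qweight")  # ...down_proj.qweight
new_key = scb_name.lower().replace(".scb", ".weight")   # ...down_proj.weight

weight_name = "model.layers.0.mlp.down_proj.weight"
print(weight_name == new_key)  # True: scales found under the weight's own name
print(weight_name == old_key)  # False: the old scheme had to rename weights too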
@@ -806,7 +818,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
         weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
         temp_state_dict = {}
         for weight_name, weight_tensor in weight_iterator:
-            if weight_name.endswith((".weight", ".bias")):
+            if not self._is_4bit_weight_name(weight_name):
                 continue
             # bitsandbytes library requires
             # weight.quant_state.bitsandbytes__* in CPU
@@ -830,16 +842,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
             hf_weights_files, use_safetensors
         ):
-            if not weight_name.endswith((".weight", ".bias")):
+            if self._is_4bit_weight_name(weight_name):
                 continue
 
             if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or (
                 f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
             ):
                 quant_state = _parse_quant_state(weight_name, temp_state_dict)
-                weight_name = weight_name.replace(".weight", ".qweight")
                 quant_state_dict[weight_name] = quant_state
-                yield weight_name.replace(".weight", ".qweight"), weight_tensor
+                yield weight_name, weight_tensor
             else:
                 yield weight_name, weight_tensor
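The 4-bit generator now follows the same principle: everything is keyed by the original parameter name. In a bitsandbytes 4-bit (NF4/FP4) checkpoint, each quantized weight travels with sibling tensors whose names extend the weight's own name; a sketch with assumed names:

# Illustrative entries for one linear layer in an NF4 checkpoint:
#   model.layers.0.mlp.down_proj.weight                                (packed 4-bit data)
#   model.layers.0.mlp.down_proj.weight.absmax                         (quant state)
#   model.layers.0.mlp.down_proj.weight.quant_map                      (quant state)
#   model.layers.0.mlp.down_proj.weight.nested_absmax                  (double quantization)
#   model.layers.0.mlp.down_proj.weight.nested_quant_map               (double quantization)
#   model.layers.0.mlp.down_proj.weight.quant_state.bitsandbytes__nf4  (packed metadata)
# First pass: quant-state tensors (matched by _is_4bit_weight_name) go into
# temp_state_dict. Second pass: each real weight looks up its state with the
# key below and is yielded under its unmodified name.
weight_name = "model.layers.0.mlp.down_proj.weight"
key = f"{weight_name}.quant_state.bitsandbytes__nf4"
print(key)  # model.layers.0.mlp.down_proj.weight.quant_state.bitsandbytes__nf4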
@@ -307,6 +307,25 @@ class Gemma2Model(nn.Module):
 
 class Gemma2ForCausalLM(nn.Module):
+    # BitsAndBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
             ...
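The new class attributes tell the loader which Gemma 2 projections to quantize and how the per-projection checkpoint shards map onto vLLM's fused qkv_proj and gate_up_proj parameters. With the fix applied, loading Gemma 2 with in-flight bitsandbytes quantization should work along these lines; a sketch against the public vLLM API, assuming a build that includes this commit (the model ID and sampling settings are placeholders):

from vllm import LLM, SamplingParams

llm = LLM(
    model="google/gemma-2-2b-it",      # placeholder model ID
    quantization="bitsandbytes",
    load_format="bitsandbytes",
)
outputs = llm.generate(["Why is the sky blue?"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)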