"vscode:/vscode.git/clone" did not exist on "7facfa33d062ee615402c80f8e4f1ea5146598ef"
Unverified Commit 08effbff authored by Sangchun Ha (Patrick)'s avatar Sangchun Ha (Patrick) Committed by GitHub
Browse files

Fix the error that occurs when loading the Gemma model in bitsandbytes format. (#2557)

parent 60bd3272
...@@ -770,6 +770,21 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -770,6 +770,21 @@ class BitsAndBytesModelLoader(BaseModelLoader):
quant_state_dict, quant_state_dict,
) )
def _is_8bit_weight_name(self, weight_name: str):
quantized_suffix = {".scb", ".weight_format"}
return any(weight_name.lower().endswith(suffix) for suffix in quantized_suffix)
def _is_4bit_weight_name(self, weight_name: str):
quantized_suffix = {
"absmax",
"quant_map",
"nested_absmax",
"nested_quant_map",
"bitsandbytes",
}
suffix = weight_name.split(".")[-1]
return any(q_suffix in suffix for q_suffix in quantized_suffix)
def _quantized_8bit_generator( def _quantized_8bit_generator(
self, hf_weights_files, use_safetensors, quant_state_dict self, hf_weights_files, use_safetensors, quant_state_dict
) -> Generator: ) -> Generator:
...@@ -779,21 +794,18 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -779,21 +794,18 @@ class BitsAndBytesModelLoader(BaseModelLoader):
if not weight_name.lower().endswith(".scb"): if not weight_name.lower().endswith(".scb"):
continue continue
weight_key = weight_name.lower().replace(".scb", ".qweight") weight_key = weight_name.lower().replace(".scb", ".weight")
quant_state_dict[weight_key] = weight_tensor quant_state_dict[weight_key] = weight_tensor
for weight_name, weight_tensor in self._hf_weight_iter( for weight_name, weight_tensor in self._hf_weight_iter(
hf_weights_files, use_safetensors hf_weights_files, use_safetensors
): ):
if self._is_8bit_weight_name(weight_name):
if not weight_name.endswith((".weight", ".bias")):
continue continue
qweight_name = weight_name.replace(".weight", ".qweight") if weight_name in quant_state_dict:
if qweight_name in quant_state_dict:
set_weight_attrs(weight_tensor, {"load_in_8bit": True}) set_weight_attrs(weight_tensor, {"load_in_8bit": True})
yield qweight_name, weight_tensor yield weight_name, weight_tensor
else: else:
yield weight_name, weight_tensor yield weight_name, weight_tensor
...@@ -806,7 +818,7 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -806,7 +818,7 @@ class BitsAndBytesModelLoader(BaseModelLoader):
weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors) weight_iterator = self._hf_weight_iter(hf_weights_files, use_safetensors)
temp_state_dict = {} temp_state_dict = {}
for weight_name, weight_tensor in weight_iterator: for weight_name, weight_tensor in weight_iterator:
if weight_name.endswith((".weight", ".bias")): if not self._is_4bit_weight_name(weight_name):
continue continue
# bitsandbytes library requires # bitsandbytes library requires
# weight.quant_state.bitsandbytes__* in CPU # weight.quant_state.bitsandbytes__* in CPU
...@@ -830,16 +842,15 @@ class BitsAndBytesModelLoader(BaseModelLoader): ...@@ -830,16 +842,15 @@ class BitsAndBytesModelLoader(BaseModelLoader):
hf_weights_files, use_safetensors hf_weights_files, use_safetensors
): ):
if not weight_name.endswith((".weight", ".bias")): if self._is_4bit_weight_name(weight_name):
continue continue
if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or ( if (f"{weight_name}.quant_state.bitsandbytes__nf4" in temp_state_dict) or (
f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict f"{weight_name}.quant_state.bitsandbytes__fp4" in temp_state_dict
): ):
quant_state = _parse_quant_state(weight_name, temp_state_dict) quant_state = _parse_quant_state(weight_name, temp_state_dict)
weight_name = weight_name.replace(".weight", ".qweight")
quant_state_dict[weight_name] = quant_state quant_state_dict[weight_name] = quant_state
yield weight_name.replace(".weight", ".qweight"), weight_tensor yield weight_name, weight_tensor
else: else:
yield weight_name, weight_tensor yield weight_name, weight_tensor
......
...@@ -307,6 +307,25 @@ class Gemma2Model(nn.Module): ...@@ -307,6 +307,25 @@ class Gemma2Model(nn.Module):
class Gemma2ForCausalLM(nn.Module): class Gemma2ForCausalLM(nn.Module):
# BitsAndBytes-specific attributes.
# Linear projections eligible for bitsandbytes quantization; names are
# dot-delimited fragments matched against full parameter names.
default_bitsandbytes_target_modules = [
    ".gate_proj.",
    ".down_proj.",
    ".up_proj.",
    ".q_proj.",
    ".k_proj.",
    ".v_proj.",
    ".o_proj.",
]
# Maps each per-head/per-gate shard in the checkpoint to the fused
# parameter it is packed into and its shard index within that parameter.
bitsandbytes_stacked_params_mapping = {
    # shard_name, weight_name, index
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}
packed_modules_mapping = { packed_modules_mapping = {
"qkv_proj": [ "qkv_proj": [
"q_proj", "q_proj",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment