Commit 61898ea3 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix llama and qwen layout

parent 4a7d8ab8
......@@ -86,7 +86,6 @@ def get_model_architecture(
model_config: ModelConfig) -> Tuple[Type[nn.Module], str]:
architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
# TODO: support deepseek distillation series models ( 'LlamaForCausalLM', 'Qwen2ForCausalLM' )
support_nn_architectures = ['LlamaForCausalLM', 'Qwen2ForCausalLM', 'QWenLMHeadModel', 'Qwen2VLForConditionalGeneration', 'Qwen2_5_VLForConditionalGeneration', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV2ForCausalLM', 'DeepseekV3ForCausalLM', 'DeepSeekMTPModel']
if any(arch in architectures for arch in support_nn_architectures):
......
......@@ -432,6 +432,11 @@ def safetensors_weights_iterator(
use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
total_count = 0
for st_file in hf_weights_files:
with safe_open(st_file, framework="pt") as f:
total_count += len(f.keys())
current_count = 0
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors checkpoint shards",
......@@ -440,7 +445,10 @@ def safetensors_weights_iterator(
):
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
current_count += 1
param = f.get_tensor(name)
param.current_count = current_count
param.total_count = total_count
yield name, param
......
......@@ -409,6 +409,8 @@ class LlamaModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
current_count = loaded_weight.current_count
total_count = loaded_weight.total_count
if "rotary_emb.inv_freq" in name:
continue
if ("rotary_emb.cos_cached" in name
......@@ -461,7 +463,7 @@ class LlamaModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None :
if self.use_llama_nn and self.quant_method is None and current_count==total_count:
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
......@@ -474,7 +476,8 @@ class LlamaModel(nn.Module):
# qkv_words = "|".join(lay_qkv_words)
# for layername, weight in params_dict.items():
for layername in loaded_params:
# for layername in loaded_params:
for layername in params_dict.keys():
weight = params_dict[layername]
if "lm_head.weight" in layername and weight.shape[1] >= 4096:
lay_key_words.append("lm_head.weight")
......
......@@ -394,6 +394,8 @@ class Qwen2Model(nn.Module):
params_dict = dict(self.named_parameters(remove_duplicate=False))
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
current_count = loaded_weight.current_count
total_count = loaded_weight.total_count
if "rotary_emb.inv_freq" in name:
continue
if (self.quant_config is not None and
......@@ -436,7 +438,7 @@ class Qwen2Model(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
if self.use_llama_nn and self.quant_method is None and current_count==total_count:
lay_key_words = [
"self_attn.qkv_proj.weight",
"self_attn.o_proj.weight",
......@@ -452,7 +454,8 @@ class Qwen2Model(nn.Module):
# qkv_bias_words = "|".join(lay_qkv_bias_words)
# for layername, weight in params_dict.items():
for layername in loaded_params:
# for layername in loaded_params:
for layername in params_dict.keys():
weight = params_dict[layername]
if "lm_head.weight" in layername and weight.shape[1] >= 3584:
lay_key_words.append("lm_head.weight")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment