Commit 7bdb03ea authored by zhuwenwen
Browse files

update qwen3 moe inference

parent 95099cbc
......@@ -414,6 +414,9 @@ class Qwen3MoeModel(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if self.use_llama_nn:
current_count = loaded_weight.current_count
total_count = loaded_weight.total_count
for (param_name, weight_name, shard_id) in stacked_params_mapping:
# Skip non-stacked layers and experts (experts handled below).
if weight_name not in name:
......@@ -490,7 +493,7 @@ class Qwen3MoeModel(nn.Module):
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
if self.use_llama_nn and self.quant_method is None and current_count==total_count:
lay_key_words = [
"gate_up_proj.weight",
"down_proj.weight",
......@@ -507,7 +510,8 @@ class Qwen3MoeModel(nn.Module):
# lay_qkv_bias_words = ["self_attn.qkv_proj.bias"]
# qkv_bias_words = "|".join(lay_qkv_bias_words)
for layername in loaded_params:
# for layername in loaded_params:
for layername in params_dict.keys():
weight = params_dict[layername]
os.environ['LM_NN'] = '0'
# if self.use_fa_pad and (re.findall(qkv_bias_words, layername)):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment