"vscode:/vscode.git/clone" did not exist on "2df85862a51711b49abd95b8797900a6511599d3"
Unverified Commit 89588179 authored by Stefan He, committed by GitHub

[1/3] Optimize Slime Update Weights: Remove QWen3MOE Load Weight Overhead (#8751)

parent 8c7bb39d
@@ -766,7 +766,10 @@ class Qwen3MoeForCausalLM(nn.Module):
             num_experts=self.config.num_experts,
         )
-        params_dict = dict(self.named_parameters())
+        # Cache params_dict to avoid repeated expensive traversal of model parameters
+        if not hasattr(self, "_cached_params_dict"):
+            self._cached_params_dict = dict(self.named_parameters())
+        params_dict = self._cached_params_dict
         for name, loaded_weight in weights:
             layer_id = get_layer_id(name)
             if (
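The first hunk memoizes the name-to-parameter mapping on the module instance, so repeated calls to load_weights (for example during online weight updates from a trainer) do not rebuild it by traversing the whole module tree each time. A minimal sketch of the caching pattern in isolation, using a hypothetical TinyModel rather than the real Qwen3MoeForCausalLM:

import torch.nn as nn

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)

    def load_weights(self, weights):
        # Build the name -> Parameter dict once; dict(self.named_parameters())
        # walks every submodule, which is costly to repeat on large MoE models.
        if not hasattr(self, "_cached_params_dict"):
            self._cached_params_dict = dict(self.named_parameters())
        params_dict = self._cached_params_dict
        for name, loaded_weight in weights:
            if name in params_dict:
                params_dict[name].data.copy_(loaded_weight)

The trade-off is that the cache assumes the parameter set does not change between calls, which holds when only the values of a fixed architecture are being refreshed.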
@@ -805,11 +808,22 @@ class Qwen3MoeForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Track if this is an expert weight to enable early skipping
+                is_expert_weight = False
                 for mapping in expert_params_mapping:
                     param_name, weight_name, expert_id, shard_id = mapping
                     if weight_name not in name:
                         continue
+                    # Mark as expert weight regardless of whether we can process it
+                    is_expert_weight = True
                     name = name.replace(weight_name, param_name)
+                    if name not in params_dict:
+                        # Expert weight not on this rank, will be skipped below
+                        continue
                     param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(
@@ -821,6 +835,10 @@ class Qwen3MoeForCausalLM(nn.Module):
                     )
                     break
                 else:
+                    if is_expert_weight:
+                        # This is an expert weight but not mapped to this rank, skip all remaining processing
+                        continue
                     # Skip loading extra bias for GPTQ models.
                     if name.endswith(".bias") and name not in params_dict:
                         continue
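The second and third hunks work together: a sentinel flag records that a checkpoint tensor matched an expert-weight pattern, and Python's for/else clause (the else body runs only when the loop finishes without break) then skips the generic fallback instead of warning about a parameter that is simply sharded onto another rank under expert parallelism. A standalone sketch of that control flow with placeholder names; the real weight_loader call and fallback path from the diff are stubbed out here:

def load_expert_weight(name, loaded_weight, params_dict, expert_params_mapping):
    is_expert_weight = False
    for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
        if weight_name not in name:
            continue
        # The name matches an expert pattern even if this rank does not own it.
        is_expert_weight = True
        mapped = name.replace(weight_name, param_name)
        if mapped not in params_dict:
            continue  # expert held by a different rank
        # Stand-in for the real weight_loader(param, loaded_weight, ...) call.
        params_dict[mapped].data.copy_(loaded_weight)
        break
    else:
        if is_expert_weight:
            return  # expert weight owned elsewhere: nothing to load, no warning
        print(f"Parameter {name} not found")  # stand-in for the non-expert fallback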
@@ -837,11 +855,13 @@ class Qwen3MoeForCausalLM(nn.Module):
                         logger.warning(f"Parameter {name} not found in params_dict")
 
         # TODO mimic deepseek
-        self.routed_experts_weights_of_layer = {
-            layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
-            for layer_id in range(self.start_layer, self.end_layer)
-            if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
-        }
+        # Lazy initialization of expert weights cache to avoid slowing down load_weights
+        if not hasattr(self, "routed_experts_weights_of_layer"):
+            self.routed_experts_weights_of_layer = {
+                layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
+                for layer_id in range(self.start_layer, self.end_layer)
+                if isinstance(self.model.layers[layer_id].mlp, Qwen3MoeSparseMoeBlock)
+            }
 
     @classmethod
     def get_model_config_for_expert_location(cls, config):
...
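The last hunk guards the per-layer expert-weights dictionary with the same hasattr check, so it is assembled once on the first load_weights call instead of being rebuilt on every subsequent weight update. Outside a diff, the same build-once-per-instance behavior is often written with functools.cached_property; a hedged sketch with invented names, not the code the commit uses:

from functools import cached_property

class ExpertWeightsHolder:
    def __init__(self, moe_layers):
        self._moe_layers = moe_layers  # hypothetical list of MoE blocks

    @cached_property
    def routed_experts_weights_of_layer(self):
        # Computed on first access, then memoized on the instance.
        return {i: layer.get_moe_weights() for i, layer in enumerate(self._moe_layers)}

The commit keeps the explicit hasattr guard, presumably because the attribute is populated as a side effect of load_weights rather than read lazily by callers.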