Unverified Commit ffa5d74f authored by Harry Mellor's avatar Harry Mellor Committed by GitHub
Browse files

Enable loading of fused expert weights in the Transformers modelling backend (#36997)


Signed-off-by: default avatarHarry Mellor <19981378+hmellor@users.noreply.github.com>
parent 74fe80ee
...@@ -1342,22 +1342,41 @@ class FusedMoE(CustomOp): ...@@ -1342,22 +1342,41 @@ class FusedMoE(CustomOp):
weight_name = qual_name.replace(weight_name, param_name) weight_name = qual_name.replace(weight_name, param_name)
param_name = weight_name.removeprefix(f"{self.layer_name}.") param_name = weight_name.removeprefix(f"{self.layer_name}.")
param = getattr(self, param_name) param = getattr(self, param_name)
success = self.weight_loader( # Fused expert weights can be identified by their 3D tensors
param=param, if loaded_weight.dim() == 3:
loaded_weight=loaded_weight, # Repurpose expert_id as shard_idx for deconcatenating w1 and w3
weight_name=weight_name, if shard_id in {"w1", "w3"}:
shard_id=shard_id, shard_idx = expert_id
expert_id=expert_id, experts_shard = loaded_weight.chunk(2, dim=1)[shard_idx]
return_success=True, else:
) experts_shard = loaded_weight
if success: start = 0
logger.debug( else:
"Loaded %s for expert %d into %s", # loaded_weight is a single expert weight, so we add a dummy expert
param_name, # dimension to unify the loading logic with the fused case
expert_id, experts_shard = loaded_weight.unsqueeze(0)
self.layer_name, start = expert_id
# Unified loading logic for fused and non-fused experts
loaded_experts = experts_shard.unbind()
for expert_id, loaded_expert in enumerate(loaded_experts, start=start):
success = self.weight_loader(
param=param,
loaded_weight=loaded_expert,
weight_name=weight_name,
shard_id=shard_id,
expert_id=expert_id,
return_success=True,
) )
yield param_name if success:
logger.debug(
"Loaded expert %d of shard %s into %s for layer %s",
expert_id,
shard_id,
param_name,
self.layer_name,
)
yield param_name
def get_expert_weights(self) -> Iterable[torch.Tensor]: def get_expert_weights(self) -> Iterable[torch.Tensor]:
def _maybe_make_contiguous( def _maybe_make_contiguous(
......
...@@ -156,6 +156,17 @@ class MoEMixin(MixtureOfExperts): ...@@ -156,6 +156,17 @@ class MoEMixin(MixtureOfExperts):
Params for weights, fp8 weight scales, fp8 activation scales Params for weights, fp8 weight scales, fp8 activation scales
(param_name, weight_name, expert_id, shard_id) (param_name, weight_name, expert_id, shard_id)
""" """
# Models saved with fused experts. These are checkpoints released:
# - After Transformers v5
# - Before Transformers v5, but re-saved with save_original_format=False
# In the fused experts case, we repurpose the expert_id as shard_idx for
# deconcatenating w1 and w3 in FusedMoE.load_weights.
expert_mapping = [
("experts.w13_weight", "experts.gate_up_proj", 0, "w1"),
("experts.w13_weight", "experts.gate_up_proj", 1, "w3"),
("experts.w2_weight", "experts.down_proj", 0, "w2"),
]
# Models saved with ModuleList experts
ckpt_names = [ ckpt_names = [
# (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name) # (ckpt_gate_proj_name, ckpt_down_proj_name, ckpt_up_proj_name)
("gate_proj", "down_proj", "up_proj"), # Most common MoE style ("gate_proj", "down_proj", "up_proj"), # Most common MoE style
...@@ -164,7 +175,6 @@ class MoEMixin(MixtureOfExperts): ...@@ -164,7 +175,6 @@ class MoEMixin(MixtureOfExperts):
] ]
num_experts = self.model_config.get_num_experts() num_experts = self.model_config.get_num_experts()
num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
expert_mapping = []
for gate_proj, down_proj, up_proj in ckpt_names: for gate_proj, down_proj, up_proj in ckpt_names:
expert_mapping.extend( expert_mapping.extend(
FusedMoE.make_expert_params_mapping( FusedMoE.make_expert_params_mapping(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment