Unverified Commit 2df532ef authored by fzyzcjy, committed by GitHub

Fix EPLB support in the global-scale fix and improve its enabling condition (#10369)

parent abea9250
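
Background, as far as the diff shows: ModelOptNvFp4FusedMoEMethod derives its input scaling factor from the max over all experts' scaling factors, so its input-scale parameters must be loaded with global expert ids. The old code special-cased this inside FusedMoE's weight loader via an isinstance check on the quant method plus a weight-name match, and that path also bypassed the EPLB expert-location remapping, which is why it did not support EPLB. This commit replaces the special case with a per-parameter flag, `_sglang_require_global_experts`: the loader remaps global to local expert ids only for untagged parameters. A runnable sketch of the pattern follows the last hunk.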
@@ -504,14 +504,8 @@ class FusedMoE(torch.nn.Module):
             param.data[:, :dim1, :dim2].copy_(loaded_weight)
             return
 
-        # ModelOptNvFp4FusedMoEMethod uses max of global expert scaling factors for input scaling factor
-        load_global_experts = (
-            isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
-            and "input_scale" in weight_name
-        )
-
         global_expert_location_metadata = get_global_expert_location_metadata()
-        if global_expert_location_metadata is None or load_global_experts:
+        if global_expert_location_metadata is None:
             self._weight_loader_impl(
                 param=param,
                 loaded_weight=loaded_weight,
@@ -548,10 +542,12 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
-        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
-        if expert_id == -1:
-            return
+        # WARN: This makes the `expert_id` mean "local" and "global" in different cases
+        if not getattr(param, "_sglang_require_global_experts", False):
+            expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+            if expert_id == -1:
+                return
 
         self._weight_loader_impl(
             param=param,
             loaded_weight=loaded_weight,
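
Note the WARN in the hunk above: after this change, the `expert_id` that reaches `_weight_loader_impl` is a local id for ordinary parameters but stays a global id for parameters tagged `_sglang_require_global_experts` (set in the next hunk), so the callee must cope with both meanings.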
@@ -999,12 +999,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w13_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w13_input_scale", w13_input_scale)
 
         w2_input_scale = PerTensorScaleParameter(
             data=torch.empty(layer.num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w2_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def swizzle_blockscale(self, scale: torch.Tensor):
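
Taken together, the three hunks implement a small flag-on-parameter dispatch. Below is a minimal, self-contained sketch of that pattern, not the actual sglang code: `map_global_to_local` and the two-expert layout are hypothetical stand-ins for `_map_global_expert_id_to_local_expert_id` and the EPLB expert-location metadata.

import torch

# Hypothetical stand-in for FusedMoE._map_global_expert_id_to_local_expert_id:
# returns the local slot for a global expert id, or -1 if this rank does not
# host that expert (here the rank is assumed to host global experts 2 and 3).
_LOCAL_EXPERTS = {2: 0, 3: 1}

def map_global_to_local(expert_id: int) -> int:
    return _LOCAL_EXPERTS.get(expert_id, -1)

def weight_loader(param: torch.nn.Parameter, expert_id: int) -> None:
    # Parameters tagged with the flag keep their global expert id; all others
    # are remapped, and non-local experts are skipped entirely.
    if not getattr(param, "_sglang_require_global_experts", False):
        expert_id = map_global_to_local(expert_id)
        if expert_id == -1:
            return
    print(f"loading expert slot {expert_id} into {param.shape}")

ordinary = torch.nn.Parameter(torch.empty(2, 8))  # one row per local expert
scale = torch.nn.Parameter(torch.empty(4))        # one entry per global expert
scale._sglang_require_global_experts = True       # as in the third hunk

weight_loader(ordinary, expert_id=3)  # remapped to local slot 1
weight_loader(ordinary, expert_id=0)  # not hosted here -> skipped
weight_loader(scale, expert_id=0)     # flag set -> global id 0 kept

The design point mirrors the commit: rather than the loader hard-coding which quant methods need global ids (the removed isinstance check), each parameter declares its own requirement, so the enabling condition reduces to a single check of the EPLB metadata.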