"docs/source/api/vscode:/vscode.git/clone" did not exist on "869093e890ffface86430e378cc71ebcc080f2cf"
Unverified commit 2df532ef authored by fzyzcjy, committed by GitHub

Fix the global scale fix not supporting EPLB, and improve the enabling condition (#10369)

parent abea9250
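Summary of the change, as read from the diff below: previously, `FusedMoE.weight_loader` decided whether to bypass the expert-location mapping with an inline check (`isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod) and "input_scale" in weight_name`), which did not interact correctly with EPLB. The commit moves that decision onto the parameters themselves: `ModelOptNvFp4FusedMoEMethod` tags its per-tensor input-scale parameters with `_sglang_require_global_experts = True`, and the per-expert loading path only remaps the global expert id to a local one when that flag is absent.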
@@ -504,14 +504,8 @@ class FusedMoE(torch.nn.Module):
             param.data[:, :dim1, :dim2].copy_(loaded_weight)
             return
-        # ModelOptNvFp4FusedMoEMethod uses max of global expert scaling factors for input scaling factor
-        load_global_experts = (
-            isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
-            and "input_scale" in weight_name
-        )
         global_expert_location_metadata = get_global_expert_location_metadata()
-        if global_expert_location_metadata is None or load_global_experts:
+        if global_expert_location_metadata is None:
             self._weight_loader_impl(
                 param=param,
                 loaded_weight=loaded_weight,
@@ -548,10 +542,12 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         expert_id: int,
     ) -> None:
-        expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
-        if expert_id == -1:
-            return
+        # WARN: This makes the `expert_id` mean "local" and "global" in different cases
+        if not getattr(param, "_sglang_require_global_experts", False):
+            expert_id = self._map_global_expert_id_to_local_expert_id(expert_id)
+            if expert_id == -1:
+                return
         self._weight_loader_impl(
             param=param,
             loaded_weight=loaded_weight,
...
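The diff relies on `self._map_global_expert_id_to_local_expert_id` returning `-1` for experts that are not hosted on the current rank, in which case the load is skipped. The helper below is a minimal, hypothetical illustration of that contract; the name `local_to_global` and the list-based metadata shape are assumptions for the sketch, not the actual SGLang implementation.

```python
from typing import Sequence


def map_global_to_local_expert_id(
    global_expert_id: int,
    local_to_global: Sequence[int],
) -> int:
    """Illustrative only: return the local slot holding `global_expert_id`,
    or -1 if this rank does not host that expert (the caller then skips the load).

    `local_to_global[i]` is assumed to be the global expert id stored in local slot i.
    """
    for local_id, hosted_global_id in enumerate(local_to_global):
        if hosted_global_id == global_expert_id:
            return local_id
    return -1


# Example: a rank hosting global experts 4..7 in local slots 0..3.
assert map_global_to_local_expert_id(5, [4, 5, 6, 7]) == 1
assert map_global_to_local_expert_id(0, [4, 5, 6, 7]) == -1
```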
@@ -999,12 +999,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w13_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w13_input_scale", w13_input_scale)
         w2_input_scale = PerTensorScaleParameter(
             data=torch.empty(layer.num_experts, dtype=torch.float32),
             weight_loader=weight_loader,
         )
+        w2_input_scale._sglang_require_global_experts = True
         layer.register_parameter("w2_input_scale", w2_input_scale)
     def swizzle_blockscale(self, scale: torch.Tensor):
...
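Putting the two hunks together, the intended flow is: parameters tagged with `_sglang_require_global_experts` keep the checkpoint's global `expert_id`, so scales from every expert stay visible for the max-reduction the NVFP4 input scale uses (per the removed comment in the first hunk), while untagged weights are first remapped to a local expert id and skipped when the expert is not hosted locally. The sketch below mirrors only that control flow; the function name, `SimpleNamespace` stand-ins, and list-based storage are simplifications, not the real SGLang classes.

```python
from types import SimpleNamespace


def load_expert_weight(param, loaded_weight, expert_id, map_global_to_local):
    """Sketch of the post-fix loading path for a single expert."""
    if not getattr(param, "_sglang_require_global_experts", False):
        # Regular expert weights: translate the checkpoint's global expert id
        # into this rank's local slot; -1 means "not hosted here", so skip.
        expert_id = map_global_to_local(expert_id)
        if expert_id == -1:
            return
    # Flagged parameters (e.g. NVFP4 per-tensor input scales) keep the global id,
    # so scales from all experts can later be reduced together.
    param.data[expert_id] = loaded_weight


# Tiny usage example with 8 global experts, this rank hosting ids 4..7.
hosted = [4, 5, 6, 7]
map_g2l = lambda g: hosted.index(g) if g in hosted else -1

w13_input_scale = SimpleNamespace(data=[0.0] * 8)
w13_input_scale._sglang_require_global_experts = True  # set by the quant method
load_expert_weight(w13_input_scale, 1.25, expert_id=2, map_global_to_local=map_g2l)
assert w13_input_scale.data[2] == 1.25  # stored at the *global* index

w13_weight = SimpleNamespace(data=[0.0] * 4)  # only local experts allocated
load_expert_weight(w13_weight, 3.0, expert_id=5, map_global_to_local=map_g2l)
assert w13_weight.data[1] == 3.0  # global id 5 -> local slot 1
load_expert_weight(w13_weight, 3.0, expert_id=0, map_global_to_local=map_g2l)  # not hosted: no-op
```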