Unverified Commit c7e85f53 authored by Trevor Morris's avatar Trevor Morris Committed by GitHub
Browse files

fix: flashinfer_cutlass_moe: Use max of global expert scales instead of local...

fix: flashinfer_cutlass_moe: Use max of global expert scales instead of local for input scale (#10296)
parent 3df05f4d
...@@ -503,8 +503,14 @@ class FusedMoE(torch.nn.Module): ...@@ -503,8 +503,14 @@ class FusedMoE(torch.nn.Module):
param.data[:, :dim1, :dim2].copy_(loaded_weight) param.data[:, :dim1, :dim2].copy_(loaded_weight)
return return
# ModelOptNvFp4FusedMoEMethod uses max of global expert scaling factors for input scaling factor
load_global_experts = (
isinstance(self.quant_method, ModelOptNvFp4FusedMoEMethod)
and "input_scale" in weight_name
)
global_expert_location_metadata = get_global_expert_location_metadata() global_expert_location_metadata = get_global_expert_location_metadata()
if global_expert_location_metadata is None: if global_expert_location_metadata is None or load_global_experts:
self._weight_loader_impl( self._weight_loader_impl(
param=param, param=param,
loaded_weight=loaded_weight, loaded_weight=loaded_weight,
......
...@@ -996,13 +996,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase): ...@@ -996,13 +996,13 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
) )
w13_input_scale = PerTensorScaleParameter( w13_input_scale = PerTensorScaleParameter(
data=torch.empty(layer.num_local_experts, 2, dtype=torch.float32), data=torch.empty(layer.num_experts, 2, dtype=torch.float32),
weight_loader=weight_loader, weight_loader=weight_loader,
) )
layer.register_parameter("w13_input_scale", w13_input_scale) layer.register_parameter("w13_input_scale", w13_input_scale)
w2_input_scale = PerTensorScaleParameter( w2_input_scale = PerTensorScaleParameter(
data=torch.empty(layer.num_local_experts, dtype=torch.float32), data=torch.empty(layer.num_experts, dtype=torch.float32),
weight_loader=weight_loader, weight_loader=weight_loader,
) )
layer.register_parameter("w2_input_scale", w2_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment