Unverified Commit 288ae41f authored by Shu Wang, committed by GitHub

[NVIDIA] Fix num_experts in modelopt_quant (#8811)

parent 01c99a99
@@ -1063,10 +1063,15 @@ class FlashInferFP4MoE(FusedMoE):
             gemm1_weights_scale=self.gemm1_scales_fp4_shuffled.data.view(
                 torch.float8_e4m3fn
             ),
+            gemm1_bias=None,
+            gemm1_alpha=None,
+            gemm1_beta=None,
+            gemm1_clamp_limit=None,
             gemm2_weights=self.gemm2_weights_fp4_shuffled.data,
             gemm2_weights_scale=self.gemm2_scales_fp4_shuffled.data.view(
                 torch.float8_e4m3fn
             ),
+            gemm2_bias=None,
             output1_scale_scalar=self.g1_scale_c.data,
             output1_scale_gate_scalar=self.g1_alphas.data,
             output2_scale_scalar=self.g2_alphas.data,
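All five new keyword arguments are passed as `None`, so the call site matches an updated kernel signature without changing behavior. If the installed kernel build may or may not accept these parameters yet, a caller can filter keyword arguments by signature; the sketch below is a hypothetical compatibility shim (the helper name `call_moe_kernel` is illustrative, not part of this change):

```python
import inspect

# Optional arguments introduced by the newer kernel signature (from the hunk above).
NEW_OPTIONAL_ARGS = (
    "gemm1_bias",
    "gemm1_alpha",
    "gemm1_beta",
    "gemm1_clamp_limit",
    "gemm2_bias",
)

def call_moe_kernel(kernel, **kwargs):
    """Invoke a MoE kernel, dropping the new optional kwargs when the installed
    build does not accept them (hypothetical shim, not this repository's code)."""
    try:
        accepted = inspect.signature(kernel).parameters
    except (TypeError, ValueError):
        # Some compiled bindings are not introspectable; pass everything through.
        return kernel(**kwargs)
    for name in NEW_OPTIONAL_ARGS:
        if name in kwargs and name not in accepted:
            kwargs.pop(name)
    return kernel(**kwargs)
```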
@@ -764,8 +764,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         )
         # TODO(ch-wan): check if this is needed
-        layer.num_experts = num_experts
-        layer.num_local_experts = num_experts
         layer.intermediate_size_per_partition = intermediate_size_per_partition
         layer.params_dtype = params_dtype
         layer.quant_config = self.quant_config
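If these assignments were overwriting counts the layer already tracks, the distinction matters under expert parallelism; a hedged illustration of why global and per-rank expert counts must stay separate (numbers and names below are illustrative, not from this commit):

```python
# With expert parallelism, each rank holds only a slice of the experts, so the
# per-rank ("local") count differs from the global count and the two must not
# share one attribute.
global_num_experts = 256
ep_size = 8                                  # expert-parallel world size
num_local_experts = global_num_experts // ep_size

assert num_local_experts == 32
assert num_local_experts != global_num_experts
```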
@@ -1106,7 +1104,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
                 layer.w13_weight_scale,
             )
-            print("Applied flashinfer weight processing for both w13 and w2")
+            logger.info_once("Applied flashinfer weight processing for both w13 and w2")
         else:
             # CUTLASS processing - handle w13 and w2 separately
@@ -1126,7 +1124,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             layer.w2_weight = Parameter(layer.w2_weight.data, requires_grad=False)
             # Both flashinfer cutlass and regular cutlass use same processing for w2
-            print("Applied weight processing for both w13 and w2")
+            logger.info_once("Applied weight processing for both w13 and w2")
             # Set up CUTLASS MoE parameters
             device = layer.w13_weight.device
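Both hunks route the message through `logger.info_once` instead of a bare `print`, so it goes through the logging system and is emitted only once per process. A minimal sketch of what an `info_once`-style helper can look like, assuming nothing about how this repository actually implements it:

```python
import logging

logger = logging.getLogger(__name__)
_seen_messages: set[str] = set()

def info_once(message: str) -> None:
    """Log an INFO message only the first time it is seen (illustrative sketch)."""
    if message not in _seen_messages:
        _seen_messages.add(message)
        logger.info(message)

# Example: repeated calls log the message only once.
info_once("Applied weight processing for both w13 and w2")
info_once("Applied weight processing for both w13 and w2")
```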