Unverified Commit 8caffd92 authored by Pengchao Wang's avatar Pengchao Wang Committed by GitHub
Browse files

[Bugfix][MXFP4] Call `trtllm_fp4_block_scale_moe` with kwargs (#33104)


Signed-off-by: default avatarPengchao Wang <wpc@fb.com>
parent 58a05b0c
...@@ -1053,32 +1053,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase): ...@@ -1053,32 +1053,32 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1) x_scale = x_scale.view(torch.float8_e4m3fn).reshape(*x.shape[:-1], -1)
trtllm_gen_output = trtllm_fp4_block_scale_moe( trtllm_gen_output = trtllm_fp4_block_scale_moe(
router_logits.to(torch.bfloat16), routing_logits=router_logits.to(torch.bfloat16),
None, # routing_bias routing_bias=None,
x_quant, hidden_states=x_quant,
x_scale, hidden_states_scale=x_scale,
layer.w13_weight, # uint8 (e2m1 x 2) gemm1_weights=layer.w13_weight, # uint8 (e2m1 x 2)
layer.w13_weight_scale, # uint8 (e4m3 x 2) gemm1_weights_scale=layer.w13_weight_scale, # uint8 (e4m3 x 2)
layer.w13_bias, # fp32 per expert per channel gemm1_bias=layer.w13_bias, # fp32 per expert per channel
layer.gemm1_alpha, # fp32 per expert gemm1_alpha=layer.gemm1_alpha, # fp32 per expert
layer.gemm1_beta, # fp32 per expert gemm1_beta=layer.gemm1_beta, # fp32 per expert
layer.gemm1_clamp_limit, # fp32 per expert gemm1_clamp_limit=layer.gemm1_clamp_limit, # fp32 per expert
layer.w2_weight, # uint8 (e2m1 x 2) gemm2_weights=layer.w2_weight, # uint8 (e2m1 x 2)
layer.w2_weight_scale, # ue8m0 gemm2_weights_scale=layer.w2_weight_scale, # ue8m0
layer.w2_bias, # fp32 per expert per channel gemm2_bias=layer.w2_bias, # fp32 per expert per channel
None, # output1_scale_scalar output1_scale_scalar=None,
None, # output1_scale_gate_scalar output1_scale_gate_scalar=None,
None, # output2_scale_scalar output2_scale_scalar=None,
layer.global_num_experts, num_experts=layer.global_num_experts,
layer.top_k, top_k=layer.top_k,
None, # n_group n_group=None,
None, # topk_group topk_group=None,
self.intermediate_size, # padded to multiple of 256 intermediate_size=self.intermediate_size, # padded to multiple of 256
layer.ep_rank * layer.local_num_experts, # local_expert_offset local_expert_offset=layer.ep_rank * layer.local_num_experts,
self.num_experts, # local num experts local_num_experts=self.num_experts,
None, # routed_scaling_factor routed_scaling_factor=None,
1 if layer.renormalize else 0, # routing_method_type, renormalize routing_method_type=1 if layer.renormalize else 0,
True, # do finalize do_finalize=True,
tune_max_num_tokens=max(self.max_capture_size, 1), tune_max_num_tokens=max(self.max_capture_size, 1),
)[0] )[0]
return trtllm_gen_output return trtllm_gen_output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment