Unverified Commit 5190ba7f authored by fzyzcjy, committed by GitHub

Fuse two kernels of hidden states padding into quantization kernel (#9005)


Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
parent 5438886c
@@ -210,13 +210,13 @@ class FusedMoE(torch.nn.Module):
        self.use_enable_flashinfer_mxfp4_moe = global_server_args_dict.get(
            "enable_flashinfer_mxfp4_moe", False
        )
        # TODO maybe we should remove this `if`, since `Mxfp4MoEMethod` does another round-up logic
        if (
            self.quant_config is not None
            and self.quant_config.get_name() == "mxfp4"
            and self.use_enable_flashinfer_mxfp4_moe
        ):
            hidden_size = round_up(hidden_size, 256)
        self.hidden_size = hidden_size
        self.quant_method.create_weights(
            layer=self,
            num_experts=self.num_local_experts,
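For reference, the rounding above can be sketched as follows (a minimal stand-in for sglang's `round_up` utility, assuming the usual round-up-to-multiple semantics):

```python
def round_up(n: int, multiple: int) -> int:
    # Round n up to the nearest multiple of `multiple`.
    return ((n + multiple - 1) // multiple) * multiple

# For example, a hidden size of 2880 is padded up to 3072 on the
# flashinfer mxfp4 path, while an already-aligned 2048 is unchanged.
assert round_up(2880, 256) == 3072
assert round_up(2048, 256) == 2048
```

`self.hidden_size` therefore stores the rounded-up width, which the quantization path below asserts against.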
@@ -796,13 +796,6 @@ class FusedMoE(torch.nn.Module):
    def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
        origin_hidden_states_dim = hidden_states.shape[-1]
-       if self.hidden_size != origin_hidden_states_dim:
-           hidden_states = torch.nn.functional.pad(
-               hidden_states,
-               (0, self.hidden_size - origin_hidden_states_dim),
-               mode="constant",
-               value=0.0,
-           )
        assert self.quant_method is not None
        if self.moe_ep_size > 1 and not self.enable_flashinfer_cutlass_moe:
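The deleted branch zero-padded the activations' last dimension up to the rounded-up `self.hidden_size` before dispatch. A self-contained sketch of what it computed (the helper name is hypothetical):

```python
import torch
import torch.nn.functional as F

def pad_hidden_states(hidden_states: torch.Tensor, hidden_size: int) -> torch.Tensor:
    # Zero-pad the last dim from the model's native width up to the
    # rounded-up hidden_size expected by the flashinfer mxfp4 path.
    pad = hidden_size - hidden_states.shape[-1]
    if pad == 0:
        return hidden_states
    return F.pad(hidden_states, (0, pad), mode="constant", value=0.0)

x = torch.randn(4, 2880)
assert pad_hidden_states(x, 3072).shape == (4, 3072)
```

After this commit, `forward` passes the unpadded tensor through, and the padding happens inside the quantization kernel instead (next hunk).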
@@ -570,8 +570,11 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
    ) -> torch.Tensor:
        if self.use_flashinfer:
            # Based on profiling results, we need to quantize x to mxfp8 here to achieve better performance
-            x_quant, x_scale = mxfp8_quantize(x, False)  # to mxfp8
+            x_quant, x_scale = mxfp8_quantize(
+                x, False, alignment=self.hidden_size
+            )  # to mxfp8
            x_scale = x_scale.view(torch.float8_e4m3fn).reshape(-1)
+            assert x_quant.shape[-1] == self.hidden_size
            top_k, router_logits = topk_output
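Per this diff, the `alignment` argument makes `mxfp8_quantize` emit output already padded to `self.hidden_size`, so the two standalone padding kernels are gone. A pure-PyTorch sketch of the intended shape contract only (`mxfp8_quantize_ref` is a hypothetical stand-in, not flashinfer's kernel, and the real op also returns block scales):

```python
import torch
import torch.nn.functional as F

def mxfp8_quantize_ref(x: torch.Tensor, alignment: int) -> torch.Tensor:
    # Shape contract only: zero-pad the last dim up to `alignment`, then
    # cast. The real kernel fuses the pad into a single mxfp8 quantization
    # pass (e4m3 data plus shared block scales).
    pad = alignment - x.shape[-1]
    if pad > 0:
        x = F.pad(x, (0, pad), value=0.0)
    return x.to(torch.float8_e4m3fn)  # placeholder for the mxfp8 conversion

x = torch.randn(4, 2880, dtype=torch.bfloat16)
assert mxfp8_quantize_ref(x, 3072).shape[-1] == 3072  # mirrors the new assert
```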