Commit af7b564d authored by zhuwenwen
Browse files

[fix]fix tests of kernels

parent 1faa2c78
...@@ -50,6 +50,7 @@ def get_config_quant_dtype( ...@@ -50,6 +50,7 @@ def get_config_quant_dtype(
use_int8_w8a8: bool, use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
use_int4_w4a8: bool,
use_mxfp4_w4a4: bool, use_mxfp4_w4a4: bool,
) -> Union[None, torch.dtype, str]: ) -> Union[None, torch.dtype, str]:
if use_fp8_w8a8: if use_fp8_w8a8:
...@@ -130,6 +131,7 @@ class FusedMoEQuantConfig: ...@@ -130,6 +131,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False, use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
use_int4_w4a8: bool = False,
use_mxfp4_w4a4: bool = False, use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False, per_act_token_quant: bool = False,
per_out_ch_quant: bool = False, per_out_ch_quant: bool = False,
...@@ -141,6 +143,7 @@ class FusedMoEQuantConfig: ...@@ -141,6 +143,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8, use_int8_w8a8,
use_int8_w8a16, use_int8_w8a16,
use_int4_w4a16, use_int4_w4a16,
use_int4_w4a8,
use_mxfp4_w4a4, use_mxfp4_w4a4,
] ]
]) <= 1, "Quantization flags are mutually exclusive." ]) <= 1, "Quantization flags are mutually exclusive."
...@@ -150,6 +153,7 @@ class FusedMoEQuantConfig: ...@@ -150,6 +153,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
use_mxfp4_w4a4=use_mxfp4_w4a4, use_mxfp4_w4a4=use_mxfp4_w4a4,
) )
return FusedMoEQuantConfig( return FusedMoEQuantConfig(
......
...@@ -676,7 +676,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -676,7 +676,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B_scale: Optional[torch.Tensor], B_scale: Optional[torch.Tensor],
B_zp: Optional[torch.Tensor], B_zp: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor], topk_weights: Optional[torch.Tensor],
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor, sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor, expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor, num_tokens_post_padded: torch.Tensor,
...@@ -1863,6 +1862,7 @@ def fused_experts_impl( ...@@ -1863,6 +1862,7 @@ def fused_experts_impl(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
use_mxfp4_w4a4=use_mxfp4_w4a4) use_mxfp4_w4a4=use_mxfp4_w4a4)
get_config_func = functools.partial( get_config_func = functools.partial(
...@@ -1955,7 +1955,6 @@ def fused_experts_impl( ...@@ -1955,7 +1955,6 @@ def fused_experts_impl(
w1_scale, w1_scale,
w1_zp, w1_zp,
curr_topk_weights, curr_topk_weights,
curr_topk_ids,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
...@@ -2008,7 +2007,6 @@ def fused_experts_impl( ...@@ -2008,7 +2007,6 @@ def fused_experts_impl(
w2_scale, w2_scale,
w2_zp, w2_zp,
curr_topk_weights, curr_topk_weights,
curr_topk_ids,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
...@@ -2370,7 +2368,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute): ...@@ -2370,7 +2368,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.use_int8_w8a8, use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16, use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16, use_int4_w4a16=self.use_int4_w4a16,
use_int4_w4a8= self.use_int4_w4a8, use_int4_w4a8=self.use_int4_w4a8,
per_channel_quant=self.per_act_token_quant, per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape, block_shape=self.block_shape,
B_bias=None # TODO support B_bias B_bias=None # TODO support B_bias
...@@ -2384,7 +2382,7 @@ def modular_triton_fused_moe( ...@@ -2384,7 +2382,7 @@ def modular_triton_fused_moe(
use_int8_w8a8: bool, use_int8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
use_int4_w4a8:bool, use_int4_w4a8: bool,
use_mxfp4_w4a4: bool, use_mxfp4_w4a4: bool,
per_act_token_quant: bool, per_act_token_quant: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
...@@ -2396,7 +2394,7 @@ def modular_triton_fused_moe( ...@@ -2396,7 +2394,7 @@ def modular_triton_fused_moe(
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8= use_int4_w4a8, use_int4_w4a8=use_int4_w4a8,
use_mxfp4_w4a4=use_mxfp4_w4a4, use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant, per_act_token_quant=per_act_token_quant,
block_shape=block_shape, block_shape=block_shape,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment