Commit af7b564d authored by zhuwenwen's avatar zhuwenwen
Browse files

[fix]fix tests of kernels

parent 1faa2c78
......@@ -50,6 +50,7 @@ def get_config_quant_dtype(
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
use_int4_w4a8: bool,
use_mxfp4_w4a4: bool,
) -> Union[None, torch.dtype, str]:
if use_fp8_w8a8:
......@@ -130,6 +131,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
use_int4_w4a8: bool = False,
use_mxfp4_w4a4: bool = False,
per_act_token_quant: bool = False,
per_out_ch_quant: bool = False,
......@@ -141,6 +143,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8,
use_int8_w8a16,
use_int4_w4a16,
use_int4_w4a8,
use_mxfp4_w4a4,
]
]) <= 1, "Quantization flags are mutually exclusive."
......@@ -150,6 +153,7 @@ class FusedMoEQuantConfig:
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
use_mxfp4_w4a4=use_mxfp4_w4a4,
)
return FusedMoEQuantConfig(
......
......@@ -676,7 +676,6 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
B_scale: Optional[torch.Tensor],
B_zp: Optional[torch.Tensor],
topk_weights: Optional[torch.Tensor],
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
......@@ -1863,6 +1862,7 @@ def fused_experts_impl(
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8=use_int4_w4a8,
use_mxfp4_w4a4=use_mxfp4_w4a4)
get_config_func = functools.partial(
......@@ -1955,7 +1955,6 @@ def fused_experts_impl(
w1_scale,
w1_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
......@@ -2008,7 +2007,6 @@ def fused_experts_impl(
w2_scale,
w2_zp,
curr_topk_weights,
curr_topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
......@@ -2370,7 +2368,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
use_int8_w8a8=self.use_int8_w8a8,
use_int8_w8a16=self.use_int8_w8a16,
use_int4_w4a16=self.use_int4_w4a16,
use_int4_w4a8= self.use_int4_w4a8,
use_int4_w4a8=self.use_int4_w4a8,
per_channel_quant=self.per_act_token_quant,
block_shape=self.block_shape,
B_bias=None # TODO support B_bias
......@@ -2384,7 +2382,7 @@ def modular_triton_fused_moe(
use_int8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
use_int4_w4a8:bool,
use_int4_w4a8: bool,
use_mxfp4_w4a4: bool,
per_act_token_quant: bool,
block_shape: Optional[List[int]] = None,
......@@ -2396,7 +2394,7 @@ def modular_triton_fused_moe(
use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
use_int4_w4a8= use_int4_w4a8,
use_int4_w4a8=use_int4_w4a8,
use_mxfp4_w4a4=use_mxfp4_w4a4,
per_act_token_quant=per_act_token_quant,
block_shape=block_shape,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment