Unverified commit 48a65ccb, authored by Michael Goin, committed by GitHub
Browse files

[CI] Speed up test_fused_marlin_moe (#40178)


Signed-off-by: mgoin <mgoin64@gmail.com>
Signed-off-by: Michael Goin <mgoin64@gmail.com>
parent 55842a8d
...@@ -143,12 +143,14 @@ MOE_MARLIN_QUANT_TEST_CONFIGS = [ ...@@ -143,12 +143,14 @@ MOE_MARLIN_QUANT_TEST_CONFIGS = [
{ {
"a_type": [scalar_types.bfloat16], "a_type": [scalar_types.bfloat16],
"b_type": scalar_types.float4_e2m1f, "b_type": scalar_types.float4_e2m1f,
"c_type": [scalar_types.bfloat16],
"group_blocks": [2], "group_blocks": [2],
}, },
# MXFP8 # MXFP8
{ {
"a_type": [scalar_types.bfloat16], "a_type": [scalar_types.bfloat16],
"b_type": scalar_types.float8_e4m3fn, "b_type": scalar_types.float8_e4m3fn,
"c_type": [scalar_types.bfloat16],
"group_blocks": [2], "group_blocks": [2],
}, },
# AWQ-INT4 with INT8 activation # AWQ-INT4 with INT8 activation
...@@ -674,31 +676,35 @@ def test_fused_moe_wn16( ...@@ -674,31 +676,35 @@ def test_fused_moe_wn16(
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
# Hand-picked Marlin MoE scenarios. Each row is laid out as
#   (m, n, k, e, topk, ep_size, act_order, is_k_full)
# Rows without act_order keep is_k_full=True, which matches the usual case
# (marlin_is_k_full).
# N>=256 is required for the Marlin kernel thread config for MXFP8.
# NOTE(review): the first row uses n=128 -- presumably it is filtered out for
# MXFP8 elsewhere; confirm against the validity check.
MARLIN_MOE_SCENARIOS = [
    (1, 128, 256, 5, 2, 1, False, True),  # single token, small matrices
    (1, 1024, 2048, 5, 2, 1, False, True),  # single token, large matrices
    (133, 256, 256, 5, 2, 1, False, True),  # unaligned m, small matrices
    (133, 1024, 2048, 12, 3, 1, False, True),  # unaligned m, large matrices
    (128, 256, 256, 5, 2, 1, False, True),  # aligned batch, small matrices
    (128, 1024, 2048, 12, 3, 1, False, True),  # aligned batch, large matrices
    (64, 1024, 2048, 12, 3, 4, False, True),  # expert parallelism
    # Act order with is_k_full=True (no tensor parallelism)
    (1, 1024, 2048, 5, 2, 1, True, True),
    # Act order with is_k_full=False (tensor parallelism)
    (133, 256, 256, 5, 2, 1, True, False),
]
def marlin_moe_generate_valid_test_cases(): def marlin_moe_generate_valid_test_cases():
import itertools import itertools
m_list = [1, 123, 666] def is_valid(
n_list = [128, 1024]
k_list = [256, 2048]
e_list = [5, 12]
topk_list = [2, 3]
ep_size_list = [1, 4]
act_order_list = [True, False]
is_k_full_list = [True, False]
all_combinations = itertools.product(
MOE_MARLIN_QUANT_TEST_CONFIGS,
m_list,
n_list,
k_list,
e_list,
topk_list,
ep_size_list,
act_order_list,
is_k_full_list,
)
def is_invalid(
a_type, a_type,
b_type, b_type,
c_type, c_type,
...@@ -715,39 +721,42 @@ def marlin_moe_generate_valid_test_cases(): ...@@ -715,39 +721,42 @@ def marlin_moe_generate_valid_test_cases():
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16 group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
if group_size > 0 and k % group_size != 0: if group_size > 0 and k % group_size != 0:
return False return False
if act_order and group_size in [-1, k, n]: if act_order and group_size in [-1, k, n]:
return False return False
if group_size in [k, n]: if group_size in [k, n]:
return False return False
if not act_order and is_k_full: if b_type == scalar_types.float8_e4m3fn and group_size == 32 and is_k_full:
return False return False
return a_type.size_bits < 16 or a_type is c_type return a_type.size_bits < 16 or a_type is c_type
cases = [] cases = []
for case in all_combinations: for quant_test_config in MOE_MARLIN_QUANT_TEST_CONFIGS:
quant_test_config, m, n, k, _, _, _, act_order, *_ = case
if act_order and not quant_test_config.get("support_act_order", False):
continue
f16_types = [scalar_types.float16] f16_types = [scalar_types.float16]
inner_combinations = itertools.product( inner_combinations = list(
quant_test_config.get("a_type", f16_types), itertools.product(
[quant_test_config["b_type"]], quant_test_config.get("a_type", f16_types),
quant_test_config.get("c_type", f16_types), [quant_test_config["b_type"]],
quant_test_config["group_blocks"], quant_test_config.get("c_type", f16_types),
quant_test_config["group_blocks"],
)
) )
supports_act_order = quant_test_config.get("support_act_order", False)
for sub_case in inner_combinations: for sub_case in inner_combinations:
if ( if (
sub_case[0] == scalar_types.float8_e4m3fn sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120] and current_platform.get_device_capability() not in [89, 120]
): ):
continue continue
args = sub_case + (m, n, k) + case[4:]
if is_invalid(*args): for scenario in MARLIN_MOE_SCENARIOS:
cases.append(args) m, n, k, e, topk, ep_size, act_order, is_k_full = scenario
if act_order and not supports_act_order:
continue
args = sub_case + (m, n, k, e, topk, ep_size, act_order, is_k_full)
if is_valid(*args):
cases.append(args)
return cases return cases
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment