Unverified Commit 48a65ccb authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[CI] Speed up test_fused_marlin_moe (#40178)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
Signed-off-by: default avatarMichael Goin <mgoin64@gmail.com>
parent 55842a8d
......@@ -143,12 +143,14 @@ MOE_MARLIN_QUANT_TEST_CONFIGS = [
{
"a_type": [scalar_types.bfloat16],
"b_type": scalar_types.float4_e2m1f,
"c_type": [scalar_types.bfloat16],
"group_blocks": [2],
},
# MXFP8
{
"a_type": [scalar_types.bfloat16],
"b_type": scalar_types.float8_e4m3fn,
"c_type": [scalar_types.bfloat16],
"group_blocks": [2],
},
# AWQ-INT4 with INT8 activation
......@@ -674,31 +676,35 @@ def test_fused_moe_wn16(
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
MARLIN_MOE_SCENARIOS = [
# (m, n, k, e, topk, ep_size, act_order, is_k_full)
# No act_order: is_k_full=True matches usual case (marlin_is_k_full).
# N>=256 required for Marlin kernel thread config for MXFP8.
# Single token, small matrices
(1, 128, 256, 5, 2, 1, False, True),
# Single token, large matrices
(1, 1024, 2048, 5, 2, 1, False, True),
# Unaligned m, small matrices
(133, 256, 256, 5, 2, 1, False, True),
# Unaligned m, large matrices
(133, 1024, 2048, 12, 3, 1, False, True),
# Aligned batch, small matrices
(128, 256, 256, 5, 2, 1, False, True),
# Aligned batch, large matrices
(128, 1024, 2048, 12, 3, 1, False, True),
# Expert parallelism
(64, 1024, 2048, 12, 3, 4, False, True),
# Act order with is_k_full=True (no tensor parallelism)
(1, 1024, 2048, 5, 2, 1, True, True),
# Act order with is_k_full=False (tensor parallelism)
(133, 256, 256, 5, 2, 1, True, False),
]
def marlin_moe_generate_valid_test_cases():
import itertools
m_list = [1, 123, 666]
n_list = [128, 1024]
k_list = [256, 2048]
e_list = [5, 12]
topk_list = [2, 3]
ep_size_list = [1, 4]
act_order_list = [True, False]
is_k_full_list = [True, False]
all_combinations = itertools.product(
MOE_MARLIN_QUANT_TEST_CONFIGS,
m_list,
n_list,
k_list,
e_list,
topk_list,
ep_size_list,
act_order_list,
is_k_full_list,
)
def is_invalid(
def is_valid(
a_type,
b_type,
c_type,
......@@ -715,39 +721,42 @@ def marlin_moe_generate_valid_test_cases():
group_size = group_blocks if group_blocks <= 0 else group_blocks * 16
if group_size > 0 and k % group_size != 0:
return False
if act_order and group_size in [-1, k, n]:
return False
if group_size in [k, n]:
return False
if not act_order and is_k_full:
if b_type == scalar_types.float8_e4m3fn and group_size == 32 and is_k_full:
return False
return a_type.size_bits < 16 or a_type is c_type
cases = []
for case in all_combinations:
quant_test_config, m, n, k, _, _, _, act_order, *_ = case
if act_order and not quant_test_config.get("support_act_order", False):
continue
for quant_test_config in MOE_MARLIN_QUANT_TEST_CONFIGS:
f16_types = [scalar_types.float16]
inner_combinations = itertools.product(
quant_test_config.get("a_type", f16_types),
[quant_test_config["b_type"]],
quant_test_config.get("c_type", f16_types),
quant_test_config["group_blocks"],
inner_combinations = list(
itertools.product(
quant_test_config.get("a_type", f16_types),
[quant_test_config["b_type"]],
quant_test_config.get("c_type", f16_types),
quant_test_config["group_blocks"],
)
)
supports_act_order = quant_test_config.get("support_act_order", False)
for sub_case in inner_combinations:
if (
sub_case[0] == scalar_types.float8_e4m3fn
and current_platform.get_device_capability() not in [89, 120]
):
continue
args = sub_case + (m, n, k) + case[4:]
if is_invalid(*args):
cases.append(args)
for scenario in MARLIN_MOE_SCENARIOS:
m, n, k, e, topk, ep_size, act_order, is_k_full = scenario
if act_order and not supports_act_order:
continue
args = sub_case + (m, n, k, e, topk, ep_size, act_order, is_k_full)
if is_valid(*args):
cases.append(args)
return cases
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment