Unverified commit b8c48c5d, authored by Fardin Hoque and committed by GitHub
Browse files

kernels/moe test pruning (#27053)


Signed-off-by: Fardin Hoque <kfhfar@amazon.com>
Co-authored-by: Wentao Ye <44945378+yewentao256@users.noreply.github.com>
parent 17d055f5
...@@ -24,23 +24,16 @@ from vllm.triton_utils import tl ...@@ -24,23 +24,16 @@ from vllm.triton_utils import tl
MNK_FACTORS = [ MNK_FACTORS = [
(1, 128, 128), (1, 128, 128),
(1, 128, 2048),
(1, 512, 512), (1, 512, 512),
(1, 1024, 128),
(1, 1024, 2048), (1, 1024, 2048),
(32, 128, 128), (32, 128, 128),
(32, 512, 512), (32, 512, 512),
(32, 1024, 2048), (32, 1024, 2048),
(45, 128, 128),
(45, 128, 2048), (45, 128, 2048),
(45, 512, 512),
(45, 1024, 128), (45, 1024, 128),
(45, 1024, 2048),
(64, 512, 512), (64, 512, 512),
(64, 1024, 2048), (64, 1024, 2048),
(222, 128, 128),
(222, 128, 2048), (222, 128, 2048),
(222, 1024, 128),
(222, 1024, 2048), (222, 1024, 2048),
] ]
NUM_EXPERTS = [8, 64] NUM_EXPERTS = [8, 64]
...@@ -117,10 +110,19 @@ def test_batched_mm( ...@@ -117,10 +110,19 @@ def test_batched_mm(
block_shape: list[int] | None, block_shape: list[int] | None,
per_act_token_quant: bool, per_act_token_quant: bool,
): ):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware."""
current_platform.seed_everything(7) current_platform.seed_everything(7)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn use_fp8_w8a8 = dtype == torch.float8_e4m3fn
if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
89
):
pytest.skip(
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)
if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8: if (per_act_token_quant or block_shape is not None) and not use_fp8_w8a8:
pytest.skip("Don't test blocking for non-quantized types.") pytest.skip("Don't test blocking for non-quantized types.")
...@@ -244,10 +246,19 @@ def test_fused_moe_batched_experts( ...@@ -244,10 +246,19 @@ def test_fused_moe_batched_experts(
block_shape: list[int] | None, block_shape: list[int] | None,
input_scales: bool, input_scales: bool,
): ):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware."""
current_platform.seed_everything(7) current_platform.seed_everything(7)
use_fp8_w8a8 = dtype == torch.float8_e4m3fn use_fp8_w8a8 = dtype == torch.float8_e4m3fn
if (dtype == torch.float8_e4m3fn) and not current_platform.has_device_capability(
89
):
pytest.skip(
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)
if topk > e: if topk > e:
pytest.skip("topk > e") pytest.skip("topk > e")
......
...@@ -42,57 +42,43 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32] ...@@ -42,57 +42,43 @@ DTYPES = [torch.bfloat16] # [torch.half, torch.bfloat16, torch.float32]
# and its hidden size is 7168. # and its hidden size is 7168.
MNK_FACTORS = [ MNK_FACTORS = [
(1, 128, 128), (1, 128, 128),
(1, 512, 512),
(1, 128, 7168), (1, 128, 7168),
(1, 1024, 7168), (1, 1024, 7168),
(1, 4608, 128), (1, 4608, 128),
(1, 4608, 512),
(1, 4608, 7168), (1, 4608, 7168),
(83, 128, 128), (83, 128, 128),
(83, 512, 512), (83, 512, 512),
(83, 1024, 7168),
(83, 4608, 512), (83, 4608, 512),
(83, 4608, 7168), (83, 4608, 7168),
(128, 128, 128),
(128, 512, 512), (128, 512, 512),
(128, 1024, 7168), (128, 1024, 7168),
(128, 4608, 512),
(128, 4608, 7168), (128, 4608, 7168),
(2048, 128, 128), (2048, 128, 128),
(2048, 1024, 7168), (2048, 1024, 7168),
(2048, 4608, 512), (2048, 4608, 512),
(2048, 4608, 7168), (2048, 4608, 7168),
(8192, 128, 128), (8192, 128, 128),
(8192, 512, 512),
(8192, 128, 7168), (8192, 128, 7168),
(8192, 1024, 7168), (8192, 1024, 7168),
(8192, 4608, 512),
(8192, 4608, 7168), (8192, 4608, 7168),
] ]
MNK_FACTORS_DG = [ MNK_FACTORS_DG = [
(128, 128, 128), (128, 128, 128),
(128, 512, 512),
(128, 128, 7168), (128, 128, 7168),
(128, 1024, 7168), (128, 1024, 7168),
(128, 4608, 128), (128, 4608, 128),
(128, 4608, 512),
(128, 4608, 7168), (128, 4608, 7168),
(192, 128, 128),
(192, 512, 512), (192, 512, 512),
(192, 1024, 7168), (192, 1024, 7168),
(192, 4608, 512),
(192, 4608, 7168), (192, 4608, 7168),
(1335, 128, 128), (1335, 128, 128),
(1335, 1024, 7168), (1335, 1024, 7168),
(1335, 4608, 512), (1335, 4608, 512),
(1335, 4608, 7168), (1335, 4608, 7168),
(2048, 128, 128), (2048, 128, 128),
(2048, 512, 512),
(2048, 128, 7168), (2048, 128, 7168),
(2048, 1024, 7168), (2048, 1024, 7168),
(2048, 4608, 128),
(2048, 4608, 512),
(2048, 4608, 7168), (2048, 4608, 7168),
] ]
......
...@@ -21,36 +21,28 @@ vllm_config = VllmConfig() ...@@ -21,36 +21,28 @@ vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128 vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192 vllm_config.scheduler_config.max_model_len = 8192
DTYPES = [torch.half, torch.bfloat16] DTYPES = [torch.bfloat16]
MNK_FACTORS = [ MNK_FACTORS = [
(1, 128, 128), (1, 128, 128),
(1, 512, 512),
(1, 128, 7168), (1, 128, 7168),
(1, 1024, 7168), (1, 1024, 7168),
(1, 4096, 128),
(1, 4096, 512), (1, 4096, 512),
(1, 4096, 7168), (1, 4096, 7168),
(33, 128, 128),
(33, 512, 512), (33, 512, 512),
(33, 128, 7168), (33, 128, 7168),
(33, 1024, 7168), (33, 1024, 7168),
(33, 4096, 128), (33, 4096, 128),
(33, 4096, 512),
(33, 4096, 7168), (33, 4096, 7168),
(128, 128, 128), (128, 128, 128),
(128, 512, 512),
(128, 1024, 7168), (128, 1024, 7168),
(128, 4096, 512), (128, 4096, 512),
(128, 4096, 7168), (128, 4096, 7168),
(222, 128, 128),
(222, 512, 512), (222, 512, 512),
(222, 1024, 7168), (222, 1024, 7168),
(222, 4096, 512),
(222, 4096, 7168), (222, 4096, 7168),
(2048, 128, 128), (2048, 128, 128),
(2048, 1024, 7168), (2048, 1024, 7168),
(2048, 4096, 512),
(2048, 4096, 4096), (2048, 4096, 4096),
] ]
......
...@@ -26,16 +26,13 @@ TOP_KS = [6, 8] ...@@ -26,16 +26,13 @@ TOP_KS = [6, 8]
MNK_FACTORS = [ MNK_FACTORS = [
(2, 1024, 1024), (2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024), (2, 3072, 1024),
(2, 3072, 1536), (2, 3072, 1536),
(7, 3072, 1536), (7, 3072, 1536),
(64, 1024, 1024), (64, 1024, 1024),
(64, 1024, 1536), (64, 1024, 1536),
(64, 3072, 1024), (64, 3072, 1024),
(64, 3072, 1536),
(224, 1024, 1024), (224, 1024, 1024),
(224, 1024, 1536),
(224, 3072, 1024), (224, 3072, 1024),
(224, 3072, 1536), (224, 3072, 1536),
(32768, 1024, 1024), (32768, 1024, 1024),
......
...@@ -393,7 +393,6 @@ def _test_deepep_deepgemm_moe( ...@@ -393,7 +393,6 @@ def _test_deepep_deepgemm_moe(
MNKs = [ MNKs = [
(8, 128, 128), (8, 128, 128),
(8, 128, 512), (8, 128, 512),
(8, 512, 512),
(3, 1024, 2048), (3, 1024, 2048),
(32, 128, 1024), (32, 128, 1024),
(45, 512, 2048), (45, 512, 2048),
......
...@@ -130,10 +130,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size): ...@@ -130,10 +130,8 @@ def run_single_case(m, n, k, topk, num_experts, block_size):
# Note: N <= 512 will disable the deepgemm path due to performance issues. # Note: N <= 512 will disable the deepgemm path due to performance issues.
MNKs = [ MNKs = [
(1024, 768, 128), (1024, 768, 128),
(1024, 768, 512),
(2048, 768, 512), (2048, 768, 512),
(512, 1024, 1024), (512, 1024, 1024),
(512, 2048, 2048),
(4096, 4096, 1024), (4096, 4096, 1024),
] ]
......
...@@ -34,8 +34,6 @@ TOP_KS = [1] ...@@ -34,8 +34,6 @@ TOP_KS = [1]
MNK_FACTORS = [ MNK_FACTORS = [
(256, 8192, 5120), (256, 8192, 5120),
(256, 4096, 5120),
(127, 8192, 5120),
(127, 4096, 5120), (127, 4096, 5120),
(10, 8192, 5120), (10, 8192, 5120),
(10, 4096, 5120), (10, 4096, 5120),
......
...@@ -34,10 +34,8 @@ if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_cap ...@@ -34,10 +34,8 @@ if not has_flashinfer_cutlass_fused_moe() or not current_platform.has_device_cap
MNK_FACTORS = [ MNK_FACTORS = [
(2, 1024, 1024), (2, 1024, 1024),
(2, 1024, 1536),
(2, 3072, 1024), (2, 3072, 1024),
(2, 3072, 1536), (2, 3072, 1536),
(64, 1024, 1024),
(64, 1024, 1536), (64, 1024, 1536),
(64, 3072, 1024), (64, 3072, 1024),
(64, 2048, 1536), (64, 2048, 1536),
...@@ -49,7 +47,7 @@ MNK_FACTORS = [ ...@@ -49,7 +47,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256]) @pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8]) @pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode() @torch.inference_mode()
def test_flashinfer_fp4_moe_no_graph( def test_flashinfer_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
......
...@@ -27,7 +27,7 @@ from vllm.platforms import current_platform ...@@ -27,7 +27,7 @@ from vllm.platforms import current_platform
@pytest.mark.parametrize("topk_group", [2]) @pytest.mark.parametrize("topk_group", [2])
@pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"]) @pytest.mark.parametrize("scoring_func", ["softmax", "sigmoid"])
@pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5]) @pytest.mark.parametrize("routed_scaling_factor", [1.0, 2.5])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32]) @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float32])
def test_grouped_topk( def test_grouped_topk(
monkeypatch: pytest.MonkeyPatch, monkeypatch: pytest.MonkeyPatch,
n_token: int, n_token: int,
......
...@@ -295,6 +295,8 @@ def test_modular_kernel_combinations_singlegpu( ...@@ -295,6 +295,8 @@ def test_modular_kernel_combinations_singlegpu(
world_size: int, world_size: int,
pytestconfig, pytestconfig,
): ):
"""Note: float8_e4m3fn is not supported on CUDA architecture < 89,
and those tests will be skipped on unsupported hardware."""
config = Config( config = Config(
Ms=Ms, Ms=Ms,
K=k, K=k,
...@@ -309,6 +311,12 @@ def test_modular_kernel_combinations_singlegpu( ...@@ -309,6 +311,12 @@ def test_modular_kernel_combinations_singlegpu(
world_size=world_size, world_size=world_size,
) )
if (
quant_config is not None and quant_config.quant_dtype == torch.float8_e4m3fn
) and not current_platform.has_device_capability(89):
pytest.skip(
"Triton limitation: fp8e4nv data type is not supported on CUDA arch < 89"
)
verbosity = pytestconfig.getoption("verbose") verbosity = pytestconfig.getoption("verbose")
run(config, verbosity > 0) run(config, verbosity > 0)
......
...@@ -66,8 +66,6 @@ FUSED_MOE_MNK_FACTORS = [ ...@@ -66,8 +66,6 @@ FUSED_MOE_MNK_FACTORS = [
(1, 128, 128), (1, 128, 128),
(1, 2048, 128), (1, 2048, 128),
(33, 2048, 128), (33, 2048, 128),
(222, 1024, 1024),
(32768, 128, 128),
(32768, 2048, 511), (32768, 2048, 511),
(40000, 1024, 1024), (40000, 1024, 1024),
] ]
...@@ -76,7 +74,6 @@ FUSED_MOE_WN16_MNK_FACTORS = [ ...@@ -76,7 +74,6 @@ FUSED_MOE_WN16_MNK_FACTORS = [
(1, 128, 128), (1, 128, 128),
(1, 1024, 1024), (1, 1024, 1024),
(32, 2048, 128), (32, 2048, 128),
(32, 1024, 1024),
(222, 2048, 1024), (222, 2048, 1024),
] ]
...@@ -512,8 +509,8 @@ def marlin_moe_generate_valid_test_cases(): ...@@ -512,8 +509,8 @@ def marlin_moe_generate_valid_test_cases():
e_list = [4, 12] e_list = [4, 12]
topk_list = [2, 3] topk_list = [2, 3]
ep_size_list = [1, 4] ep_size_list = [1, 4]
dtype_list = [torch.half, torch.bfloat16] dtype_list = [torch.bfloat16]
group_size_list = [-1, 16, 32, 128] group_size_list = [-1, 32, 128]
act_order_list = [True, False] act_order_list = [True, False]
quant_type_list = [ quant_type_list = [
scalar_types.float4_e2m1f, scalar_types.float4_e2m1f,
...@@ -885,10 +882,10 @@ def test_batched_moe_align_block_size_opcheck(): ...@@ -885,10 +882,10 @@ def test_batched_moe_align_block_size_opcheck():
) )
@pytest.mark.parametrize("m", [1, 33, 64, 222]) @pytest.mark.parametrize("m", [1, 33, 222])
@pytest.mark.parametrize("topk", TOP_KS) @pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("k", [128, 511, 1024]) @pytest.mark.parametrize("k", [128, 511, 1024])
@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16])
@pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm") @pytest.mark.skipif(current_platform.is_rocm(), reason="Skip for rocm")
def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype): def test_moe_sum(m: int, topk: int, k: int, dtype: torch.dtype):
input = torch.randn((m, topk, k), device="cuda", dtype=dtype) input = torch.randn((m, topk, k), device="cuda", dtype=dtype)
......
...@@ -26,9 +26,7 @@ MNK_FACTORS = [ ...@@ -26,9 +26,7 @@ MNK_FACTORS = [
(2, 1024, 1024), (2, 1024, 1024),
(2, 1024, 1536), (2, 1024, 1536),
(2, 3072, 1024), (2, 3072, 1024),
(2, 3072, 1536),
(64, 1024, 1024), (64, 1024, 1024),
(64, 1024, 1536),
(64, 3072, 1024), (64, 3072, 1024),
(64, 2048, 1536), (64, 2048, 1536),
(224, 1024, 1024), (224, 1024, 1024),
...@@ -39,7 +37,7 @@ MNK_FACTORS = [ ...@@ -39,7 +37,7 @@ MNK_FACTORS = [
@pytest.mark.parametrize("m,n,k", MNK_FACTORS) @pytest.mark.parametrize("m,n,k", MNK_FACTORS)
@pytest.mark.parametrize("e", [40, 64, 256]) @pytest.mark.parametrize("e", [40, 64, 256])
@pytest.mark.parametrize("topk", [1, 6, 8]) @pytest.mark.parametrize("topk", [1, 6, 8])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16]) @pytest.mark.parametrize("dtype", [torch.bfloat16])
@torch.inference_mode() @torch.inference_mode()
def test_cutlass_fp4_moe_no_graph( def test_cutlass_fp4_moe_no_graph(
m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype m: int, n: int, k: int, e: int, topk: int, dtype: torch.dtype
......
...@@ -19,20 +19,16 @@ CASES = [ ...@@ -19,20 +19,16 @@ CASES = [
(32, 64, 256, fp8_dtype), (32, 64, 256, fp8_dtype),
(17, 31, 768, fp8_dtype), (17, 31, 768, fp8_dtype),
(1, 1, 128 * 1, fp8_dtype), (1, 1, 128 * 1, fp8_dtype),
(1, 1, 128 * 2, fp8_dtype),
(1, 1, 128 * 3, fp8_dtype), (1, 1, 128 * 3, fp8_dtype),
(1, 1, 128 * 4, fp8_dtype), (1, 1, 128 * 4, fp8_dtype),
(8, 16, 128 * 1, fp8_dtype), (8, 16, 128 * 1, fp8_dtype),
(8, 16, 128 * 2, fp8_dtype), (8, 16, 128 * 2, fp8_dtype),
(8, 16, 128 * 3, fp8_dtype), (8, 16, 128 * 3, fp8_dtype),
(8, 16, 128 * 4, fp8_dtype),
(8, 64, 7168, fp8_dtype), (8, 64, 7168, fp8_dtype),
(8, 128, 7168, fp8_dtype), (8, 128, 7168, fp8_dtype),
(8, 256, 7168, fp8_dtype),
(8, 512, 7168, fp8_dtype), (8, 512, 7168, fp8_dtype),
(8, 1024, 7168, fp8_dtype), (8, 1024, 7168, fp8_dtype),
(256, 8, 7168, fp8_dtype), (256, 8, 7168, fp8_dtype),
(256, 16, 7168, fp8_dtype),
(256, 32, 7168, fp8_dtype), (256, 32, 7168, fp8_dtype),
(256, 64, 7168, fp8_dtype), (256, 64, 7168, fp8_dtype),
# Only add a few fnuz tests to help with long CI times. # Only add a few fnuz tests to help with long CI times.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment