# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# DeepGEMM Style Cutlass Grouped GEMM Test
# See https://github.com/deepseek-ai/DeepGEMM/blob/main/tests/test_core.py

import random

import pytest
import torch

12
from tests.kernels.moe.utils import per_token_cast_to_fp8
13
14
15
from tests.kernels.utils import baseline_scaled_mm
from vllm import _custom_ops as ops
from vllm.platforms import current_platform
16
from vllm.utils.deep_gemm import per_block_cast_to_fp8
17
from vllm.utils.math_utils import cdiv
18
19


@pytest.mark.parametrize(
    "num_groups, expected_m_per_group, k, n",
    [
        (4, 8192, 7168, 4096),
        (4, 8192, 2048, 7168),
        (8, 4096, 7168, 4096),
        (8, 4096, 2048, 7168),
        (32, 1024, 7168, 4096),
        (32, 1024, 2048, 7168),
    ],
)
@pytest.mark.parametrize("out_dtype", [torch.float16])
@pytest.mark.skipif(
    (lambda x: x is None or x.to_int() != 100)(
        current_platform.get_device_capability()
    ),
    reason="Block Scaled Grouped GEMM is only supported on SM100.",
)
def test_cutlass_grouped_gemm(
    num_groups: int,
    expected_m_per_group: int,
    k: int,
    n: int,
    out_dtype: torch.dtype,
):
    """Check ``ops.cutlass_blockwise_scaled_grouped_mm`` against a per-group
    ``baseline_scaled_mm`` reference.

    Activations ``x`` are quantized with ``per_token_cast_to_fp8``. Each
    group's weight ``y[i]`` is quantized with ``per_block_cast_to_fp8`` using
    128x128 blocks. The grouped kernel output must match the per-group
    reference rows within ``atol=5e-1, rtol=1e-3``.
    """
    device = "cuda"
    alignment = 128
    quant_block = 128  # block edge for weight quantization and its scale shape

    # Each group gets a random row count within +/-30% of the expected size.
    group_ms = [
        int(expected_m_per_group * random.uniform(0.7, 1.3)) for _ in range(num_groups)
    ]
    # Total rows: every group's count rounded up to a multiple of `alignment`.
    m = sum(cdiv(group_m, alignment) * alignment for group_m in group_ms)

    x = torch.randn((m, k), device=device, dtype=out_dtype)
    y = torch.randn((num_groups, n, k), device=device, dtype=out_dtype)
    out = torch.empty((m, n), device=device, dtype=out_dtype)
    # Random fill acts as a sentinel: any row the reference loop below failed
    # to overwrite would make the final comparison fail.
    ref_out = torch.randn((m, n), device=device, dtype=out_dtype)

    # Row offsets are running sums of the unaligned group_ms. The final group
    # is extended to `m`, so it absorbs all of the alignment padding.
    # NOTE(review): per-group start offsets are therefore not multiples of
    # `alignment`; confirm this is the intended layout to exercise.
    ep_offset = [0]
    for group_m in group_ms[:-1]:
        ep_offset.append(ep_offset[-1] + group_m)
    ep_offset.append(m)

    # One (M, N, K) row per group, with M taken from consecutive offsets.
    problem_sizes = torch.tensor(
        [[end - start, n, k] for start, end in zip(ep_offset, ep_offset[1:])],
        device=device,
        dtype=torch.int32,
    )
    expert_offsets = torch.tensor(ep_offset, device=device, dtype=torch.int32)

    # Indexed below as [0] = quantized values, [1] = scales.
    x_fp8 = per_token_cast_to_fp8(x)
    y_fp8 = (
        torch.empty_like(y, dtype=torch.float8_e4m3fn),
        torch.empty(
            (num_groups, cdiv(n, quant_block), cdiv(k, quant_block)),
            device=device,
            dtype=torch.float,
        ),
    )
    for i in range(num_groups):
        y_fp8[0][i], y_fp8[1][i] = per_block_cast_to_fp8(
            y[i], [quant_block, quant_block]
        )

    # Reference: run baseline_scaled_mm separately on each group's row slice.
    for i, (start, end) in enumerate(zip(ep_offset, ep_offset[1:])):
        a = x_fp8[0][start:end]
        a_scale = x_fp8[1][start:end]
        b = y_fp8[0][i].t()
        b_scale = y_fp8[1][i].t()
        ref_out[start:end] = baseline_scaled_mm(a, b, a_scale, b_scale, out_dtype)

    # Only the per-group start offsets are passed; the trailing total `m`
    # is dropped.
    ops.cutlass_blockwise_scaled_grouped_mm(
        out,
        x_fp8[0],
        y_fp8[0],
        x_fp8[1],
        y_fp8[1],
        problem_sizes,
        expert_offsets[:-1],
    )

    # assert_close(actual, expected): the kernel output is `actual`, so rtol is
    # measured relative to the reference.
    torch.testing.assert_close(out, ref_out, atol=5e-1, rtol=1e-3)