# test_cutlass.py
"""Tests for cutlass kernels

Run `pytest tests/kernels/test_cutlass.py`.
"""
from typing import Type

import pytest
import torch

from vllm import _custom_ops as ops

# Device-parametrized tests use a single GPU when exactly one is reported,
# and the first two device indices otherwise.
_num_test_devices = 1 if torch.cuda.device_count() == 1 else 2
CUDA_DEVICES = [f"cuda:{idx}" for idx in range(_num_test_devices)]

# Compute capability folded into one integer: major * 10 + minor,
# e.g. (8, 9) -> 89. Used by the FP8 skip conditions below.
_cc_major, _cc_minor = torch.cuda.get_device_capability()
capability = _cc_major * 10 + _cc_minor


20
def to_fp8(tensor: torch.Tensor):
21
22
23
24
25
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    """Round ``tensor`` and convert it to ``torch.int8``.

    Values are clamped to ``[-128, 127]`` before rounding, so out-of-range
    inputs saturate instead of wrapping around in the integer cast.

    Args:
        tensor: Floating-point tensor to quantize.

    Returns:
        A tensor of the same shape with dtype ``torch.int8``.
    """
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def cutlass_fp8_gemm_helper(m: int,
                            n: int,
                            k: int,
                            per_token_act_quant: bool,
                            per_out_channel_weight_quant: bool,
                            bias: bool,
                            out_dtype: torch.dtype = torch.bfloat16,
                            device: str = "cuda") -> None:
    """Check ``ops.cutlass_scaled_mm`` on FP8 inputs against ``torch.mm``.

    Builds a random ``(m, k)`` activation and a ``(k, n)`` weight (created as
    ``(n, k)`` and transposed), both quantized with ``to_fp8``, and compares
    the kernel output with a float32 reference cast to ``out_dtype``.

    Args:
        m: Number of rows of the activation matrix.
        n: Number of columns of the weight matrix.
        k: Shared inner dimension.
        per_token_act_quant: If true, use one activation scale per row
            (shape ``(m, 1)``); otherwise a single ``(1, 1)`` scale.
        per_out_channel_weight_quant: If true, use one weight scale per
            output column (shape ``(1, n)``); otherwise a single scale.
        bias: Whether to pass a bias vector of length ``n`` to the kernel.
        out_dtype: Output dtype requested from the kernel.
        device: Device string on which all tensors are created.

    Raises:
        AssertionError: If the kernel output is not close to the reference.
    """
    # Test for a cutlass kernel with per-token activation quantization
    # and per-output channel weight quantization.
    a = to_fp8(torch.randn((m, k), device=device))
    b = to_fp8(torch.randn((n, k), device=device).t())

    m_a_scales = m if per_token_act_quant else 1
    n_b_scales = n if per_out_channel_weight_quant else 1

    scale_a = (torch.randn(
        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
    scale_b = (torch.randn(
        (1, n_b_scales), device=device, dtype=torch.float32) / 10)

    if bias:
        # bias term should be > 1 so that the absolute tolerance can catch it
        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
    else:
        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
        # Scalar zero keeps the reference expression below uniform.
        bias_t = 0

    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
                         scale_b * b.to(dtype=torch.float32)) +
                bias_t).to(out_dtype)

    assert torch.allclose(out, baseline, rtol=1e-2, atol=1e-1)


def cutlass_int8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             per_token_act_quant: bool,
                             per_out_channel_weight_quant: bool,
                             bias: bool,
                             out_dtype: torch.dtype = torch.bfloat16,
                             device: str = "cuda") -> None:
    """Check ``ops.cutlass_scaled_mm`` on INT8 inputs against ``torch.mm``.

    Builds a random ``(m, k)`` activation and a ``(k, n)`` weight (created as
    ``(n, k)`` and transposed), both scaled by 5 and quantized with
    ``to_int8``, and compares the kernel output with a float32 reference
    cast to ``out_dtype``.

    Args:
        m: Number of rows of the activation matrix.
        n: Number of columns of the weight matrix.
        k: Shared inner dimension.
        per_token_act_quant: If true, use one activation scale per row
            (shape ``(m, 1)``); otherwise a single ``(1, 1)`` scale.
        per_out_channel_weight_quant: If true, use one weight scale per
            output column (shape ``(1, n)``); otherwise a single scale.
        bias: Whether to pass a bias vector of length ``n`` to the kernel.
        out_dtype: Output dtype requested from the kernel.
        device: Device string on which all tensors are created.

    Raises:
        AssertionError: If the kernel output is not close to the reference.
    """
    # Test for a cutlass kernel with per-token activation quantization
    # and per-output channel weight quantization.
    a = to_int8(torch.randn((m, k), device=device) * 5)
    b = to_int8(torch.randn((n, k), device=device).t() * 5)

    m_a_scales = m if per_token_act_quant else 1
    n_b_scales = n if per_out_channel_weight_quant else 1

    scale_a = (torch.randn(
        (m_a_scales, 1), device=device, dtype=torch.float32) / 10)
    scale_b = (torch.randn(
        (1, n_b_scales), device=device, dtype=torch.float32) / 10)

    if bias:
        # bias term should be > 1 so that the absolute tolerance can catch it
        bias_t = torch.rand((n, ), device=device, dtype=out_dtype) + 1.0
        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias_t)
    else:
        out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype)
        # Scalar zero keeps the reference expression below uniform.
        bias_t = 0

    baseline = (torch.mm(scale_a * a.to(dtype=torch.float32),
                         scale_b * b.to(dtype=torch.float32)) +
                bias_t).to(dtype=out_dtype)

    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)


@pytest.mark.parametrize("m", [512, 222, 100, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
                          per_out_ch: bool, bias: bool) -> None:
    """Run the FP8 GEMM check over a grid of problem sizes and scale modes."""
    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)


@pytest.mark.parametrize("m", [512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                           per_out_ch: bool, bias: bool) -> None:
    """Run the INT8 GEMM check over a grid of problem sizes and scale modes."""
    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, bias)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                        out_dtype: torch.dtype,
                                        bias: bool) -> None:
    """Run the INT8 GEMM check at 512x512x512 for each output dtype."""
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
                             bias,
                             out_dtype=out_dtype)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                       out_dtype: torch.dtype,
                                       bias: bool) -> None:
    """Run the FP8 GEMM check at 512x512x512 for each output dtype."""
    cutlass_fp8_gemm_helper(512,
                            512,
                            512,
                            per_act_token,
                            per_out_ch,
                            bias,
                            out_dtype=out_dtype)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                  bias: bool, device: str) -> None:
    """Run the FP8 GEMM check at 512x512x512 on each entry of CUDA_DEVICES."""
    cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, bias,
                            torch.bfloat16, device)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                   bias: bool, device: str) -> None:
    """Run the INT8 GEMM check at 512x512x512 on each entry of CUDA_DEVICES."""
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
                             bias,
                             out_dtype=torch.bfloat16,
                             device=device)


# For the following two tests:
# N and K correspond to the size of the weight matrix and are likely to be
# multiples of a large power of two. In any case, the kernel will have a naive
# fallback when N and K are not divisible by 16. But M is the number of tokens
# and the kernel must handle any M thrown at it.
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                  bias: bool) -> None:
    """Run the FP8 GEMM check for every M in 1..127 with N = K in 32/64/96."""
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch, bias)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                   bias: bool) -> None:
    """Run the INT8 GEMM check for every M in 1..127 with N = K in 32/64/96."""
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                     bias)


# Test working with a subset of A and B
def test_cutlass_subset() -> None:
    """Check ``ops.cutlass_scaled_mm`` on slices of larger INT8 tensors.

    ``a`` and ``b`` are 512x512 views into 1024x1024 tensors, so the kernel
    receives inputs that keep the parent tensors' strides.
    """
    big_m, big_n, big_k = 1024, 1024, 1024
    m, n, k = 512, 512, 512

    whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5)
    whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5)
    a = whole_a[0:m, 0:k]
    b = whole_b[0:k, 0:n]

    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_mm(a,
                                b,
                                scale_a,
                                scale_b,
                                out_dtype=torch.bfloat16)
    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                        scale_b *
                        b.to(dtype=torch.float32)).to(dtype=torch.bfloat16)

    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)


# Test to make sure cuda graphs work
class CutlassLayer(torch.nn.Module):
    """Minimal module whose forward pass is one ``ops.cutlass_scaled_mm`` call.

    The weight, both scales and the output dtype are fixed at construction;
    only the activation is supplied per call.
    """

    def __init__(self, b: torch.Tensor, scale_a: torch.Tensor,
                 scale_b: torch.Tensor, out_dtype: torch.dtype) -> None:
        super().__init__()
        self.b = b
        self.scale_a = scale_a
        self.scale_b = scale_b
        self.out_dtype = out_dtype

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        """Return ``ops.cutlass_scaled_mm`` of ``a`` with the stored operands."""
        return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
                                     self.out_dtype)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool) -> None:
    """Capture a ``CutlassLayer`` call in a CUDA graph and verify its replay.

    The captured output is zeroed before replay and then compared with a
    float32 ``torch.mm`` reference cast to bfloat16.
    """
    m, n, k = 512, 512, 512

    a = to_int8(torch.randn((m, k), device="cuda"))
    b = to_int8(torch.randn((n, k), device="cuda").t())

    m_a_scales = m if per_act_token else 1
    n_b_scales = n if per_out_ch else 1

    scale_a = (torch.randn(
        (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
    scale_b = (torch.randn(
        (1, n_b_scales), device="cuda", dtype=torch.float32) / 10)

    # Construct a trivial model with a single layer that calls a CUTLASS kernel
    model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)

    # Run the model with a cuda graph
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            out = model(a)
    # Clear the captured output so the check below depends on the replay
    # writing the result again.
    out.zero_()
    g.replay()

    baseline = torch.mm(scale_a * a.to(dtype=torch.float32),
                        scale_b * b.to(dtype=torch.float32)).to(torch.bfloat16)
    assert torch.allclose(out, baseline, rtol=1e-1, atol=1e0)