"""Tests for CUTLASS kernels (``ops.cutlass_scaled_mm``).

Covers FP8 inputs (skipped below compute capability 8.9) and INT8 inputs.
Run `pytest tests/kernels/test_cutlass.py`.
"""
from typing import Optional, Type

import pytest
import torch

from vllm import _custom_ops as ops

# Devices for the multi-device tests: "cuda:0", plus "cuda:1" when a second
# GPU is present. The list is capped at two entries and is empty (rather than
# naming non-existent devices) when no GPU is visible.
CUDA_DEVICES = [
    f"cuda:{i}" for i in range(min(torch.cuda.device_count(), 2))
]

# Compute capability of the current device as an int, e.g. (8, 9) -> 89.
# The FP8 tests below are skipped when this is < 89.
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]


20
def to_fp8(tensor: torch.Tensor):
21
22
23
24
25
    finfo = torch.finfo(torch.float8_e4m3fn)
    return torch.round(tensor.clamp(
        min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    """Clamp ``tensor`` to [-128, 127], round it and cast it to ``int8``."""
    clamped = tensor.clamp(min=-128, max=127)
    return torch.round(clamped).to(dtype=torch.int8)


def baseline_scaled_mm(a: torch.Tensor,
                       b: torch.Tensor,
                       scale_a: torch.Tensor,
                       scale_b: torch.Tensor,
                       out_dtype: torch.dtype,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Reference result for ``ops.cutlass_scaled_mm`` built on ``torch.mm``.

    Computes ``scale_a * (scale_b * (a @ b))`` in float32, casts it to
    ``out_dtype``, and then adds ``bias`` (if given) in that dtype.

    Args:
        a: Left operand; upcast to float32 before the matmul.
        b: Right operand; upcast to float32 before the matmul.
        scale_a: Broadcast-multiplied into the product. Callers in this
            file pass shape (m, 1) or (1, 1).
        scale_b: Broadcast-multiplied into the product. Callers in this
            file pass shape (1, n) or (1, 1).
        out_dtype: Dtype of the returned tensor.
        bias: Optional tensor added after the cast to ``out_dtype``.

    Returns:
        The scaled (and optionally biased) product.
    """
    product = torch.mm(a.to(dtype=torch.float32), b.to(dtype=torch.float32))
    output = (scale_a * (scale_b * product)).to(out_dtype)
    if bias is not None:
        output = output + bias
    return output


def cutlass_fp8_gemm_helper(m: int,
                            n: int,
                            k: int,
                            per_token_act_quant: bool,
                            per_out_channel_weight_quant: bool,
                            use_bias: bool,
                            out_dtype: torch.dtype = torch.bfloat16,
                            device: str = "cuda"):
    """Compare ``ops.cutlass_scaled_mm`` on FP8 inputs with a reference.

    Args:
        m, n, k: GEMM size; ``a`` is (m, k) and ``b`` is (k, n).
        per_token_act_quant: Use one scale per row of ``a`` (shape (m, 1))
            instead of a single scale (shape (1, 1)).
        per_out_channel_weight_quant: Use one scale per column of ``b``
            (shape (1, n)) instead of a single scale (shape (1, 1)).
        use_bias: Add a random length-``n`` bias in ``out_dtype``.
        out_dtype: Output dtype requested from the kernel and the reference.
        device: Device on which every tensor is allocated.
    """
    a = to_fp8(torch.randn((m, k), device=device))
    # ``b`` is the transpose of an (n, k) tensor, i.e. (k, n) with
    # column-major strides. NOTE(review): presumably the layout the kernel
    # expects; confirm.
    b = to_fp8(torch.randn((n, k), device=device).t())

    m_a_scales = m if per_token_act_quant else 1
    n_b_scales = n if per_out_channel_weight_quant else 1
    scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32)
    scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32)

    bias = (torch.rand((n, ), device=device, dtype=out_dtype) * 10
            if use_bias else None)

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    # Unlike torch.allclose, assert_close does not broadcast. It also fails
    # on a shape, dtype or device mismatch, and reports the worst offenders.
    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=5e-2)


def cutlass_int8_gemm_helper(m: int,
                             n: int,
                             k: int,
                             per_token_act_quant: bool,
                             per_out_channel_weight_quant: bool,
                             use_bias: bool,
                             out_dtype: torch.dtype = torch.bfloat16,
                             device: str = "cuda"):
    """Compare ``ops.cutlass_scaled_mm`` on INT8 inputs with a reference.

    Args:
        m, n, k: GEMM size; ``a`` is (m, k) and ``b`` is (k, n).
        per_token_act_quant: Use one scale per row of ``a`` (shape (m, 1))
            instead of a single scale (shape (1, 1)).
        per_out_channel_weight_quant: Use one scale per column of ``b``
            (shape (1, n)) instead of a single scale (shape (1, 1)).
        use_bias: Add a random length-``n`` bias in ``out_dtype``.
        out_dtype: Output dtype requested from the kernel and the reference.
        device: Device on which every tensor is allocated.
    """
    # ``* 5`` widens the spread of the rounded int8 values. ``b`` is the
    # transpose of an (n, k) tensor, i.e. (k, n) with column-major strides.
    a = to_int8(torch.randn((m, k), device=device) * 5)
    b = to_int8(torch.randn((n, k), device=device).t() * 5)

    m_a_scales = m if per_token_act_quant else 1
    n_b_scales = n if per_out_channel_weight_quant else 1
    scale_a = torch.randn((m_a_scales, 1), device=device, dtype=torch.float32)
    scale_b = torch.randn((1, n_b_scales), device=device, dtype=torch.float32)

    bias = (torch.rand((n, ), device=device, dtype=out_dtype) * 10
            if use_bias else None)

    out = ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)

    # Unlike torch.allclose, assert_close does not broadcast. It also fails
    # on a shape, dtype or device mismatch, and reports the worst offenders.
    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


@pytest.mark.parametrize("m", [512, 222, 100, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
                          per_out_ch: bool, use_bias: bool):
    """FP8 GEMM over a grid of shapes, scale granularities and bias."""
    cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)


@pytest.mark.parametrize("m", [512, 222, 33, 1])
@pytest.mark.parametrize("n", [2048, 256, 1024])
@pytest.mark.parametrize("k", [128, 496, 1024])
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                           per_out_ch: bool, use_bias: bool):
    """INT8 GEMM over a grid of shapes, scale granularities and bias."""
    cutlass_int8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                        out_dtype: torch.dtype,
                                        use_bias: bool):
    """512x512x512 INT8 GEMM for each supported output dtype."""
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
                             use_bias,
                             out_dtype=out_dtype)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
                                       out_dtype: torch.dtype,
                                       use_bias: bool):
    """512x512x512 FP8 GEMM for each supported output dtype."""
    cutlass_fp8_gemm_helper(512,
                            512,
                            512,
                            per_act_token,
                            per_out_ch,
                            use_bias,
                            out_dtype=out_dtype)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                  use_bias: bool, device: str):
    """FP8 GEMM with all tensors placed on each device in CUDA_DEVICES."""
    cutlass_fp8_gemm_helper(512,
                            512,
                            512,
                            per_act_token,
                            per_out_ch,
                            use_bias,
                            out_dtype=torch.bfloat16,
                            device=device)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
                                   use_bias: bool, device: str):
    """INT8 GEMM with all tensors placed on each device in CUDA_DEVICES."""
    cutlass_int8_gemm_helper(512,
                             512,
                             512,
                             per_act_token,
                             per_out_ch,
                             use_bias,
                             out_dtype=torch.bfloat16,
                             device=device)


# For the following two tests:
# N and K are the dimensions of the weight matrix and are likely to be
# multiples of a large power of two; in any case, the kernel has a naive
# fallback when N and K are not divisible by 16. M, however, is the number
# of tokens, and the kernel must handle any M thrown at it.
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.skipif(capability < 89,
                    reason="FP8 is not supported on this GPU type.")
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                  use_bias: bool):
    """FP8 GEMM for every M in [1, 127] with N = K in {32, 64, 96}."""
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                    use_bias)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
                                   use_bias: bool):
    """INT8 GEMM for every M in [1, 127] with N = K in {32, 64, 96}."""
    for nk in range(32, 128, 32):
        for m in range(1, 128):
            cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
                                     use_bias)


# Test working with a subset of A and B
def test_cutlass_subset():
    """Run cutlass_scaled_mm on sliced views of larger INT8 tensors.

    ``a`` and ``b`` are 512x512 corners of 1024x1024 tensors, so they keep
    the parent tensors' strides and are not contiguous.
    """
    big_m, big_n, big_k = 1024, 1024, 1024
    m, n, k = 512, 512, 512

    whole_a = to_int8(torch.randn((big_m, big_k), device="cuda") * 5)
    whole_b = to_int8(torch.randn((big_n, big_k), device="cuda").t() * 5)
    a = whole_a[0:m, 0:k]
    b = whole_b[0:k, 0:n]

    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10

    out = ops.cutlass_scaled_mm(a,
                                b,
                                scale_a,
                                scale_b,
                                out_dtype=torch.bfloat16)
    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=torch.bfloat16)

    # assert_close also checks shape/dtype/device (allclose would broadcast).
    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


# Helper module for the CUDA graph test below.
class CutlassLayer(torch.nn.Module):
    """Minimal module that wraps a single ``ops.cutlass_scaled_mm`` call.

    The right operand ``b``, both scales and the output dtype are fixed at
    construction time. ``forward`` multiplies its input ``a`` by ``b``.
    """

    def __init__(self, b: torch.Tensor, scale_a: torch.Tensor,
                 scale_b: torch.Tensor, out_dtype: torch.dtype):
        super().__init__()
        self.b = b
        self.scale_a = scale_a
        self.scale_b = scale_b
        self.out_dtype = out_dtype

    def forward(self, a: torch.Tensor) -> torch.Tensor:
        """Return ``ops.cutlass_scaled_mm(a, b, scale_a, scale_b, ...)``."""
        return ops.cutlass_scaled_mm(a, self.b, self.scale_a, self.scale_b,
                                     self.out_dtype)


@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
    """Capture an INT8 cutlass_scaled_mm call in a CUDA graph and replay it."""
    m, n, k = 512, 512, 512

    a = to_int8(torch.randn((m, k), device="cuda"))
    b = to_int8(torch.randn((n, k), device="cuda").t())

    m_a_scales = m if per_act_token else 1
    n_b_scales = n if per_out_ch else 1

    scale_a = (torch.randn(
        (m_a_scales, 1), device="cuda", dtype=torch.float32) / 10)
    scale_b = (torch.randn(
        (1, n_b_scales), device="cuda", dtype=torch.float32) / 10)

    # Construct a trivial model with a single layer that calls a CUTLASS kernel
    model = CutlassLayer(b, scale_a, scale_b, torch.bfloat16)

    # Run the model with a cuda graph
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            out = model(a)
    # Clear the captured output buffer so the check below can only pass if
    # replaying the graph rewrites it.
    out.zero_()
    g.replay()

    # Same reference as the other tests; assert_close also checks the shape.
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b, torch.bfloat16)
    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)