test_cutlass_w4a8.py 8.38 KB
Newer Older
1
2
3
4
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the CUTLASS W4A8 kernel.

5
Run `pytest tests/kernels/quantization/test_cutlass_w4a8.py`.
6
7
8
9
10
11
12
13
14
"""

from dataclasses import dataclass

import pytest
import torch

from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.quant_utils import (
15
16
17
    pack_rows,
    quantize_weights,
)
18
19
20
21
22
23
24
25
26
27
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

# Problem sizes (M, N, K) for the GEMM tests: activations are (M, K) and
# weights are (K, N). Covers decode-like (M=1) through prefill-like (M=1024).
MNK_SHAPES = [
    (1, 128, 128),
    (1, 512, 1024),
    (1, 4096, 4096),
    (1, 8192, 28672),
    (13, 8192, 4096),
    (26, 4096, 8192),
    (64, 4096, 4096),
    (64, 8192, 28672),
    (257, 128, 4096),
    (257, 4096, 4096),
    (1024, 4096, 8192),
    (1024, 8192, 4096),
]

# TODO(czhu): get supported schedules from fn
# Kernel schedule names the GEMM test is parametrized over (presumably
# "<tile>_<cluster>" shapes — confirm against the kernel's schedule list).
SCHEDULES = [
    "128x16_1x1x1",
    "256x16_1x1x1",
    "128x32_1x1x1",
    "256x32_1x1x1",
    "128x64_1x1x1",
    "256x64_1x1x1",
    "128x128_1x1x1",
    "256x128_1x1x1",
    "128x256_1x1x1",
    "128x256_2x1x1",
]


@dataclass
class TypeConfig:
    """Dtype combination describing one W4A8 GEMM test case."""

    # dtype of the activation matrix; the reference casts `a_ref` to it
    act_type: torch.dtype
    # scalar type the weights are quantized to (e.g. scalar_types.int4)
    weight_type: ScalarType
    # output dtype; used as `out_dtype` of the reference GEMM
    output_type: torch.dtype | None
    # dtype requested for the per-group weight scales
    group_scale_type: torch.dtype | None
    # dtype of the per-output-channel scales (length N)
    channel_scale_type: torch.dtype | None
    # dtype of the per-token activation scales (length M)
    token_scale_type: torch.dtype | None


@dataclass
class Tensors:
    """Operands for one GEMM test case, as built by `create_test_tensors`."""

    # dequantized weights used by the reference GEMM (float32 there)
    w_ref: torch.Tensor
    # activations used by the reference GEMM (float32 there)
    a_ref: torch.Tensor
    # fp8 activations fed to the kernel
    a: torch.Tensor
    # int4 weights after packing and CUTLASS encode/reorder
    w_q: torch.Tensor
    # packed per-group weight scales
    w_g_s: torch.Tensor
    # per-output-channel scales
    w_ch_s: torch.Tensor
    # per-token activation scales
    w_tok_s: torch.Tensor


# Tuple form of a test-type description:
#   (act types, weight type, output type, scale type, zero points)
# NOTE(review): appears unused in this module; TEST_TYPES is built from
# TypeConfig instead.
TestTypeTuple = tuple[
    list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
]

# Type combinations exercised by the GEMM tests: fp8 (e4m3) activations and
# group scales, int4 weights, fp32 channel/token scales, bf16 output.
TEST_TYPES = [
    TypeConfig(
        act_type=torch.float8_e4m3fn,
        weight_type=w_type,
        output_type=o_type,
        group_scale_type=torch.float8_e4m3fn,
        channel_scale_type=torch.float32,
        token_scale_type=torch.float32,
    )
    for w_type in [scalar_types.int4]
    # TODO(czhu): fp16 out type
    for o_type in [torch.bfloat16]
]

# TODO: in a future PR, refactor this and `is_quant_method_supported` in the
#  kernel unit tests into a common utility function. Currently the use of
#  `is_quant_method_supported` conflates kernels with quantization methods, an
#  assumption that is breaking down as quantization methods can have multiple
#  kernels and some kernels support multiple quantization methods.
# The W4A8 tests are gated on device capability >= 90 (SM90).
_W4A8_MIN_CAPABILITY = 90
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(_W4A8_MIN_CAPABILITY)


# For testing quantized linear kernels
def to_fp8(tensor: torch.Tensor):
    finfo = torch.finfo(torch.float8_e4m3fn)
111
    return tensor.clamp(min=finfo.min, max=finfo.max).to(dtype=torch.float8_e4m3fn)
112
113


def cutlass_quantize_and_pack(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: torch.dtype | None,
    group_size: int | None,
    zero_points: bool = False,
):
    """Quantize `w` group-wise and pack it into the CUTLASS W4A8 layout.

    Args:
        atype: Activation dtype. The group scales are cast to this dtype both
            for the reference weights and for the packed kernel scales.
        w: Floating-point weight matrix, (K, N) as used by the tests; groups
            run along dim 0.
        wtype: Integer scalar type to quantize to (e.g. `scalar_types.int4`).
        stype: Requested group-scale dtype. NOTE(review): currently unused —
            scales are cast to `atype` instead; confirm this is intended.
        group_size: Number of rows sharing one scale. Must not be None.
        zero_points: Must be False; the reference below applies no zero point.

    Returns:
        (w_ref, w_q_packed, w_s_packed, w_zp): reference weights dequantized
        with the `atype`-rounded scales (in `atype`), the encoded/reordered
        int4 weights, the packed group scales, and the zero points returned by
        `quantize_weights`.
    """
    assert wtype.is_integer(), "TODO: support floating point weights"
    # w_zp is never applied when building w_ref below, so asymmetric
    # quantization would yield a wrong reference; reject it explicitly.
    assert not zero_points, "TODO: support zero points"
    # repeat_interleave below needs a concrete group size.
    assert group_size is not None, "group_size is required"

    _, w_q, w_s, w_zp = quantize_weights(
        w, wtype, group_size=group_size, zero_points=zero_points
    )

    # Scales are handed to the kernel in `atype` (fp8), so round them once and
    # build the reference from exactly the same rounded values.
    w_s_atype = w_s.to(atype)
    w_ref = (
        w_q.to(torch.float32)
        * w_s_atype.to(torch.float32).repeat_interleave(group_size, dim=0)
    ).to(atype)

    # bit mask prevents sign extending int4 when packing
    w_q = pack_rows(w_q & 0x0F, wtype.size_bits, *w_q.shape)
    w_q = w_q.t().contiguous().t()  # convert to col major

    w_q_packed = ops.cutlass_encode_and_reorder_int4b(w_q)
    w_s_packed = ops.cutlass_pack_scale_fp8(w_s_atype)

    return w_ref, w_q_packed, w_s_packed, w_zp


def create_test_tensors(
    shape: tuple[int, int, int], types: TypeConfig, group_size: int | None
) -> Tensors:
    """Build random fp8 activations and group-quantized weights for one case.

    `shape` is (M, N, K): activations are (M, K) and weights are (K, N).
    Returns the kernel inputs together with float32 reference operands.
    """
    m, n, k = shape

    print(
        "create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size
    )

    act_fp8 = to_fp8(torch.randn((m, k), device="cuda"))
    weight = to_fp8(torch.randn((k, n), device="cuda"))

    # NOTE(review): with the fp8 configs in TEST_TYPES this cast is a no-op,
    # since `weight` is already fp8.
    if types.group_scale_type is not None:
        weight = weight.to(types.group_scale_type)
    # 1-byte dtypes are upcast to fp16 before quantization (presumably
    # quantize_weights cannot operate on fp8 directly — confirm).
    if weight.dtype.itemsize == 1:
        weight = weight.to(torch.float16)

    w_ref, w_q_packed, w_g_s, _ = cutlass_quantize_and_pack(
        act_fp8.dtype,
        weight,
        types.weight_type,
        types.group_scale_type,
        group_size,
        zero_points=False,
    )

    # for the practical use case we need per-tok scales for fp8 activations
    w_tok_s = torch.randn((m,), device="cuda", dtype=types.token_scale_type)
    # weights already carry per-group scales, so channel scales are all ones
    w_ch_s = torch.ones((n,), device="cuda", dtype=types.channel_scale_type)

    return Tensors(
        w_ref=w_ref.to(torch.float32),
        a_ref=act_fp8.to(torch.float32),
        a=act_fp8,
        w_q=w_q_packed,
        w_g_s=w_g_s,
        w_ch_s=w_ch_s,
        w_tok_s=w_tok_s,
    )


def mm_test_helper(
    types: TypeConfig,
    tensors: Tensors,
    group_size: int | None = None,
    schedule: str | None = None,
):
    """Run the W4A8 GEMM and compare it against a `torch._scaled_mm` reference.

    Args:
        types: dtype configuration; `act_type` and `output_type` drive the
            reference computation.
        tensors: Operands produced by `create_test_tensors`.
        group_size: Number of weight rows sharing one group scale.
        schedule: Kernel schedule name (one of SCHEDULES). None leaves the
            choice to the kernel.
    """
    # CUTLASS upstream uses fp8 with fastaccum as reference
    # https://github.com/NVIDIA/cutlass/blob/main/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L406
    output_ref = torch._scaled_mm(
        tensors.a_ref.to(types.act_type),
        tensors.w_ref.to(types.act_type).t().contiguous().t(),  # col major
        tensors.w_tok_s.unsqueeze(1),
        tensors.w_ch_s.unsqueeze(0),
        out_dtype=types.output_type,
        use_fast_accum=True,
    )

    output = ops.cutlass_w4a8_mm(
        a=tensors.a,
        b_q=tensors.w_q,
        b_group_scales=tensors.w_g_s,
        b_group_size=group_size,
        b_channel_scales=tensors.w_ch_s,
        a_token_scales=tensors.w_tok_s,
        # Forward the schedule: without it every SCHEDULES parametrization
        # silently exercises the same default kernel.
        maybe_schedule=schedule,
    )

    torch.testing.assert_close(
        output, output_ref.to(output.dtype), rtol=1e-3, atol=1e-3
    )


@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
@pytest.mark.parametrize(
    "shape", MNK_SHAPES, ids=lambda dims: "x".join(map(str, dims))
)
@pytest.mark.parametrize("types", TEST_TYPES)
@pytest.mark.parametrize("schedule", SCHEDULES)
def test_cutlass_w4a8(shape, types: TypeConfig, schedule):
    """Check the W4A8 GEMM against the reference for every shape/type/schedule."""
    for group_size in (128,):
        case_tensors = create_test_tensors(shape, types, group_size)
        mm_test_helper(types, case_tensors, group_size, schedule)


# Test to make sure cuda graphs work
class W4A8Layer(torch.nn.Module):
    """Minimal module wrapping `ops.cutlass_w4a8_mm` with fixed keyword args.

    Every keyword argument given to the constructor is stored in `self.kwargs`
    and forwarded unchanged to the kernel on each `forward` call, alongside
    the activation tensor `a`.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs

    def forward(self, a):
        bound_args = self.kwargs
        return ops.cutlass_w4a8_mm(a=a, **bound_args)


@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="CUTLASS W4A8 is not supported on this GPU type."
)
def test_w4a8_cuda_graph():
    """Capture the W4A8 GEMM in a CUDA graph and check the replayed output."""
    m, n, k = 512, 4096, 4096

    a = to_fp8(torch.randn((m, k), device="cuda"))
    b = to_fp8(torch.randn((k, n), device="cuda"))

    group_size = 128
    w_ref, w_q_packed, w_s, _ = cutlass_quantize_and_pack(
        a.dtype,
        b.to(torch.float16),
        wtype=scalar_types.int4,
        stype=torch.float8_e4m3fn,
        group_size=group_size,
        zero_points=False,
    )

    w_tok_s = torch.randn((m,), device="cuda", dtype=torch.float32)
    w_ch_s = torch.ones((n,), device="cuda", dtype=torch.float32)

    # Construct a trivial model with a single layer that calls the kernel
    model = W4A8Layer(
        b_q=w_q_packed,
        b_group_scales=w_s,
        b_group_size=group_size,
        b_channel_scales=w_ch_s,
        a_token_scales=w_tok_s,
    )

    # Eager fp8 reference computed with fast accumulation.
    col_major_w = w_ref.to(a.dtype).t().contiguous().t()
    output_ref = torch._scaled_mm(
        a,
        col_major_w,
        w_tok_s.unsqueeze(1),
        w_ch_s.unsqueeze(0),
        out_dtype=torch.bfloat16,
        use_fast_accum=True,
    )

    # Run the model with a cuda graph
    capture_stream = torch.cuda.Stream()
    with torch.cuda.stream(capture_stream):
        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(graph):
            output = model(a)

    # Clear the captured output so the check only passes if replay rewrites it.
    output.zero_()
    graph.replay()

    torch.testing.assert_close(output, output_ref, rtol=1e-3, atol=1e-3)