# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for the machete kernel.

Run `pytest tests/kernels/quantization/test_machete_mm.py`.
"""

import math
from dataclasses import dataclass, fields

import pytest
import torch

from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
    query_machete_supported_group_sizes,
)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    pack_rows,
    quantize_weights,
)
from vllm.platforms import current_platform
from vllm.scalar_type import ScalarType, scalar_types

if current_platform.is_rocm():
    pytest.skip(
        "These tests require machete_prepack_B, not supported on ROCm.",
        allow_module_level=True,
    )

# Devices for `test_machete_devices`: only "cuda:0" when exactly one GPU is
# visible, otherwise "cuda:0" and "cuda:1".
CUDA_DEVICES = [f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)]

# (M, N, K) GEMM problem sizes: each test multiplies an (M, K) activation by a
# (K, N) quantized weight.
MNK_SHAPES = [
    (1, 128, 128),
    (1, 8192, 28672),
    (13, 8192, 4096),
    (26, 4096, 8192),
    (64, 4096, 4096),
    (64, 8192, 28672),
    (257, 128, 4096),
    (257, 4224, 4160),
    (1024, 8192, 4096),
]


@dataclass
class TypeConfig:
    """Data types describing one machete GEMM configuration under test.

    A ``None`` scale / zero-point type means that tensor is not used.
    A ``None`` ``output_type`` means the output uses the activation type.
    """

    # dtype of the activation matrix A
    act_type: torch.dtype
    # quantized type of the weight matrix B
    weight_type: ScalarType
    output_type: torch.dtype | None
    group_scale_type: torch.dtype | None
    group_zero_type: torch.dtype | None
    channel_scale_type: torch.dtype | None
    token_scale_type: torch.dtype | None


@dataclass
class Tensors:
    w_ref: torch.Tensor
    a_ref: torch.Tensor
    a: torch.Tensor
    w_q: torch.Tensor
71
72
73
74
    w_g_s: torch.Tensor | None
    w_g_zp: torch.Tensor | None
    w_ch_s: torch.Tensor | None
    w_tok_s: torch.Tensor | None
75
76
77
78
79
80


# NOTE(review): `TestTypeTuple` is not referenced by the tests in this module;
# test configurations are expressed with `TypeConfig` instead. It is kept only
# in case external code imports it -- consider removing it.
TestTypeTuple = tuple[
    list[torch.dtype], ScalarType, torch.dtype | None, torch.dtype | None, bool
]
TEST_TYPES = [
    # GPTQ style: group scales, no zero points
    *(
        TypeConfig(
            act_type=a_type,
            weight_type=w_type,
            output_type=None,
            group_scale_type=a_type,
            group_zero_type=None,
            channel_scale_type=None,
            token_scale_type=None,
        )
        for w_type in [scalar_types.uint4b8, scalar_types.uint8b128]
        for a_type in [torch.float16, torch.bfloat16]
    ),
    # AWQ style: group scales and group zero points
    *(
        TypeConfig(
            act_type=a_type,
            weight_type=w_type,
            output_type=None,
            group_scale_type=a_type,
            group_zero_type=a_type,
            channel_scale_type=None,
            token_scale_type=None,
        )
        for w_type in [scalar_types.uint4, scalar_types.uint8]
        for a_type in [torch.float16, torch.bfloat16]
    ),
    # # QQQ style (currently disabled)
    # *(TypeConfig(act_type=torch.int8,
    #              weight_type=scalar_types.uint4b8,
    #              output_type=torch.float16,
    #              group_scale_type=group_scale_type,
    #              group_zero_type=None,
    #              channel_scale_type=torch.float,
    #              token_scale_type=torch.float)
    #   for group_scale_type in [None, torch.float16]),
    # *(TypeConfig(act_type=torch.float8_e4m3fn,
    #              weight_type=scalar_types.uint4b8,
    #              output_type=torch.float16,
    #              group_scale_type=group_scale_type,
    #              group_zero_type=None,
    #              channel_scale_type=torch.float,
    #              token_scale_type=torch.float)
    #   for group_scale_type in [None, torch.float16]),
]

# TODO: in a future PR, refactor this and `is_quant_method_supported` in the
#  kernel unit tests into a common utility function. Currently the use of
#  `is_quant_method_supported` conflates kernels with quantization methods,
#  an assumption which is breaking down as quantization methods can have
#  several kernels and some kernels support multiple quantization methods.
# Machete tests run only on devices with compute capability >= 9.0.
IS_SUPPORTED_BY_GPU = current_platform.has_device_capability(90)


def rand_data(shape, dtype=torch.float16, scale=1, offset=0, device="cuda"):
    """Return a random tensor of ``shape`` and ``dtype`` on ``device``.

    Floating-point dtypes are sampled as ``scale * U[0, 1) - offset``.
    Integer dtypes ignore ``scale``/``offset`` and are sampled uniformly from
    the signed 4-bit range [-8, 7].
    """
    if dtype.is_floating_point:
        return (scale * torch.rand(shape, device=device) - offset).to(dtype)
    # torch.randint's upper bound is exclusive, so use 8 to make 7 reachable.
    return torch.randint(-8, 8, shape, dtype=dtype, device=device)


147
def maybe_convert_zeropoints(zps: torch.Tensor | None, s: torch.Tensor):
148
149
150
    return zps if zps is None else -1 * s * (zps.to(s.dtype))


151
def group_size_valid(shape: tuple[int, int, int], group_size: int | None) -> bool:
152
    return group_size is None or group_size == -1 or shape[2] % group_size == 0
153
154


def machete_quantize_and_pack(
    atype: torch.dtype,
    w: torch.Tensor,
    wtype: ScalarType,
    stype: torch.dtype | None,
    group_size: int | None,
    zero_points: bool = False,
):
    """Quantize ``w`` to ``wtype`` and prepack it with ``machete_prepack_B``.

    Also runs ``opcheck`` on the ``machete_prepack_B`` custom op.

    Args:
        atype: activation dtype the packed weights will be multiplied with.
        w: weight matrix to quantize.
        wtype: integer weight type to quantize to.
        stype: group scale dtype forwarded to ``machete_prepack_B``.
        group_size: quantization group size forwarded to ``quantize_weights``.
        zero_points: whether ``quantize_weights`` should produce zero points.

    Returns:
        ``(w_ref, w_q_machete, w_s, w_zp)``: the reference weights and scales /
        zero points from ``quantize_weights``, plus the prepacked weights.
    """
    assert wtype.is_integer(), "TODO: support floating point weights"

    w_ref, w_q, w_s, w_zp = quantize_weights(
        w,
        wtype,
        group_size=group_size,
        zero_points=zero_points,
        # to match how the kernel applies zps
        ref_zero_points_after_scales=True,
    )

    w_q = pack_rows(w_q, wtype.size_bits, *w_q.shape)
    w_q = w_q.t().contiguous().t()  # convert to col major

    w_q_machete = ops.machete_prepack_B(w_q, atype, wtype, stype)
    opcheck(torch.ops._C.machete_prepack_B, (w_q, atype, wtype.id, stype))

    return w_ref, w_q_machete, w_s, w_zp


def create_test_tensors(
    shape: tuple[int, int, int],
    types: TypeConfig,
    group_size: int | None,
    subset_stride_factor: int | None = None,
) -> Tensors:
    """Build random kernel inputs and float32 references for one GEMM test.

    Args:
        shape: ``(m, n, k)``; activations are ``(m, k)`` and weights ``(k, n)``.
        types: dtypes for activations, weights, scales and output.
        group_size: quantization group size for the weights.
        subset_stride_factor: when > 1, activations and weights are allocated
            ``factor`` times larger in each dimension and then sliced back to
            the requested size, so the tensors are views into larger buffers.

    Returns:
        A ``Tensors`` bundle whose ``a_ref`` / ``w_ref`` are float32 copies used
        to compute the reference output.
    """
    m, n, k = shape
    factor = subset_stride_factor or 1

    print(
        "create_test_tensors, shape:", shape, "types:", types, "group_size:", group_size
    )

    a = rand_data((m * factor, k * factor), types.act_type, scale=3, offset=2)
    w = rand_data((k * factor, n * factor), types.act_type, scale=3, offset=1)

    if factor > 1:
        # Keep only the leading (m, k) / (k, n) corner of each allocation.
        a = a[0:m, 0:k]
        w = w[0:k, 0:n]

    if types.group_scale_type is not None:
        w = w.to(types.group_scale_type)
    # Weights generated from a 1-byte activation dtype are upcast to fp16
    # before quantization (presumably required by quantize_weights -- confirm).
    if w.dtype.itemsize == 1:
        w = w.to(torch.float16)

    w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
        a.dtype,
        w,
        types.weight_type,
        types.group_scale_type,
        group_size,
        types.group_zero_type is not None,
    )

    # For integer activations, round and clamp the reference weights to the
    # activation dtype's range.
    if not a.dtype.is_floating_point:
        aiinfo = torch.iinfo(a.dtype)
        w_ref = w_ref.round().clamp(aiinfo.min, aiinfo.max)

    a_ref = a.to(torch.float32)
    w_ref = w_ref.to(torch.float32)

    w_ch_s = (
        None
        if types.channel_scale_type is None
        else rand_data((n,), types.channel_scale_type)
    )
    w_tok_s = (
        None
        if types.token_scale_type is None
        else rand_data((m,), types.token_scale_type)
    )

    return Tensors(
        w_ref=w_ref,
        a_ref=a_ref,
        a=a,
        w_q=w_q_packed,
        w_g_s=w_s,
        w_g_zp=maybe_convert_zeropoints(w_zp, w_s),
        w_ch_s=w_ch_s,
        w_tok_s=w_tok_s,
    )


def machete_mm_test_helper(
    types: TypeConfig,
    tensors: Tensors,
    group_size: int | None = None,
    schedule: str | None = None,
):
    """Run ``ops.machete_mm`` on ``tensors`` and compare against a reference.

    The reference is ``a_ref @ w_ref``, optionally multiplied by per-column
    channel scales (``w_ch_s``) and per-row token scales (``w_tok_s``).
    ``schedule=None`` leaves schedule selection to ``ops.machete_mm``.
    """
    output_ref = torch.matmul(tensors.a_ref, tensors.w_ref)
    output_ref_type = output_ref.dtype

    if tensors.w_ch_s is not None:
        output_ref = (
            output_ref.to(tensors.w_ch_s.dtype) * tensors.w_ch_s.unsqueeze(0)
        ).to(output_ref_type)
    if tensors.w_tok_s is not None:
        output_ref = (
            output_ref.to(tensors.w_tok_s.dtype) * tensors.w_tok_s.unsqueeze(1)
        ).to(output_ref_type)

    output = ops.machete_mm(
        a=tensors.a,
        b_q=tensors.w_q,
        b_type=types.weight_type,
        b_group_scales=tensors.w_g_s,
        b_group_zeros=tensors.w_g_zp,
        b_group_size=group_size,
        b_channel_scales=tensors.w_ch_s,
        a_token_scales=tensors.w_tok_s,
        out_type=types.output_type,
        schedule=schedule,
    )

    print(output)
    print(output_ref)

    # Relax atol as our reduction dim becomes larger (more rounding error)
    # Relax atol when we have zeropoints since the way machete applies
    #  zeropoints (after scales) causes noise around 0
    atol = (
        1
        if tensors.w_g_zp is not None
        else min(5e-2 * math.sqrt(tensors.a.shape[1]), 1)
    )
    # 1-byte activation types get a looser relative tolerance.
    rtol = 1e-1 if tensors.a.element_size() >= 2 else 2e-1
    torch.testing.assert_close(
        output, output_ref.to(output.dtype), rtol=rtol, atol=atol
    )


@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_all_schedules(shape, types: TypeConfig):
    """Check every schedule machete reports as supported for ``types``.

    Each supported group size is tested, skipping sizes that do not divide K.
    """
    # Without group scales there is no grouping; otherwise try every group
    # size machete supports for this activation type.
    group_sizes: list[int | None]
    if types.group_scale_type is None:
        group_sizes = [None]
    else:
        group_sizes = query_machete_supported_group_sizes(types.act_type)

    for group_size in group_sizes:
        if not group_size_valid(shape, group_size):
            continue

        tensors = create_test_tensors(shape, types, group_size)
        print(f"MNK = {shape}")
        for schedule in ops.machete_supported_schedules(
            types.act_type,
            types.weight_type,
            group_scales_type=types.group_scale_type,
            # Query with the zero-point type actually under test (None for
            # configs without zero points). Passing the scale type here would
            # describe a config with zero points even when none are used.
            group_zeros_type=types.group_zero_type,
            out_type=types.output_type,
        ):
            print(f"Testing schedule {schedule}")
            machete_mm_test_helper(types, tensors, group_size, schedule)


@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
@pytest.mark.parametrize("shape", MNK_SHAPES, ids=lambda x: "x".join(str(v) for v in x))
@pytest.mark.parametrize("types", TEST_TYPES)
def test_machete_heuristic(shape, types: TypeConfig):
    """Check ``ops.machete_mm`` when no explicit schedule is requested.

    Each supported group size is tested, skipping sizes that do not divide K.
    """
    # Without group scales there is no grouping; otherwise try every group
    # size machete supports for this activation type.
    group_sizes: list[int | None]
    if types.group_scale_type is None:
        group_sizes = [None]
    else:
        group_sizes = query_machete_supported_group_sizes(types.act_type)

    for group_size in group_sizes:
        if not group_size_valid(shape, group_size):
            continue

        tensors = create_test_tensors(shape, types, group_size)
        machete_mm_test_helper(types, tensors, group_size)


# Test working on other devices
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
@pytest.mark.parametrize("device", CUDA_DEVICES)
def test_machete_devices(device: str):
    """Run an fp16 / uint4b8 group-scaled GEMM with all tensors on ``device``."""
    group_size = 128

    type_config = TypeConfig(
        act_type=torch.float16,
        weight_type=scalar_types.uint4b8,
        output_type=None,
        group_scale_type=torch.float16,
        group_zero_type=None,
        channel_scale_type=None,
        token_scale_type=None,
    )

    tensors = create_test_tensors((512, 4096, 4096), type_config, group_size)

    # The tensors are created on the default "cuda" device; move every
    # non-None tensor field to the device under test.
    for field in fields(Tensors):
        tensor = getattr(tensors, field.name)
        if isinstance(tensor, torch.Tensor):
            setattr(tensors, field.name, tensor.to(device))

    machete_mm_test_helper(type_config, tensors, group_size)


# Test working with a subset of A and B
@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
def test_machete_subset():
    """Run a GEMM whose inputs are sliced views of 2x-larger allocations."""
    group_size = 128

    type_config = TypeConfig(
        act_type=torch.float16,
        weight_type=scalar_types.uint4b8,
        output_type=None,
        group_scale_type=torch.float16,
        group_zero_type=None,
        channel_scale_type=None,
        token_scale_type=None,
    )

    tensors = create_test_tensors(
        (512, 4096, 4096), type_config, group_size, subset_stride_factor=2
    )
    machete_mm_test_helper(type_config, tensors, group_size)


# Thin nn.Module wrapper around ops.machete_mm, used by the CUDA graph test.
class MacheteLayer(torch.nn.Module):
    """Module whose forward pass is a single ``ops.machete_mm`` call.

    Every keyword argument given to the constructor is stored and forwarded
    unchanged to ``ops.machete_mm`` on each call, together with ``a``.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.kwargs = kwargs

    def forward(self, a):
        call_kwargs = self.kwargs
        return ops.machete_mm(a=a, **call_kwargs)


@pytest.mark.skipif(
    not IS_SUPPORTED_BY_GPU, reason="Machete is not supported on this GPU type."
)
def test_machete_cuda_graph():
    """Check that a machete GEMM can be captured in and replayed from a CUDA graph."""
    m, n, k = 512, 4096, 4096

    a = rand_data((m, k), torch.float16)
    b = rand_data((k, n), torch.float16)
    wtype = scalar_types.uint4b8
    stype = torch.float16
    group_size = 128
    zero_points = False

    w_ref, w_q_packed, w_s, w_zp = machete_quantize_and_pack(
        a.dtype, b, wtype, stype, group_size, zero_points
    )

    # Construct a trivial model with a single layer that calls a machete kernel
    model = MacheteLayer(
        b_q=w_q_packed,
        b_type=wtype,
        b_group_scales=w_s,
        b_group_zeros=maybe_convert_zeropoints(w_zp, w_s),
        b_group_size=group_size,
    )

    output_ref = torch.matmul(a, w_ref)

    # Run the model with a cuda graph
    stream = torch.cuda.Stream()
    with torch.cuda.stream(stream):
        g = torch.cuda.CUDAGraph()
        with torch.cuda.graph(g):
            output = model(a)
    # Clear the captured output buffer so the check below only passes if the
    # graph replay writes the result into it.
    output.zero_()
    g.replay()

    # Relax atol as our reduction dim becomes larger (more rounding error)
    # Relax atol when we have zeropoints since the way machete applies
    #  zeropoints (after scales) causes noise around 0
    atol = 1 if zero_points else min(5e-2 * math.sqrt(k), 1)
    torch.testing.assert_close(output, output_ref, rtol=1e-1, atol=atol)