# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for sparse cutlass kernels

Run `pytest tests/kernels/test_semi_structured.py`.
"""

import pytest
import torch

11
from tests.kernels.utils import baseline_scaled_mm, to_fp8, to_int8
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported)
from vllm.platforms import current_platform

# Device strings for multi-GPU tests: one entry when exactly one CUDA device
# is visible, otherwise the two entries "cuda:0" and "cuda:1".
CUDA_DEVICES = [
    "cuda:" + str(device_idx)
    for device_idx in range(2 if torch.cuda.device_count() != 1 else 1)
]

# Compute capability of the current device folded into a single integer,
# e.g. (8, 9) -> 89.
_device_cap = current_platform.get_device_capability()
capability = 10 * _device_cap[0] + _device_cap[1]
del _device_cap


def to_bf16(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` cast to ``torch.bfloat16``."""
    return tensor.bfloat16()


def to_fp16(tensor: torch.Tensor) -> torch.Tensor:
    """Return ``tensor`` cast to ``torch.float16``."""
    return tensor.half()


def prune_to_2_4(tensor: torch.Tensor) -> torch.Tensor:
    """Zero all but the two largest-magnitude entries in every group of four.

    Groups are runs of four consecutive elements along the last dimension,
    i.e. the result is 2:4 structured-sparse row by row.

    Args:
        tensor: Tensor of rank >= 1 whose last dimension is a multiple of 4.

    Returns:
        A tensor with the same shape and dtype as ``tensor`` in which every
        group of four holds at most two non-zeros. All zeros are written as
        +0.0 (no negative zeros).

    Raises:
        ValueError: If ``tensor`` is 0-d or its last dimension is not a
            multiple of 4. Without this check, ``reshape(-1, 4)`` would form
            groups straddling two rows and the result would not be 2:4 sparse
            per row.
    """
    if tensor.dim() == 0 or tensor.shape[-1] % 4 != 0:
        raise ValueError(
            "prune_to_2_4 expects the last dimension to be a multiple of 4, "
            f"got shape {tuple(tensor.shape)}")

    original_shape = tensor.shape
    # One row per group of 4; the check above guarantees that no group
    # mixes elements from two different rows of `tensor`.
    groups = tensor.reshape(-1, 4)

    # Indices of the 2 largest-magnitude entries in each group.
    _, keep_idx = torch.topk(torch.abs(groups), k=2, dim=1)

    # Binary keep-mask with ones at the selected positions.
    mask = torch.zeros_like(groups)
    mask.scatter_(dim=1,
                  index=keep_idx,
                  src=torch.ones_like(keep_idx, dtype=mask.dtype))

    pruned = groups * mask

    # Masking a negative value produces -0.0. Because -0.0 == 0.0, this
    # comparison selects every zero and rewrites it as +0.0.
    pruned[pruned == -0.0] = 0.0

    return pruned.reshape(original_shape)


def check_compress_decompress_invariance(dtype: torch.dtype, b: torch.Tensor,
                                         b_compressed: torch.Tensor,
                                         b_metadata: torch.Tensor) -> None:
    """Assert that multiplying an identity by the compressed ``b`` recovers ``b``.

    ``I @ B`` is evaluated with ``ops.cutlass_scaled_sparse_mm`` using unit
    scales, so the output is the dense operand reconstructed from
    ``b_compressed`` and ``b_metadata``. It must match ``b`` (cast to the
    output dtype) under ``torch.testing.assert_close`` default tolerances.

    Args:
        dtype: Element type of ``b``; the identity matrix uses it as well.
        b: Dense operand expected to be 2:4 sparse; ``b.shape[0]`` is the
            reduction size and therefore the size of the identity.
        b_compressed: Compressed non-zeros of ``b`` (in this file produced by
            ``ops.cutlass_sparse_compress(b.t())``).
        b_metadata: Sparsity metadata returned alongside ``b_compressed``.
    """
    # For float16 and bfloat16, cutlass_scaled_sparse_mm's output must be the
    # same dtype as its inputs. This line addresses that constraint while
    # arbitrarily using bfloat16 for the int8/fp8 cases.
    out_dtype = torch.float16 if dtype is torch.float16 else torch.bfloat16

    # Follow b's device instead of hard-coding "cuda" so the identity and the
    # scales always live on the same device as the operands being checked.
    eye = torch.eye(b.shape[0], device=b.device, dtype=dtype)
    eye_scale = torch.ones(1, device=b.device, dtype=torch.float32)
    b_decomp = ops.cutlass_scaled_sparse_mm(eye,
                                            b_compressed,
                                            b_metadata,
                                            eye_scale,
                                            eye_scale,
                                            out_dtype=out_dtype)

    torch.testing.assert_close(b.to(dtype=out_dtype), b_decomp)


def make_rand_sparse_tensors(
        dtype: torch.dtype, m: int, n: int, k: int
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build a random dense A, a 2:4-sparse B and B's compressed form.

    The compressed form is round-trip checked with
    ``check_compress_decompress_invariance`` before returning.

    Args:
        dtype: Element type of A and B: ``torch.int8``,
            ``torch.float8_e4m3fn``, ``torch.float16`` or ``torch.bfloat16``.
        m: Rows of A.
        n: Columns of B.
        k: Reduction size (columns of A, rows of B).

    Returns:
        ``(b_compressed, metadata, a, b)`` where ``a`` is (m, k), ``b`` is
        (k, n), and ``b_compressed`` / ``metadata`` are the outputs of
        ``ops.cutlass_sparse_compress(b.t())``.

    Raises:
        ValueError: If ``dtype`` is not one of the supported types.
    """
    a = torch.randn((m, k), device='cuda')
    # Generated as (n, k) and transposed, so B is a column-major (k, n) view.
    b = torch.randn((n, k), device='cuda').t()

    if dtype == torch.int8:
        # ensure A and B aren't all zeros after rounding
        a = a * 5.0
        b = b * 5.0

    # Prune B^T row-wise, i.e. along K, so every group of 4 consecutive K
    # entries in each column of B keeps at most 2 non-zeros.
    b = prune_to_2_4(b.t()).t()

    if dtype == torch.int8:
        a, b = to_int8(a), to_int8(b)
    elif dtype == torch.float8_e4m3fn:
        a, b = to_fp8(a), to_fp8(b)
    elif dtype == torch.float16:
        a, b = to_fp16(a), to_fp16(b)
    elif dtype == torch.bfloat16:
        a, b = to_bf16(a), to_bf16(b)
    else:
        raise ValueError("unsupported dtype")

    b_compressed, e = ops.cutlass_sparse_compress(b.t())
    check_compress_decompress_invariance(dtype, b, b_compressed, e)

    # Compressed B, Metadata, Original A, B
    return b_compressed, e, a, b


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
def test_cutlass_sparse_subset():
    """Sparse FP8 GEMM on a row slice of a larger activation matrix.

    A is generated with ``big_m`` rows and only its first ``m`` rows are
    multiplied against the compressed B. The CUTLASS result must match
    ``baseline_scaled_mm`` evaluated on the same slice.
    """
    big_m = 1024
    m, n, k = 512, 512, 512
    out_dtype = torch.bfloat16

    b_comp, e, whole_a, b = make_rand_sparse_tensors(
        torch.float8_e4m3fn, big_m, n, k)
    # Basic slicing: a view onto the leading m rows of the larger tensor.
    a = whole_a[:m, :k]

    # Two independent random scalar scales, drawn in the order A then B.
    scale_a, scale_b = [
        torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
        for _ in range(2)
    ]

    out = ops.cutlass_scaled_sparse_mm(a, b_comp, e, scale_a, scale_b,
                                       out_dtype=out_dtype)
    baseline = baseline_scaled_mm(a, b, scale_a, scale_b,
                                  out_dtype=out_dtype)

    torch.testing.assert_close(out, baseline, rtol=1e-1, atol=1e0)


# GEMM problem sizes shared by the parametrized tests below.
# NOTE(review): the tests do not unpack these tuples consistently:
# test_cutlass_sparse_gemm parametrizes them as "m, n, k", while the fp8 and
# int8 tests use "m, k, n". Every second and third entry is a multiple of 4,
# so either reading yields a K that prune_to_2_4 can split into groups of 4.
MNK_FACTORS = [
    (1, 256, 128),
    (1, 16384, 1024),
    (1, 24576, 512),
    (16, 256, 512),
    (16, 16384, 128),
    (16, 24576, 4096),
    (32, 8192, 4096),
    (32, 16384, 4096),
    (33, 1024, 1024),
    (33, 8192, 128),
    (64, 2048, 512),
    (64, 16384, 1024),
    (100, 8192, 512),
    (128, 32768, 4096),
    (256, 4096, 4096),
    (512, 256, 1024),
    (512, 8192, 4096),
    (512, 16384, 128),
    (512, 24576, 128),
]


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, n, k", MNK_FACTORS)
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_gemm(m: int, k: int, n: int, dtype: torch.dtype,
                             use_bias: bool):
    """Compare the 16-bit sparse GEMM against ``baseline_scaled_mm``.

    A and B are float16 or bfloat16, the output uses the same dtype, both
    scales are fixed to 1.0, and an optional random bias is added.
    ``MNK_FACTORS`` entries are unpacked as ``(m, n, k)`` here.
    """
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k)
    scale_a = torch.ones((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.ones((1, 1), device="cuda", dtype=torch.float32)

    bias = torch.rand((n, ), device="cuda", dtype=dtype) if use_bias else None

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=dtype,
                                       bias=bias)

    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=dtype,
                                  bias=bias)

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m, k, n", MNK_FACTORS)
@pytest.mark.skipif(not current_platform.has_device_capability(89),
                    reason="FP8 is not supported on this GPU type.")
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_fp8_gemm(m: int, n: int, k: int, use_bias: bool):
    """Compare the FP8 (e4m3) sparse GEMM against ``baseline_scaled_mm``.

    Uses random scalar scales for A and B, a bfloat16 output and an optional
    random bias. ``MNK_FACTORS`` entries are unpacked as ``(m, k, n)`` here.
    """
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.float8_e4m3fn, m, n, k)
    scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32)
    scale_b = torch.randn((1, 1), device="cuda", dtype=torch.float32)

    out_dtype = torch.bfloat16

    # Bias entries are drawn uniformly from [0, 10).
    bias = (torch.rand((n, ), device="cuda", dtype=out_dtype) * 10
            if use_bias else None)

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=out_dtype,
                                       bias=bias)

    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=out_dtype,
                                  bias=bias)

    torch.testing.assert_close(out, baseline, rtol=1e-2, atol=3e-1)


@pytest.mark.skipif(not sparse_cutlass_supported(),
                    reason="Sparse CUTLASS is not supported on this GPU type.")
@pytest.mark.parametrize("m,k,n", MNK_FACTORS)
@pytest.mark.parametrize("per_act_token", [True, False])
@pytest.mark.parametrize("per_out_ch", [True, False])
@pytest.mark.parametrize("use_bias", [True, False])
def test_cutlass_sparse_int8_gemm(m: int, n: int, k: int, per_act_token: bool,
                                  per_out_ch: bool, use_bias: bool):
    """Compare the int8 sparse GEMM against ``baseline_scaled_mm``.

    ``per_act_token`` selects one scale per row of A, shaped (m, 1), instead of
    a single scalar. ``per_out_ch`` selects one scale per column of B, shaped
    (1, n), instead of a single scalar. Previously both flags were
    parametrized but ignored, so every scale combination ran the same
    scalar-scale case. The output is bfloat16 with an optional random bias.
    ``MNK_FACTORS`` entries are unpacked as ``(m, k, n)`` here.
    """
    # Create tensors
    b_comp, e, a, b = make_rand_sparse_tensors(torch.int8, m, n, k)

    num_a_scales = m if per_act_token else 1
    num_b_scales = n if per_out_ch else 1
    scale_a = torch.randn((num_a_scales, 1),
                          device="cuda",
                          dtype=torch.float32)
    scale_b = torch.randn((1, num_b_scales),
                          device="cuda",
                          dtype=torch.float32)

    out_dtype = torch.bfloat16

    # Bias entries are drawn uniformly from [0, 10).
    bias = (torch.rand((n, ), device="cuda", dtype=out_dtype) * 10
            if use_bias else None)

    out = ops.cutlass_scaled_sparse_mm(a,
                                       b_comp,
                                       e,
                                       scale_a,
                                       scale_b,
                                       out_dtype=out_dtype,
                                       bias=bias)

    baseline = baseline_scaled_mm(a,
                                  b,
                                  scale_a,
                                  scale_b,
                                  out_dtype=out_dtype,
                                  bias=bias)

    torch.testing.assert_close(out, baseline, rtol=1e0, atol=2e0)