test_compressed_tensors.py 21.5 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
"""Test model set-up and weight loading for llmcompressor-quantized models.
3
4
5

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
6

7
from typing import Optional
8

9
import pytest
10
import torch
11
from compressed_tensors.quantization import QuantizationType
12

13
from tests.models.utils import check_logprobs_close
14
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
15
16
17
18
    CompressedTensors24, CompressedTensorsLinearMethod,
    CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8,
    CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8,
    CompressedTensorsWNA16)
19
20
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported)
21
from vllm.platforms import current_platform
22

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# Checkpoints that may run through the ROCm AITER INT8 GEMM path.
# AITER provides only per-channel/per-channel and per-tensor/per-tensor INT8
# GEMMs; mixed-precision matmuls and mixed quantization schemes are not
# supported, so every other model is skipped when AITER is enabled.
ROCM_AITER_SUPPORTED_INT8_MODEL = [
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
]

# Checkpoints that may run on ROCm via TritonScaledMMLinearKernel.
# That kernel handles symmetric quantization only; models missing from this
# list are skipped on ROCm.
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
    "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
    "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
]

40

41
42
43
44
45
46
47
48
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Pin every test in this module to the V0 engine.

    The checks in this module rely on V0 internals, so ``VLLM_USE_V1`` is
    set to ``'0'`` through ``monkeypatch`` for the duration of each test.
    """
    env_var, v0_flag = 'VLLM_USE_V1', '0'
    monkeypatch.setenv(env_var, v0_flag)


49
50
@pytest.mark.parametrize(
    "model_args",
    [
        (
            "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
            "tensor",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
            "channel",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama",
            "tensor",
            QuantizationType.INT,
            2560,
            False,
        ),
    ],
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    """Check layer setup for W8A8 INT8 models with static input scales.

    ``model_args`` is ``(model_path, strategy, quant_type, shape_0,
    is_symmetric)``:

    * ``strategy`` - expected strategy of the qkv projection's scheme.
    * ``quant_type`` - carried in the parameters but not checked here.
    * ``shape_0`` - expected first dimension of the qkv weight scale when
      the strategy is "tensor".
    * ``is_symmetric`` - whether input zero points must be absent
      (symmetric) or present as int32 tensors (asymmetric).
    """
    model_path, strategy, _quant_type, shape_0, is_symmetric = model_args

    if current_platform.is_rocm(
    ) and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL:
        pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")

    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            projections = (
                qkv_proj,
                layer.self_attn.o_proj,
                layer.mlp.gate_up_proj,
                layer.mlp.down_proj,
            )

            # Symmetric schemes carry no input zero point; asymmetric ones
            # must provide an int32 zero point.
            def zp_valid(zp: Optional[torch.Tensor]):
                if is_symmetric:
                    return zp is None
                return zp is not None and zp.dtype is torch.int32

            # Apply the same per-layer checks to every quantized linear
            # layer; previously down_proj's weight dtype went unchecked.
            for proj in projections:
                assert zp_valid(proj.input_zero_point)
                assert isinstance(proj.quant_method,
                                  CompressedTensorsLinearMethod)
                assert proj.weight.dtype is torch.int8

            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.scheme.is_static_input_scheme

            if qkv_proj.scheme.strategy == "tensor":
                # After process_weights_after_loading the weight scale is
                # expected to be a channelwise (shape_0, 1) buffer.
                assert len(qkv_proj.weight_scale.shape) == 2
                assert qkv_proj.weight_scale.shape[0] == shape_0
                assert qkv_proj.weight_scale.shape[1] == 1
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert qkv_proj.input_scale.dtype is torch.float32

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

136

137
138
139
140
141
142
143
144
145
@pytest.mark.parametrize(
    "model_path",
    [
        "neuralmagic/Llama-3.2-1B-quantized.w8a8",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dynamic-Asym",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
    ],
)
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.parametrize(
    "use_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_compressed_tensors_w8a8_logprobs(
    hf_runner,
    vllm_runner,
    example_prompts,
    model_path,
    max_tokens,
    num_logprobs,
    use_aiter,
    monkeypatch,
):
    """Compare greedy logprobs from HF and vLLM for W8A8 INT8 checkpoints.

    Both runners load ``model_path`` in bfloat16 and generate ``max_tokens``
    tokens per prompt with the top ``num_logprobs`` logprobs. The outputs
    are compared with ``check_logprobs_close``. On ROCm the comparison is
    also run with AITER enabled (``use_aiter``) for models AITER supports.
    """
    running_on_rocm = current_platform.is_rocm()

    if (running_on_rocm and
            model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL):
        pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")

    if use_aiter and model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
        pytest.skip(f"Skip model {model_path} as it is not support by aiter.")
    if use_aiter:
        # this will enable VLLM_ROCM_USE_AITER_LINEAR
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    dtype = "bfloat16"

    # The static per-tensor asymmetric checkpoint is compared without the
    # final example prompt (the language translation prompt).
    asym_static_model = (
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym")
    prompts = (example_prompts[:-1]
               if model_path == asym_static_model else example_prompts)

    with hf_runner(model_path, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            prompts, max_tokens, num_logprobs)

    with vllm_runner(model_path, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )

    if running_on_rocm:
        torch.cuda.synchronize()

198

199
def test_compressed_tensors_no_enforce_eager(vllm_runner):
    """Smoke-test a static W8A8 checkpoint without ``enforce_eager``.

    Loads the model with the runner's default execution settings and checks
    that a short greedy generation produces output.
    """
    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
    with vllm_runner(model_path) as llm:
        # generate_greedy takes a list of prompts; a bare str would be
        # treated as a sequence of one-character prompts.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output


206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
@pytest.mark.parametrize(
    "model_args",
    [
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2-asym", "tensor"),
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
            "channel",
        ),
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2-asym",
            "channel",
        ),
    ],
)
@pytest.mark.parametrize(
    "use_aiter", [True, False] if current_platform.is_rocm() else [False])
def test_compressed_tensors_w8a8_dynamic_per_token(
    vllm_runner,
    model_args,
    use_aiter,
    monkeypatch,
):
    """Check W8A8 INT8 checkpoints that use dynamic activation scales.

    ``model_args`` is ``(model_path, strategy)`` where ``strategy`` is the
    scheme strategy expected on the fused qkv projection. On ROCm the test
    additionally runs with AITER enabled when ``use_aiter`` is set.
    """
    model_path, strategy = model_args

    if (current_platform.is_rocm() and
            model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL):
        pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")

    if use_aiter and model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
        pytest.skip(f"Skip model {model_path} as it is not support by aiter.")
    if use_aiter:
        # this will enable VLLM_ROCM_USE_AITER_LINEAR
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    with vllm_runner(model_path, dtype=torch.float16) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj
            scheme = qkv_proj.scheme

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(scheme, CompressedTensorsW8A8Int8)
            # Dynamic activation quantization: no static input scheme.
            assert not scheme.is_static_input_scheme
            assert scheme.strategy == strategy
            assert qkv_proj.weight.dtype is torch.int8

        llm.apply_model(check_model)

        greedy_output = llm.generate_greedy(["Hello my name is"],
                                            max_tokens=20)
        assert greedy_output

261

262
263
@pytest.mark.parametrize(
    "wNa16_args",
    [
        ("nm-testing/tinyllama-oneshot-w4a16-channel-v2", "channel", None, 8),
        ("nm-testing/tinyllama-oneshot-w4a16-group128-v2", "group", 128, 8),
        ("nm-testing/tinyllama-oneshot-w8a16-per-channel", "channel", None, 4),
    ],
)
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="The tests are skipped on non-CUDA platform.")
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    """Check weight-only (WNA16) checkpoints load with the expected scheme.

    ``wNa16_args`` is ``(model, strategy, group, pack_factor)``:

    * ``group`` is ``None`` for channelwise checkpoints, which must map to a
      scheme ``group_size`` of -1.
    * ``pack_factor`` is the expected number of quantized values per packed
      word (8 for the 4-bit and 4 for the 8-bit checkpoints here).
    """
    model_path, strategy, group, pack_factor = wNa16_args

    with vllm_runner(model_path) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)

            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.scheme.group_size == (-1
                                                  if group is None else group)
            assert qkv_proj.scheme.pack_factor == pack_factor

        llm.apply_model(check_model)

        # generate_greedy takes a list of prompts; a bare str would be
        # treated as a sequence of one-character prompts.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output

295

296
297
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="This test is skipped on non-CUDA platform.")
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    """Check a 2:4-sparse W4A16 (Marlin24) checkpoint loads and generates."""
    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
    with vllm_runner(model_path) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
            # The packed 4-bit weights are held in int32 storage.
            assert qkv_proj.weight_packed.dtype is torch.int32

        llm.apply_model(check_model)

        # generate_greedy takes a list of prompts; a bare str would be
        # treated as a sequence of one-character prompts.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output
316
317
318
319
320
321


def test_compressed_tensors_fp8(vllm_runner):
    """Check an FP8 compressed-tensors checkpoint loads and generates.

    The qkv projection may end up with either the W8A8 FP8 scheme or the
    weight-only W8A16 FP8 scheme. The detailed dtype/shape checks apply
    only to the W8A8 case.
    """
    model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
    with vllm_runner(model_path) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(
                qkv_proj.scheme,
                (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8),
            )

            assert qkv_proj.input_scale.dtype is torch.float32

            if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
                # Scalar (0-d) input and weight scales; weights stored in
                # the platform's FP8 dtype.
                assert len(qkv_proj.input_scale.shape) == 0
                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
                assert qkv_proj.weight_scale.dtype is torch.float32
                assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        # generate_greedy takes a list of prompts; a bare str would be
        # treated as a sequence of one-character prompts.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        assert output
346
347


348
349
@pytest.mark.skipif(not current_platform.is_cuda(),
                    reason="This test is skipped on non-CUDA platform.")
def test_compressed_tensors_kv_cache(vllm_runner):
    """Smoke-test a checkpoint with a compressed-tensors KV-cache scheme.

    The model is loaded with ``kv_cache_dtype="fp8"`` and must produce a
    short greedy generation.
    """
    model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
    with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
        # generate_greedy takes a list of prompts; a bare str would be
        # treated as a sequence of one-character prompts.
        output = llm.generate_greedy(["Hello world!"], max_tokens=20)
        assert output
355
356


357
358
359
360
361
362
363
364
def _test_2of4_quant_models(qkv_proj,
                            weight_strategy,
                            input_strategy,
                            format="dense"):
    """Assert that ``qkv_proj`` is set up as a quantized 2:4-sparse layer.

    This is a plain helper called from the 2of4 tests, not a test itself.
    A pytest skip marker on it would therefore have no effect; each calling
    test carries its own hardware gating.

    Args:
        qkv_proj: fused QKV projection layer taken from the loaded model.
        weight_strategy: expected ``scheme.weight_quant.strategy``.
        input_strategy: expected ``scheme.input_quant.strategy``.
        format: expected sparsity compression format recorded for the
            "Linear" entry of the quantization config's
            ``sparsity_scheme_map``.
    """
    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(qkv_proj.scheme, CompressedTensors24)

    assert qkv_proj.scheme.weight_quant.strategy == weight_strategy
    assert qkv_proj.scheme.input_quant.strategy == input_strategy
    assert qkv_proj.scheme.quantized

    sparsity_map = (
        qkv_proj.quant_method.quantization_config.sparsity_scheme_map)
    assert sparsity_map
    linear_scheme = sparsity_map.get("Linear")
    # Fail with a clear assertion rather than an AttributeError on None.
    assert linear_scheme is not None
    assert linear_scheme.format == format
    assert linear_scheme.sparsity_structure == "2:4"


377
@pytest.mark.skipif(
    not current_platform.is_cuda()
    or not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    """Check 2:4-sparse FP8 checkpoints stored in the "dense" format.

    ``args_2of4`` is ``(model, weight_strategy, input_strategy)``.
    """
    model_path, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model_path) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        llm.apply_model(check_model)

        # generate_greedy takes a list of prompts; a bare str would be
        # treated as a sequence of one-character prompts.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        print(output)
        assert output


425
@pytest.mark.skipif(
    not current_platform.is_cuda()
    or not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
    """Check 2:4-sparse FP8 checkpoints stored as "sparse-24-bitmask".

    ``args_2of4`` is ``(model, weight_strategy, input_strategy)``.
    """
    model_path, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model_path) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            _test_2of4_quant_models(
                qkv_proj,
                weight_strategy,
                input_strategy,
                format="sparse-24-bitmask",
            )

        llm.apply_model(check_model)

        # generate_greedy takes a list of prompts; a bare str would be
        # treated as a sequence of one-character prompts.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
    """Check 2:4-sparse INT8 checkpoints stored as "sparse-24-bitmask".

    ``args_2of4`` is ``(model, weight_strategy, input_strategy)``.
    """
    model_path, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model_path) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            _test_2of4_quant_models(
                qkv_proj,
                weight_strategy,
                input_strategy,
                format="sparse-24-bitmask",
            )

        llm.apply_model(check_model)

        # generate_greedy takes a list of prompts; a bare str would be
        # treated as a sequence of one-character prompts.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse INT8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
    """Check 2:4-sparse INT8 checkpoints stored in the "dense" format.

    ``args_2of4`` is ``(model, weight_strategy, input_strategy)``.
    """
    model_path, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model_path) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        llm.apply_model(check_model)

        # generate_greedy takes a list of prompts; a bare str would be
        # treated as a sequence of one-character prompts.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        print(output)
        assert output


572
573
@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="2of4 Sparse is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")],
)
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
    """Check an unquantized 2:4-sparse checkpoint in the "dense" format.

    ``args_2of4`` is the model path itself (a plain string, not a tuple).
    """
    model_path = args_2of4
    with vllm_runner(model_path) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensors24)

            # Sparsity only: neither weights nor inputs are quantized.
            assert qkv_proj.scheme.weight_quant is None
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized

            sparsity_map = (
                qkv_proj.quant_method.quantization_config.sparsity_scheme_map)
            assert sparsity_map
            assert sparsity_map.get("Linear").format == "dense"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        # generate_greedy takes a list of prompts; a bare str would be
        # treated as a sequence of one-character prompts.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        print(output)
        assert output
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")])
def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
    """Check an unquantized 2:4-sparse checkpoint stored as
    "sparse-24-bitmask".

    ``args_2of4`` is the model path itself (a plain string, not a tuple).
    """
    model_path = args_2of4
    with vllm_runner(model_path) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method,
                              CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensors24)

            # Sparsity only: neither weights nor inputs are quantized.
            assert qkv_proj.scheme.weight_quant is None
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized

            sparsity_map = (
                qkv_proj.quant_method.quantization_config.sparsity_scheme_map)
            assert sparsity_map
            assert sparsity_map.get("Linear").format == "sparse-24-bitmask"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        # generate_greedy takes a list of prompts; a bare str would be
        # treated as a sequence of one-character prompts.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=20)
        print(output)
        assert output