# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test model set-up and weight loading for llmcompressor-quantized models.

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""

import pytest
9
import torch
10
from compressed_tensors.quantization import QuantizationType
11

12
from tests.models.utils import check_logprobs_close
13
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
14
15
16
17
18
19
20
21
22
23
24
    CompressedTensors24,
    CompressedTensorsLinearMethod,
    CompressedTensorsW4A4Fp4,
    CompressedTensorsW4A8Fp8,
    CompressedTensorsW4A16Fp4,
    CompressedTensorsW4A16Sparse24,
    CompressedTensorsW8A8Fp8,
    CompressedTensorsW8A8Int8,
    CompressedTensorsW8A16Fp8,
    CompressedTensorsWNA16,
)
25
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
26
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
27
from vllm.model_executor.layers.quantization.utils.quant_utils import (
28
29
    cutlass_fp4_supported,
)
30
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
31
32
    sparse_cutlass_supported,
)
33
from vllm.platforms import current_platform

# AITER only supports per-channel-per-channel INT8 gemm
# and per-tensor-per-tensor INT8 GEMM.
# It does not support mix precision MM and mix quantization scheme.
ROCM_AITER_SUPPORTED_INT8_MODEL = [
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
]

# TritonScaledMMLinearKernel only supports symmetric quantization.
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
    "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
    "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
]


@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """`LLM.apply_model` requires pickling a function.

    Sets ``VLLM_ALLOW_INSECURE_SERIALIZATION=1`` via ``monkeypatch`` for every
    test in this module (``autouse=True``); ``monkeypatch`` restores the
    environment after each test.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


@pytest.mark.parametrize(
    "model_args",
    [
        (
            "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
            "tensor",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama",
            "tensor",
            QuantizationType.INT,
            2560,
            False,
        ),
    ],
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    """Check layer setup for static-input-scale W8A8 INT8 checkpoints.

    ``model_args`` is ``(model_path, strategy, quant_type, shape_0,
    is_symmetric)``. ``quant_type`` is part of the parametrization but is not
    inspected by this test.
    """
    model_path, strategy, _quant_type, shape_0, is_symmetric = model_args

    if (
        current_platform.is_rocm()
        and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
    ):
        pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")

    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            o_proj = layer.self_attn.o_proj
            gate_up_proj = layer.mlp.gate_up_proj
            down_proj = layer.mlp.down_proj
            projections = (qkv_proj, o_proj, gate_up_proj, down_proj)

            # Symmetric checkpoints must have no input zero point; asymmetric
            # ones must carry an int32 zero point.
            def zp_valid(zp: torch.Tensor | None):
                if is_symmetric:
                    return zp is None

                return zp is not None and zp.dtype is torch.int32

            # Every projection in the block is expected to be quantized the
            # same way, so check all four uniformly (down_proj included).
            for proj in projections:
                assert zp_valid(proj.input_zero_point)
                assert isinstance(proj.quant_method, CompressedTensorsLinearMethod)
                assert proj.weight.dtype is torch.int8

            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.scheme.is_static_input_scheme

            if qkv_proj.scheme.strategy == "tensor":
                # Make sure it is a channelwise buffer
                # After running process_weights_after_loading
                assert len(qkv_proj.weight_scale.shape) == 2
                assert qkv_proj.weight_scale.shape[0] == shape_0
                assert qkv_proj.weight_scale.shape[1] == 1
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert qkv_proj.input_scale.dtype is torch.float32

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


@pytest.mark.parametrize(
    "model_path",
    [
        "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    ],
)
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.parametrize(
    "use_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_compressed_tensors_w8a8_logprobs(
    hf_runner,
    vllm_runner,
    example_prompts,
    model_path,
    max_tokens,
    num_logprobs,
    use_aiter,
    monkeypatch,
):
    """Compare greedy top-``num_logprobs`` logprobs between HF and vLLM.

    On ROCm, models outside the supported lists are skipped, and the
    ``use_aiter=True`` variant sets ``VLLM_ROCM_USE_AITER=1``.
    """
    if (
        current_platform.is_rocm()
        and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
    ):
        pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")

    if use_aiter:
        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
            pytest.skip(f"Skip model {model_path} as it is not support by aiter.")
        # this will enable VLLM_ROCM_USE_AITER_LINEAR
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    dtype = "bfloat16"

    # skip language translation prompt for the static per tensor models
    # NOTE(review): neither model below is in the current parametrization,
    # so this branch is not exercised today; kept for when they are re-added.
    if model_path in (
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
    ):
        example_prompts = example_prompts[0:-1]

    with hf_runner(model_path, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    with vllm_runner(model_path, dtype=dtype) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )

    if current_platform.is_rocm():
        torch.cuda.synchronize()


def test_compressed_tensors_no_enforce_eager(vllm_runner):
    """Smoke-test generation for a compressed-tensors model without eager mode."""
    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
    with vllm_runner(model_path) as llm:
        # Prompts must be a list: a bare string would be iterated and each
        # character submitted as a separate prompt.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


@pytest.mark.parametrize(
    "model_args",
    [
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
            "channel",
        ),
    ],
)
@pytest.mark.parametrize(
    "use_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_compressed_tensors_w8a8_dynamic_per_token(
    vllm_runner,
    model_args,
    use_aiter,
    monkeypatch,
):
    """Check setup of W8A8 INT8 checkpoints with dynamic input scales.

    ``model_args`` is ``(model_path, strategy)``; ``strategy`` is the expected
    weight strategy reported by the scheme.
    """
    model_path, strategy = model_args

    if (
        current_platform.is_rocm()
        and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
    ):
        pytest.skip(f"Skip model {model_path} as it is not support on ROCm.")

    if use_aiter:
        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
            pytest.skip(f"Skip model {model_path} as it is not support by aiter.")
        # this will enable VLLM_ROCM_USE_AITER_LINEAR
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    with vllm_runner(model_path, enforce_eager=True, dtype=torch.float16) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
            # Dynamic per-token activations: no static input scale.
            assert not qkv_proj.scheme.is_static_input_scheme
            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.weight.dtype is torch.int8

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


@pytest.mark.parametrize(
    "wNa16_args",
    [
        (
            "nm-testing/tinyllama-oneshot-w4a16-channel-v2",
            "channel",
            None,
            8,
            True,
            False,
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder",
            "group",
            128,
            8,
            False,
            True,
        ),
    ],
)
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="The tests are skipped on non-CUDA platform."
)
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    """Check CompressedTensorsWNA16 scheme attributes for weight-only models.

    ``wNa16_args`` is ``(model_path, strategy, group, pack_factor, symmetric,
    has_g_idx)``; ``group=None`` means the scheme reports ``group_size == -1``.
    """
    model_path, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)

            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.scheme.group_size == (-1 if group is None else group)

            assert qkv_proj.scheme.pack_factor == pack_factor
            assert qkv_proj.scheme.symmetric == symmetric
            assert qkv_proj.scheme.has_g_idx == has_g_idx

        llm.apply_model(check_model)

        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    """Check the 2:4-sparse W4A16 scheme and its packed int32 weights."""
    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
            assert qkv_proj.weight_packed.dtype is torch.int32

        llm.apply_model(check_model)

        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


def test_compressed_tensors_fp8(vllm_runner):
    """Check FP8 scheme selection and scale/weight dtypes for an FP8 model.

    Either the W8A8 or the W8A16 FP8 scheme is accepted; the W8A8-specific
    scalar-scale and fp8-weight checks only run for the W8A8 scheme.
    """
    model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(
                qkv_proj.scheme,
                (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8),
            )

            assert qkv_proj.input_scale.dtype is torch.float32

            if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
                assert len(qkv_proj.input_scale.shape) == 0
                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
                assert qkv_proj.weight_scale.dtype is torch.float32
                assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


@pytest.mark.skipif(
    not current_platform.is_kv_cache_dtype_supported("fp8", None),
    reason="FP8 KV cache is not supported on this device.",
)
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
def test_compressed_tensors_kv_cache(vllm_runner):
    """Smoke-test generation with an fp8 KV cache on a kv-cache-scheme model."""
    model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
    with vllm_runner(model_path, enforce_eager=True, kv_cache_dtype="fp8") as llm:
        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello world!"], max_tokens=4)
        assert output


def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy, format="dense"):
    """Assert that ``qkv_proj`` uses a quantized 2:4-sparse compressed scheme.

    This is a helper called from the 2:4 tests' ``check_model`` callbacks, not
    a test itself. The callers carry their own ``skipif`` marks; a pytest mark
    placed on this helper would have no effect.

    Args:
        qkv_proj: The layer under test.
        weight_strategy: Expected ``scheme.weight_quant.strategy``.
        input_strategy: Expected ``scheme.input_quant.strategy``.
        format: Expected sparsity compression format recorded for ``"Linear"``.
    """
    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(qkv_proj.scheme, CompressedTensors24)

    assert qkv_proj.scheme.weight_quant.strategy == weight_strategy
    assert qkv_proj.scheme.input_quant.strategy == input_strategy
    assert qkv_proj.scheme.quantized

    sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map
    assert sparsity_map
    # Fail with a clear message instead of an AttributeError on None.
    linear_sparsity = sparsity_map.get("Linear")
    assert linear_sparsity is not None, "no sparsity scheme recorded for 'Linear'"
    assert linear_sparsity.format == format
    assert linear_sparsity.sparsity_structure == "2:4"


@pytest.mark.skipif(
    not current_platform.is_cuda() or not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    """Check 2:4-sparse FP8 checkpoints stored in the dense format.

    ``args_2of4`` is ``(model_path, weight_strategy, input_strategy)``.
    """
    model_path, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        llm.apply_model(check_model)

        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda() or not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
    """Check 2:4-sparse FP8 checkpoints stored in the sparse-24-bitmask format.

    ``args_2of4`` is ``(model_path, weight_strategy, input_strategy)``.
    """
    model_path, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            _test_2of4_quant_models(
                qkv_proj,
                weight_strategy,
                input_strategy,
                format="sparse-24-bitmask",
            )

        llm.apply_model(check_model)

        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
    """Check 2:4-sparse INT8 checkpoints stored in the sparse-24-bitmask format.

    ``args_2of4`` is ``(model_path, weight_strategy, input_strategy)``.
    """
    model_path, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            _test_2of4_quant_models(
                qkv_proj,
                weight_strategy,
                input_strategy,
                format="sparse-24-bitmask",
            )

        llm.apply_model(check_model)

        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse INT8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
    """Check 2:4-sparse INT8 checkpoints stored in the dense format.

    ``args_2of4`` is ``(model_path, weight_strategy, input_strategy)``.
    """
    model_path, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        llm.apply_model(check_model)

        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="2of4 Sparse is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    ["nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor"],
)
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
    """Check an unquantized 2:4-sparse checkpoint stored in the dense format.

    ``args_2of4`` is the model path (a plain string, not a tuple).
    """
    model_path = args_2of4
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensors24)

            # Sparsity only: no weight or input quantization.
            assert qkv_proj.scheme.weight_quant is None
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized
            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
            assert sparsity_map.get("Linear").format == "dense"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4", ["nm-testing/llama2.c-stories42M-pruned2.4-compressed"]
)
def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
    """Check an unquantized 2:4-sparse checkpoint in sparse-24-bitmask format.

    ``args_2of4`` is the model path (a plain string, not a tuple).
    """
    model_path = args_2of4
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensors24)

            # Sparsity only: no weight or input quantization.
            assert qkv_proj.scheme.weight_quant is None
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized
            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
            assert sparsity_map.get("Linear").format == "sparse-24-bitmask"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.parametrize(
    "args",
    [
        # TODO: Enable once model is available again
        # ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4),
        ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4),
    ],
)
def test_compressed_tensors_nvfp4(vllm_runner, args):
    """Check NVFP4 scheme selection and group size.

    ``args`` is ``(model_path, expected_scheme_class)``. The W4A16 FP4 scheme
    is also accepted when ``cutlass_fp4_supported()`` returns False.
    """
    model_path, scheme = args
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)

            # Explicit parentheses document the intended precedence:
            # expected scheme, OR (W4A16 FP4 AND no CUTLASS FP4 support).
            scheme_ok = isinstance(qkv_proj.scheme, scheme) or (
                isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
                and not cutlass_fp4_supported()
            )
            assert scheme_ok, "FP4 Scheme Mismatch"

            assert qkv_proj.scheme.group_size == 16

        llm.apply_model(check_model)
        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda() or not current_platform.has_device_capability(90),
    reason="W4A8 FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args",
    [("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8)],
)
def test_compressed_tensors_w4a8_fp8(vllm_runner, args):
    """Check the W4A8 FP8 scheme and parameter dtypes on every projection.

    ``args`` is ``(model_path, expected_scheme_class)``.
    """
    model_path, scheme = args
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            o_proj = layer.self_attn.o_proj
            gate_up_proj = layer.mlp.gate_up_proj
            down_proj = layer.mlp.down_proj

            for proj in (qkv_proj, o_proj, gate_up_proj, down_proj):
                assert isinstance(proj.quant_method, CompressedTensorsLinearMethod)
                assert isinstance(proj.scheme, scheme)

                assert proj.weight_packed.dtype is torch.int32
                assert proj.weight_scale.dtype is torch.float8_e4m3fn
                assert proj.weight_chan_scale.dtype is torch.float32
                assert proj.scheme.group_size == 128

        llm.apply_model(check_model)
        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize(
    "model,prompt,exp_perplexity",
    [
        (
            "nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16",
            "Flat is better than nested.\nSparse is better than dense.",
            150.0,
        ),
        (
            "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16",
            "Flat is better than nested.\nSparse is better than dense.",
            150.0,
        ),
    ],
)
def test_compressed_tensors_transforms_perplexity(
    vllm_runner, model, prompt, exp_perplexity
):
    """Check that transform-based w4a16 models keep prompt perplexity bounded.

    The test fails if the perplexity reported for ``prompt`` exceeds
    ``exp_perplexity``.
    """
    with vllm_runner(model, enforce_eager=True) as llm:
        perplexity = llm.generate_prompt_perplexity([prompt])[0]
        print(perplexity)
        assert perplexity <= exp_perplexity, (
            f"perplexity {perplexity} exceeds bound {exp_perplexity} for {model}"
        )


def test_compressed_tensors_fp8_block_enabled(vllm_runner):
    """Check that a block-FP8 checkpoint uses the W8A8 block-FP8 linear op.

    Verifies weight/scale dtypes and 2-D shapes, and that the input quant op is
    a ``QuantFP8`` dispatching to ``forward_cuda``.
    """
    model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
    with vllm_runner(model_path, enforce_eager=True) as llm:
        fp8_dtype = current_platform.fp8_dtype()

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
            assert isinstance(
                qkv_proj.scheme.w8a8_block_fp8_linear, W8A8BlockFp8LinearOp
            )

            assert qkv_proj.weight.dtype is fp8_dtype
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert len(qkv_proj.weight.shape) == 2
            assert len(qkv_proj.weight_scale.shape) == 2

            input_quant_op = qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
            assert isinstance(input_quant_op, QuantFP8)
            assert input_quant_op._forward_method == input_quant_op.forward_cuda

        llm.apply_model(check_model)

        # A list of prompts, not a bare string (which would be split per char).
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output