# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test model set-up and weight loading for llmcompressor-quantized models.

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""

import pytest
import torch
from compressed_tensors.quantization import QuantizationType

from tests.models.utils import check_logprobs_close
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
    CompressedTensors24,
    CompressedTensorsLinearMethod,
    CompressedTensorsW4A4Fp4,
    CompressedTensorsW4A8Fp8,
    CompressedTensorsW4A16Fp4,
    CompressedTensorsW4A16Sparse24,
    CompressedTensorsW8A8Fp8,
    CompressedTensorsW8A8Int8,
    CompressedTensorsW8A16Fp8,
    CompressedTensorsWNA16,
)
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.fp8_utils import W8A8BlockFp8LinearOp
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    cutlass_fp4_supported,
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    sparse_cutlass_supported,
)
from vllm.platforms import current_platform


# AITER only supports per-channel-per-channel INT8 GEMM
# and per-tensor-per-tensor INT8 GEMM.
# It does not support mixed-precision MM or mixed quantization schemes.
ROCM_AITER_SUPPORTED_INT8_MODEL = [
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
]

# TritonInt8ScaledMMLinearKernel only supports symmetric quantization.
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
    "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
    "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
]


@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """Set ``VLLM_ALLOW_INSECURE_SERIALIZATION=1`` for every test in this module.

    `LLM.apply_model` requires pickling a function. The variable is set
    through ``monkeypatch``, so it is undone when each test finishes.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


@pytest.mark.parametrize(
    "model_args",
    [
        (
            "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
            "tensor",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama",
            "tensor",
            QuantizationType.INT,
            2560,
            False,
        ),
    ],
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    """Check layer set-up for W8A8 INT8 models with static input scales.

    ``model_args`` is ``(model_path, strategy, quant_type, shape_0,
    is_symmetric)``:

    - ``shape_0`` is the expected leading dimension of ``qkv_proj.weight_scale``.
    - ``is_symmetric`` selects whether input zero points must be absent
      (symmetric) or int32 tensors (asymmetric).
    """
    model_path, strategy, quant_type, shape_0, is_symmetric = model_args

    if (
        current_platform.is_rocm()
        and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
    ):
        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")

    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            o_proj = layer.self_attn.o_proj
            gate_up_proj = layer.mlp.gate_up_proj
            down_proj = layer.mlp.down_proj
            projections = (qkv_proj, o_proj, gate_up_proj, down_proj)

            # assert zp for symmetric and asymmetric cases
            def zp_valid(zp: torch.Tensor | None) -> bool:
                if is_symmetric:
                    return zp is None
                return zp is not None and zp.dtype is torch.int32

            # Every projection of the layer is quantized, so every weight
            # (including down_proj, previously unchecked) must be int8.
            for proj in projections:
                assert zp_valid(proj.input_zero_point)
                assert isinstance(proj.quant_method, CompressedTensorsLinearMethod)
                assert proj.weight.dtype is torch.int8

            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.scheme.is_static_input_scheme

            if qkv_proj.scheme.strategy == "tensor":
                # Make sure it is a channelwise buffer after running
                # process_weights_after_loading.
                assert len(qkv_proj.weight_scale.shape) == 2
                assert qkv_proj.weight_scale.shape[0] == shape_0
                assert qkv_proj.weight_scale.shape[1] == 1
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert qkv_proj.input_scale.dtype is torch.float32

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


@pytest.mark.parametrize(
    "model_path",
    [
        "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    ],
)
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.parametrize(
    "use_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_compressed_tensors_w8a8_logprobs(
    hf_runner,
    vllm_runner,
    example_prompts,
    model_path,
    max_tokens,
    num_logprobs,
    use_aiter,
    monkeypatch,
):
    """Check that vLLM greedy logprobs for a W8A8 INT8 model match HF's.

    ``use_aiter`` is only ``True`` on ROCm. In that case, models listed in
    ``ROCM_AITER_SUPPORTED_INT8_MODEL`` run with ``VLLM_ROCM_USE_AITER=1``
    and all other models are skipped.
    """
    if (
        current_platform.is_rocm()
        and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
    ):
        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")

    if use_aiter:
        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
            pytest.skip(f"Skip model {model_path} as it is not support by aiter.")
        # this will enable VLLM_ROCM_USE_AITER_LINEAR
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    dtype = "bfloat16"

    # Skip the language translation prompt for the static per-tensor models.
    # NOTE: neither model is in the current parametrize list; the branch only
    # applies if they are added back.
    if model_path in (
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
    ):
        example_prompts = example_prompts[0:-1]

    with hf_runner(model_path, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    with vllm_runner(model_path, dtype=dtype, enforce_eager=True) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )

    if current_platform.is_rocm():
        torch.cuda.synchronize()


def test_compressed_tensors_no_enforce_eager(vllm_runner):
    """Smoke-test generation without passing ``enforce_eager=True``."""
    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
    with vllm_runner(model_path) as llm:
        # Prompts are a list: a bare string would be iterated per character.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


@pytest.mark.parametrize(
    "model_args",
    [
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
            "channel",
        ),
    ],
)
@pytest.mark.parametrize(
    "use_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_compressed_tensors_w8a8_dynamic_per_token(
    vllm_runner,
    model_args,
    use_aiter,
    monkeypatch,
):
    """Check set-up of W8A8 INT8 models with dynamic per-token activations.

    ``model_args`` is ``(model_path, strategy)``, where ``strategy`` is the
    expected weight strategy of the ``qkv_proj`` scheme.
    """
    model_path, strategy = model_args

    if (
        current_platform.is_rocm()
        and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
    ):
        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")

    if use_aiter:
        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
            pytest.skip(f"Skip model {model_path} as it is not support by aiter.")
        # this will enable VLLM_ROCM_USE_AITER_LINEAR
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    with vllm_runner(model_path, enforce_eager=True, dtype=torch.float16) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
            # Dynamic activation quantization: no static input scale scheme.
            assert not qkv_proj.scheme.is_static_input_scheme
            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.weight.dtype is torch.int8

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


@pytest.mark.parametrize(
    "wNa16_args",
    [
        (
            "nm-testing/tinyllama-oneshot-w4a16-channel-v2",
            "channel",
            None,
            8,
            True,
            False,
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder",
            "group",
            128,
            8,
            False,
            True,
        ),
    ],
)
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="The tests are skipped on non-CUDA platform."
)
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    """Check scheme attributes of weight-only (WNA16) quantized models.

    ``wNa16_args`` is ``(model, strategy, group, pack_factor, symmetric,
    has_g_idx)``. A ``group`` of ``None`` means the scheme must report
    ``group_size == -1``.
    """
    model_path, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)

            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.scheme.group_size == (-1 if group is None else group)

            assert qkv_proj.scheme.pack_factor == pack_factor
            assert qkv_proj.scheme.symmetric == symmetric
            assert qkv_proj.scheme.has_g_idx == has_g_idx

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
def test_compressed_tensors_w4a16_marlin24(vllm_runner):
    """Check set-up of a 2:4-sparse W4A16 (Marlin24) model."""
    model_path = "nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t"
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW4A16Sparse24)
            assert qkv_proj.weight_packed.dtype is torch.int32

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


def test_compressed_tensors_fp8(vllm_runner):
    """Check set-up of an FP8 compressed-tensors model.

    The scheme may be either W8A8 FP8 or the W8A16 FP8 variant. The
    per-tensor scale and FP8 weight checks only apply to the W8A8 scheme.
    """
    model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(
                qkv_proj.scheme,
                (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8),
            )

            assert qkv_proj.input_scale.dtype is torch.float32

            if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
                assert len(qkv_proj.input_scale.shape) == 0
                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
                assert qkv_proj.weight_scale.dtype is torch.float32
                assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
def test_compressed_tensors_kv_cache(vllm_runner):
    """Smoke-test generation with a compressed-tensors model and fp8 KV cache."""
    model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme"
    with vllm_runner(model_path, enforce_eager=True, kv_cache_dtype="fp8") as llm:
        output = llm.generate_greedy(["Hello world!"], max_tokens=4)
        assert output


def _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy, format="dense"):
    """Assert that ``qkv_proj`` uses a quantized 2:4-sparse scheme.

    This is a plain helper called from inside ``check_model`` callbacks. A
    pytest skip marker here would have no effect (marks only apply to
    collected tests), so each calling test declares its own platform skip.

    Args:
        qkv_proj: The fused QKV projection module of a decoder layer.
        weight_strategy: Expected ``scheme.weight_quant.strategy``.
        input_strategy: Expected ``scheme.input_quant.strategy``.
        format: Expected sparsity compression format of the ``"Linear"``
            entry in the sparsity scheme map. The name shadows the builtin
            but is kept because callers pass it by keyword.
    """
    assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
    assert isinstance(qkv_proj.scheme, CompressedTensors24)

    assert qkv_proj.scheme.weight_quant.strategy == weight_strategy
    assert qkv_proj.scheme.input_quant.strategy == input_strategy
    assert qkv_proj.scheme.quantized

    sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map
    assert sparsity_map
    # Fail with a clear assertion rather than AttributeError on None.
    linear_scheme = sparsity_map.get("Linear")
    assert linear_scheme is not None, "no 'Linear' entry in sparsity_scheme_map"
    assert linear_scheme.format == format
    assert linear_scheme.sparsity_structure == "2:4"


@pytest.mark.skipif(
    not current_platform.is_cuda() or not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-2of4-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-Per-Tensor-testing",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/Meta-Llama-3-8B-Instruct-FP8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8(vllm_runner, args_2of4):
    """Check dense-format 2:4-sparse FP8 models.

    ``args_2of4`` is ``(model, weight_strategy, input_strategy)``.
    """
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda() or not current_platform.has_device_capability(90),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_fp8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_fp8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_fp8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_fp8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_fp8_compressed(vllm_runner, args_2of4):
    """Check bitmask-compressed 2:4-sparse FP8 models.

    ``args_2of4`` is ``(model, weight_strategy, input_strategy)``.
    """
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.float8_e4m3fn
            _test_2of4_quant_models(
                qkv_proj,
                weight_strategy,
                input_strategy,
                format="sparse-24-bitmask",
            )

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_per_tok_dyn_act_int8-BitM",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-chnl_wts_tensor_act_int8-BitM",
            "channel",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_per_tok_dyn_act_int8-BitM",
            "tensor",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-gsm8k-pruned.2of4-tensor_wts_tensor_act_int8-BitM",
            "tensor",
            "tensor",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8_compressed(vllm_runner, args_2of4):
    """Check bitmask-compressed 2:4-sparse INT8 models.

    ``args_2of4`` is ``(model, weight_strategy, input_strategy)``.
    """
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            _test_2of4_quant_models(
                qkv_proj,
                weight_strategy,
                input_strategy,
                format="sparse-24-bitmask",
            )

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Sparse FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Channel-Weight-testing",
            "channel",
            "token",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Static-testing",
            "tensor",
            "tensor",
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-INT8-Dynamic-IA-Per-Tensor-Weight-testing",
            "tensor",
            "token",
        ),
    ],
)
def test_compressed_tensors_2of4_quant_int8(vllm_runner, args_2of4):
    """Check dense-format 2:4-sparse INT8 models.

    ``args_2of4`` is ``(model, weight_strategy, input_strategy)``.
    """
    model, weight_strategy, input_strategy = args_2of4
    with vllm_runner(model, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert qkv_proj.scheme.weights_dtype == torch.int8
            _test_2of4_quant_models(qkv_proj, weight_strategy, input_strategy)

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="2of4 Sparse is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4",
    [("nm-testing/TinyLlama-1.1B-Chat-v1.0-2of4-Sparse-Dense-Compressor")],
)
def test_compressed_tensors_2of4_sparse(vllm_runner, args_2of4):
    """Check a 2:4-sparse, unquantized model stored in dense format.

    ``args_2of4`` is the model path (a plain string, not a tuple).
    """
    model = args_2of4
    with vllm_runner(model, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensors24)

            # Sparse-only: no weight or activation quantization.
            assert qkv_proj.scheme.weight_quant is None
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized
            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
            assert sparsity_map.get("Linear").format == "dense"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not sparse_cutlass_supported(),
    reason="Cutlass is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args_2of4", [("nm-testing/llama2.c-stories42M-pruned2.4-compressed")]
)
def test_compressed_tensors_2of4_sparse_compressed(vllm_runner, args_2of4):
    """Check a 2:4-sparse, unquantized model stored in bitmask format.

    ``args_2of4`` is the model path (a plain string, not a tuple).
    """
    model = args_2of4
    with vllm_runner(model, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensors24)

            # Sparse-only: no weight or activation quantization.
            assert qkv_proj.scheme.weight_quant is None
            assert qkv_proj.scheme.input_quant is None
            assert not qkv_proj.scheme.quantized
            assert qkv_proj.quant_method.quantization_config.sparsity_scheme_map
            sparsity_map = qkv_proj.quant_method.quantization_config.sparsity_scheme_map  # noqa: E501
            assert sparsity_map.get("Linear").format == "sparse-24-bitmask"
            assert sparsity_map.get("Linear").sparsity_structure == "2:4"

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize(
    "args",
    [
        # TODO: Enable once model is available again
        # ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4),
        ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4),
    ],
)
def test_compressed_tensors_nvfp4(vllm_runner, args):
    """Check set-up of NVFP4 models.

    ``args`` is ``(model, scheme)``, where ``scheme`` is the expected scheme
    class for ``qkv_proj``.
    """
    model, scheme = args
    with vllm_runner(model, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            # Accept the expected scheme. Also accept CompressedTensorsW4A16Fp4,
            # but only when CUTLASS FP4 is not supported on this platform.
            assert isinstance(qkv_proj.scheme, scheme) or (
                isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
                and not cutlass_fp4_supported()
            ), "FP4 Scheme Mismatch"

            assert qkv_proj.scheme.group_size == 16

        llm.apply_model(check_model)
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda() or not current_platform.has_device_capability(90),
    reason="W4A8 FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args",
    [("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8)],
)
def test_compressed_tensors_w4a8_fp8(vllm_runner, args):
    """Check that every linear projection of layer 0 uses the W4A8 FP8 scheme.

    ``args`` is ``(model, scheme)``, where ``scheme`` is the expected scheme
    class.
    """
    model, scheme = args
    with vllm_runner(model, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            o_proj = layer.self_attn.o_proj
            gate_up_proj = layer.mlp.gate_up_proj
            down_proj = layer.mlp.down_proj

            for proj in (qkv_proj, o_proj, gate_up_proj, down_proj):
                assert isinstance(proj.quant_method, CompressedTensorsLinearMethod)
                assert isinstance(proj.scheme, scheme)

                assert proj.weight_packed.dtype is torch.int32
                assert proj.weight_scale.dtype is torch.float8_e4m3fn
                assert proj.weight_chan_scale.dtype is torch.float32
                assert proj.scheme.group_size == 128

        llm.apply_model(check_model)
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize(
    "model,prompt,exp_perplexity",
    [
        (
            "nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16",
            "Flat is better than nested.\nSparse is better than dense.",
            150.0,
        ),
        (
            "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16",
            "Flat is better than nested.\nSparse is better than dense.",
            150.0,
        ),
    ],
)
def test_compressed_tensors_transforms_perplexity(
    vllm_runner, model, prompt, exp_perplexity
):
    """Check that transform-based (SpinQuant / QuIP) W4A16 models keep the
    prompt perplexity at or below ``exp_perplexity``."""
    with vllm_runner(model, enforce_eager=True) as llm:
        perplexity = llm.generate_prompt_perplexity([prompt])[0]
        print(perplexity)
        assert perplexity <= exp_perplexity, (
            f"perplexity {perplexity} exceeds threshold {exp_perplexity}"
        )


def test_compressed_tensors_fp8_block_enabled(vllm_runner):
    """Check set-up of a block-quantized FP8 model.

    Verifies the block FP8 linear op, the weight and scale dtypes/ranks, and
    that activation quantization dispatches to the CUDA or HIP forward
    method.
    """
    model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
    with vllm_runner(model_path, enforce_eager=True) as llm:
        fp8_dtype = current_platform.fp8_dtype()

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
            assert isinstance(
                qkv_proj.scheme.w8a8_block_fp8_linear, W8A8BlockFp8LinearOp
            )

            assert qkv_proj.weight.dtype is fp8_dtype
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert len(qkv_proj.weight.shape) == 2
            assert len(qkv_proj.weight_scale.shape) == 2

            input_quant_op = qkv_proj.scheme.w8a8_block_fp8_linear.input_quant_op
            assert isinstance(input_quant_op, QuantFP8)
            assert input_quant_op._forward_method in (
                input_quant_op.forward_cuda,
                input_quant_op.forward_hip,
            )

        llm.apply_model(check_model)

        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="This test is not for non-CUDA platforms",
)
def test_compressed_tensors_moe_ignore_with_model(vllm_runner):
    """
    Integration test for MoE layer ignore functionality with a real model.

    Loads a compressed-tensors quantized MoE model whose ignore list covers
    some MoE layers. Checks that:

    - a non-ignored layer (layer 0) uses ``CompressedTensorsMoEMethod``;
    - an ignored layer (layer 3) uses ``UnquantizedFusedMoEMethod``.

    Expected model structure:
    - Compressed-tensors quantized MoE model
    - Config with ignore list containing specific MoE layers
    - Multiple MoE layers where some are quantized and some are not
    """

    # Uses the CT 12.2 ("CTstable") export. The variant
    # "nm-testing/tinysmokeqwen3moe-W4A16-first-only" targets CT 12.3.
    model_path = "nm-testing/tinysmokeqwen3moe-W4A16-first-only-CTstable"

    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            from vllm.model_executor.layers.fused_moe import FusedMoE
            from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
                CompressedTensorsMoEMethod,
            )

            # Check layer 0 MoE (should be quantized)
            layer_quantized = model.model.layers[0].mlp.experts
            assert isinstance(layer_quantized, FusedMoE)
            assert isinstance(layer_quantized.quant_method, CompressedTensorsMoEMethod)

            # Check layer 3 MoE (should be unquantized + ignored)
            layer_unquantized = model.model.layers[3].mlp.experts
            assert isinstance(layer_unquantized, FusedMoE)
            assert isinstance(layer_unquantized.quant_method, UnquantizedFusedMoEMethod)

        llm.apply_model(check_model)

        # Verify the model can generate output
        output = llm.generate_greedy(["Hello, my name is"], max_tokens=4)
        assert output