test_compressed_tensors.py 21.9 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""Test model set-up and weight loading for llmcompressor-quantized models.

Run `pytest tests/quantization/test_compressed_tensors.py`.
"""
7

8
9
from unittest.mock import Mock

10
import pytest
11
import torch
12
13
14
15
16
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)
17

18
from tests.models.utils import check_logprobs_close
19
20
21
from vllm.model_executor.kernels.linear import (
    Fp8BlockScaledMMLinearKernel,
)
22
from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod
23
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (  # noqa: E501
24
    CompressedTensorsConfig,
25
26
27
28
29
30
31
32
33
    CompressedTensorsLinearMethod,
    CompressedTensorsW4A4Fp4,
    CompressedTensorsW4A8Fp8,
    CompressedTensorsW4A16Fp4,
    CompressedTensorsW8A8Fp8,
    CompressedTensorsW8A8Int8,
    CompressedTensorsW8A16Fp8,
    CompressedTensorsWNA16,
)
34
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
35
from vllm.model_executor.layers.quantization.utils.nvfp4_utils import (
36
37
    cutlass_fp4_supported,
)
38
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
39
from vllm.platforms import current_platform
40
from vllm.v1.attention.backends.fa_utils import get_flash_attn_version
41

42
43
44
45
46
# AITER only supports per-channel-per-channel INT8 GEMM
# and per-tensor-per-tensor INT8 GEMM.
# It does not support mixed-precision MM or mixed quantization schemes.
# Tests that enable AITER skip any model that is not listed here.
ROCM_AITER_SUPPORTED_INT8_MODEL = [
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
]

50
# TritonInt8ScaledMMLinearKernel only supports symmetric quantization.
# On ROCm, the INT8 tests skip any model that is not listed here.
ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL = [
    "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
    "nm-testing/tinyllama-oneshot-w8-channel-a8-tensor",
    "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    "nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2",
    "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
]

59

60
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """Allow insecure serialization for every test in this module.

    `LLM.apply_model` requires pickling a function, so each test runs with
    the `VLLM_ALLOW_INSECURE_SERIALIZATION` environment variable set to "1".
    The variable is set through pytest's `monkeypatch` fixture.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
64
65


66
67
@pytest.mark.parametrize(
    "model_args",
    [
        (
            "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change",
            "tensor",
            QuantizationType.INT,
            2560,
            True,
        ),
        (
            "nm-testing/asym-w8w8-int8-static-per-tensor-tiny-llama",
            "tensor",
            QuantizationType.INT,
            2560,
            False,
        ),
    ],
)
def test_compressed_tensors_w8a8_static_setup(vllm_runner, model_args):
    """Check layer set-up of W8A8 INT8 models with static input quantization.

    Each `model_args` tuple is
    `(model_path, strategy, quant_type, shape_0, is_symmetric)`:

    * `strategy` is the expected `scheme.strategy` of `qkv_proj`.
    * `quant_type` is carried in the parameter set but not checked here.
    * `shape_0` is the expected first dimension of `qkv_proj.weight_scale`.
    * `is_symmetric` selects whether input zero points must be absent
      (symmetric) or present with dtype int32 (asymmetric).

    The first decoder layer is inspected, then a short greedy generation is
    run as a smoke test.
    """
    model_path, strategy, _quant_type, shape_0, is_symmetric = model_args

    if (
        current_platform.is_rocm()
        and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
    ):
        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")

    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            o_proj = layer.self_attn.o_proj
            gate_up_proj = layer.mlp.gate_up_proj
            down_proj = layer.mlp.down_proj

            # assert zp for symmetric and asymmetric cases
            def zp_valid(zp: torch.Tensor | None):
                if is_symmetric:
                    return zp is None

                return zp is not None and zp.dtype is torch.int32

            assert zp_valid(qkv_proj.input_zero_point)
            assert zp_valid(o_proj.input_zero_point)
            assert zp_valid(gate_up_proj.input_zero_point)
            assert zp_valid(down_proj.input_zero_point)

            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(o_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(gate_up_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(down_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)

            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.scheme.is_static_input_scheme
            expected_type = torch.int8

            assert qkv_proj.weight.dtype is expected_type
            assert o_proj.weight.dtype is expected_type
            assert gate_up_proj.weight.dtype is expected_type

            if qkv_proj.scheme.strategy == "tensor":
                # Make sure it is a channelwise buffer
                # After running process_weights_after_loading
                assert len(qkv_proj.weight_scale.shape) == 2
                assert qkv_proj.weight_scale.shape[0] == shape_0
                assert qkv_proj.weight_scale.shape[1] == 1
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert qkv_proj.input_scale.dtype is torch.float32

        llm.apply_model(check_model)

        # Smoke test: the loaded model must be able to generate.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output

144

145
146
147
148
149
150
@pytest.mark.parametrize(
    "model_path",
    [
        "neuralmagic/Llama-3.2-1B-quantized.w8a8",
    ],
)
@pytest.mark.parametrize("max_tokens", [4])
@pytest.mark.parametrize("num_logprobs", [10])
@pytest.mark.parametrize(
    "use_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_compressed_tensors_w8a8_logprobs(
    hf_runner,
    vllm_runner,
    example_prompts,
    model_path,
    max_tokens,
    num_logprobs,
    use_aiter,
    monkeypatch,
):
    """Compare greedy log-probabilities between the HF runner and vLLM.

    Both runners generate `max_tokens` tokens with `num_logprobs`
    log-probabilities per token in bfloat16, and the results are compared
    with `check_logprobs_close`. `use_aiter` is only parametrized to True on
    ROCm, where it sets `VLLM_ROCM_USE_AITER=1` for the test.
    """
    if (
        current_platform.is_rocm()
        and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
    ):
        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")

    if use_aiter:
        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
            pytest.skip(f"Skip model {model_path} as it is not support by aiter.")
        # this will enable VLLM_ROCM_USE_AITER_LINEAR
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    dtype = "bfloat16"

    # skip language translation prompt for the static per tensor models
    if model_path in (
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Sym",
        "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Static-Per-Tensor-Asym",
    ):
        example_prompts = example_prompts[0:-1]

    with hf_runner(model_path, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs
        )

    with vllm_runner(model_path, dtype=dtype, enforce_eager=True) as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs
        )

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )

    if current_platform.is_rocm():
        torch.accelerator.synchronize()
206

207

208
def test_compressed_tensors_no_enforce_eager(vllm_runner):
    """Smoke-test generation without passing `enforce_eager` to the runner."""
    model_path = "nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
    with vllm_runner(model_path) as llm:
        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output


215
216
217
218
219
220
221
222
223
224
@pytest.mark.parametrize(
    "model_args",
    [
        ("nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2", "tensor"),
        (
            "nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2",
            "channel",
        ),
    ],
)
@pytest.mark.parametrize(
    "use_aiter", [True, False] if current_platform.is_rocm() else [False]
)
def test_compressed_tensors_w8a8_dynamic_per_token(
    vllm_runner,
    model_args,
    use_aiter,
    monkeypatch,
):
    """Check layer set-up of W8A8 INT8 models with dynamic input quantization.

    Each `model_args` tuple is `(model_path, strategy)`, where `strategy` is
    the expected `scheme.strategy` of the first layer's `qkv_proj`.
    `use_aiter` is only parametrized to True on ROCm, where it sets
    `VLLM_ROCM_USE_AITER=1` for the test.
    """
    model_path, strategy = model_args

    if (
        current_platform.is_rocm()
        and model_path not in ROCM_TRITON_SCALED_MM_SUPPORTED_INT8_MODEL
    ):
        pytest.skip(f"Skip model {model_path} as it is not supported on ROCm.")

    if use_aiter:
        if model_path not in ROCM_AITER_SUPPORTED_INT8_MODEL:
            pytest.skip(f"Skip model {model_path} as it is not support by aiter.")
        # this will enable VLLM_ROCM_USE_AITER_LINEAR
        monkeypatch.setenv("VLLM_ROCM_USE_AITER", "1")

    with vllm_runner(model_path, enforce_eager=True, dtype=torch.float16) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Int8)
            # Dynamic activation quantization: the scheme must not be static.
            assert not qkv_proj.scheme.is_static_input_scheme
            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.weight.dtype is torch.int8

        llm.apply_model(check_model)

        # Smoke test: the loaded model must be able to generate.
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        assert output

266

267
268
@pytest.mark.parametrize(
    "wNa16_args",
    [
        (
            "nm-testing/tinyllama-oneshot-w4a16-channel-v2",
            "channel",
            None,
            8,
            True,
            False,
        ),
        (
            "nm-testing/TinyLlama-1.1B-Chat-v1.0-W4A16-G128-Asym-Updated-ActOrder",
            "group",
            128,
            8,
            False,
            True,
        ),
    ],
)
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="The tests are skipped on non-CUDA platform."
)
def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
    """Check the WNA16 scheme attributes of the first layer's `qkv_proj`.

    Each `wNa16_args` tuple is
    `(model, strategy, group, pack_factor, symmetric, has_g_idx)`.
    A `group` of None is expected to show up as `group_size == -1` on the
    scheme.
    """
    model, strategy, group, pack_factor, symmetric, has_g_idx = wNa16_args
    with vllm_runner(model, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsWNA16)

            assert qkv_proj.scheme.strategy == strategy
            assert qkv_proj.scheme.group_size == (-1 if group is None else group)

            assert qkv_proj.scheme.pack_factor == pack_factor
            assert qkv_proj.scheme.symmetric == symmetric
            assert qkv_proj.scheme.has_g_idx == has_g_idx

        llm.apply_model(check_model)

        # Smoke test: the loaded model must be able to generate.
        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output

314

315
316
def test_compressed_tensors_fp8(vllm_runner):
    """Check FP8 scheme selection and parameter dtypes on `qkv_proj`.

    The scheme may be either `CompressedTensorsW8A8Fp8` or
    `CompressedTensorsW8A16Fp8`. The scalar-shape and weight-dtype checks are
    applied only when the W8A8 scheme was selected.
    """
    model_path = "nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test"
    with vllm_runner(model_path, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(
                qkv_proj.scheme,
                (CompressedTensorsW8A8Fp8, CompressedTensorsW8A16Fp8),
            )

            assert qkv_proj.input_scale.dtype is torch.float32

            if isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8):
                # Scales are expected to be 0-dim (scalar) tensors here.
                assert len(qkv_proj.input_scale.shape) == 0
                assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
                assert qkv_proj.weight_scale.dtype is torch.float32
                assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        # Smoke test: the loaded model must be able to generate.
        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
342
343


344
345
346
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
def test_compressed_tensors_kv_cache_fp8_per_tensor(vllm_runner):
    """Smoke-test generation with the `kvcache-fp8-tensor` checkpoint."""
    model_path = "nm-testing/TinyLlama-1.1B-Chat-v1.0-kvcache-fp8-tensor"
    with vllm_runner(model_path) as llm:
        output = llm.generate_greedy("Hello world!", max_tokens=4)
        assert output


@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
def test_compressed_tensors_kv_cache_fp8_per_attn_head(vllm_runner):
    """Smoke-test generation with the `kvcache-fp8-attn_head` checkpoint.

    The test is skipped when `get_flash_attn_version()` raises, returns None,
    or reports a version below 3. Otherwise the model is run with the
    `FLASH_ATTN` attention backend.
    """
    model_path = "nm-testing/TinyLlama-1.1B-Chat-v1.0-kvcache-fp8-attn_head"
    try:
        fa_version = get_flash_attn_version()
    except Exception:
        # Any failure to query the version is treated as "backend unavailable".
        pytest.skip("This test requires FlashAttention backend.")
    if fa_version is None or fa_version < 3:
        pytest.skip("This test requires FlashAttention version >= 3.")

    with vllm_runner(model_path, attention_config={"backend": "FLASH_ATTN"}) as llm:
        output = llm.generate_greedy("Hello world!", max_tokens=4)
        assert output
369
370


371
372
@pytest.mark.parametrize(
    "args",
    [
        # TODO: Enable once model is available again
        # ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4A16", CompressedTensorsW4A16Fp4),
        ("nm-testing/TinyLlama-1.1B-Chat-v1.0-NVFP4", CompressedTensorsW4A4Fp4),
    ],
)
def test_compressed_tensors_nvfp4(vllm_runner, args):
    """Check NVFP4 scheme selection and group size on `qkv_proj`.

    Each `args` tuple is `(model, scheme)`. The layer's scheme must be an
    instance of `scheme`, or a `CompressedTensorsW4A16Fp4` when
    `cutlass_fp4_supported()` is False.
    """
    model, scheme = args
    with vllm_runner(model, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)

            # Explicit parentheses: `and` binds tighter than `or`, so the
            # W4A16 scheme is only accepted when CUTLASS FP4 is unsupported.
            # NOTE(review): presumably W4A16 is the fallback path in that
            # case - confirm against the scheme selection logic.
            scheme_matches = isinstance(qkv_proj.scheme, scheme) or (
                isinstance(qkv_proj.scheme, CompressedTensorsW4A16Fp4)
                and not cutlass_fp4_supported()
            )
            assert scheme_matches, "FP4 Scheme Mismatch"

            assert qkv_proj.scheme.group_size == 16

        llm.apply_model(check_model)
        output = llm.generate_greedy(["Hello my name is"], max_tokens=4)
        print(output)
        assert output
403
404
405


@pytest.mark.skipif(
    not current_platform.is_cuda() or not current_platform.has_device_capability(90),
    reason="W4A8 FP8 is not yet supported on this GPU type.",
)
@pytest.mark.parametrize(
    "args",
    [("czhu-cohere/TinyLlama-1.1B-Chat-v1.0-W4A8-e2e", CompressedTensorsW4A8Fp8)],
)
def test_compressed_tensors_w4a8_fp8(vllm_runner, args):
    """Check the W4A8 FP8 scheme and parameter dtypes on all four projections.

    Each `args` tuple is `(model, scheme)`. For `qkv_proj`, `o_proj`,
    `gate_up_proj` and `down_proj` of the first layer, the test checks the
    scheme type, the dtypes of the packed weight and both scale tensors, and
    a group size of 128.
    """
    model, scheme = args
    with vllm_runner(model, enforce_eager=True) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            o_proj = layer.self_attn.o_proj
            gate_up_proj = layer.mlp.gate_up_proj
            down_proj = layer.mlp.down_proj

            for proj in (qkv_proj, o_proj, gate_up_proj, down_proj):
                assert isinstance(proj.quant_method, CompressedTensorsLinearMethod)
                assert isinstance(proj.scheme, scheme)

                assert proj.weight_packed.dtype is torch.int32
                assert proj.weight_scale.dtype is torch.float8_e4m3fn
                assert proj.weight_chan_scale.dtype is torch.float32
                assert proj.scheme.group_size == 128

        llm.apply_model(check_model)
        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        print(output)
        assert output
438
439


440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
@pytest.mark.skipif(
    not current_platform.is_cuda(), reason="This test is skipped on non-CUDA platform."
)
@pytest.mark.parametrize(
    "model,prompt,exp_perplexity",
    [
        (
            "nm-testing/Llama-3.2-1B-Instruct-spinquantR1R2R4-w4a16",
            "Flat is better than nested.\nSparse is better than dense.",
            150.0,
        ),
        (
            "nm-testing/Llama-3.2-1B-Instruct-quip-w4a16",
            "Flat is better than nested.\nSparse is better than dense.",
            150.0,
        ),
    ],
)
def test_compressed_tensors_transforms_perplexity(
    vllm_runner, model, prompt, exp_perplexity
):
    """Check that prompt perplexity stays at or below `exp_perplexity`.

    The perplexity of the single `prompt` is obtained from
    `generate_prompt_perplexity` and printed before the comparison.
    """
    with vllm_runner(model, enforce_eager=True) as llm:
        perplexity = llm.generate_prompt_perplexity([prompt])[0]
        print(perplexity)
        assert perplexity <= exp_perplexity
465
466
467
468


def test_compressed_tensors_fp8_block_enabled(vllm_runner):
    """Check that an FP8 block-quantized model uses the block-scaled kernel.

    On the first layer's `qkv_proj`, the test checks that the W8A8 FP8 scheme
    holds an `Fp8BlockScaledMMLinearKernel`, that the weight and weight scale
    are both 2-D, and that the input quantization op is a `QuantFP8`
    dispatching to its CUDA or HIP forward method.
    """
    model_path = "RedHatAI/Qwen3-0.6B-FP8-BLOCK"
    with vllm_runner(model_path, enforce_eager=True) as llm:
        # Resolved here and captured by the closure below.
        fp8_dtype = current_platform.fp8_dtype()

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj
            assert isinstance(qkv_proj.quant_method, CompressedTensorsLinearMethod)
            assert isinstance(qkv_proj.scheme, CompressedTensorsW8A8Fp8)
            assert isinstance(qkv_proj.scheme.fp8_linear, Fp8BlockScaledMMLinearKernel)

            assert qkv_proj.weight.dtype is fp8_dtype
            assert qkv_proj.weight_scale.dtype is torch.float32
            assert len(qkv_proj.weight.shape) == 2
            assert len(qkv_proj.weight_scale.shape) == 2

            input_quant_op = qkv_proj.scheme.fp8_linear.quant_fp8
            assert isinstance(input_quant_op, QuantFP8)
            assert input_quant_op._forward_method in (
                input_quant_op.forward_cuda,
                input_quant_op.forward_hip,
            )

        llm.apply_model(check_model)

        # Smoke test: the loaded model must be able to generate.
        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542


@pytest.mark.skipif(
    not current_platform.is_cuda(),
    reason="This test is not for non-CUDA platforms",
)
def test_compressed_tensors_moe_ignore_with_model(vllm_runner):
    """Integration test: ignored MoE layers stay unquantized in a real model.

    Loads a compressed-tensors quantized MoE checkpoint and inspects the
    `mlp.experts` module of two decoder layers:

    - layer 0 must use `CompressedTensorsMoEMethod` (quantized);
    - layer 3 must use `UnquantizedFusedMoEMethod` (ignored).

    A short greedy generation is then run as a smoke test.
    """

    def check_model(model):
        from vllm.model_executor.layers.fused_moe import FusedMoE
        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
            CompressedTensorsMoEMethod,
        )

        decoder_layers = model.model.layers

        # Layer 0 MoE: expected to be quantized.
        quantized_experts = decoder_layers[0].mlp.experts
        assert isinstance(quantized_experts, FusedMoE)
        assert isinstance(quantized_experts.quant_method, CompressedTensorsMoEMethod)

        # Layer 3 MoE: expected to be unquantized (ignored).
        ignored_experts = decoder_layers[3].mlp.experts
        assert isinstance(ignored_experts, FusedMoE)
        assert isinstance(ignored_experts.quant_method, UnquantizedFusedMoEMethod)

    # The CT 12.3 variant of this checkpoint is
    # "nm-testing/tinysmokeqwen3moe-W4A16-first-only".
    model_path = "nm-testing/tinysmokeqwen3moe-W4A16-first-only-CTstable"  # CT 12.2

    with vllm_runner(model_path, enforce_eager=True) as llm:
        llm.apply_model(check_model)

        # The model must still be able to generate.
        generated = llm.generate_greedy("Hello, my name is", max_tokens=4)
        assert generated
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565


def test_w4a16_moe_torch_compile(vllm_runner):
    """Regression test for W4A16 MoE under torch.compile.

    The MoE quant_config must be initialized inside the moe_forward custom
    op, not only in forward_native: forward_native is compiled by Dynamo,
    and attribute mutations are not replayed at runtime.

    Without the fix in _moe_forward/_moe_forward_shared, this hits:
        AssertionError: Hidden size mismatch 2048 != 1024
    because use_int4_w4a16 is False (moe_quant_config stays None).
    """
    # Compilation stays enabled (enforce_eager=False); CUDA graphs are off.
    runner_kwargs = {
        "enforce_eager": False,
        "max_model_len": 256,
        "compilation_config": {"cudagraph_mode": "NONE"},
    }
    checkpoint = "nm-testing/tinysmokeqwen3moe-W4A16-first-only-CTstable"

    with vllm_runner(checkpoint, **runner_kwargs) as llm:
        generated = llm.generate_greedy("Hi", max_tokens=1)
        assert generated
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634


def _make_ct_config(*, target: str = "Linear") -> CompressedTensorsConfig:
    """Return a minimal CompressedTensorsConfig for a single target.

    The one scheme, keyed by `target`, uses symmetric, non-dynamic 8-bit INT
    weights with the channel strategy and no input-activation quantization.
    The ignore list and both sparsity settings are empty.
    """
    fmt = "pack-quantized"
    int8_channel_weights = QuantizationArgs(
        num_bits=8,
        type=QuantizationType.INT,
        strategy=QuantizationStrategy.CHANNEL,
        symmetric=True,
        dynamic=False,
    )
    scheme = {
        "weights": int8_channel_weights,
        "input_activations": None,
        "format": fmt,
    }
    return CompressedTensorsConfig(
        target_scheme_map={target: scheme},
        ignore=[],
        quant_format=fmt,
        sparsity_scheme_map={},
        sparsity_ignore_list=[],
    )


def test_get_quant_method_returns_linear_method_for_parallel_lm_head():
    """A ParallelLMHead whose prefix matches the target gets a quantized method."""
    lm_head = Mock(spec=ParallelLMHead)
    lm_head.__class__ = ParallelLMHead
    ct_config = _make_ct_config(target="re:.*lm_head")

    method = ct_config.get_quant_method(lm_head, prefix="model.lm_head")

    failure_msg = f"Expected CompressedTensorsLinearMethod, got {type(method).__name__}"
    assert isinstance(method, CompressedTensorsLinearMethod), failure_msg


def test_get_quant_method_returns_none_for_ignored_parallel_lm_head():
    """A ParallelLMHead on the ignore list is left unquantized (None)."""
    lm_head_pattern = "re:.*lm_head"
    ct_config = _make_ct_config(target=lm_head_pattern)
    # The same pattern is both targeted and ignored.
    ct_config.ignore = [lm_head_pattern]

    lm_head = Mock(spec=ParallelLMHead)
    lm_head.__class__ = ParallelLMHead

    method = ct_config.get_quant_method(lm_head, prefix="model.lm_head")

    failure_msg = (
        f"Expected None for ignored ParallelLMHead, got {type(method).__name__}"
    )
    assert method is None, failure_msg


def test_get_quant_method_returns_none_for_unmatched_parallel_lm_head():
    """A ParallelLMHead not covered by target='Linear' must not crash.

    Most compressed-tensors models target only 'Linear', which ParallelLMHead
    does not match. get_quant_method should then return None (unquantized)
    rather than raise ValueError.
    """
    lm_head = Mock(spec=ParallelLMHead)
    lm_head.__class__ = ParallelLMHead
    ct_config = _make_ct_config(target="Linear")

    method = ct_config.get_quant_method(lm_head, prefix="model.lm_head")

    failure_msg = (
        f"Expected None for unmatched ParallelLMHead, got {type(method).__name__}"
    )
    assert method is None, failure_msg