test_quark.py 12.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
"""Test model set-up and weight loading for quark-quantized models.

Run `pytest tests/quantization/test_quark.py`.

See also `tests/kernels/moe/test_ocp_mx_moe.py`.
"""

10
11
import importlib.metadata
from dataclasses import dataclass
12
from importlib.util import find_spec
13
14
15

import huggingface_hub
import lm_eval
16
import pytest
17
import torch
18
from packaging import version
19
20

from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
21
22
23
24
    QuarkLinearMethod,
    QuarkW8A8Fp8,
    QuarkW8A8Int8,
)
25
26
27
from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
    QuarkW8A8Int8MoEMethod,
)
28
from vllm.platforms import current_platform
29

30
31
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch

32
33
34
# Minimum amd-quark version for MXFP4/OCP_MX tests (single source of truth).
QUARK_MXFP4_MIN_VERSION = "0.8.99"


def _quark_mxfp4_available() -> bool:
    """Return True when amd-quark >= QUARK_MXFP4_MIN_VERSION is installed.

    Returns False (instead of raising at collection time) when the `quark`
    module cannot be found or when the `amd-quark` distribution metadata is
    missing.
    """
    if find_spec("quark") is None:
        return False
    try:
        installed = importlib.metadata.version("amd-quark")
    except importlib.metadata.PackageNotFoundError:
        # A module named `quark` is importable, but it was not installed from
        # the `amd-quark` distribution: treat it as unavailable.
        return False
    return version.parse(installed) >= version.parse(QUARK_MXFP4_MIN_VERSION)


QUARK_MXFP4_AVAILABLE = _quark_mxfp4_available()
38
39

if QUARK_MXFP4_AVAILABLE:
40
    from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer
41
42
43
44
45
    from quark.torch.kernel import mx as mx_kernel
    from quark.torch.quantization.config.config import FP4PerGroupSpec

# Probe a repository under huggingface.co/amd once at import time. A
# RepositoryNotFoundError is interpreted as "no read access", which is used
# below to skip tests that need it.
try:
    huggingface_hub.list_repo_refs(
        "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ"
    )
except huggingface_hub.errors.RepositoryNotFoundError:
    HF_HUB_AMD_ORG_ACCESS = False
else:
    HF_HUB_AMD_ORG_ACCESS = True

52

53
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """Set VLLM_ALLOW_INSECURE_SERIALIZATION=1 for every test in this module.

    `LLM.apply_model` requires pickling a function.
    """
    env_name = "VLLM_ALLOW_INSECURE_SERIALIZATION"
    env_value = "1"
    monkeypatch.setenv(env_name, env_value)
57
58


59
60
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
    """Load a quark FP8 checkpoint and check the first layer's qkv_proj.

    Verifies that the layer uses `QuarkLinearMethod` with the `QuarkW8A8Fp8`
    scheme, that the weight has the platform FP8 dtype, and that both the
    input scale and the weight scale are zero-dimensional. Finally runs a
    short greedy generation as a smoke test.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
    with vllm_runner(
        model_path,
        enforce_eager=True,
        kv_cache_dtype=kv_cache_dtype,
        tensor_parallel_size=tp,
    ) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            # The scheme assert makes any further isinstance guard redundant.
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            # Both scales are expected to be 0-dim tensors.
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
87
88


89
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp):
    """Load a quark FP8 "ptpc" checkpoint and check the first layer's qkv_proj.

    Verifies that the layer uses `QuarkLinearMethod` with the `QuarkW8A8Fp8`
    scheme, that the weight has the platform FP8 dtype, and that the weight
    scale has shape `(weight.shape[1], 1)`. Finally runs a short greedy
    generation as a smoke test.
    """
    model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts"
    with vllm_runner(model_path, enforce_eager=True, tensor_parallel_size=tp) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            # The scheme assert makes any further isinstance guard redundant.
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            # One scale per entry along weight dim 1, kept as a column vector.
            assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[1]
            assert qkv_proj.weight_scale.shape[1] == 1

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
111
112


113
@pytest.mark.parametrize("tp", [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
    """Load a quark INT8 checkpoint and check the first layer's qkv_proj.

    The projection must use `QuarkLinearMethod` with the `QuarkW8A8Int8`
    scheme; a short greedy generation is then run as a smoke test.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
    with vllm_runner(model_path, enforce_eager=True, tensor_parallel_size=tp) as llm:

        def check_model(model):
            first_layer = model.model.layers[0]
            projection = first_layer.self_attn.qkv_proj

            assert isinstance(projection.quant_method, QuarkLinearMethod)
            assert isinstance(projection.scheme, QuarkW8A8Int8)

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
130
131


132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
@pytest.mark.parametrize("tp", [1])
def test_quark_int8_w8a8_moe(vllm_runner, tp):
    """Test W8A8 INT8 MoE quantization with a tiny Qwen3 MoE model.

    Checks the quantization method picked for the first layer's experts and
    the scheme picked for its attention qkv projection, then runs a short
    greedy generation as a smoke test.
    """
    model_path = "nameistoken/tiny-qwen3-moe-w8a8-int8-quark"
    runner_kwargs = dict(
        enforce_eager=True,
        tensor_parallel_size=tp,
        gpu_memory_utilization=0.1,
    )
    with vllm_runner(model_path, **runner_kwargs) as llm:

        def check_model(model):
            first_layer = model.model.layers[0]

            # MoE experts should use QuarkW8A8Int8MoEMethod
            moe = first_layer.mlp.experts
            assert isinstance(moe.quant_method, QuarkW8A8Int8MoEMethod), (
                f"Expected QuarkW8A8Int8MoEMethod, got {type(moe.quant_method)}"
            )

            # Non-MoE linear layers should use QuarkW8A8Int8
            assert isinstance(first_layer.self_attn.qkv_proj.scheme, QuarkW8A8Int8)

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello", max_tokens=4)
        assert output


160
161
162
163
164
165
166
def test_quark_fp8_parity(vllm_runner):
    """Check that two tiny llama checkpoints load to identical state dicts.

    Both models are loaded side by side, their state dicts are copied to CPU,
    and then the key sets and every tensor are compared for exact equality.
    """
    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"

    llm_kwargs = {
        "tensor_parallel_size": 1,
        "enforce_eager": True,
        "gpu_memory_utilization": 0.1,
    }
    with (
        vllm_runner(quark_model_id, **llm_kwargs) as quark_handle,
        vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle,
    ):

        def get_state_dict(model):
            # Move tensors to CPU so they can be compared after the runners exit.
            return {k: v.cpu() for k, v in model.state_dict().items()}

        # `apply_model` results are unpacked as single-element sequences.
        (quark_state_dict,) = quark_handle.apply_model(get_state_dict)
        (fp8_state_dict,) = fp8_handle.apply_model(get_state_dict)

    assert fp8_state_dict.keys() == quark_state_dict.keys()

    for key, fp8_tensor in fp8_state_dict.items():
        # Name the offending entry so a failure is actionable.
        assert torch.equal(fp8_tensor, quark_state_dict[key]), (
            f"State-dict tensors differ for key {key!r}"
        )
184
185
186


@dataclass
187
class AccuracyTestConfig:
188
189
190
    model_name: str
    excepted_value: float

191
192
193
    def get_model_args(
        self,
        tp_size: int,
194
195
        model_max_len: int | None = None,
        kwargs: dict | None = None,
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
    ) -> dict:
        if kwargs is None:
            kwargs = {}

        model_args = {
            "pretrained": self.model_name,
            "dtype": "auto",
            "add_bos_token": True,
            "tensor_parallel_size": tp_size,
            "gpu_memory_utilization": 0.7,
            **kwargs,
        }
        if model_max_len is not None:
            model_args["max_model_len"] = model_max_len

        return model_args


# (model name, reference gsm8k score) pairs turned into test configurations.
GSM8K_ACCURACY_CONFIGS = [
    AccuracyTestConfig(model_name=name, excepted_value=reference)
    for name, reference in (
        # Private model.
        ("amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant", 0.96),
    )
]

222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
# (model name, reference wikitext word perplexity) pairs turned into test
# configurations.
WIKITEXT_ACCURACY_CONFIGS = [
    AccuracyTestConfig(model_name=name, excepted_value=reference)
    for name, reference in (
        ("fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3", 11.3),
        ("fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2", 10.6),
        ("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", 12.4),
    )
]


237
238
239
240
@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
    """Evaluate a model on wikitext and compare word perplexity to a reference.

    The test is skipped when fewer than `tp_size` accelerator devices are
    visible. The measured perplexity must lie strictly within +/-0.1 of
    `config.excepted_value`.
    """
    available_devices = torch.accelerator.device_count()
    if available_devices < tp_size:
        pytest.skip(
            f"This test requires >={tp_size} gpus, got only {available_devices}"
        )

    task = "wikitext"
    # Applied as an absolute (additive) tolerance around the reference value.
    abs_tolerance = 0.1

    # Smaller cudagraph_capture_sizes to speed up the test.
    model_args = config.get_model_args(
        tp_size=tp_size, kwargs={"cudagraph_capture_sizes": [16]}
    )
    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=task,
        batch_size=64,
    )

    expected = config.excepted_value
    measured = results["results"][task]["word_perplexity,none"]
    lower_bound = expected - abs_tolerance
    upper_bound = expected + abs_tolerance
    assert lower_bound < measured < upper_bound, (
        f"Expected: {expected} |  Measured: {measured}"
    )

268

269
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
    """Evaluate a model on 8-shot gsm8k and compare the score to a reference.

    The test is skipped when fewer than 8 accelerator devices are visible.
    The strict-match exact-match score must lie strictly within +/-0.03 of
    `config.excepted_value`.
    """
    available_devices = torch.accelerator.device_count()
    if available_devices < 8:
        pytest.skip(f"This test requires >=8 gpus, got only {available_devices}")

    task = "gsm8k"
    # Applied as an absolute (additive) tolerance around the reference value.
    abs_tolerance = 0.03

    # NOTE(review): 38768 is not a power of two (32768?) - confirm it is intended.
    model_args = config.get_model_args(tp_size=8, model_max_len=38768)
    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=task,
        batch_size=64,
        num_fewshot=8,
    )

    expected = config.excepted_value
    measured = results["results"][task]["exact_match,strict-match"]
    assert measured - abs_tolerance < expected < measured + abs_tolerance, (
        f"Expected: {expected} |  Measured: {measured}"
    )
300
301


302
303
304
305
@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[float]):
    """Compare quark's `qdq_mxfp4_hip` with the `qdq_mxfp4_torch` reference.

    A random input is built in groups of 32 columns, each group multiplied by
    one of `scalings` (cycled) so that groups differ in magnitude. Both
    implementations are run with the "even" mode, and every group of the two
    results must be finite and close.

    Args:
        float_dtype: Floating-point dtype of the input tensor.
        scalings: Per-group multipliers (floats), cycled over the groups.
    """
    torch.manual_seed(0)

    group_size = 32
    hidden_size = 64 * group_size
    groups = [
        slice(start, start + group_size)
        for start in range(0, hidden_size, group_size)
    ]

    # Uniform values in [-1, 1), then rescaled group by group.
    inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - 0.5) * 2
    for i, group in enumerate(groups):
        inp[:, group] = inp[:, group] * scalings[i % len(scalings)]

    # Each implementation gets its own copy of the input (presumably so that
    # neither can be affected by the other modifying its argument).
    inp_kernel = inp.clone()
    inp_kernel_clone = inp_kernel.clone()

    res_hip = mx_kernel.qdq_mxfp4_hip(inp_kernel_clone, "even")
    res_torch = qdq_mxfp4_torch(inp_kernel, "even")

    for group in groups:
        assert torch.all(torch.isfinite(res_hip[:, group]))
        assert torch.all(torch.isfinite(res_torch[:, group]))

        torch.testing.assert_close(res_hip[:, group], res_torch[:, group])
331
332


333
334
335
336
@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_dequant_kernel_match_quark(
    float_dtype: torch.dtype, scalings: list[float]
):
    """Compare quark's `dq_mxfp4_hip` with the `dq_mxfp4_torch` reference.

    A random weight is quantized with quark's `StaticScaledRealQuantizer`
    (FP4 per-group spec, group size 32, "e8m0" scales), then dequantized by
    both implementations; the two outputs must be exactly equal.

    Args:
        float_dtype: Floating-point dtype of the weight and of the outputs.
        scalings: Per-group multipliers (floats), cycled over the groups.
    """
    group_size = 32

    qspec = FP4PerGroupSpec(
        ch_axis=-1,
        group_size=group_size,
        scale_format="e8m0",
        scale_calculation_mode="even",
        is_dynamic=False,
    ).to_quantization_spec()

    weight_quantizer = StaticScaledRealQuantizer(
        qspec=qspec,
        quantizer=None,
        reorder=False,
        real_quantized=True,
        float_dtype=float_dtype,
        device="cuda",
    )

    observer = qspec.observer_cls(qspec, device="cuda")

    hidden_size = 512
    shape = (11008, hidden_size)

    # Uniform values in [-1, 1).
    w = (torch.rand(shape, device="cuda", dtype=float_dtype) - 0.5) * 2

    # Make it so that different groups have different scales.
    for i in range(hidden_size // group_size):
        group = slice(i * group_size, (i + 1) * group_size)
        w[:, group] = w[:, group] * scalings[i % len(scalings)]

    # Let the observer see the weight, then hand its scale to the quantizer.
    observer(w)
    scale, _ = observer._calculate_qparams()
    weight_quantizer.scale = scale

    w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to("cuda")
    weight_quantizer.maybe_convert_and_transpose_scale()

    # Re-read the scale: the call above may have replaced it on the quantizer.
    scale = weight_quantizer.scale

    out_hip = mx_kernel.dq_mxfp4_hip(w_mxfp4, scale, float_dtype)

    out_torch = dq_mxfp4_torch(w_mxfp4, scale, float_dtype)

    assert torch.equal(out_hip, out_torch)