test_quark.py 12.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
"""Test model set-up and weight loading for quark-quantized models.

Run `pytest tests/quantization/test_quark.py`.
6

7
See also `tests/kernels/moe/test_ocp_mx_moe.py`.
8
9
"""

10
11
import importlib.metadata
from dataclasses import dataclass
12
from importlib.util import find_spec
13
14
15

import huggingface_hub
import lm_eval
16
import pytest
17
import torch
18
from packaging import version
19
20

from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
21
22
23
24
    QuarkLinearMethod,
    QuarkW8A8Fp8,
    QuarkW8A8Int8,
)
25
26
27
from vllm.model_executor.layers.quantization.quark.quark_moe import (  # noqa: E501
    QuarkW8A8Int8MoEMethod,
)
28
from vllm.platforms import current_platform
29

30
31
from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch

32
33
34
# Minimum amd-quark version for MXFP4/OCP_MX tests (single source of truth).
QUARK_MXFP4_MIN_VERSION = "0.8.99"


def _quark_mxfp4_available() -> bool:
    """Return True if amd-quark >= QUARK_MXFP4_MIN_VERSION is installed.

    A top-level ``quark`` module may be importable while the ``amd-quark``
    distribution metadata is absent (for example a different package that
    happens to be named ``quark``). ``importlib.metadata.version`` then raises
    ``PackageNotFoundError``; report "not available" instead of letting that
    error abort collection of this whole test module.
    """
    if find_spec("quark") is None:
        return False
    try:
        installed = importlib.metadata.version("amd-quark")
    except importlib.metadata.PackageNotFoundError:
        return False
    return version.parse(installed) >= version.parse(QUARK_MXFP4_MIN_VERSION)


QUARK_MXFP4_AVAILABLE = _quark_mxfp4_available()

DEVICE_TYPE = current_platform.device_type

# Only import quark internals when a compatible amd-quark is present; the
# tests that use these names are skipped otherwise.
if QUARK_MXFP4_AVAILABLE:
    from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer
    from quark.torch.kernel import mx as mx_kernel
    from quark.torch.quantization.config.config import FP4PerGroupSpec

# Probe whether the current Hugging Face credentials can read a repo in the
# `amd` org; `test_mxfp4_gsm8k_correctness` is skipped when they cannot.
# Besides "not found / no access", other Hub HTTP errors and connection or
# offline-mode errors (OSError subclasses) are also treated as "no access",
# so a missing network skips that single test instead of making this whole
# module fail at import (collection) time.
try:
    huggingface_hub.list_repo_refs(
        "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ"
    )
except (
    huggingface_hub.errors.RepositoryNotFoundError,
    huggingface_hub.errors.HfHubHTTPError,
    OSError,
):
    HF_HUB_AMD_ORG_ACCESS = False
else:
    HF_HUB_AMD_ORG_ACCESS = True

54

55
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """Opt every test in this module into insecure (pickle) serialization.

    The model checks hand a Python callable to `LLM.apply_model`, which
    requires pickling that function.
    """
    env_name = "VLLM_ALLOW_INSECURE_SERIALIZATION"
    monkeypatch.setenv(env_name, "1")
59
60


61
62
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
    """Load a Quark FP8 checkpoint with per-tensor weight/activation scales.

    Checks that the first decoder layer's QKV projection is handled by
    Quark's FP8 W8A8 scheme with scalar (0-dim) scales, then runs a short
    greedy generation as a smoke test.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
    with vllm_runner(
        model_path,
        enforce_eager=True,
        kv_cache_dtype=kv_cache_dtype,
        tensor_parallel_size=tp,
    ) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            # The scheme is asserted above, so the FP8-specific attributes
            # can be checked directly without a redundant isinstance guard.
            # Per-tensor quantization: both scales are 0-dim tensors.
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
89
90


91
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp):
    """Load a Quark FP8 checkpoint with per-channel weight scales and
    per-token activation scales (PTPC).

    Checks that the first decoder layer's QKV projection uses Quark's FP8
    W8A8 scheme with one weight scale per channel, then smoke-tests
    generation.
    """
    model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts"
    with vllm_runner(model_path, enforce_eager=True, tensor_parallel_size=tp) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            # The scheme is asserted above, so the FP8-specific attributes
            # can be checked directly without a redundant isinstance guard.
            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            # NOTE(review): comparing against weight.shape[1] presumes the
            # loaded weight is stored transposed (channels on dim 1) —
            # confirm against the scheme's post-load weight processing.
            assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[1]
            assert qkv_proj.weight_scale.shape[1] == 1

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
113
114


115
@pytest.mark.parametrize("tp", [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
    """Smoke-test a Quark INT8 W8A8 checkpoint.

    Verifies which quantization method/scheme is wired into the first
    layer's QKV projection, and that greedy generation produces output.
    """
    with vllm_runner(
        "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test",
        enforce_eager=True,
        tensor_parallel_size=tp,
    ) as llm:

        def check_model(model):
            proj = model.model.layers[0].self_attn.qkv_proj
            assert isinstance(proj.quant_method, QuarkLinearMethod)
            assert isinstance(proj.scheme, QuarkW8A8Int8)

        llm.apply_model(check_model)

        generated = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert generated
132
133


134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
@pytest.mark.parametrize("tp", [1])
def test_quark_int8_w8a8_moe(vllm_runner, tp):
    """Test W8A8 INT8 MoE quantization with a tiny Qwen3 MoE model."""
    model_path = "nameistoken/tiny-qwen3-moe-w8a8-int8-quark"
    runner_kwargs = {
        "enforce_eager": True,
        "tensor_parallel_size": tp,
        "gpu_memory_utilization": 0.1,
    }
    with vllm_runner(model_path, **runner_kwargs) as llm:

        def check_model(model):
            first_layer = model.model.layers[0]

            # Expert weights must be handled by the Quark INT8 MoE method ...
            moe_method = first_layer.mlp.experts.quant_method
            assert isinstance(moe_method, QuarkW8A8Int8MoEMethod), (
                f"Expected QuarkW8A8Int8MoEMethod, got {type(moe_method)}"
            )

            # ... while the dense attention projection uses the INT8 scheme.
            attn_scheme = first_layer.self_attn.qkv_proj.scheme
            assert isinstance(attn_scheme, QuarkW8A8Int8)

        llm.apply_model(check_model)

        generated = llm.generate_greedy("Hello", max_tokens=4)
        assert generated


162
163
164
165
166
167
168
def test_quark_fp8_parity(vllm_runner):
    """The Quark-format and FP8-format tiny Llama checkpoints must load into
    identical model state: same parameter names and bit-identical tensors."""
    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"

    common_kwargs = {
        "tensor_parallel_size": 1,
        "enforce_eager": True,
        "gpu_memory_utilization": 0.1,
    }

    def cpu_state_dict(model):
        # Move every tensor to CPU so the snapshot can be returned from
        # `apply_model` and compared outside the engines.
        return {name: tensor.cpu() for name, tensor in model.state_dict().items()}

    with (
        vllm_runner(quark_model_id, **common_kwargs) as quark_llm,
        vllm_runner(fp8_model_id, **common_kwargs) as fp8_llm,
    ):
        (quark_state,) = quark_llm.apply_model(cpu_state_dict)
        (fp8_state,) = fp8_llm.apply_model(cpu_state_dict)

    assert fp8_state.keys() == quark_state.keys()

    for name, fp8_tensor in fp8_state.items():
        assert torch.equal(fp8_tensor, quark_state[name])
186
187
188


@dataclass
189
class AccuracyTestConfig:
190
191
192
    model_name: str
    excepted_value: float

193
194
195
    def get_model_args(
        self,
        tp_size: int,
196
197
        model_max_len: int | None = None,
        kwargs: dict | None = None,
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
    ) -> dict:
        if kwargs is None:
            kwargs = {}

        model_args = {
            "pretrained": self.model_name,
            "dtype": "auto",
            "add_bos_token": True,
            "tensor_parallel_size": tp_size,
            "gpu_memory_utilization": 0.7,
            **kwargs,
        }
        if model_max_len is not None:
            model_args["max_model_len"] = model_max_len

        return model_args


# Reference values are gsm8k ``exact_match,strict-match`` scores, as read by
# test_mxfp4_gsm8k_correctness.
GSM8K_ACCURACY_CONFIGS = [
    # Private model: needs read access to huggingface.co/amd.
    AccuracyTestConfig(
        model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
        excepted_value=0.96,
    ),
]

224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
# Reference values are wikitext ``word_perplexity,none`` scores, as read by
# test_ocp_mx_wikitext_correctness.
WIKITEXT_ACCURACY_CONFIGS = [
    # Per the repo name: FP4 weights, FP6 (e2m3) activations.
    AccuracyTestConfig(
        model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3",
        excepted_value=11.3,
    ),
    # Per the repo name: FP6 (e3m2) weights and activations.
    AccuracyTestConfig(
        model_name="fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2",
        excepted_value=10.6,
    ),
    # Per the repo name: MXFP4.
    AccuracyTestConfig(
        model_name="fxmarty/qwen_1.5-moe-a2.7b-mxfp4",
        excepted_value=12.4,
    ),
]


239
240
241
242
@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
    """Check wikitext word perplexity of OCP MX quantized models.

    The measured perplexity must lie strictly within ``atol`` of the
    configured reference value.
    """
    device_count = torch.accelerator.device_count()
    if device_count < tp_size:
        pytest.skip(f"This test requires >={tp_size} gpus, got only {device_count}")

    task = "wikitext"
    # Absolute tolerance on perplexity: the bound is |measured - expected|,
    # not relative to the expected value.
    atol = 0.1

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=config.get_model_args(
            tp_size=tp_size,
            # Smaller cudagraph_capture_sizes to speed up the test.
            kwargs={"cudagraph_capture_sizes": [16]},
        ),
        tasks=task,
        batch_size=64,
    )

    expected_value = config.excepted_value
    measured_value = results["results"][task]["word_perplexity,none"]
    assert abs(measured_value - expected_value) < atol, (
        f"Expected: {expected_value} |  Measured: {measured_value}"
    )

270

271
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
    """Check 8-shot gsm8k strict-match accuracy of an MXFP4 model on 8 GPUs.

    The measured score must lie strictly within ``atol`` of the configured
    reference value.
    """
    device_count = torch.accelerator.device_count()
    if device_count < 8:
        pytest.skip(f"This test requires >=8 gpus, got only {device_count}")

    task = "gsm8k"
    # Absolute tolerance on the exact-match score: the bound is
    # |measured - expected|, not relative to the expected value.
    atol = 0.03

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=config.get_model_args(tp_size=8, model_max_len=38768),
        tasks=task,
        batch_size=64,
        num_fewshot=8,
    )

    expected_value = config.excepted_value
    measured_value = results["results"][task]["exact_match,strict-match"]
    assert abs(measured_value - expected_value) < atol, (
        f"Expected: {expected_value} |  Measured: {measured_value}"
    )
302
303


304
305
306
307
@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[float]):
    """Compare Quark's HIP MXFP4 quantize-dequantize kernel against the
    pure-torch reference, one 32-element block at a time.

    ``scalings`` holds float multipliers applied to successive blocks (cycled)
    so that different blocks span very different magnitudes.
    """
    torch.manual_seed(0)

    group_size = 32
    hidden_size = 64 * group_size
    num_groups = hidden_size // group_size
    groups = [
        slice(i * group_size, (i + 1) * group_size) for i in range(num_groups)
    ]

    inp = (torch.rand(1, hidden_size, dtype=float_dtype, device=DEVICE_TYPE) - 0.5) * 2
    for i, group in enumerate(groups):
        inp[:, group] = inp[:, group] * scalings[i % len(scalings)]

    # Separate copies for each implementation, presumably so an in-place
    # modification by one cannot affect the other.
    inp_kernel = inp.clone()
    inp_kernel_clone = inp_kernel.clone()

    res_hip = mx_kernel.qdq_mxfp4_hip(inp_kernel_clone, "even")
    res_torch = qdq_mxfp4_torch(inp_kernel, "even")

    for group in groups:
        assert torch.all(torch.isfinite(res_hip[:, group]))
        assert torch.all(torch.isfinite(res_torch[:, group]))

        torch.testing.assert_close(res_hip[:, group], res_torch[:, group])
333
334


335
336
337
338
@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_dequant_kernel_match_quark(
    float_dtype: torch.dtype, scalings: list[float]
):
    """Check Quark's HIP MXFP4 dequantization kernel is bit-exact against the
    pure-torch reference on a weight real-quantized by Quark.

    ``scalings`` holds float multipliers applied to successive 32-column
    groups (cycled) so that different groups get different scales.
    """
    # Seed for reproducibility, matching test_mxfp4_fused_qdq_match_quark.
    torch.manual_seed(0)

    group_size = 32
    qspec = FP4PerGroupSpec(
        ch_axis=-1,
        group_size=group_size,
        scale_format="e8m0",
        scale_calculation_mode="even",
        is_dynamic=False,
    ).to_quantization_spec()

    weight_quantizer = StaticScaledRealQuantizer(
        qspec=qspec,
        quantizer=None,
        reorder=False,
        real_quantized=True,
        float_dtype=float_dtype,
        device=DEVICE_TYPE,
    )

    observer = qspec.observer_cls(qspec, device=DEVICE_TYPE)

    hidden_size = 512
    shape = (11008, hidden_size)

    w = (torch.rand(shape, device=DEVICE_TYPE, dtype=float_dtype) - 0.5) * 2

    # Make it so that different groups have different scales.
    for i in range(hidden_size // group_size):
        group = slice(i * group_size, (i + 1) * group_size)
        w[:, group] = w[:, group] * scalings[i % len(scalings)]

    # Compute scales with the observer and install them as the quantizer's
    # static scale before real-quantizing the weight.
    observer(w)
    scale, _ = observer._calculate_qparams()
    weight_quantizer.scale = scale

    w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to(DEVICE_TYPE)

    weight_quantizer.maybe_convert_and_transpose_scale()

    scale = weight_quantizer.scale

    out_hip = mx_kernel.dq_mxfp4_hip(w_mxfp4, scale, float_dtype)

    out_torch = dq_mxfp4_torch(w_mxfp4, scale, float_dtype)

    # Exact equality: both paths must decode to identical values.
    assert torch.equal(out_hip, out_torch)