# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test model set-up and weight loading for quark-quantized models.

Run `pytest tests/quantization/test_quark.py`.

See also `tests/kernels/moe/test_ocp_mx_moe.py`.
"""

import importlib.metadata
from dataclasses import dataclass
from importlib.util import find_spec

import huggingface_hub
import lm_eval
import pytest
import torch
from packaging import version

from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
    QuarkLinearMethod,
    QuarkW8A8Fp8,
    QuarkW8A8Int8,
)
from vllm.platforms import current_platform

from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch

# Minimum amd-quark version for MXFP4/OCP_MX tests (single source of truth).
QUARK_MXFP4_MIN_VERSION = "0.8.99"

# `find_spec("quark")` only shows that *some* top-level `quark` module is
# importable; the `amd-quark` distribution metadata can still be missing
# (e.g. an unrelated package named `quark`), in which case
# `importlib.metadata.version` raises PackageNotFoundError. Treat that as
# "not available" rather than failing collection of this whole module.
try:
    QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
        importlib.metadata.version("amd-quark")
    ) >= version.parse(QUARK_MXFP4_MIN_VERSION)
except importlib.metadata.PackageNotFoundError:
    QUARK_MXFP4_AVAILABLE = False

if QUARK_MXFP4_AVAILABLE:
    from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer
    from quark.torch.kernel import mx as mx_kernel
    from quark.torch.quantization.config.config import FP4PerGroupSpec

# Probe a repo in the `amd` HF org to decide whether tests that need read
# access to it can run. A missing/inaccessible repo means "no access"; other
# HTTP errors and network/offline errors are treated the same way so that an
# unreachable Hub skips those tests instead of breaking collection of every
# test in this module.
try:
    huggingface_hub.list_repo_refs(
        "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ"
    )
except (
    huggingface_hub.errors.RepositoryNotFoundError,
    huggingface_hub.errors.HfHubHTTPError,
    OSError,
):
    HF_HUB_AMD_ORG_ACCESS = False
else:
    HF_HUB_AMD_ORG_ACCESS = True


@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """Opt every test in this module into insecure (pickle) serialization.

    ``LLM.apply_model`` needs to pickle the callable it is handed; setting
    ``VLLM_ALLOW_INSECURE_SERIALIZATION`` presumably unlocks that path in
    vLLM. Because ``monkeypatch`` is used, the variable is restored after
    each test.
    """
    flag_name, flag_value = "VLLM_ALLOW_INSECURE_SERIALIZATION", "1"
    monkeypatch.setenv(flag_name, flag_value)


@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
    """Load a per-tensor FP8 Quark checkpoint and inspect its first qkv_proj.

    Checks that the first decoder layer's ``qkv_proj`` is handled by
    ``QuarkLinearMethod`` with the ``QuarkW8A8Fp8`` scheme, that its weight
    is stored in the platform's FP8 dtype, and that both the input and the
    weight scale are 0-dim tensors (one scale per tensor). Then runs a short
    greedy generation as a smoke test.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
    with vllm_runner(
        model_path,
        enforce_eager=True,
        kv_cache_dtype=kv_cache_dtype,
        tensor_parallel_size=tp,
    ) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            # The scheme is asserted unconditionally, so the FP8-specific
            # checks below need no additional isinstance guard.
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            # Per-tensor quantization: scales are scalars (0-dim tensors).
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output


@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp):
    """Load a per-channel FP8 Quark checkpoint and inspect its first qkv_proj.

    Checks that the first decoder layer's ``qkv_proj`` is handled by
    ``QuarkLinearMethod`` with the ``QuarkW8A8Fp8`` scheme, that its weight
    is stored in the platform's FP8 dtype, and that the weight scale is a
    column with one entry per channel. Then runs a short greedy generation
    as a smoke test.
    """
    model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts"
    with vllm_runner(model_path, enforce_eager=True, tensor_parallel_size=tp) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            # The scheme is asserted unconditionally, so the FP8-specific
            # checks below need no additional isinstance guard.
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            # Weight scale is (channels, 1). Comparing against
            # weight.shape[1] presumes the weight is held transposed after
            # loading — TODO confirm against the scheme implementation.
            assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[1]
            assert qkv_proj.weight_scale.shape[1] == 1

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output


@pytest.mark.parametrize("tp", [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
    """Load an INT8 W8A8 Quark checkpoint and smoke-test generation.

    Only dispatch is verified: the first decoder layer's ``qkv_proj`` must be
    handled by ``QuarkLinearMethod`` using the ``QuarkW8A8Int8`` scheme.
    """
    runner_kwargs = dict(enforce_eager=True, tensor_parallel_size=tp)

    def check_model(model):
        attn = model.model.layers[0].self_attn
        assert isinstance(attn.qkv_proj.quant_method, QuarkLinearMethod)
        assert isinstance(attn.qkv_proj.scheme, QuarkW8A8Int8)

    with vllm_runner(
        "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test", **runner_kwargs
    ) as llm:
        llm.apply_model(check_model)
        generated = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert generated


def test_quark_fp8_parity(vllm_runner):
    """Quark-format and plain-FP8-format checkpoints must load identically.

    Both tiny Llama checkpoints are loaded side by side. Their state dicts
    (copied to CPU) must expose the same parameter names, and every tensor
    must be exactly equal.
    """
    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"
    runner_kwargs = dict(
        tensor_parallel_size=1,
        enforce_eager=True,
        gpu_memory_utilization=0.1,
    )

    def snapshot_weights(model):
        # CPU copies, presumably so the result can be shipped back from the
        # worker and compared after the runners shut down.
        return {name: t.cpu() for name, t in model.state_dict().items()}

    with (
        vllm_runner(quark_model_id, **runner_kwargs) as quark_llm,
        vllm_runner(fp8_model_id, **runner_kwargs) as fp8_llm,
    ):
        (quark_weights,) = quark_llm.apply_model(snapshot_weights)
        (fp8_weights,) = fp8_llm.apply_model(snapshot_weights)

    # Same parameter names, and bit-identical tensors for each of them.
    assert fp8_weights.keys() == quark_weights.keys()
    for name, fp8_tensor in fp8_weights.items():
        assert torch.equal(fp8_tensor, quark_weights[name])


@dataclass
156
class AccuracyTestConfig:
157
158
159
    model_name: str
    excepted_value: float

160
161
162
    def get_model_args(
        self,
        tp_size: int,
163
164
        model_max_len: int | None = None,
        kwargs: dict | None = None,
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
    ) -> dict:
        if kwargs is None:
            kwargs = {}

        model_args = {
            "pretrained": self.model_name,
            "dtype": "auto",
            "add_bos_token": True,
            "tensor_parallel_size": tp_size,
            "gpu_memory_utilization": 0.7,
            **kwargs,
        }
        if model_max_len is not None:
            model_args["max_model_len"] = model_max_len

        return model_args


# GSM8K targets (`excepted_value` is the expected strict-match score).
GSM8K_ACCURACY_CONFIGS = [
    AccuracyTestConfig(model_name=name, excepted_value=score)
    for name, score in (
        # Private model.
        ("amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant", 0.96),
    )
]

# Wikitext targets (`excepted_value` is the expected word perplexity).
WIKITEXT_ACCURACY_CONFIGS = [
    AccuracyTestConfig(model_name=name, excepted_value=perplexity)
    for name, perplexity in (
        ("fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3", 11.3),
        ("fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2", 10.6),
        ("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", 12.4),
    )
]


@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
    """Check wikitext word perplexity of OCP MX-quantized models via lm-eval.

    Skips when fewer than ``tp_size`` accelerators are visible. The measured
    ``word_perplexity,none`` must lie strictly within ``atol`` of
    ``config.excepted_value``.
    """
    device_count = torch.accelerator.device_count()
    if device_count < tp_size:
        pytest.skip(f"This test requires >={tp_size} gpus, got only {device_count}")

    task = "wikitext"
    # Absolute tolerance on perplexity: pass iff |measured - expected| < atol.
    atol = 0.1

    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=config.get_model_args(
            # Smaller cudagraph_capture_sizes to speed up the test.
            tp_size=tp_size,
            kwargs={"cudagraph_capture_sizes": [16]},
        ),
        tasks=task,
        batch_size=64,
    )

    expected_value = config.excepted_value
    measured_value = results["results"][task]["word_perplexity,none"]
    assert abs(measured_value - expected_value) < atol, (
        f"Expected: {expected_value} |  Measured: {measured_value}"
    )


@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
    """Check 8-shot GSM8K strict-match accuracy of an MXFP4 model with TP=8.

    Skips when fewer than 8 accelerators are visible. The measured
    ``exact_match,strict-match`` score must lie strictly within ``atol`` of
    ``config.excepted_value``.
    """
    tp_size = 8
    device_count = torch.accelerator.device_count()
    if device_count < tp_size:
        pytest.skip(f"This test requires >={tp_size} gpus, got only {device_count}")

    task = "gsm8k"
    # Absolute tolerance on the score: pass iff |measured - expected| < atol.
    atol = 0.03

    results = lm_eval.simple_evaluate(
        model="vllm",
        # NOTE(review): 38768 is an unusual max_model_len (32768 may have been
        # intended) — confirm.
        model_args=config.get_model_args(tp_size=tp_size, model_max_len=38768),
        tasks=task,
        batch_size=64,
        num_fewshot=8,
    )

    expected_value = config.excepted_value
    measured_value = results["results"][task]["exact_match,strict-match"]
    assert abs(measured_value - expected_value) < atol, (
        f"Expected: {expected_value} |  Measured: {measured_value}"
    )


@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_fused_qdq_match_quark(
    float_dtype: torch.dtype, scalings: list[float]
):
    """Compare Quark's HIP MXFP4 quantize-dequantize with the torch reference.

    Builds one row of ``64 * 32`` random values in [-1, 1) and multiplies
    each 32-element group by a factor taken cyclically from ``scalings``, so
    different groups span very different magnitudes. Both implementations
    run with the "even" scale-calculation mode; every group of both results
    must be finite, and the two results must match group by group.
    """
    torch.manual_seed(0)

    group_size = 32
    num_groups = 64
    hidden_size = num_groups * group_size

    inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - 0.5) * 2
    for g in range(num_groups):
        cols = slice(g * group_size, (g + 1) * group_size)
        inp[:, cols] = inp[:, cols] * scalings[g % len(scalings)]

    # Each implementation gets its own copy of the input, presumably so an
    # in-place kernel cannot influence the other side — TODO confirm.
    inp_kernel = inp.clone()
    inp_kernel_clone = inp_kernel.clone()

    res_hip = mx_kernel.qdq_mxfp4_hip(inp_kernel_clone, "even")
    res_torch = qdq_mxfp4_torch(inp_kernel, "even")

    # Compare group by group so a failure points at the offending group.
    for g in range(num_groups):
        cols = slice(g * group_size, (g + 1) * group_size)
        assert torch.all(torch.isfinite(res_hip[:, cols]))
        assert torch.all(torch.isfinite(res_torch[:, cols]))
        torch.testing.assert_close(res_hip[:, cols], res_torch[:, cols])


@pytest.mark.skipif(
    not QUARK_MXFP4_AVAILABLE,
    reason=f"amd-quark>={QUARK_MXFP4_MIN_VERSION} is not available",
)
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_dequant_kernel_match_quark(
    float_dtype: torch.dtype, scalings: list[float]
):
    """Compare Quark's HIP MXFP4 dequant kernel with the torch reference.

    A random ``(11008, 512)`` weight is quantized to MXFP4 with Quark's real
    quantizer (group size 32, e8m0 scales, "even" scale calculation, static
    scales). The packed weight and its scales are then dequantized by both
    implementations, whose outputs must be exactly equal.
    """
    # Fix the RNG so the random weight, and therefore any failure, is
    # reproducible.
    torch.manual_seed(0)

    group_size = 32
    qspec = FP4PerGroupSpec(
        ch_axis=-1,
        group_size=group_size,
        scale_format="e8m0",
        scale_calculation_mode="even",
        is_dynamic=False,
    ).to_quantization_spec()

    weight_quantizer = StaticScaledRealQuantizer(
        qspec=qspec,
        quantizer=None,
        reorder=False,
        real_quantized=True,
        float_dtype=float_dtype,
        device="cuda",
    )

    observer = qspec.observer_cls(qspec, device="cuda")

    hidden_size = 512
    shape = (11008, hidden_size)

    w = (torch.rand(shape, device="cuda", dtype=float_dtype) - 0.5) * 2

    # Make it so that different groups have different scales.
    for g in range(hidden_size // group_size):
        cols = slice(g * group_size, (g + 1) * group_size)
        w[:, cols] = w[:, cols] * scalings[g % len(scalings)]

    # Derive the scales with Quark's observer, then produce the real
    # (packed) MXFP4 weight with the quantizer.
    observer(w)
    scale, _ = observer._calculate_qparams()
    weight_quantizer.scale = scale

    w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to("cuda")
    # NOTE(review): presumably converts/transposes the scales into the layout
    # the dequant implementations expect — confirm against Quark.
    weight_quantizer.maybe_convert_and_transpose_scale()

    scale = weight_quantizer.scale

    out_hip = mx_kernel.dq_mxfp4_hip(w_mxfp4, scale, float_dtype)

    out_torch = dq_mxfp4_torch(w_mxfp4, scale, float_dtype)

    # Kernel and reference must agree exactly, not just approximately.
    assert torch.equal(out_hip, out_torch)