# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

"""Test model set-up and weight loading for quark-quantized models.

Run `pytest tests/quantization/test_quark.py`.

See also `tests/kernels/moe/test_mxfp4_moe.py`.
"""

import importlib.metadata
import os
from dataclasses import dataclass
from importlib.util import find_spec
from typing import Optional

import huggingface_hub
import lm_eval
import pytest
import torch
from packaging import version

from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
    QuarkLinearMethod,
    QuarkW8A8Fp8,
    QuarkW8A8Int8,
)
from vllm.platforms import current_platform

from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch

# The MXFP4 tests below are skipped unless amd-quark >= 0.9 is installed.
# Comparing against "0.8.99" presumably lets 0.9 pre-releases (e.g. 0.9.0rc1)
# qualify as well.
try:
    QUARK_MXFP4_AVAILABLE = find_spec("quark") is not None and version.parse(
        importlib.metadata.version("amd-quark")
    ) >= version.parse("0.8.99")
except importlib.metadata.PackageNotFoundError:
    # A top-level `quark` module is importable but the `amd-quark`
    # distribution is not installed, so the MXFP4 APIs cannot be relied on.
    # Report "unavailable" instead of failing test collection.
    QUARK_MXFP4_AVAILABLE = False

if QUARK_MXFP4_AVAILABLE:
    # These APIs only exist in amd-quark >= 0.9. Tests that use them are
    # guarded by `skipif(not QUARK_MXFP4_AVAILABLE)`.
    from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer
    from quark.torch.kernel import mx as mx_kernel
    from quark.torch.quantization.config.config import FP4PerGroupSpec

# Some tests use private checkpoints under huggingface.co/amd. Probe read access
# once at import time so those tests can be skipped rather than fail.
try:
    huggingface_hub.list_repo_refs(
        "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ"
    )
    HF_HUB_AMD_ORG_ACCESS = True
except (huggingface_hub.errors.RepositoryNotFoundError, OSError):
    # RepositoryNotFoundError: no read access to the private repo.
    # OSError: connection/offline failures (e.g. errors raised through
    # `requests`, whose exception hierarchy derives from IOError). Treat these
    # as "no access" instead of breaking collection of the whole module.
    HF_HUB_AMD_ORG_ACCESS = False


@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """Allow insecure serialization for every test in this module.

    `LLM.apply_model` requires pickling a function. The environment change
    is made through `monkeypatch`, so pytest reverts it after each test.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")


@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
    """FP8 weights and activations, each with a single per-tensor scale.

    Checks that the Quark FP8 scheme is selected for the first layer's
    `qkv_proj` and that the model still generates text.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
    with vllm_runner(
        model_path, kv_cache_dtype=kv_cache_dtype, tensor_parallel_size=tp
    ) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            # Per-tensor scales are stored as 0-dim tensors.
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output


@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp):
    """FP8 with per-channel weight scales and per-token activation scales.

    Checks that the Quark FP8 scheme is selected for the first layer's
    `qkv_proj`, that the weight scale shape is per channel, and that the
    model still generates text.
    """
    model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts"
    with vllm_runner(model_path, tensor_parallel_size=tp) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            # One scale per channel, stored as a column vector. The channel
            # count is read from weight.shape[1], which assumes the loaded
            # weight is kept transposed -- TODO confirm against the scheme.
            assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[1]
            assert qkv_proj.weight_scale.shape[1] == 1

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output


@pytest.mark.parametrize("tp", [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
    """INT8 weights and activations with per-tensor scales.

    Checks that the Quark INT8 scheme is selected for the first layer's
    `qkv_proj` and that the model still generates text.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
    with vllm_runner(model_path, tensor_parallel_size=tp) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Int8)

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output


def test_quark_fp8_parity(vllm_runner):
    """Compare the Quark FP8 checkpoint with its plain-FP8 counterpart.

    Both models are loaded side by side. Their state dicts must have the
    same keys and bit-identical tensors.
    """
    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"

    llm_kwargs = {
        "tensor_parallel_size": 1,
        "enforce_eager": True,
        # Both engines are alive at the same time, so each gets a small share.
        "gpu_memory_utilization": 0.1,
    }
    with (
        vllm_runner(quark_model_id, **llm_kwargs) as quark_handle,
        vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle,
    ):

        def get_state_dict(model):
            # Copy to CPU so the tensors stay usable after the runners exit.
            return {k: v.cpu() for k, v in model.state_dict().items()}

        (quark_state_dict,) = quark_handle.apply_model(get_state_dict)
        (fp8_state_dict,) = fp8_handle.apply_model(get_state_dict)

    assert fp8_state_dict.keys() == quark_state_dict.keys()

    for key in fp8_state_dict:
        assert torch.equal(fp8_state_dict[key], quark_state_dict[key])


@dataclass
class AccuracyTestConfig:
    """A model to evaluate with lm_eval and the reference score it should hit."""

    model_name: str
    # Reference metric value. "excepted" is a typo for "expected"; the field
    # name is kept because every caller accesses it under this name.
    excepted_value: float

    def get_model_args(
        self,
        tp_size: int,
        model_max_len: Optional[int] = None,
        kwargs: Optional[dict] = None,
    ) -> dict:
        """Build the `model_args` dict for `lm_eval.simple_evaluate(model="vllm")`.

        Args:
            tp_size: Tensor-parallel size.
            model_max_len: If not None, set as `max_model_len`. This takes
                precedence over a `max_model_len` entry in `kwargs`.
            kwargs: Extra engine arguments. They override the defaults below.

        Returns:
            A new dict. `kwargs` itself is not modified.
        """
        model_args = {
            "pretrained": self.model_name,
            "dtype": "auto",
            "add_bos_token": True,
            "tensor_parallel_size": tp_size,
            "gpu_memory_utilization": 0.7,
            **(kwargs or {}),
        }
        if model_max_len is not None:
            model_args["max_model_len"] = model_max_len

        return model_args


# Models checked by `test_mxfp4_gsm8k_correctness`. `excepted_value` is the
# expected gsm8k "exact_match,strict-match" score.
GSM8K_ACCURACY_CONFIGS = [
    # Private model.
    AccuracyTestConfig(
        model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
        excepted_value=0.96,
    ),
]

# Models checked by `test_ocp_mx_wikitext_correctness`, paired with their
# expected wikitext word perplexity.
WIKITEXT_ACCURACY_CONFIGS = [
    AccuracyTestConfig(model_name=checkpoint, excepted_value=perplexity)
    for checkpoint, perplexity in (
        ("fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3", 11.3),
        ("fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2", 10.6),
        ("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", 12.4),
    )
]


@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
    """Evaluate wikitext word perplexity and compare it to the reference value.

    The measured perplexity must lie strictly within +/-0.1 (absolute) of
    `config.excepted_value`.
    """
    gpu_count = torch.cuda.device_count()
    if gpu_count < tp_size:
        pytest.skip(
            f"This test requires >={tp_size} gpus, got only {gpu_count}"
        )

    task = "wikitext"
    # Absolute tolerance on the perplexity.
    tolerance = 0.1

    # Smaller cuda_graph_sizes to speed up the test.
    engine_args = config.get_model_args(
        tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
    )
    results = lm_eval.simple_evaluate(
        model="vllm", model_args=engine_args, tasks=task, batch_size=64
    )

    expected = config.excepted_value
    measured = results["results"][task]["word_perplexity,none"]
    lower, upper = expected - tolerance, expected + tolerance
    assert lower < measured < upper, f"Expected: {expected} |  Measured: {measured}"


@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
    """Run 8-shot gsm8k with TP=8 and compare the strict-match accuracy.

    The measured score must lie strictly within +/-0.03 (absolute) of
    `config.excepted_value`. `VLLM_USE_TRITON_FLASH_ATTN` is forced to "0"
    for the evaluation only. It is restored afterwards even if the
    evaluation raises.
    """
    if torch.cuda.device_count() < 8:
        pytest.skip(
            f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
        )

    task = "gsm8k"
    # Absolute tolerance on the exact-match score.
    atol = 0.03

    env_var = "VLLM_USE_TRITON_FLASH_ATTN"
    previous = os.environ.get(env_var)
    os.environ[env_var] = "0"
    try:
        results = lm_eval.simple_evaluate(
            model="vllm",
            model_args=config.get_model_args(tp_size=8, model_max_len=38768),
            tasks=task,
            batch_size=64,
            num_fewshot=8,
        )
    finally:
        # Restore the caller's environment so the override does not leak into
        # later tests when evaluation fails, and do not drop a pre-existing value.
        if previous is None:
            os.environ.pop(env_var, None)
        else:
            os.environ[env_var] = previous

    expected_value = config.excepted_value
    measured_value = results["results"][task]["exact_match,strict-match"]
    assert (
        measured_value - atol < expected_value
        and measured_value + atol > expected_value
    ), f"Expected: {expected_value} |  Measured: {measured_value}"


@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[float]):
    """Quark's HIP MXFP4 quantize-dequantize must match the torch reference.

    The input row is split into 32-element groups. Each group is multiplied by
    a factor from `scalings` (cycled) so the groups span very different
    magnitudes. Results are compared group by group.
    """
    torch.manual_seed(0)

    group_size = 32
    hidden_size = 64 * group_size
    num_groups = hidden_size // group_size

    inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - 0.5) * 2
    for i in range(num_groups):
        group = slice(i * group_size, (i + 1) * group_size)
        inp[:, group] = inp[:, group] * scalings[i % len(scalings)]

    # Each implementation gets its own copy of the input, presumably in case
    # either one modifies its argument in place.
    inp_kernel = inp.clone()
    inp_kernel_clone = inp_kernel.clone()

    res_hip = mx_kernel.qdq_mxfp4_hip(inp_kernel_clone, "even")
    res_torch = qdq_mxfp4_torch(inp_kernel, "even")

    for i in range(num_groups):
        group = slice(i * group_size, (i + 1) * group_size)
        assert torch.all(torch.isfinite(res_hip[:, group]))
        assert torch.all(torch.isfinite(res_torch[:, group]))
        torch.testing.assert_close(res_hip[:, group], res_torch[:, group])


@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_dequant_kernel_match_quark(
    float_dtype: torch.dtype, scalings: list[float]
):
    """Quark's HIP MXFP4 dequantization must exactly match the torch reference.

    A random weight is quantized with Quark's static real quantizer (FP4,
    group size 32, e8m0 scales, "even" scale mode). The packed weight and its
    scales are then dequantized by both implementations, and the outputs must
    be bit-identical.
    """
    # Seed for reproducible inputs, matching test_mxfp4_fused_qdq_match_quark.
    torch.manual_seed(0)

    group_size = 32

    qspec = FP4PerGroupSpec(
        ch_axis=-1,
        group_size=group_size,
        scale_format="e8m0",
        scale_calculation_mode="even",
        is_dynamic=False,
    ).to_quantization_spec()

    weight_quantizer = StaticScaledRealQuantizer(
        qspec=qspec,
        quantizer=None,
        reorder=False,
        real_quantized=True,
        float_dtype=float_dtype,
        device="cuda",
    )

    observer = qspec.observer_cls(qspec, device="cuda")

    hidden_size = 512
    shape = (11008, hidden_size)

    w = (torch.rand(shape, device="cuda", dtype=float_dtype) - 0.5) * 2

    # Make it so that different groups have different scales.
    for i in range(hidden_size // group_size):
        group = slice(i * group_size, (i + 1) * group_size)
        w[:, group] = w[:, group] * scalings[i % len(scalings)]

    # Derive the static scales from the weight itself.
    observer(w)
    scale, _ = observer._calculate_qparams()
    weight_quantizer.scale = scale

    w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to("cuda")
    # Put the scale in the form the dequant kernels consume.
    weight_quantizer.maybe_convert_and_transpose_scale()
    scale = weight_quantizer.scale

    out_hip = mx_kernel.dq_mxfp4_hip(w_mxfp4, scale, float_dtype)
    out_torch = dq_mxfp4_torch(w_mxfp4, scale, float_dtype)

    assert torch.equal(out_hip, out_torch)