# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test model set-up and weight loading for quark-quantized models.

Run `pytest tests/quantization/test_quark.py`.

See also `tests/kernels/moe/test_ocp_mx_moe.py`.
"""

10
11
12
import importlib.metadata
import os
from dataclasses import dataclass
from importlib.util import find_spec

import huggingface_hub
import lm_eval
import pytest
import torch
from packaging import version

from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
    QuarkLinearMethod,
    QuarkW8A8Fp8,
    QuarkW8A8Int8,
)
from vllm.platforms import current_platform

from .reference_mxfp4 import dq_mxfp4_torch, qdq_mxfp4_torch

30
def _quark_mxfp4_available() -> bool:
    """Return True when the `quark` module and `amd-quark` >= 0.8.99 are present.

    The check must never raise: it runs at import time, and an exception here
    would abort collection of every test in this module.
    """
    if find_spec("quark") is None:
        return False
    try:
        installed = importlib.metadata.version("amd-quark")
    except importlib.metadata.PackageNotFoundError:
        # A `quark` module is importable, but no `amd-quark` distribution
        # metadata exists, so the version cannot be verified.
        return False
    # NOTE(review): the skip reasons in this file say "amd-quark>=0.9"; the
    # 0.8.99 floor presumably admits 0.9 pre-release builds -- confirm.
    return version.parse(installed) >= version.parse("0.8.99")


# Gate for the MXFP4 tests and for the optional quark imports in this module.
QUARK_MXFP4_AVAILABLE = _quark_mxfp4_available()
33
34

if QUARK_MXFP4_AVAILABLE:
    # Only imported when a recent enough amd-quark is installed; the tests that
    # use these names are skipped otherwise.
    from quark.torch.export.nn.modules.realquantizer import StaticScaledRealQuantizer
    from quark.torch.kernel import mx as mx_kernel
    from quark.torch.quantization.config.config import FP4PerGroupSpec

# Probe, once at import time, whether the current Hugging Face credentials can
# read a model under the `amd` organisation. Tests that need such access are
# skipped when the probe fails.
try:
    huggingface_hub.list_repo_refs(
        "amd/Llama-3.3-70B-Instruct-WMXFP4-AMXFP4-KVFP8-Scale-UINT8-SQ"
    )
except (huggingface_hub.errors.HfHubHTTPError, OSError):
    # `HfHubHTTPError` is the base class of `RepositoryNotFoundError` (the only
    # error handled previously) and of the other HTTP failures. `OSError`
    # covers offline mode and `requests`-level connection failures. Any of
    # these means "no access" and must not abort collection of this module.
    HF_HUB_AMD_ORG_ACCESS = False
else:
    HF_HUB_AMD_ORG_ACCESS = True

47

48
@pytest.fixture(scope="function", autouse=True)
def enable_pickle(monkeypatch):
    """`LLM.apply_model` requires pickling a function.

    Sets `VLLM_ALLOW_INSECURE_SERIALIZATION=1` for every test in this module
    (the fixture is autouse); `monkeypatch` undoes the change after each test.
    """
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")
52
53


54
55
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
    """Load a quark FP8 checkpoint and inspect the first layer's `qkv_proj`.

    Checks that the layer uses `QuarkLinearMethod` with the `QuarkW8A8Fp8`
    scheme, that its weight has the platform FP8 dtype, and that both the input
    scale and the weight scale are 0-dimensional. A short greedy generation
    must then return a non-empty result.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
    with vllm_runner(
        model_path,
        enforce_eager=True,
        kv_cache_dtype=kv_cache_dtype,
        tensor_parallel_size=tp,
    ) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            # The scheme type is asserted above, so the checks below always
            # run (no extra isinstance guard is needed).
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
82
83


84
@pytest.mark.parametrize("tp", [1])
def test_quark_fp8_w_per_channel_a_per_token(vllm_runner, tp):
    """Load a quark FP8 checkpoint and inspect the first layer's `qkv_proj`.

    Checks that the layer uses `QuarkLinearMethod` with the `QuarkW8A8Fp8`
    scheme, that its weight has the platform FP8 dtype, and that the weight
    scale has shape `(weight.shape[1], 1)`. A short greedy generation must
    then return a non-empty result.
    """
    model_path = "amd/Qwen2.5-1.5B-Instruct-ptpc-Quark-ts"
    with vllm_runner(model_path, enforce_eager=True, tensor_parallel_size=tp) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            # The scheme type is asserted above, so the checks below always
            # run (no extra isinstance guard is needed).
            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            assert qkv_proj.weight_scale.shape[0] == qkv_proj.weight.shape[1]
            assert qkv_proj.weight_scale.shape[1] == 1

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
106
107


108
@pytest.mark.parametrize("tp", [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
    """Load a quark INT8 checkpoint and inspect the first layer's `qkv_proj`.

    Checks that the layer uses `QuarkLinearMethod` with the `QuarkW8A8Int8`
    scheme, then that a short greedy generation returns a non-empty result.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"
    with vllm_runner(model_path, enforce_eager=True, tensor_parallel_size=tp) as llm:

        def check_model(model):
            qkv_proj = model.model.layers[0].self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Int8)

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=4)
        assert output
125
126
127
128
129
130
131
132
133


def test_quark_fp8_parity(vllm_runner):
    """Check that two tiny FP8 checkpoints load to identical state dicts.

    The two model ids differ only in `quark-quant-method` vs `quant-method`.
    Both models are loaded side by side, their state dicts are copied to CPU,
    and the key sets and every tensor are required to be exactly equal.
    """
    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"

    # Low memory utilization so that both models can be alive at the same time.
    llm_kwargs = {
        "tensor_parallel_size": 1,
        "enforce_eager": True,
        "gpu_memory_utilization": 0.1,
    }
    with (
        vllm_runner(quark_model_id, **llm_kwargs) as quark_handle,
        vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle,
    ):

        def get_state_dict(model):
            return {k: v.cpu() for k, v in model.state_dict().items()}

        # The single-element unpack also requires exactly one result per model.
        (quark_state_dict,) = quark_handle.apply_model(get_state_dict)
        (fp8_state_dict,) = fp8_handle.apply_model(get_state_dict)

    assert fp8_state_dict.keys() == quark_state_dict.keys()

    for key in fp8_state_dict:
        assert torch.equal(fp8_state_dict[key], quark_state_dict[key]), (
            f"State dict tensors differ for key: {key}"
        )
151
152
153


@dataclass
154
class AccuracyTestConfig:
155
156
157
    model_name: str
    excepted_value: float

158
159
160
    def get_model_args(
        self,
        tp_size: int,
161
162
        model_max_len: int | None = None,
        kwargs: dict | None = None,
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
    ) -> dict:
        if kwargs is None:
            kwargs = {}

        model_args = {
            "pretrained": self.model_name,
            "dtype": "auto",
            "add_bos_token": True,
            "tensor_parallel_size": tp_size,
            "gpu_memory_utilization": 0.7,
            **kwargs,
        }
        if model_max_len is not None:
            model_args["max_model_len"] = model_max_len

        return model_args


# Models evaluated by the gsm8k accuracy test, with their reference scores.
GSM8K_ACCURACY_CONFIGS = [
    # Private model.
    AccuracyTestConfig(
        model_name="amd/DeepSeek-R1-WMXFP4-AMXFP4-Scale-UINT8-MoE-Quant",
        excepted_value=0.96,
    ),
]

189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
# Models evaluated by the wikitext test, each paired with the reference value
# that the measured word perplexity is compared against.
WIKITEXT_ACCURACY_CONFIGS = [
    AccuracyTestConfig(model_name=name, excepted_value=reference)
    for name, reference in (
        ("fxmarty/qwen1.5_moe_a2.7b_chat_w_fp4_a_fp6_e2m3", 11.3),
        ("fxmarty/qwen1.5_moe_a2.7b_chat_w_fp6_e3m2_a_fp6_e3m2", 10.6),
        ("fxmarty/qwen_1.5-moe-a2.7b-mxfp4", 12.4),
    )
]


@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("config", WIKITEXT_ACCURACY_CONFIGS)
@pytest.mark.parametrize("tp_size", [1, 2])
def test_ocp_mx_wikitext_correctness(config: AccuracyTestConfig, tp_size: int):
    """Evaluate `config` on wikitext through lm_eval and check word perplexity.

    Skipped when fewer than `tp_size` CUDA devices are visible. The measured
    value must lie strictly within 0.1 (absolute) of `config.excepted_value`.
    """
    if tp_size > torch.cuda.device_count():
        pytest.skip(
            f"This test requires >={tp_size} gpus, got only {torch.cuda.device_count()}"
        )

    task_name = "wikitext"
    tolerance = 0.1

    # Smaller cuda_graph_sizes to speed up the test.
    model_args = config.get_model_args(
        tp_size=tp_size, kwargs={"cuda_graph_sizes": [16]}
    )
    results = lm_eval.simple_evaluate(
        model="vllm",
        model_args=model_args,
        tasks=task_name,
        batch_size=64,
    )

    EXPECTED_VALUE = config.excepted_value
    measured_value = results["results"][task_name]["word_perplexity,none"]
    within_tolerance = (
        EXPECTED_VALUE - tolerance < measured_value < EXPECTED_VALUE + tolerance
    )
    assert within_tolerance, f"Expected: {EXPECTED_VALUE} |  Measured: {measured_value}"

233

234
@pytest.mark.parametrize("config", GSM8K_ACCURACY_CONFIGS)
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.skipif(
    not HF_HUB_AMD_ORG_ACCESS,
    reason="Read access to huggingface.co/amd is required for this test.",
)
def test_mxfp4_gsm8k_correctness(config: AccuracyTestConfig):
    """Evaluate `config` on gsm8k (8-shot) through lm_eval with 8-way TP.

    Skipped when fewer than 8 CUDA devices are visible. The strict-match exact
    match score must lie strictly within 0.03 (absolute) of
    `config.excepted_value`.
    """
    if torch.cuda.device_count() < 8:
        pytest.skip(
            f"This test requires >=8 gpus, got only {torch.cuda.device_count()}"
        )

    task = "gsm8k"
    rtol = 0.03

    # Set the variable only for the duration of the evaluation and always put
    # the previous state back, so a failure here cannot leak the setting into
    # later tests and a pre-existing value is not lost.
    env_key = "VLLM_USE_TRITON_FLASH_ATTN"
    previous_value = os.environ.get(env_key)
    os.environ[env_key] = "0"
    try:
        results = lm_eval.simple_evaluate(
            model="vllm",
            # NOTE(review): 38768 is an unusual length (32768 intended?) --
            # confirm; left unchanged.
            model_args=config.get_model_args(tp_size=8, model_max_len=38768),
            tasks=task,
            batch_size=64,
            num_fewshot=8,
        )
    finally:
        if previous_value is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous_value

    EXPECTED_VALUE = config.excepted_value
    measured_value = results["results"][task]["exact_match,strict-match"]
    assert (
        measured_value - rtol < EXPECTED_VALUE
        and measured_value + rtol > EXPECTED_VALUE
    ), f"Expected: {EXPECTED_VALUE} |  Measured: {measured_value}"


269
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_fused_qdq_match_quark(float_dtype: torch.dtype, scalings: list[float]):
    """Compare `qdq_mxfp4_torch` against quark's `mx_kernel.qdq_mxfp4_hip`.

    A random `(1, 2048)` CUDA tensor is scaled in blocks of 32 columns, cycling
    through `scalings`, and passed to both implementations in "even" mode.
    Every block of both outputs must be finite and the outputs must be close.
    """
    torch.manual_seed(0)

    group_size = 32
    hidden_size = 64 * group_size
    num_groups = hidden_size // group_size

    # Uniform values in [-1, 1), then a different magnitude for each block.
    inp = (torch.rand(1, hidden_size, dtype=float_dtype, device="cuda") - 0.5) * 2
    for i in range(num_groups):
        cols = slice(i * group_size, (i + 1) * group_size)
        inp[:, cols] = inp[:, cols] * scalings[i % len(scalings)]

    # Each implementation gets its own copy of the input.
    inp_kernel = inp.clone()
    inp_kernel_clone = inp_kernel.clone()

    res_hip = mx_kernel.qdq_mxfp4_hip(inp_kernel_clone, "even")
    res_torch = qdq_mxfp4_torch(inp_kernel, "even")

    for i in range(num_groups):
        cols = slice(i * group_size, (i + 1) * group_size)
        assert torch.all(torch.isfinite(res_hip[:, cols]))
        assert torch.all(torch.isfinite(res_torch[:, cols]))

        torch.testing.assert_close(res_hip[:, cols], res_torch[:, cols])
295
296


297
@pytest.mark.skipif(not QUARK_MXFP4_AVAILABLE, reason="amd-quark>=0.9 is not available")
@pytest.mark.parametrize("float_dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("scalings", [[2.3, 0.03, 7.3, 0.1, 0.004, 17.3, 1e4, 1e-4]])
def test_mxfp4_dequant_kernel_match_quark(
    float_dtype: torch.dtype, scalings: list[float]
):
    """Compare `dq_mxfp4_torch` against quark's `mx_kernel.dq_mxfp4_hip`.

    A random `(11008, 512)` CUDA weight is scaled in blocks of 32 columns,
    cycling through `scalings`, and quantized with quark's
    `StaticScaledRealQuantizer`. The two dequantization implementations are
    then run on the result and must produce exactly equal tensors.
    """
    group_size = 32

    qspec = FP4PerGroupSpec(
        ch_axis=-1,
        group_size=group_size,
        scale_format="e8m0",
        scale_calculation_mode="even",
        is_dynamic=False,
    ).to_quantization_spec()

    weight_quantizer = StaticScaledRealQuantizer(
        qspec=qspec,
        quantizer=None,
        reorder=False,
        real_quantized=True,
        float_dtype=float_dtype,
        device="cuda",
    )

    observer = qspec.observer_cls(qspec, device="cuda")

    hidden_size = 512
    shape = (11008, hidden_size)

    w = (torch.rand(shape, device="cuda", dtype=float_dtype) - 0.5) * 2

    # Make it so that different groups have different scales.
    for i in range(hidden_size // group_size):
        cols = slice(i * group_size, (i + 1) * group_size)
        w[:, cols] = w[:, cols] * scalings[i % len(scalings)]

    # Take the scale from the observer and hand it to the quantizer.
    observer(w)
    scale, _ = observer._calculate_qparams()
    weight_quantizer.scale = scale

    w_mxfp4 = weight_quantizer.to_real_quantize_params(w).to("cuda")
    weight_quantizer.maybe_convert_and_transpose_scale()

    # Re-read the scale: the call above may have replaced it on the quantizer.
    scale = weight_quantizer.scale

    out_hip = mx_kernel.dq_mxfp4_hip(w_mxfp4, scale, float_dtype)

    out_torch = dq_mxfp4_torch(w_mxfp4, scale, float_dtype)

    assert torch.equal(out_hip, out_torch)