test_quark.py 3.16 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7
"""Test model set-up and weight loading for quark-quantized models.

Run `pytest tests/quantization/test_quark.py`.
"""

8
import pytest
9
import torch
10
11

from vllm.model_executor.layers.quantization.quark.quark import (  # noqa: E501
12
13
    QuarkLinearMethod, QuarkW8A8Fp8, QuarkW8A8Int8)
from vllm.platforms import current_platform
14
15


16
17
18
19
20
21
22
23
24
25
26
@pytest.fixture(scope="function", autouse=True)
def use_v0_only(monkeypatch):
    """Pin every test in this module to the V0 engine.

    The tests here depend on V0 internals, so the ``VLLM_USE_V1``
    environment variable is set to ``0`` through ``monkeypatch`` before
    each test function runs (the fixture is function-scoped and autouse).
    """
    monkeypatch.setenv("VLLM_USE_V1", "0")


@pytest.mark.parametrize('kv_cache_dtype', ['auto', 'fp8'])
@pytest.mark.parametrize('tp', [1])
def test_quark_fp8_w_per_tensor_a_per_tensor(vllm_runner, kv_cache_dtype, tp):
    """Load the FP8 Quark test checkpoint and inspect layer 0's ``qkv_proj``.

    The model is started through ``vllm_runner`` with the parametrized
    ``kv_cache_dtype`` and ``tensor_parallel_size``. The check then asserts
    that ``qkv_proj`` of the first decoder layer:

    * uses ``QuarkLinearMethod`` as its quant method,
    * uses the ``QuarkW8A8Fp8`` scheme,
    * has 0-dimensional ``input_scale`` and ``weight_scale`` tensors, and
    * stores its weight in ``current_platform.fp8_dtype()``.

    Finally a short greedy generation is run and its result must be truthy.
    """
    model_path = "amd/Llama-3.1-8B-Instruct-FP8-KV-Quark-test"
    with vllm_runner(model_path,
                     kv_cache_dtype=kv_cache_dtype,
                     tensor_parallel_size=tp) as llm:

        def check_model(model):
            layer = model.model.layers[0]

            qkv_proj = layer.self_attn.qkv_proj

            assert isinstance(qkv_proj.quant_method, QuarkLinearMethod)
            assert isinstance(qkv_proj.scheme, QuarkW8A8Fp8)

            # The scheme type is asserted just above, so the scale and dtype
            # checks run unconditionally instead of behind a second
            # isinstance() guard that could never be false.
            assert len(qkv_proj.input_scale.shape) == 0
            assert qkv_proj.weight.dtype is current_platform.fp8_dtype()
            assert len(qkv_proj.weight_scale.shape) == 0

        llm.apply_model(check_model)

        output = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert output
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67


@pytest.mark.parametrize('tp', [1])
def test_quark_int8_w_per_tensor_a_per_tensor(vllm_runner, tp):
    """Load the INT8 Quark test checkpoint and inspect layer 0's ``qkv_proj``.

    Asserts that ``qkv_proj`` of the first decoder layer uses
    ``QuarkLinearMethod`` with the ``QuarkW8A8Int8`` scheme, then runs a
    short greedy generation whose result must be truthy.
    """
    checkpoint = "amd/Llama-3.1-8B-Instruct-w-int8-a-int8-sym-test"

    def verify_first_layer(model):
        # Only the fused QKV projection of decoder layer 0 is inspected.
        proj = model.model.layers[0].self_attn.qkv_proj
        assert isinstance(proj.quant_method, QuarkLinearMethod)
        assert isinstance(proj.scheme, QuarkW8A8Int8)

    with vllm_runner(checkpoint, tensor_parallel_size=tp) as llm:
        llm.apply_model(verify_first_layer)

        completions = llm.generate_greedy("Hello my name is", max_tokens=20)
        assert completions
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92


def test_quark_fp8_parity(vllm_runner):
    """Check that two tiny FP8 checkpoints load to identical state dicts.

    Both checkpoints are loaded side by side with the same runner options.
    The state dict of each loaded model is captured while the runners are
    alive; afterwards the two dicts must have the same keys and, for every
    key, tensors that compare equal under ``torch.equal``.
    """
    quark_model_id = "amd-quark/llama-tiny-fp8-quark-quant-method"
    fp8_model_id = "amd-quark/llama-tiny-fp8-quant-method"

    def loaded_state_dict(handle):
        # Reach through the engine to the model object held by the driver
        # worker's model runner, then snapshot its state dict.
        executor = handle.model.llm_engine.model_executor
        return executor.driver_worker.model_runner.model.state_dict()

    llm_kwargs = dict(
        tensor_parallel_size=1,
        enforce_eager=True,
        gpu_memory_utilization=0.1,
    )

    with (vllm_runner(quark_model_id, **llm_kwargs) as quark_handle,
          vllm_runner(fp8_model_id, **llm_kwargs) as fp8_handle):
        quark_state_dict = loaded_state_dict(quark_handle)
        fp8_state_dict = loaded_state_dict(fp8_handle)

    assert fp8_state_dict.keys() == quark_state_dict.keys()

    for name, fp8_tensor in fp8_state_dict.items():
        assert torch.equal(fp8_tensor, quark_state_dict[name])