# test_marlin.py
"""Compare the outputs of a GPTQ model to a Marlin model.

Note: GPTQ and Marlin outputs are not bitwise identical.
As a result, in this test, we only confirm that the top selected tokens of the
Marlin/GPTQ models are within the top `num_logprobs` (currently 5) selections
of each other.

Note: Marlin internally uses locks to synchronize threads. This can result in
very slight nondeterminism for Marlin. As a result, the test is run up to
3 times (the first attempt plus 2 reruns) to see if it passes.

Run `pytest tests/models/test_marlin.py`.
"""

from dataclasses import dataclass

15
16
import pytest
import torch
17

18
from tests.models.utils import check_logprobs_close
19
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
20
21
22

# Marlin requires a minimum CUDA compute capability. The device query is
# guarded so that importing this module on a machine without CUDA marks the
# Marlin test as unsupported (and therefore skipped) instead of raising at
# collection time.
if torch.cuda.is_available():
    capability = torch.cuda.get_device_capability()
    # Encode (major, minor) as a two-digit int, e.g. (8, 0) -> 80, to compare
    # against the quantization method's minimum capability.
    capability = capability[0] * 10 + capability[1]
    marlin_not_supported = (capability <
                            QUANTIZATION_METHODS["marlin"].get_min_capability())
else:
    # No CUDA device available: keep `capability` bound (0 = no device) and
    # mark Marlin as unsupported so the test is skipped.
    capability = 0
    marlin_not_supported = True


@dataclass
class ModelPair:
    """Names of two checkpoints of the same model that are compared in test_models.

    Each name is passed as the first argument to ``vllm_runner``.
    """

    # Checkpoint loaded with quantization="marlin".
    model_marlin: str
    # Checkpoint loaded with quantization="gptq".
    model_gptq: str


# (marlin checkpoint, gptq checkpoint) pairs; each pair becomes one
# parametrization of test_models.
model_pairs = [
    ModelPair(model_marlin=marlin_id, model_gptq=gptq_id)
    for marlin_id, gptq_id in (
        ("nm-testing/zephyr-beta-7b-marlin-g128",
         "nm-testing/zephyr-beta-7b-gptq-g128"),
        ("robertgshaw2/zephyr-7b-beta-channelwise-marlin",
         "robertgshaw2/zephyr-7b-beta-channelwise-gptq"),
        ("robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-marlin",
         "robertgshaw2/TinyLlama-1.1B-Chat-v1.0-g128-gptq"),
    )
]


# Marlin's lock-based thread synchronization is slightly nondeterministic, so
# allow 2 reruns (3 attempts in total) before reporting a failure.
@pytest.mark.flaky(reruns=2)
@pytest.mark.skipif(marlin_not_supported,
                    reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("model_pair", model_pairs)
@pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    vllm_runner,
    example_prompts,
    model_pair: ModelPair,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Check that Marlin and GPTQ checkpoints of a model give close logprobs.

    The Marlin checkpoint is run with ``quantization="marlin"`` and the GPTQ
    checkpoint with ``quantization="gptq"``. Both use greedy decoding over
    ``example_prompts``, and their outputs are compared with
    ``check_logprobs_close``.

    Args:
        vllm_runner: pytest fixture used to construct a model runner.
        example_prompts: pytest fixture supplying the prompts to generate from.
        model_pair: The Marlin/GPTQ checkpoint names to compare.
        dtype: Model dtype passed to the runner.
        max_tokens: Number of tokens to generate per prompt.
        num_logprobs: Number of logprobs requested per generated token.
    """
    marlin_model = vllm_runner(model_pair.model_marlin,
                               dtype=dtype,
                               quantization="marlin")
    marlin_outputs = marlin_model.generate_greedy_logprobs(
        example_prompts, max_tokens, num_logprobs)
    # Drop the first runner before building the second one -- presumably to
    # release its GPU memory; confirm against the vllm_runner fixture.
    del marlin_model

    gptq_model = vllm_runner(model_pair.model_gptq,
                             dtype=dtype,
                             quantization="gptq")
    gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts,
                                                       max_tokens,
                                                       num_logprobs)
    del gptq_model

    check_logprobs_close(
        outputs_0_lst=gptq_outputs,
        outputs_1_lst=marlin_outputs,
        name_0="gptq",
        name_1="marlin",
    )