"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.

Run `pytest tests/models/test_mistral.py`.
"""
import pytest

from ...utils import check_logprobs_close

# Checkpoints exercised by the HF-vs-vLLM comparison below. Only the entries
# from index 1 onward are used for the HF-format vs Mistral-format test
# (it parametrizes over ``MODELS[1:]``).
MODELS = [
    "mistralai/Mistral-7B-Instruct-v0.1",
    "mistralai/Mistral-7B-Instruct-v0.3",
]


@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(
    hf_runner,
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare HF and vLLM greedy-decoding logprobs for each entry in MODELS.

    HF outputs are produced by ``hf_runner(...).generate_greedy_logprobs_limit``.
    vLLM outputs are produced by ``vllm_runner`` with
    ``tokenizer_mode="mistral"``. Both use the same prompts, ``max_tokens``
    and ``num_logprobs``. The two output lists are then compared with
    ``check_logprobs_close``.
    """
    # TODO(sang): Sliding window should be tested separately.
    with hf_runner(model, dtype=dtype) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)

    with vllm_runner(model, dtype=dtype,
                     tokenizer_mode="mistral") as vllm_model:
        vllm_outputs = vllm_model.generate_greedy_logprobs(
            example_prompts, max_tokens, num_logprobs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
    )


@pytest.mark.parametrize("model", MODELS[1:])
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_mistral_format(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    """Compare vLLM outputs for one checkpoint loaded in two formats.

    The same ``model`` is loaded in two configurations:

    * HF-format settings: ``auto`` tokenizer mode, ``safetensors`` load
      format and ``hf`` config format.
    * Mistral-native settings: ``mistral`` for tokenizer mode, load format
      and config format.

    Greedy logprobs from the two runs are compared with
    ``check_logprobs_close``.
    """
    hf_format_kwargs = dict(
        tokenizer_mode="auto",
        load_format="safetensors",
        config_format="hf",
    )
    mistral_format_kwargs = dict(
        tokenizer_mode="mistral",
        load_format="mistral",
        config_format="mistral",
    )

    def _greedy_logprobs(format_kwargs):
        # One `with` per call: the first runner's context is exited before
        # the second runner is created, matching two sequential `with` blocks.
        with vllm_runner(model, dtype=dtype, **format_kwargs) as runner:
            return runner.generate_greedy_logprobs(
                example_prompts, max_tokens, num_logprobs)

    hf_outputs = _greedy_logprobs(hf_format_kwargs)
    mistral_outputs = _greedy_logprobs(mistral_format_kwargs)

    check_logprobs_close(
        outputs_0_lst=hf_outputs,
        outputs_1_lst=mistral_outputs,
        name_0="hf",
        name_1="mistral",
    )