# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A basic correctness check for TPUs

Run `pytest tests/v1/tpu/test_basic.py`.
"""
7
8
9
10
from __future__ import annotations

from typing import TYPE_CHECKING

11
import pytest
12
from torch_xla._internal import tpu
13
14
15

from vllm.platforms import current_platform

16
17
if TYPE_CHECKING:
    from tests.conftest import VllmRunner
18
19

# Models exercised by `test_basic`.
MODELS = [
    "Qwen/Qwen2.5-1.5B-Instruct",
    # TODO: Enable these models with v6e
    # "Qwen/Qwen2-7B-Instruct",
    # "meta-llama/Llama-3.1-8B",
]

# Parametrization values for `test_basic`.
TENSOR_PARALLEL_SIZES = [1]
MAX_NUM_REQS = [16, 1024]

# TODO: Enable once CI/CD has a multi-TPU instance
# TENSOR_PARALLEL_SIZES = [1, 4]


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic test for TPU only")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
def test_basic(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    max_tokens: int,
    tensor_parallel_size: int,
    max_num_seqs: int,
) -> None:
    """Greedy-decode a long counting prompt and sanity-check the output.

    The prompt lists the integers 0..1023 and asks for the next numbers.
    The model is run with ``VLLM_USE_V1=1`` and a 1024-token batching
    budget so that the prompt has to be processed in chunks.

    Args:
        vllm_runner: Runner fixture; called with the model name and engine
            options and used as a context manager.
        monkeypatch: pytest fixture used to set ``VLLM_USE_V1`` for the
            duration of the test only.
        model: Model identifier to load.
        max_tokens: Number of tokens to generate.
        tensor_parallel_size: Tensor-parallel degree passed to the runner.
        max_num_seqs: Maximum number of concurrent sequences passed to the
            runner.
    """
    prompt = "The next numbers of the sequence " + ", ".join(
        str(i) for i in range(1024)) + " are:"
    example_prompts = [prompt]

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        with vllm_runner(
                model,
                # Note: max_num_batched_tokens == 1024 is needed here to
                # actually test chunked prompt
                max_num_batched_tokens=1024,
                max_model_len=8192,
                gpu_memory_utilization=0.7,
                max_num_seqs=max_num_seqs,
                tensor_parallel_size=tensor_parallel_size) as vllm_model:
            vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                      max_tokens)
        # Each result is indexed as (…, text); take the text of the only
        # prompt submitted.
        output = vllm_outputs[0][1]

        # NOTE(review): if the returned text includes the prompt, the
        # "0, 1" alternative is always satisfied by the prompt itself and
        # this check cannot fail -- confirm what generate_greedy returns.
        assert "1024" in output or "0, 1" in output
68
69


70
@pytest.mark.skip(reason="Temporarily disabled due to timeout")
@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic test for TPU only")
@pytest.mark.parametrize("max_tokens", [8])
@pytest.mark.parametrize("max_num_seqs", [16])
def test_phi3(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
    max_tokens: int,
    max_num_seqs: int,
) -> None:
    """Check greedy completions of Phi-3 mini (head dim 96) on TPU.

    Each prompt is paired with the continuation expected to appear in the
    generated text. Currently skipped unconditionally because of a timeout.

    Args:
        vllm_runner: Runner fixture; called with the model name and engine
            options and used as a context manager.
        monkeypatch: pytest fixture used to set ``VLLM_USE_V1`` for the
            duration of the test only.
        max_tokens: Number of tokens to generate per prompt.
        max_num_seqs: Maximum number of concurrent sequences passed to the
            runner.
    """
    prompts = [
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly;",
        "The greatest glory in living lies not in never falling,",
    ]
    answers = [
        " or, by violating privacy",
        " what is essential is love.",
        " but in rising every time we fall.",
    ]
    # test head dim = 96
    model = "microsoft/Phi-3-mini-128k-instruct"

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        with vllm_runner(model,
                         max_num_batched_tokens=256,
                         max_num_seqs=max_num_seqs) as vllm_model:
            vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
        # zip() stops at the shorter input, so without this check a missing
        # result would let the loop below pass without testing every prompt.
        assert len(vllm_outputs) == len(prompts)
        # vllm_outputs is a list of tuples whose first element is the token id
        # and the second element is the output (including the prompt).
        for output, answer in zip(vllm_outputs, answers):
            generated_text = output[1]
            assert answer in generated_text


108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# Tensor-parallel degree, and minimum number of available TPU chips, used by
# the Gemma-3 27B test below.
TP_SIZE_8 = 8


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a test for TPU only")
@pytest.mark.skipif(tpu.num_available_chips() < TP_SIZE_8,
                    reason=f"This test requires {TP_SIZE_8} TPU chips.")
def test_gemma3_27b_with_text_input_and_tp(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check greedy text completions of Gemma-3 27B with tensor parallelism.

    Runs the model with ``tensor_parallel_size=TP_SIZE_8`` and verifies that
    each prompt's generated text contains the expected continuation. Skipped
    when not on TPU or when fewer than ``TP_SIZE_8`` chips are available.

    Args:
        vllm_runner: Runner fixture; called with the model name and engine
            options and used as a context manager.
        monkeypatch: pytest fixture used to set ``VLLM_USE_V1`` for the
            duration of the test only.
    """
    model = "google/gemma-3-27b-it"
    max_tokens = 16
    tensor_parallel_size = TP_SIZE_8
    max_num_seqs = 4
    prompts = [
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly;",
        "The greatest glory in living lies not in never falling,",
    ]
    answers = [
        " or, through inaction, allow a human being to come to harm.",
        " what is essential is invisible to the eye.",
        " but in rising every time we fall.",
    ]

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        with vllm_runner(
                model,
                max_num_batched_tokens=256,
                max_num_seqs=max_num_seqs,
                tensor_parallel_size=tensor_parallel_size) as vllm_model:
            vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)
        # zip() stops at the shorter input, so without this check a missing
        # result would let the loop below pass without testing every prompt.
        assert len(vllm_outputs) == len(prompts)
        # vllm_outputs is a list of tuples whose first element is the token id
        # and the second element is the output (including the prompt).
        for output, answer in zip(vllm_outputs, answers):
            generated_text = output[1]
            assert answer in generated_text, (
                f"expected {answer!r} in {generated_text!r}")