test_basic.py 3.55 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
"""A basic correctness check for TPUs

Run `pytest tests/v1/tpu/test_basic.py`.
"""
7
8
9
10
from __future__ import annotations

from typing import TYPE_CHECKING

11
import pytest
12
from torch_xla._internal import tpu
13
14
15

from vllm.platforms import current_platform

16
17
if TYPE_CHECKING:
    from tests.conftest import VllmRunner
18
19

# Models exercised by the parametrized basic TPU test.
MODELS = [
    "Qwen/Qwen2.5-1.5B-Instruct",
    # TODO: Enable these models with v6e
    # "Qwen/Qwen2-7B-Instruct",
    # "meta-llama/Llama-3.1-8B",
]

# Parameter grids for the basic TPU test.
TENSOR_PARALLEL_SIZES = [1]
MAX_NUM_REQS = [16, 1024]

# TODO: Enable when CI/CD has a multi-TPU instance
# TENSOR_PARALLEL_SIZES = [1, 4]


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic test for TPU only")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
def test_basic(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    max_tokens: int,
    tensor_parallel_size: int,
    max_num_seqs: int,
) -> None:
    """Greedy-decode a long counting prompt and sanity-check the output.

    The prompt lists the integers 0..1023 and asks for the next numbers of
    the sequence. The test is skipped unless the current platform is a TPU.

    Args:
        vllm_runner: Runner fixture, used as a context manager to load
            ``model``.
        monkeypatch: pytest fixture used to set ``VLLM_USE_V1=1`` for the
            duration of the test.
        model: Model name passed to the runner.
        max_tokens: Number of tokens to generate greedily.
        tensor_parallel_size: Tensor-parallel degree passed to the runner.
        max_num_seqs: ``max_num_seqs`` value passed to the runner.
    """
    prompt = "The next numbers of the sequence " + ", ".join(
        str(i) for i in range(1024)) + " are:"
    example_prompts = [prompt]

    # The env var is restored when the monkeypatch context exits.
    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        with vllm_runner(
                model,
                # Note: max_num_batched_tokens == 1024 is needed here to
                # actually test chunked prompt
                max_num_batched_tokens=1024,
                max_model_len=8192,
                gpu_memory_utilization=0.7,
                max_num_seqs=max_num_seqs,
                tensor_parallel_size=tensor_parallel_size) as vllm_model:
            vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                      max_tokens)
        # Single prompt, so take the text element of the first result.
        output = vllm_outputs[0][1]

        # NOTE(review): if the returned text is prompt + completion, the
        # prompt already contains "0, 1", so the second alternative would
        # always hold -- confirm whether only the completion should be checked.
        assert "1024" in output or "0, 1" in output
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109


# Tensor-parallel degree for the Gemma 3 27B test; also the minimum number of
# available TPU chips required for that test to run.
TP_SIZE_8 = 8


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a test for TPU only")
@pytest.mark.skipif(tpu.num_available_chips() < TP_SIZE_8,
                    reason=f"This test requires {TP_SIZE_8} TPU chips.")
def test_gemma3_27b_with_text_input_and_tp(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """Check Gemma 3 27B text completions with tensor parallelism.

    Loads ``google/gemma-3-27b-it`` with ``tensor_parallel_size=TP_SIZE_8``,
    greedily generates 16 tokens for three quote prefixes, and asserts that
    each expected continuation appears in the corresponding output text.
    Skipped unless running on a TPU platform with at least ``TP_SIZE_8``
    available chips.

    Args:
        vllm_runner: Runner fixture, used as a context manager to load the
            model.
        monkeypatch: pytest fixture used to set ``VLLM_USE_V1=1`` for the
            duration of the test.
    """
    model = "google/gemma-3-27b-it"
    max_tokens = 16
    tensor_parallel_size = TP_SIZE_8
    max_num_seqs = 4
    prompts = [
        "A robot may not injure a human being",
        "It is only with the heart that one can see rightly;",
        "The greatest glory in living lies not in never falling,",
    ]
    answers = [
        " or, through inaction, allow a human being to come to harm.",
        " what is essential is invisible to the eye.",
        " but in rising every time we fall.",
    ]

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        with vllm_runner(
                model,
                max_num_batched_tokens=256,
                max_num_seqs=max_num_seqs,
                tensor_parallel_size=tensor_parallel_size) as vllm_model:
            vllm_outputs = vllm_model.generate_greedy(prompts, max_tokens)

        # zip() stops at the shorter input, so check the count explicitly;
        # otherwise missing outputs would go unnoticed.
        assert len(vllm_outputs) == len(answers), (
            f"expected {len(answers)} outputs, got {len(vllm_outputs)}")

        # vllm_outputs is a list of tuples whose first element is the token
        # ids and the second element is the output (including the prompt).
        for output, answer in zip(vllm_outputs, answers):
            generated_text = output[1]
            assert answer in generated_text, (
                f"expected {answer!r} in {generated_text!r}")