# SPDX-License-Identifier: Apache-2.0
"""A basic correctness check for TPUs.

Run `pytest tests/v1/tpu/test_basic.py`.
"""
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from vllm.platforms import current_platform

if TYPE_CHECKING:
    # Only needed for the `vllm_runner` annotation; TYPE_CHECKING is False at
    # runtime, so this import never executes during the test run.
    from tests.conftest import VllmRunner

# Models parametrized into `test_basic`.
MODELS = [
    "Qwen/Qwen2.5-1.5B-Instruct",
    # TODO: Enable these models on v6e.
    # "Qwen/Qwen2-7B-Instruct",
    # "meta-llama/Llama-3.1-8B",
]

# Tensor-parallel degrees parametrized into `test_basic`.
TENSOR_PARALLEL_SIZES = [1]
# Values passed to the runner as `max_num_seqs` by `test_basic`.
MAX_NUM_REQS = [16, 1024]

# TODO: Enable when CI/CD has a multi-TPU instance:
# TENSOR_PARALLEL_SIZES = [1, 4]


@pytest.mark.skipif(not current_platform.is_tpu(),
                    reason="This is a basic test for TPU only")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("tensor_parallel_size", TENSOR_PARALLEL_SIZES)
@pytest.mark.parametrize("max_num_seqs", MAX_NUM_REQS)
def test_basic(
    vllm_runner: type[VllmRunner],
    monkeypatch: pytest.MonkeyPatch,
    model: str,
    max_tokens: int,
    tensor_parallel_size: int,
    max_num_seqs: int,
) -> None:
    """Greedy-decode a long counting prompt with the V1 engine on TPU.

    The prompt lists the integers 0..1023 and asks for the next numbers; the
    test expects the generated text to contain "1024" (or "0, 1").

    Args:
        vllm_runner: Fixture providing the runner class (see
            ``tests.conftest.VllmRunner``); instantiated as a context manager.
        monkeypatch: Used to set ``VLLM_USE_V1=1`` only for the duration of
            this test.
        model: Model name taken from ``MODELS``.
        max_tokens: Number of tokens to generate.
        tensor_parallel_size: Passed to the runner as
            ``tensor_parallel_size``.
        max_num_seqs: Passed to the runner as ``max_num_seqs``.
    """
    # "0, 1, 2, ..., 1023" -- presumably much longer than the 1024-token
    # batch budget configured below, so the prompt gets chunked.
    numbers = ", ".join(str(i) for i in range(1024))
    prompt = f"The next numbers of the sequence {numbers} are:"
    example_prompts = [prompt]

    with monkeypatch.context() as m:
        m.setenv("VLLM_USE_V1", "1")

        with vllm_runner(
                model,
                # Note: max_num_batched_tokens == 1024 is needed here to
                # actually test chunked prompt
                max_num_batched_tokens=1024,
                max_model_len=8192,
                gpu_memory_utilization=0.7,
                max_num_seqs=max_num_seqs,
                tensor_parallel_size=tensor_parallel_size) as vllm_model:
            vllm_outputs = vllm_model.generate_greedy(example_prompts,
                                                      max_tokens)
        # Index 1 of the first result is used as the text -- presumably an
        # (ids, text) pair; TODO confirm against VllmRunner.generate_greedy.
        output = vllm_outputs[0][1]

        # NOTE(review): the prompt itself begins with "0, 1", so if `output`
        # includes the prompt text the second alternative always matches and
        # this check becomes vacuous -- confirm what generate_greedy returns.
        assert "1024" in output or "0, 1" in output, (
            f"Unexpected greedy output: {output!r}")