# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import contextlib
import os

import pytest

from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
10
from vllm.platforms import current_platform
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40

# Model identifier passed to every LLM(...) constructed in this file.
MODEL = "Qwen/Qwen2-1.5B-Instruct"


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Apply ``env_vars`` to ``os.environ`` for the duration of the ``with``
    body, then put every touched variable back the way it was.

    Variables that were unset before entry are removed again on exit;
    variables that had a value get that value back. Restoration happens
    even if the body raises.

    This exists instead of pytest's monkeypatch because monkeypatch doesn't
    work with "module" scoped fixtures.
    """
    # Snapshot the prior state (None marks "was not set") before mutating.
    saved = {}
    for name in env_vars:
        saved[name] = os.environ.get(name)

    try:
        os.environ.update(env_vars)
        yield
    finally:
        for name, previous in saved.items():
            if previous is not None:
                os.environ[name] = previous
            else:
                os.environ.pop(name, None)


@pytest.fixture(scope="module")
def full_cudagraph_llm():
    """
    Module-scoped fixture returning an LLM built with
    ``CompilationConfig(full_cuda_graph=True)``.

    The engine is constructed while VLLM_USE_V1=1 and
    VLLM_FLASH_ATTN_VERSION=3 are set; temporary_environ restores both
    variables once construction has finished.
    """
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "3"
    }):
        # NOTE(review): 0.3 here and 0.6 for the piecewise fixture are
        # presumably chosen so both engines fit on one GPU -- confirm.
        return LLM(model=MODEL,
                   gpu_memory_utilization=0.3,
                   compilation_config=CompilationConfig(full_cuda_graph=True))


@pytest.fixture(scope="module")
def piecewise_llm():
    """
    Module-scoped fixture returning an LLM built with a default
    ``CompilationConfig()``; the test uses it as the reference to compare
    the full-CUDA-graph model against.

    The engine is constructed while VLLM_USE_V1=1 and
    VLLM_FLASH_ATTN_VERSION=3 are set; temporary_environ restores both
    variables once construction has finished.
    """
    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "3"
    }):
        # NOTE(review): 0.6 here and 0.3 for the full-cudagraph fixture are
        # presumably chosen so both engines fit on one GPU -- confirm.
        return LLM(model=MODEL,
                   gpu_memory_utilization=0.6,
                   compilation_config=CompilationConfig())


def generate_text(llm: LLM, batch_size: int, max_tokens: int):
    """
    Run ``llm.generate`` on ``batch_size`` copies of one fixed prompt.

    Sampling uses temperature=0.0 and top_p=0.95, with at most
    ``max_tokens`` tokens per completion. Returns whatever
    ``llm.generate`` returns.
    """
    repeated_prompts = ["Hi my name is" for _ in range(batch_size)]
    params = SamplingParams(
        temperature=0.0,
        max_tokens=max_tokens,
        top_p=0.95,
    )
    return llm.generate(repeated_prompts, params)


65
66
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
                    reason="Only Hopper GPUs support FlashAttention 3")
@pytest.mark.parametrize(("batch_size", "max_tokens"), [(1, 10), (7, 10),
                                                        (16, 10), (25, 10),
                                                        (32, 10), (45, 10),
                                                        (64, 10), (8, 5),
                                                        (8, 20), (8, 200)])
def test_full_cudagraph(batch_size, max_tokens, full_cudagraph_llm,
                        piecewise_llm):
    """
    Check that the full cudagraph model and the piecewise model generate
    identical text.

    Both models come from module-scoped fixtures, so each is loaded once
    and reused across all parametrized cases.

    Various batch sizes and max_tokens are tested to ensure that the full
    cudagraph compilation works for padded cases too.
    """
    piecewise_responses = generate_text(piecewise_llm,
                                        batch_size=batch_size,
                                        max_tokens=max_tokens)
    full_cudagraph_responses = generate_text(full_cudagraph_llm,
                                             batch_size=batch_size,
                                             max_tokens=max_tokens)

    # Guard against one side silently returning fewer/more results; zip()
    # below would otherwise truncate to the shorter list.
    assert len(piecewise_responses) == len(full_cudagraph_responses)

    # Check that all responses are the same
    for piecewise, full_cudagraph in zip(piecewise_responses,
                                         full_cudagraph_responses):
        assert piecewise.outputs[0].text == full_cudagraph.outputs[0].text


def test_full_cudagraph_with_invalid_backend():
    """
    Constructing an LLM with ``full_cuda_graph=True`` while
    VLLM_FLASH_ATTN_VERSION is "2" is expected to raise RuntimeError.
    """
    # FA2 not supported with full_cuda_graph
    fa2_env = {"VLLM_USE_V1": "1", "VLLM_FLASH_ATTN_VERSION": "2"}
    with temporary_environ(fa2_env):
        with pytest.raises(RuntimeError):
            LLM(model=MODEL,
                compilation_config=CompilationConfig(full_cuda_graph=True))