# test_full_cudagraph.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import contextlib
import os
5
import weakref
6
7
8

import pytest

9
from tests.utils import wait_for_gpu_memory_to_clear
10
from tests.v1.attention.utils import full_cg_backend_configs as backend_configs
11
12
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
13
from vllm.platforms import current_platform
14
from vllm.utils.torch_utils import is_torch_equal_or_newer
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Apply ``env_vars`` to ``os.environ`` for the duration of the ``with``
    block, then put every touched variable back the way it was.

    Variables that were unset beforehand are removed again on exit;
    variables that already existed get their previous value back. The
    restoration also runs when the body of the ``with`` block raises.

    We have to do this vs monkeypatch because monkeypatch doesn't work
    with "module" scoped fixtures.
    """
    # Snapshot first: ``None`` marks "was not set before we started".
    previous_values = {name: os.environ.get(name) for name in env_vars}
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for name, old_value in previous_values.items():
            if old_value is not None:
                os.environ[name] = old_value
                continue
            # The variable did not exist before, so drop it again.
            os.environ.pop(name, None)


36
# (model name, backend config) pairs that the full-cudagraph tests are
# parametrized over.
model_backends_full_cudagraph = []

# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]
for mla_backend in MLA_backends:
    model_backends_full_cudagraph.append(
        ("deepseek-ai/DeepSeek-V2-Lite", backend_configs[mla_backend])
    )

# Qwen/Qwen2-1.5B-Instruct with other backends
# (every entry of ``backend_configs`` whose key is not one of the MLA backends)
other_backend_configs = [
    backend_configs[c] for c in backend_configs if c not in MLA_backends
]
for backend_config in other_backend_configs:
    model_backends_full_cudagraph.append(("Qwen/Qwen2-1.5B-Instruct", backend_config))
51
52
53
54


@pytest.fixture(scope="class")
def llm_pair(request):
    """
    Build the pair of LLMs compared by the full-cudagraph tests.

    ``request.param`` is a ``(model, backend_config,
    use_inductor_graph_partition)`` tuple supplied through indirect
    parametrization. The first LLM is compiled with the backend's own
    compilation config (plus the requested
    ``use_inductor_graph_partition`` flag); the second one uses
    ``cudagraph_mode="PIECEWISE"`` and serves as the reference.

    The test is skipped when inductor graph partition is requested on
    torch < 2.9, or when the backend requires a specific GPU capability
    that differs from the current device's.

    Yields:
        ``(full, piecewise)`` as ``weakref.proxy`` objects. After the last
        test of the class both LLMs are deleted and the fixture waits for
        GPU memory on device 0 to clear.
    """
    model, backend_config, use_inductor_graph_partition = request.param

    if use_inductor_graph_partition and not is_torch_equal_or_newer("2.9.0.dev"):
        pytest.skip("Inductor graph partition only supported in torch>=2.9")

    # Dynamically skip test if GPU capability is not met
    required_arch = backend_config.specific_gpu_arch
    if required_arch and required_arch != current_platform.get_device_capability():
        if required_arch == (9, 0):
            pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
        elif required_arch == (10, 0):
            pytest.skip("Only Blackwell GPUs support Cutlass MLA")
        # Any other required architecture is just as unusable on this device;
        # skip instead of building the LLMs on unsupported hardware.
        pytest.skip(f"Backend requires GPU capability {required_arch}")

    # Work on a copy: ``backend_config`` is a shared object imported from
    # another module, so writing the flag into it would leak this
    # parametrization's setting into every later user of the same config.
    comp_config = {
        **backend_config.comp_config,
        "use_inductor_graph_partition": use_inductor_graph_partition,
    }

    env_vars = {
        # Force native sampler to avoid potential nondeterminism in FlashInfer
        # when per-request generators are not used in V1.
        "VLLM_USE_FLASHINFER_SAMPLER": "0",
        **backend_config.env_vars,
    }
    with temporary_environ(env_vars):
        full = LLM(
            model=model,
            gpu_memory_utilization=0.43,
            trust_remote_code=True,
            max_model_len=1024,
            max_num_seqs=128,
            compilation_config=CompilationConfig(**comp_config),
            generation_config="vllm",
            seed=42,
        )
        piecewise = LLM(
            model=model,
            gpu_memory_utilization=0.43,
            trust_remote_code=True,
            max_model_len=1024,
            max_num_seqs=128,
            compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
            generation_config="vllm",
            seed=42,
        )

    # PyTest caches the fixture values so we use weakref.proxy to enable GC
    yield weakref.proxy(full), weakref.proxy(piecewise)
    del full
    del piecewise

    wait_for_gpu_memory_to_clear(
        devices=[0],
        threshold_ratio=0.1,
    )


112
113
114
115
116
117
118
119
120
@pytest.mark.parametrize(
    "llm_pair",
    [
        pytest.param((model, backend_config, use_inductor_graph_partition))
        for model, backend_config in model_backends_full_cudagraph
        for use_inductor_graph_partition in [True, False]
    ],
    indirect=True,
)
class TestFullCUDAGraph:
    """
    Use a class such that an llm pair is constructed once for all
    batch_size/max_tokens combinations and released immediately after.

    Module-scope fixtures would stick around the whole time,
    meaning there would be multiple LLM instances hogging memory simultaneously.
    """

    @pytest.mark.parametrize(
        ("batch_size", "max_tokens"),
        [
            (1, 10),
            (7, 10),
            (16, 10),
            (25, 10),
            (32, 10),
            (45, 10),
            (64, 10),
            (123, 10),
            (8, 5),
            (8, 30),
        ],
    )
    def test_full_cudagraph(self, batch_size, max_tokens, llm_pair: tuple[LLM, LLM]):
        """
        Test various batch sizes and max_tokens to ensure that the
        full cudagraph compilation works for padded cases too.

        The same prompts are generated with the piecewise LLM and the full
        cudagraph LLM, and the first output of every response must match
        (compared case-insensitively).
        """
        full_cudagraph_llm, piecewise_llm = llm_pair

        prompts = ["the quick brown fox"] * batch_size
        # Use purely greedy decoding to avoid top-p truncation sensitivity
        # that can amplify tiny numeric differences across runtimes.
        sampling_params = SamplingParams(
            temperature=0.0, max_tokens=max_tokens, top_p=1.0
        )

        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)

        # ``zip`` stops at the shorter sequence, so a missing response would
        # otherwise go unnoticed; pin the counts before comparing pairwise.
        assert len(piecewise_responses) == len(full_responses) == batch_size

        # Check that all responses are the same
        for piecewise_res, full_res in zip(piecewise_responses, full_responses):
            assert (
                piecewise_res.outputs[0].text.lower()
                == full_res.outputs[0].text.lower()
            )
169
170


171
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
    """
    Constructing an LLM with ``cudagraph_mode="FULL"`` while the attention
    backend is forced to FLEX_ATTENTION is expected to raise RuntimeError.
    """
    with (
        temporary_environ(
            {
                "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
                # Flex_Attention is not supported with full cuda graph
            }
        ),
        pytest.raises(RuntimeError),
    ):
        LLM(
            model="Qwen/Qwen2-1.5B-Instruct",
            compilation_config=CompilationConfig(cudagraph_mode="FULL"),
        )