test_full_cudagraph.py 7.46 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import contextlib
import os
5
import weakref
6
7
from dataclasses import dataclass
from typing import Optional
8
9
10

import pytest

11
from tests.utils import wait_for_gpu_memory_to_clear
12
13
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
14
from vllm.platforms import current_platform
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Apply ``env_vars`` to ``os.environ`` for the duration of a ``with`` block.

    On exit (normal or via an exception) every touched key is reset to its
    previous value, or removed again if it was not set beforehand.
    We have to do this vs monkeypatch because monkeypatch doesn't work
    with "module" scoped fixtures.
    """
    # Snapshot each key we are about to overwrite; None means "was unset".
    previous = {name: os.environ.get(name) for name in env_vars}
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for name, old_value in previous.items():
            if old_value is not None:
                os.environ[name] = old_value
            else:
                os.environ.pop(name, None)


@dataclass
class BackendConfig:
    """Describe one attention-backend setup exercised with full CUDA graphs."""

    # Human-readable backend label (matches its key in ``backend_configs``).
    name: str
    # Environment variables applied while the LLM engines are constructed.
    env_vars: dict
    # Keyword arguments forwarded to ``CompilationConfig`` for the
    # full-cudagraph engine.
    comp_config: dict
    # Required (major, minor) device capability, compared against
    # ``current_platform.get_device_capability()``; None means any GPU.
    specific_gpu_arch: Optional[tuple] = None


# Define all backend configurations of full cudagraph to be tested.
# Keys are backend labels (each matches its config's ``name``). Configs with
# ``specific_gpu_arch`` set only run on that GPU architecture.
backend_configs = {
    # FA3 on Hopper
    "FA3":
    BackendConfig(name="FA3",
                  env_vars={"VLLM_FLASH_ATTN_VERSION": "3"},
                  comp_config={
                      "cudagraph_mode": "FULL",
                  },
                  specific_gpu_arch=(9, 0)),
    # FlashMLA on Hopper
    "FlashMLA":
    BackendConfig(name="FlashMLA",
                  env_vars={
                      "VLLM_ATTENTION_BACKEND": "FLASHMLA",
                  },
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  },
                  specific_gpu_arch=(9, 0)),
    # Cutlass MLA on Blackwell
    "CutlassMLA":
    BackendConfig(
        name="CutlassMLA",
        env_vars={
            "VLLM_USE_V1": "1",
            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
            "FORCE_NUM_KV_SPLITS":
            "1",  # TODO: remove this when hang issue is fixed
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
            "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
        },
        specific_gpu_arch=(10, 0)),
    # FA2
    "FA2":
    BackendConfig(name="FA2",
                  env_vars={"VLLM_FLASH_ATTN_VERSION": "2"},
                  comp_config={
                      "cudagraph_mode": "FULL",
                  }),
    # Triton Attention
    "TritonAttn":
    BackendConfig(name="TritonAttn",
                  env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN_VLLM_V1"},
                  comp_config={
                      "cudagraph_mode": "FULL",
                  }),
    # FlashInfer
    "FlashInfer":
    BackendConfig(name="FlashInfer",
                  env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
                  comp_config={
                      "cudagraph_mode": "FULL_AND_PIECEWISE",
                  }),
}

def _make_param(model, config):
    """Wrap ``(model, config)`` as a pytest param with a readable test id."""
    short_model = model.split("/")[-1]
    return pytest.param((model, config), id=f"{short_model}-{config.name}")


# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "CutlassMLA"]

# Qwen/Qwen2-1.5B-Instruct with other backends
other_backend_configs = [
    backend_configs[c] for c in backend_configs if c not in MLA_backends
]

# MLA cases first, then every remaining backend (same order as the dict).
test_params_full_cudagraph = [
    _make_param("deepseek-ai/DeepSeek-V2-Lite", backend_configs[name])
    for name in MLA_backends
] + [
    _make_param("Qwen/Qwen2-1.5B-Instruct", config)
    for config in other_backend_configs
]


@pytest.fixture(scope="class")
def llm_pair(request):
    """
    Build a (full-cudagraph LLM, piecewise-cudagraph LLM) pair.

    ``request.param`` is a ``(model, BackendConfig)`` tuple supplied through
    indirect parametrization. Both engines are created with identical
    settings except for ``compilation_config``, so their greedy outputs can
    be compared directly. The test is skipped when the backend is pinned to
    a GPU architecture that differs from the current device's.

    Yields weak proxies to the two engines. After the consuming tests finish,
    the engines are dropped and we wait for GPU 0's memory to clear.
    """
    model, backend_config = request.param

    # Dynamically skip test if GPU capability is not met
    required_arch = backend_config.specific_gpu_arch
    if required_arch and \
            required_arch != current_platform.get_device_capability():
        if required_arch == (9, 0):
            pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
        elif required_arch == (10, 0):
            pytest.skip("Only Blackwell GPUs support Cutlass MLA")
        else:
            # Any other pinned architecture must not run on the wrong GPU.
            pytest.skip(f"{backend_config.name} requires GPU capability "
                        f"{required_arch}")

    env_vars = {
        "VLLM_USE_V1": "1",
        # Force native sampler to avoid potential nondeterminism in FlashInfer
        # when per-request generators are not used in V1.
        "VLLM_USE_FLASHINFER_SAMPLER": "0",
        **backend_config.env_vars,
    }

    # Settings shared by both engines; only compilation_config differs.
    # Each engine claims 43% of GPU memory so the two can coexist.
    common_kwargs = {
        "model": model,
        "gpu_memory_utilization": 0.43,
        "trust_remote_code": True,
        "max_model_len": 1024,
        "max_num_seqs": 128,
        "generation_config": "vllm",
        "seed": 42,
    }
    with temporary_environ(env_vars):
        full = LLM(
            compilation_config=CompilationConfig(
                **backend_config.comp_config),
            **common_kwargs,
        )
        piecewise = LLM(
            compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
            **common_kwargs,
        )

    # PyTest caches the fixture values so we use weakref.proxy to enable GC
    yield weakref.proxy(full), weakref.proxy(piecewise)
    del full
    del piecewise

    wait_for_gpu_memory_to_clear(
        devices=[0],
        threshold_ratio=0.1,
    )


@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
class TestFullCUDAGraph:
    """
    Use a class such that an llm pair is constructed once for all
    batch_size/max_tokens combinations and released immediately after.

    Module-scope fixtures would stick around the whole time,
    meaning there would be multiple LLM instances hogging memory simultaneously.
    """

    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
        (1, 10),
        (7, 10),
        (16, 10),
        (25, 10),
        (32, 10),
        (45, 10),
        (64, 10),
        (123, 10),
        (8, 5),
        (8, 30),
    ])
    def test_full_cudagraph(self, batch_size, max_tokens,
                            llm_pair: tuple[LLM, LLM]):
        """
        Test various batch sizes and max_tokens to ensure that the
        full cudagraph compilation works for padded cases too.

        Both engines decode the same prompts greedily. Every response from
        the full-cudagraph engine must match the piecewise baseline,
        compared case-insensitively.
        """
        full_cudagraph_llm, piecewise_llm = llm_pair

        prompts = ["the quick brown fox"] * batch_size
        # Use purely greedy decoding to avoid top-p truncation sensitivity
        # that can amplify tiny numeric differences across runtimes.
        sampling_params = SamplingParams(temperature=0.0,
                                         max_tokens=max_tokens,
                                         top_p=1.0)

        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)

        # zip() silently truncates, so check the counts first; otherwise a
        # missing response would make the comparison below vacuously pass.
        assert len(piecewise_responses) == len(full_responses) == batch_size

        # Check that all responses are the same
        for i, (piecewise_res, full_res) in enumerate(
                zip(piecewise_responses, full_responses)):
            piecewise_text = piecewise_res.outputs[0].text
            full_text = full_res.outputs[0].text
            assert piecewise_text.lower() == full_text.lower(), (
                f"Mismatch for prompt {i}: piecewise={piecewise_text!r} "
                f"full={full_text!r}")


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
    """
    Check that FULL cudagraph mode is rejected for an unsupported backend.

    Constructing the LLM with the FLEX_ATTENTION backend and
    ``cudagraph_mode="FULL"`` is expected to raise ``RuntimeError``.
    """
    with temporary_environ({
            "VLLM_USE_V1": "1",
            # Flex_Attention is not supported with full cuda graph
            "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
    }), pytest.raises(RuntimeError):
        LLM(model="Qwen/Qwen2-1.5B-Instruct",
            compilation_config=CompilationConfig(cudagraph_mode="FULL"))