# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import os
5
import weakref
6
7
from dataclasses import dataclass
from typing import Optional
8
9
10

import pytest

11
from tests.utils import wait_for_gpu_memory_to_clear
12
13
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
14
from vllm.platforms import current_platform
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Apply `env_vars` to `os.environ` for the duration of the `with` body,
    then put every touched key back the way it was.

    Keys that were unset beforehand are removed again on exit; keys that
    already had a value get that value restored. Restoration also happens
    when the body raises.

    We have to do this vs monkeypatch because monkeypatch doesn't work
    with "module" scoped fixtures.
    """
    # Snapshot the prior value of each key (None means "was not set").
    previous = {}
    for name in env_vars:
        previous[name] = os.environ.get(name)

    try:
        os.environ.update(env_vars)
        yield
    finally:
        for name, old_value in previous.items():
            if old_value is not None:
                os.environ[name] = old_value
            else:
                # The key did not exist before, so drop it again.
                os.environ.pop(name, None)


36
37
38
39
40
41
42
43
44
45
46
47
48
@dataclass
class BackendConfig:
    """One attention-backend setup to exercise with full cudagraph."""
    # Label of the backend configuration.
    name: str
    # Environment variables to set while the LLM instances are constructed.
    env_vars: dict
    # Keyword arguments expanded into CompilationConfig(**comp_config).
    comp_config: dict
    # Optional GPU capability tuple this backend is restricted to; it is
    # compared against current_platform.get_device_capability() to decide
    # whether to skip. None means no restriction.
    specific_gpu_arch: Optional[tuple] = None


# Define all backend configurations of full cudagraph to be tested.
# Keys are backend labels. `specific_gpu_arch` restricts a backend to one
# GPU capability (Hopper is (9, 0), Blackwell is (10, 0)); it is omitted
# for backends that are not restricted.
backend_configs = {
    # FA3 on Hopper
    "FA3": BackendConfig(
        name="FA3",
        env_vars={
            "VLLM_FLASH_ATTN_VERSION": "3",
            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
        },
        comp_config={
            "cudagraph_mode": "FULL",
        },
        specific_gpu_arch=(9, 0),
    ),
    # FlashMLA on Hopper
    "FlashMLA": BackendConfig(
        name="FlashMLA",
        env_vars={
            "VLLM_ATTENTION_BACKEND": "FLASHMLA",
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
        specific_gpu_arch=(9, 0),
    ),
    # FlashAttention MLA on Hopper
    "FlashAttentionMLA": BackendConfig(
        name="FlashAttentionMLA",
        env_vars={
            "VLLM_ATTENTION_BACKEND": "FLASH_ATTN_MLA",
            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
        },
        comp_config={
            "cudagraph_mode": "FULL_DECODE_ONLY",
        },
        specific_gpu_arch=(9, 0),
    ),
    # Cutlass MLA on Blackwell
    "CutlassMLA": BackendConfig(
        name="CutlassMLA",
        env_vars={
            "VLLM_USE_V1": "1",
            "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
            # TODO: remove this when hang issue is fixed
            "FORCE_NUM_KV_SPLITS": "1",
        },
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
            "cudagraph_capture_sizes": [16, 32, 64, 128, 256, 512],
        },
        specific_gpu_arch=(10, 0),
    ),
    # FA2
    "FA2": BackendConfig(
        name="FA2",
        env_vars={
            "VLLM_FLASH_ATTN_VERSION": "2",
            "VLLM_FLASH_ATTN_MAX_NUM_SPLITS_FOR_CUDA_GRAPH": "16",
        },
        comp_config={
            "cudagraph_mode": "FULL",
        },
    ),
    # Triton Attention
    "TritonAttn": BackendConfig(
        name="TritonAttn",
        env_vars={"VLLM_ATTENTION_BACKEND": "TRITON_ATTN"},
        comp_config={
            "cudagraph_mode": "FULL",
        },
    ),
    # FlashInfer
    "FlashInfer": BackendConfig(
        name="FlashInfer",
        env_vars={"VLLM_ATTENTION_BACKEND": "FLASHINFER"},
        comp_config={
            "cudagraph_mode": "FULL_AND_PIECEWISE",
        },
    ),
}

# deepseek-ai/DeepSeek-V2-Lite with MLA
MLA_backends = ["FlashMLA", "FlashAttentionMLA", "CutlassMLA"]

# Qwen/Qwen2-1.5B-Instruct with other backends
other_backend_configs = [
    backend_configs[c] for c in backend_configs if c not in MLA_backends
]

# Each param wraps a (model name, BackendConfig) tuple. MLA backends come
# first, followed by every remaining backend in `backend_configs` order.
test_params_full_cudagraph = [
    *(pytest.param(("deepseek-ai/DeepSeek-V2-Lite", backend_configs[name]))
      for name in MLA_backends),
    *(pytest.param(("Qwen/Qwen2-1.5B-Instruct", config))
      for config in other_backend_configs),
]


@pytest.fixture(scope="class")
def llm_pair(request):
    """
    Build a pair of LLM engines for one (model, BackendConfig) parameter.

    The first engine uses the compilation config of the backend under test;
    the second uses PIECEWISE cudagraphs and serves as the reference. Both
    are created with identical remaining arguments while the backend's
    environment variables are temporarily applied.

    Yields:
        A (full, piecewise) tuple of weakref proxies to the two engines.

    The test is skipped when the backend is restricted to a GPU capability
    that differs from the current device's capability.
    """
    model, backend_config = request.param

    # Dynamically skip test if GPU capability is not met
    required_arch = backend_config.specific_gpu_arch
    if required_arch and \
            required_arch != current_platform.get_device_capability():
        if required_arch == (9, 0):
            pytest.skip("Only Hopper GPUs support FA3 and FlashMLA")
        elif required_arch == (10, 0):
            pytest.skip("Only Blackwell GPUs support Cutlass MLA")
        else:
            # Any other restricted arch must skip too, rather than fall
            # through and run the backend on hardware it was not meant for.
            pytest.skip(f"{backend_config.name} requires GPU capability "
                        f"{required_arch}")

    env_vars = {
        "VLLM_USE_V1": "1",
        # Force native sampler to avoid potential nondeterminism in FlashInfer
        # when per-request generators are not used in V1.
        "VLLM_USE_FLASHINFER_SAMPLER": "0",
        **backend_config.env_vars,
    }

    # Arguments shared by both engines, so the compilation config is the
    # only thing that differs between them.
    common_kwargs = dict(
        model=model,
        gpu_memory_utilization=0.43,
        trust_remote_code=True,
        max_model_len=1024,
        max_num_seqs=128,
        generation_config="vllm",
        seed=42,
    )

    with temporary_environ(env_vars):
        full = LLM(
            compilation_config=CompilationConfig(
                **backend_config.comp_config),
            **common_kwargs,
        )
        piecewise = LLM(
            compilation_config=CompilationConfig(cudagraph_mode="PIECEWISE"),
            **common_kwargs,
        )

    # PyTest caches the fixture values so we use weakref.proxy to enable GC
    yield weakref.proxy(full), weakref.proxy(piecewise)
    del full
    del piecewise

    wait_for_gpu_memory_to_clear(
        devices=[0],
        threshold_ratio=0.1,
    )


190
@pytest.mark.parametrize("llm_pair", test_params_full_cudagraph, indirect=True)
class TestFullCUDAGraph:
    """
    Use a class such that an llm pair is constructed once for all
    batch_size/max_tokens combinations and released immediately after.

    Module-scope fixtures would stick around the whole time,
    meaning there would be multiple LLM instances hogging memory simultaneously.
    """

    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
        (1, 10),
        (7, 10),
        (16, 10),
        (25, 10),
        (32, 10),
        (45, 10),
        (64, 10),
        (123, 10),
        (8, 5),
        (8, 30),
    ])
    def test_full_cudagraph(self, batch_size, max_tokens,
                            llm_pair: tuple[LLM, LLM]):
        """
        Test various batch sizes and max_tokens to ensure that the
        full cudagraph compilation works for padded cases too.

        The same prompts are generated with both engines of the pair and
        the output texts are compared case-insensitively.
        """
        full_cudagraph_llm, piecewise_llm = llm_pair

        prompts = ["the quick brown fox"] * batch_size
        # Use purely greedy decoding to avoid top-p truncation sensitivity
        # that can amplify tiny numeric differences across runtimes.
        sampling_params = SamplingParams(temperature=0.0,
                                         max_tokens=max_tokens,
                                         top_p=1.0)

        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)

        # zip() stops at the shorter sequence, so a missing response would
        # otherwise go unnoticed by the pairwise comparison below.
        assert len(piecewise_responses) == len(full_responses)

        # Check that all responses are the same
        for piecewise_res, full_res in zip(piecewise_responses,
                                           full_responses):
            assert piecewise_res.outputs[0].text.lower() == \
                full_res.outputs[0].text.lower()
236
237


238
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
    """
    Constructing an LLM with FULL cudagraph mode and the FLEX_ATTENTION
    backend is expected to raise RuntimeError.
    """
    env_vars = {
        "VLLM_USE_V1": "1",
        # Flex_Attention is not supported with full cuda graph
        "VLLM_ATTENTION_BACKEND": "FLEX_ATTENTION",
    }
    with temporary_environ(env_vars), pytest.raises(RuntimeError):
        LLM(model="Qwen/Qwen2-1.5B-Instruct",
            compilation_config=CompilationConfig(cudagraph_mode="FULL"))