test_full_cudagraph.py 7.42 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
import contextlib
import os
5
6
import weakref
from contextlib import ExitStack
7
8
9

import pytest

10
from tests.utils import wait_for_gpu_memory_to_clear
11
12
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
13
from vllm.platforms import current_platform
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34


@contextlib.contextmanager
def temporary_environ(env_vars):
    """
    Apply ``env_vars`` to ``os.environ`` for the body of the ``with`` block,
    then put every touched key back exactly as it was (re-set to its old
    value, or removed if it was previously unset), even if the body raises.

    This exists instead of ``monkeypatch`` because ``monkeypatch`` is
    function-scoped and cannot be used from "module"/"class" scoped fixtures.
    """
    saved = {}
    for name in env_vars:
        # None marks "was not set before".
        saved[name] = os.environ.get(name)
    try:
        os.environ.update(env_vars)
        yield
    finally:
        for name, previous in saved.items():
            if previous is not None:
                os.environ[name] = previous
            else:
                os.environ.pop(name, None)


35
36
37
@pytest.fixture(scope="class")
def llm_pair(request):
    """
    Build a full-CUDA-graph LLM and a piecewise-CUDA-graph LLM for the model
    named by ``request.param`` (supplied via indirect parametrization).

    Both engines are constructed with ``VLLM_USE_V1=1`` and
    ``VLLM_FLASH_ATTN_VERSION=3`` set; the environment is restored as soon as
    construction finishes. Each engine requests gpu_memory_utilization=0.45,
    leaving room for both on one device.

    Yields:
        ``(full, piecewise)`` -- in that order -- as ``weakref.proxy``
        objects. On teardown the strong references are dropped and the
        fixture waits for device 0's memory usage to fall back under the
        threshold.
    """
    model = request.param

    with temporary_environ({
            "VLLM_USE_V1": "1",
            "VLLM_FLASH_ATTN_VERSION": "3"
    }):
        full = LLM(
            model=model,
            gpu_memory_utilization=0.45,
            trust_remote_code=True,
            max_model_len=1024,
            compilation_config=CompilationConfig(full_cuda_graph=True),
        )
        piecewise = LLM(
            model=model,
            gpu_memory_utilization=0.45,
            trust_remote_code=True,
            max_model_len=1024,
            compilation_config=CompilationConfig(),
        )

    # PyTest caches the fixture values so we use weakref.proxy to enable GC
    yield weakref.proxy(full), weakref.proxy(piecewise)
    del full
    del piecewise

    wait_for_gpu_memory_to_clear(
        devices=[0],
        threshold_ratio=0.1,
    )


69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
@pytest.fixture(scope="class")
def cutlass_mla_llm_pair(request):
    """
    Build a full-CUDA-graph / piecewise-CUDA-graph LLM pair for the model in
    ``request.param``, with the V1 engine and the CUTLASS_MLA attention
    backend forced through environment variables during construction.

    The full-graph engine captures only the sizes 16..512 listed below.
    Yields ``weakref.proxy`` handles in ``(full, piecewise)`` order; on
    teardown the strong references are released and the fixture waits for
    device 0's memory to clear.
    """
    model_name = request.param
    shared_kwargs = dict(
        model=model_name,
        gpu_memory_utilization=0.45,
        trust_remote_code=True,
        max_model_len=1024,
    )
    # Force the V1 engine and the Cutlass MLA backend.
    forced_env = {
        "VLLM_USE_V1": "1",
        "VLLM_ATTENTION_BACKEND": "CUTLASS_MLA",
        # TODO: remove this when hang issue is fixed
        "FORCE_NUM_KV_SPLITS": "1",
    }

    with temporary_environ(forced_env):
        full = LLM(
            **shared_kwargs,
            compilation_config=CompilationConfig(
                full_cuda_graph=True,
                cudagraph_capture_sizes=[16, 32, 64, 128, 256, 512],
            ),
        )
        piecewise = LLM(
            **shared_kwargs,
            compilation_config=CompilationConfig(),
        )

    # Weak proxies only: pytest caches fixture values, and strong references
    # would keep both engines alive after the class is done.
    yield weakref.proxy(full), weakref.proxy(piecewise)
    del full, piecewise

    wait_for_gpu_memory_to_clear(devices=[0], threshold_ratio=0.1)


@pytest.mark.parametrize(
    "cutlass_mla_llm_pair",
    [
        # use an MLA model
        "deepseek-ai/DeepSeek-V2-Lite",
    ],
    indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (10, 0),
                    reason="Only Blackwell GPUs support Cutlass MLA")
class TestFullCUDAGraphCutlassMLA:
    """
    Validate full CUDA Graph with Cutlass MLA (decode-only capture).
    """

    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
        (8, 8),
    ])
    def test_full_cudagraph_sm100_cutlass_mla(
            self, batch_size, max_tokens, cutlass_mla_llm_pair: tuple[LLM,
                                                                      LLM]):
        """
        With temperature=0.0, the full-CUDA-graph engine must produce the
        same text as the piecewise engine for every prompt in the batch.
        """
        # The cutlass_mla_llm_pair fixture yields (full, piecewise), in that
        # order; unpack accordingly so the names match the engines.
        full_cudagraph_llm, piecewise_llm = cutlass_mla_llm_pair

        prompts = ["Hello, my name is"] * batch_size
        sampling_params = SamplingParams(temperature=0.0,
                                         max_tokens=max_tokens,
                                         top_p=0.95)

        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)

        # zip() silently truncates, so make sure no output went missing.
        assert len(piecewise_responses) == len(full_responses) == len(
            prompts), (f"expected {len(prompts)} outputs, got "
                       f"{len(piecewise_responses)} (piecewise) and "
                       f"{len(full_responses)} (full)")
        for piecewise_res, full_res in zip(piecewise_responses,
                                           full_responses):
            assert piecewise_res.outputs[0].text == full_res.outputs[0].text


143
144
145
146
147
148
149
150
@pytest.mark.parametrize(
    "llm_pair",
    [
        # Model names for the llm_pair fixture
        "deepseek-ai/DeepSeek-V2-Lite",
        "Qwen/Qwen2-1.5B-Instruct"
    ],
    indirect=True)
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
                    reason="Only Hopper GPUs support FA3 and FlashMLA")
class TestFullCUDAGraph:
    """
    Use a class such that an llm pair is constructed once for all
    batch_size/max_tokens combinations and released immediately after.

    Module-scope fixtures would stick around the whole time, meaning there
    would be multiple LLM instances hogging memory simultaneously.
    """

    @pytest.mark.parametrize(("batch_size", "max_tokens"), [
        (1, 10),
        (7, 10),
        (16, 10),
        (25, 10),
        (32, 10),
        (45, 10),
        (64, 10),
        (123, 10),
        (8, 5),
        (8, 30),
    ])
    def test_full_cudagraph(self, batch_size, max_tokens,
                            llm_pair: tuple[LLM, LLM]):
        """
        Test various batch sizes and max_tokens to ensure that the
        full cudagraph compilation works for padded cases too.

        With temperature=0.0, the full-CUDA-graph engine must produce the
        same text as the piecewise engine for every prompt.
        """
        # The llm_pair fixture yields (full, piecewise), in that order;
        # unpack accordingly so the names match the engines.
        full_cudagraph_llm, piecewise_llm = llm_pair

        prompts = ["Hello, my name is"] * batch_size
        sampling_params = SamplingParams(temperature=0.0,
                                         max_tokens=max_tokens,
                                         top_p=0.95)

        piecewise_responses = piecewise_llm.generate(prompts, sampling_params)
        full_responses = full_cudagraph_llm.generate(prompts, sampling_params)

        # zip() silently truncates, so make sure no output went missing.
        assert len(piecewise_responses) == len(full_responses) == len(
            prompts), (f"expected {len(prompts)} outputs, got "
                       f"{len(piecewise_responses)} (piecewise) and "
                       f"{len(full_responses)} (full)")

        # Check that all responses are the same
        for piecewise_res, full_res in zip(piecewise_responses,
                                           full_responses):
            assert piecewise_res.outputs[0].text == full_res.outputs[0].text


@pytest.mark.parametrize(
    "model, supported",
    [
        ("Qwen/Qwen2-1.5B-Instruct", True),
        # MLA does not support capturing CUDA Graphs with size > max_num_seqs
        ("deepseek-ai/DeepSeek-V2-Lite", False),
    ])
@pytest.mark.skipif(current_platform.get_device_capability() != (9, 0),
                    reason="Only Hopper GPUs support FA3 and FlashMLA")
def test_lower_max_num_seqs(model, supported):
    """
    Request full-CUDA-graph capture sizes up to 512 while max_num_seqs is
    only 256. Supported models must build and generate normally; for
    unsupported ones, a RuntimeError is expected from engine construction
    or generation.
    """
    forced_env = {"VLLM_USE_V1": "1", "VLLM_FLASH_ATTN_VERSION": "3"}
    with temporary_environ(forced_env):
        with ExitStack() as stack:
            if not supported:
                # Only arm the expectation for models that should fail.
                stack.enter_context(pytest.raises(RuntimeError))

            capture_sizes = [64, 256, 512]
            llm = LLM(model=model,
                      max_num_seqs=256,
                      trust_remote_code=True,
                      max_model_len=1024,
                      compilation_config=CompilationConfig(
                          full_cuda_graph=True,
                          cudagraph_capture_sizes=capture_sizes))
            llm.generate(["Hello, my name is"] * 10)
222
223


224
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
def test_full_cudagraph_with_invalid_backend():
    """
    Constructing an LLM with ``full_cuda_graph=True`` while FlashAttention 2
    is forced via ``VLLM_FLASH_ATTN_VERSION=2`` must raise RuntimeError.
    """
    with temporary_environ({
            "VLLM_USE_V1": "1",
            # FA2 is not supported with full_cuda_graph.
            "VLLM_FLASH_ATTN_VERSION": "2",
    }), pytest.raises(RuntimeError):
        LLM(model="Qwen/Qwen2-1.5B-Instruct",
            compilation_config=CompilationConfig(full_cuda_graph=True))