# test_cudagraph_dispatch.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    CUDAGraphMode,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher


class SimpleMLP(nn.Module):
    """Tiny helper MLP used as the wrapped model in these tests.

    Two 10 -> 10 linear layers applied back to back, with no nonlinearity
    in between.
    """

    def __init__(self):
        super().__init__()
        width = 10
        self.fc1 = nn.Linear(width, width)
        self.fc2 = nn.Linear(width, width)

    def forward(self, x):
        hidden = self.fc1(x)
        return self.fc2(hidden)


36
def _create_vllm_config(
37
38
39
    compilation_config: CompilationConfig,
    max_num_seqs: int = 8,
    lora_config: bool = False,
40
) -> MagicMock:
41
42
43
44
    mock_config = MagicMock(spec=VllmConfig)
    mock_config.compilation_config = compilation_config
    mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
    mock_config.parallel_config = ParallelConfig()
45
46
    if not lora_config:
        mock_config.lora_config = None
47
    # Mimic the behavior of VllmConfig.__post_init__()
48
    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
49
50
51
52
53
54
55
        compilation_config.set_splitting_ops_for_v1()

    return mock_config


class TestCudagraphDispatcher:
    """Tests for key registration and dispatch in CudagraphDispatcher."""

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,lora_config",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
            ("FULL", CompilationMode.NONE, False),
            # Test case 1: FULL_AND_PIECEWISE without VLLM_COMPILE; the
            # dispatcher is expected to reject this combination
            ("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
            # Test case 2: Full CG for uniform batches, no CG for mixed
            ("FULL_DECODE_ONLY", CompilationMode.NONE, False),
            # Test case 3: PIECEWISE for all
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
            # Test case 4: PIECEWISE for all, specialize LoRA cases
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ],
    )
    def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
        """Check key initialization and dispatch results for one cudagraph mode.

        Capture sizes are fixed to ``[1, 8]`` and ``max_num_seqs`` to 8. Each
        enabled runtime mode is expected to hold two keys, or four when a LoRA
        config is present.
        """
        # Setup dispatcher
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 8],
        )

        config = _create_vllm_config(
            comp_config, max_num_seqs=8, lora_config=lora_config
        )
        if (
            cudagraph_mode_str == "FULL_AND_PIECEWISE"
            and compilation_mode == CompilationMode.NONE
        ):
            # Construction itself must fail; nothing else to check.
            with pytest.raises(AssertionError):
                CudagraphDispatcher(config)
            return

        # NOTE(review): with the current parameter list, FULL_AND_PIECEWISE
        # always returns early above, so its branches below are not exercised.
        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        # Verify the key is initialized correctly
        expected_num_keys = 4 if lora_config else 2
        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
                expected_num_keys
            )
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
                expected_num_keys
            )
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0

        # Test dispatch logic
        # 1. non-uniform batch, size in cudagraph size list
        desc_full_exact = BatchDescriptor(
            num_tokens=8,
            uniform_decode=False,
        )
        rt_mode, key = dispatcher.dispatch(desc_full_exact)
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_full_exact
        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 2. uniform decode batch, size in cudagraph size list
        desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
        rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact.non_uniform
        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact
        elif cudagraph_mode_str == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_uniform_exact.non_uniform
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 3. No key match
        desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
        rt_mode, key = dispatcher.dispatch(desc_no_match)
        assert rt_mode == CUDAGraphMode.NONE
        assert key is None

        # 4. Cascade attention should have a fall back mode
        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
        rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True)
        # Substring check: matches both "PIECEWISE" and "FULL_AND_PIECEWISE".
        if "PIECEWISE" in cudagraph_mode_str:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact.non_uniform
        else:
            assert rt_mode == CUDAGraphMode.NONE


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
    """Unit tests for capture, replay and bypass in a FULL-mode CUDAGraphWrapper."""

    def setup_method(self):
        self.vllm_config = _create_vllm_config(CompilationConfig())
        self.model = SimpleMLP().to("cuda")
        self.input_tensor = torch.randn(1, 10, device="cuda")

    def test_capture_and_replay(self):
        """The first FULL call captures a graph; the next call replays it."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            wrapper(self.input_tensor)

        # 1. Capture
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            output1 = wrapper(self.input_tensor)
            # capturing phase should generate a zero output
            assert torch.allclose(output1, torch.zeros_like(output1))
            mock_cuda_graph.assert_called_once()

        assert batch_descriptor in wrapper.concrete_cudagraph_entries
        entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
        assert entry.cudagraph is not None

        # 2. Replay
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch.object(
                entry.cudagraph, "replay", wraps=entry.cudagraph.replay
            ) as mock_replay,
        ):
            output2 = wrapper(self.input_tensor)
            mock_replay.assert_called_once()

        # Compare with eager output
        eager_output = self.model(self.input_tensor)
        torch.testing.assert_close(eager_output, output2)

    def test_bypass_on_mode_mismatch(self):
        """A FULL wrapper run under a PIECEWISE context calls the model eagerly."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
            patch.object(
                self.model, "forward", wraps=self.model.forward
            ) as mock_forward,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
            mock_forward.assert_called_once()
        assert not wrapper.concrete_cudagraph_entries

    def test_bypass_on_mode_none(self):
        """A NONE runtime mode never captures and stores no graph entries."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
        assert not wrapper.concrete_cudagraph_entries


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
    """Tests combining CudagraphDispatcher decisions with CUDAGraphWrapper."""

    def setup_method(self):
        # only FULL mode for non-uniform batches
        self.comp_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode="FULL",
            cudagraph_capture_sizes=[10, 20],
        )
        self.vllm_config = _create_vllm_config(self.comp_config)
        self.dispatcher = CudagraphDispatcher(self.vllm_config)
        self.dispatcher.initialize_cudagraph_keys(
            self.comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

    def _run_and_monitor_call(
        self, wrapper, input_tensor, runtime_mode, batch_descriptor
    ):
        """Helper to run a single call and monitor the action.

        Returns the first matching label, checked in this order:
        ``"capture_global"`` if ``torch.cuda.graph`` was entered, ``"replay"``
        if the existing entry's graph replay was invoked, ``"bypass"`` if the
        wrapped runnable was called, otherwise ``"unknown"``.
        """
        with (
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context,
            patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable,
        ):
            entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None)

            context = set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=runtime_mode,
                batch_descriptor=batch_descriptor,
            )
            # Placeholder whose `.called` stays False when there is no graph.
            mock_replay = MagicMock()
            if entry and entry.cudagraph:
                # replay is replaced by a bare MagicMock, so the captured graph
                # is not actually executed here.
                with (
                    context,
                    patch.object(
                        entry.cudagraph, "replay", new_callable=MagicMock
                    ) as mock_replay,
                ):
                    wrapper(input_tensor)
            else:
                with context:
                    wrapper(input_tensor)

            if mock_graph_context.called:
                # note that this is globally mocked, so it will be detected
                # regardless of whether the inner or outer wrapper called it
                return "capture_global"
            if mock_replay.called:
                # only for outer wrapper
                return "replay"
            if mock_runnable.call_count > 0:
                # only for outer wrapper
                return "bypass"
            return "unknown"

    @create_new_process_for_each_test("spawn")
    def test_capture_replay_bypass_logic(self):
        """Walk one wrapper through capture, replay, bypass and a disabled capture."""
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
        max_bs = 16
        persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
        input_1 = persistent_input_buffer[:1]
        input_2 = persistent_input_buffer[:2]
        input_3 = persistent_input_buffer[:3]

        desc_1 = BatchDescriptor(num_tokens=1)
        desc_2 = BatchDescriptor(num_tokens=2)
        desc_3_unseen = BatchDescriptor(num_tokens=3)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        rt_mode, key = self.dispatcher.dispatch(desc_1)
        # 1. Capture first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "capture_global"

        # 2. Replay first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "replay"

        rt_mode, key = self.dispatcher.dispatch(desc_2)
        # 3. Capture second shape
        action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
        assert action == "capture_global"

        # 4. Replay second shape
        action = self._run_and_monitor_call(
            full_wrapper, input_2, CUDAGraphMode.FULL, desc_2
        )
        assert action == "replay"

        # 5. Bypass if no key match
        rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
        assert rt_mode == CUDAGraphMode.NONE
        action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
        assert action == "bypass"

        # capture unseen shape is not allowed after disable
        set_cudagraph_capturing_enabled(False)
        try:
            with pytest.raises(RuntimeError):
                self._run_and_monitor_call(
                    full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen
                )
        finally:
            # Restore the global flag even if the expectation above fails.
            set_cudagraph_capturing_enabled(True)

    @create_new_process_for_each_test("spawn")
    def test_nested_wrappers(self):
        """Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
        input_1 = torch.randn(1, 10, device="cuda")

        # Setup: Inner model is wrapped with PIECEWISE, outer with FULL
        inner_model = SimpleMLP().to("cuda")
        piecewise_wrapper = CUDAGraphWrapper(
            inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
        )
        inner_model.forward = MagicMock(wraps=inner_model.forward)
        outer_model = SimpleMLP().to("cuda")
        # When outer model is called, it calls the piecewise_wrapper
        outer_model.forward = MagicMock(
            wraps=outer_model.forward, side_effect=piecewise_wrapper
        )
        full_wrapper = CUDAGraphWrapper(
            outer_model, self.vllm_config, CUDAGraphMode.FULL
        )

        desc_1 = BatchDescriptor(num_tokens=1)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # --- Test runtime mode FULL---
        # Run with FULL mode context. Expect outer wrapper to capture.
        # The inner mock should be called once inside the graph capture.
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again. Expect outer wrapper to replay.
        # The outer model should NOT be called because the whole graph
        # is replayed.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "replay"
        assert outer_model.forward.call_count == 1  # No new call
        assert inner_model.forward.call_count == 1

        # --- Test runtime mode PIECEWISE ---
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        # Run with PIECEWISE mode context.
        # Expect outer wrapper to bypass and call inner wrapper.
        # Inner wrapper should capture.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again with PIECEWISE.
        # Outer bypasses, inner replays.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "bypass"
        assert outer_model.forward.call_count == 2
        assert inner_model.forward.call_count == 1