# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    CUDAGraphMode,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher


# Helper MLP for testing
class SimpleMLP(nn.Module):
    """Minimal two-layer linear network used as the wrapped test model.

    Both layers map 10 features to 10 features and no activation is applied
    between them.
    """

    def __init__(self):
        super().__init__()
        # fc1 is built before fc2 so parameter initialisation order is stable.
        self.fc1 = nn.Linear(in_features=10, out_features=10)
        self.fc2 = nn.Linear(in_features=10, out_features=10)

    def forward(self, x):
        """Apply ``fc1`` and then ``fc2`` to ``x`` and return the result."""
        hidden = self.fc1(x)
        out = self.fc2(hidden)
        return out


36
def _create_vllm_config(
    compilation_config: CompilationConfig,
    max_num_seqs: int = 8,
    lora_config: bool = False,
) -> MagicMock:
    """Build a ``VllmConfig``-spec'd mock around ``compilation_config``.

    Args:
        compilation_config: Compilation settings attached to the mock. It is
            mutated in place: splitting ops and the cudagraph capture-size
            fields are filled in the way the code below mimics
            ``VllmConfig.__post_init__``.
        max_num_seqs: Forwarded to ``SchedulerConfig.default_factory``.
        lora_config: When falsy, ``mock_config.lora_config`` is set to
            ``None``. When truthy, the attribute is not assigned here and is
            left as whatever the ``MagicMock`` provides.

    Returns:
        The configured ``MagicMock(spec=VllmConfig)``.
    """
    mock_config = MagicMock(spec=VllmConfig)
    mock_config.compilation_config = compilation_config
    mock_config.scheduler_config = SchedulerConfig.default_factory(
        max_num_seqs=max_num_seqs,
    )
    mock_config.parallel_config = ParallelConfig()
    mock_config.speculative_config = None  # No speculative decoding
    if not lora_config:
        mock_config.lora_config = None
    # Mimic the behavior of VllmConfig.__post_init__()
    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
        compilation_config.set_splitting_ops_for_v1(
            all2all_backend=mock_config.parallel_config.all2all_backend,
            data_parallel_size=mock_config.parallel_config.data_parallel_size,
        )

    # mimic VllmConfig.__post_init__
    if compilation_config.cudagraph_capture_sizes:
        # Takes the last configured size as the maximum.
        # NOTE(review): this presumes the list is in ascending order - confirm.
        compilation_config.max_cudagraph_capture_size = (
            compilation_config.cudagraph_capture_sizes[-1]
        )

        compilation_config.post_init_cudagraph_sizes()

    return mock_config


class TestCudagraphDispatcher:
    """Tests for ``CudagraphDispatcher`` key setup, dispatch and capture order.

    Every test builds its configuration through ``_create_vllm_config`` and
    only exercises the dispatcher object itself.
    """

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,lora_config",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
            ("FULL", CompilationMode.NONE, False),
            # Test case 1: Full CG for uniform batches, piecewise for mixed.
            # With CompilationMode.NONE the test expects construction to fail.
            ("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
            # Test case 2: Full CG for uniform batches, no CG for mixed
            ("FULL_DECODE_ONLY", CompilationMode.NONE, False),
            # Test case 3: PIECEWISE for all
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
            # Test case 4: PIECEWISE for all, specialize LoRA cases
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ],
    )
    def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
        """Check initialized key counts and ``dispatch`` results per mode."""
        # Setup dispatcher
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 8],
        )

        config = _create_vllm_config(
            comp_config, max_num_seqs=8, lora_config=lora_config
        )
        if (
            cudagraph_mode_str == "FULL_AND_PIECEWISE"
            and compilation_mode == CompilationMode.NONE
        ):
            # This combination is expected to be rejected at construction
            # time, so there is nothing further to check for it.
            with pytest.raises(AssertionError):
                CudagraphDispatcher(config)
            return

        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        # Verify the key is initialized correctly
        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
                4 if lora_config else 2
            )
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
                4 if lora_config else 2
            )
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0

        # Test dispatch logic
        # 1. non-uniform batch, size in cudagraph size list
        desc_full_exact = BatchDescriptor(
            num_tokens=8,
            uniform=False,
        )
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_full_exact
        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 2. uniform decode batch, size in cudagraph size list
        desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=True, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact
        elif cudagraph_mode_str == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 3. No key match
        rt_mode, key = dispatcher.dispatch(
            num_tokens=15, uniform_decode=False, has_lora=False
        )
        assert rt_mode == CUDAGraphMode.NONE
        assert key == BatchDescriptor(num_tokens=15)

        # 4. disable_full should have a fall back mode (e.g., cascade attention)
        desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
        )

        if "PIECEWISE" in cudagraph_mode_str:  # string contains check
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
        else:
            assert rt_mode == CUDAGraphMode.NONE

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,expected_modes",
        [
            # FULL mode: only FULL keys, no PIECEWISE
            ("FULL", CompilationMode.NONE, [CUDAGraphMode.FULL]),
            # PIECEWISE mode: only PIECEWISE keys
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, [CUDAGraphMode.PIECEWISE]),
            # FULL_DECODE_ONLY: only FULL keys for uniform decode
            ("FULL_DECODE_ONLY", CompilationMode.NONE, [CUDAGraphMode.FULL]),
            # NONE mode: no keys
            ("NONE", CompilationMode.NONE, []),
        ],
    )
    def test_get_capture_descs(
        self, cudagraph_mode_str, compilation_mode, expected_modes
    ):
        """Test get_capture_descs returns correctly grouped and ordered descs."""
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 4, 8, 16],
        )

        config = _create_vllm_config(comp_config, max_num_seqs=16)
        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        capture_descs = dispatcher.get_capture_descs()

        # Verify we get the expected modes
        actual_modes = [mode for mode, _ in capture_descs]
        assert actual_modes == expected_modes

        # Verify each group is sorted largest-first
        for mode, descs in capture_descs:
            assert len(descs) > 0, "Each group should have at least one descriptor"
            num_tokens_list = [d.num_tokens for d in descs]
            assert num_tokens_list == sorted(num_tokens_list, reverse=True), (
                f"Descriptors for {mode} should be sorted largest-first"
            )

            # All descriptors in a group should have same uniform value
            uniform_values = [d.uniform for d in descs]
            assert len(set(uniform_values)) == 1, (
                "All descriptors in a group should have the same uniform value"
            )

    def test_get_capture_descs_empty_when_not_initialized(self):
        """Test that get_capture_descs returns empty list when keys not initialized."""
        comp_config = CompilationConfig(
            cudagraph_mode="FULL",
            mode=CompilationMode.NONE,
            cudagraph_capture_sizes=[1, 8],
        )
        config = _create_vllm_config(comp_config, max_num_seqs=8)
        dispatcher = CudagraphDispatcher(config)
        # Don't initialize keys

        assert dispatcher.get_capture_descs() == []

238
239
240
241
242
243
244
245
246
247

@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
    """Tests for ``CUDAGraphWrapper`` capture, replay and bypass decisions.

    Each test wraps a ``SimpleMLP`` with ``runtime_mode=CUDAGraphMode.FULL``
    and calls it under different forward contexts. ``torch.cuda.graph`` is
    patched with ``wraps=`` so the real capture still runs while calls to it
    can be counted.
    """

    def setup_method(self):
        """Create the default mocked config, the CUDA model and input tensors."""
        self.vllm_config = _create_vllm_config(CompilationConfig())
        self.model = SimpleMLP().to("cuda")
        self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
        self.input_tensor = torch.randn(1, 10, device="cuda")

    def test_capture_and_replay(self):
        """First FULL-mode call captures a graph; the second replays it."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            wrapper(self.input_tensor)

        # 1. Capture
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            output1 = wrapper(self.input_tensor)
            # capturing phase should generate a zero output
            assert torch.allclose(output1, torch.zeros_like(output1))
            mock_cuda_graph.assert_called_once()

        assert batch_descriptor in wrapper.concrete_cudagraph_entries
        entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
        assert entry.cudagraph is not None

        # 2. Replay
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch.object(
                entry.cudagraph, "replay", wraps=entry.cudagraph.replay
            ) as mock_replay,
        ):
            output2 = wrapper(self.input_tensor)
            mock_replay.assert_called_once()

        # Compare with eager output
        eager_output = self.model(self.input_tensor)
        torch.testing.assert_close(eager_output, output2)

    def test_bypass_on_mode_mismatch(self):
        """A PIECEWISE context makes the FULL wrapper call the model directly."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
            patch.object(
                self.model, "forward", wraps=self.model.forward
            ) as mock_forward,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
            mock_forward.assert_called_once()
        # Nothing may have been recorded for the mismatched mode.
        assert not wrapper.concrete_cudagraph_entries

    def test_bypass_on_mode_none(self):
        """A NONE runtime mode must not trigger any graph capture."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
        assert not wrapper.concrete_cudagraph_entries


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
    """Tests that combine ``CudagraphDispatcher`` with ``CUDAGraphWrapper``.

    The dispatcher picks the runtime mode and batch descriptor;
    ``_run_and_monitor_call`` then reports what the wrapper did with them.
    """

    def setup_method(self):
        """Build a FULL-mode config and a dispatcher with initialized keys."""
        # only FULL mode for non-uniform batches
        self.comp_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode="FULL",
            cudagraph_capture_sizes=[10, 20],
        )
        self.vllm_config = _create_vllm_config(self.comp_config)
        self.dispatcher = CudagraphDispatcher(self.vllm_config)
        self.dispatcher.initialize_cudagraph_keys(
            self.comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

    def _run_and_monitor_call(
        self, wrapper, input_tensor, runtime_mode, batch_descriptor
    ):
        """Helper to run a single call and monitor the action.

        Args:
            wrapper: The (outer) ``CUDAGraphWrapper`` to invoke.
            input_tensor: Tensor passed to ``wrapper``.
            runtime_mode: ``cudagraph_runtime_mode`` for the forward context.
            batch_descriptor: ``batch_descriptor`` for the forward context,
                also used to look up an existing entry on ``wrapper``.

        Returns:
            ``"capture_global"`` if ``torch.cuda.graph`` was called,
            else ``"replay"`` if the existing entry's ``replay`` was called,
            else ``"bypass"`` if ``wrapper.runnable`` was called,
            else ``"unknown"``.
        """

        with (
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context,
            patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable,
        ):
            entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None)

            context = set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=runtime_mode,
                batch_descriptor=batch_descriptor,
            )
            # Placeholder that is never called; it is rebound below when an
            # already captured graph exists, so `.called` is always valid.
            mock_replay = MagicMock()
            if entry and entry.cudagraph:
                # Replace `replay` with a plain mock: the call is recorded but
                # the real graph is not replayed.
                with (
                    context,
                    patch.object(
                        entry.cudagraph, "replay", new_callable=MagicMock
                    ) as mock_replay,
                ):
                    wrapper(input_tensor)
            else:
                with context:
                    wrapper(input_tensor)

            if mock_graph_context.called:
                # note that this is globally mocked, so it will be detected
                # even whether called by the inner or outer wrapper
                return "capture_global"
            if mock_replay.called:
                # only for outer wrapper
                return "replay"
            if mock_runnable.call_count > 0:
                # only for outer wrapper
                return "bypass"
            return "unknown"

    @create_new_process_for_each_test("spawn")
    def test_capture_replay_bypass_logic(self):
        """Walk one wrapper through capture, replay, bypass and capture-disabled."""
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
        max_bs = 16
        persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
        input_1 = persistent_input_buffer[:1]
        input_2 = persistent_input_buffer[:2]
        input_3 = persistent_input_buffer[:3]

        desc_1 = BatchDescriptor(num_tokens=1)
        desc_2 = BatchDescriptor(num_tokens=2)
        desc_3_unseen = BatchDescriptor(num_tokens=3)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_1.num_tokens)
        # 1. Capture first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "capture_global"

        # 2. Replay first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "replay"

        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_2.num_tokens)
        # 3. Capture second shape
        action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
        assert action == "capture_global"

        # 4. Replay second shape
        action = self._run_and_monitor_call(
            full_wrapper, input_2, CUDAGraphMode.FULL, desc_2
        )
        assert action == "replay"

        # 5. Bypass if no key match
        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_3_unseen.num_tokens)
        assert rt_mode == CUDAGraphMode.NONE
        action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
        assert action == "bypass"

        # capture unseen shape is not allowed after disable
        set_cudagraph_capturing_enabled(False)
        try:
            with pytest.raises(RuntimeError):
                self._run_and_monitor_call(
                    full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen
                )
        finally:
            # Restore the flag even when the expectation above fails, so a
            # failing assertion does not leave capturing disabled.
            set_cudagraph_capturing_enabled(True)

    @create_new_process_for_each_test("spawn")
    def test_nested_wrappers(self):
        """Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
        # NOTE(review): `model` and this first `full_wrapper` are rebound
        # below before being used; kept in case construction has side effects.
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
        input_1 = torch.randn(1, 10, device="cuda")

        # Setup: Inner model is wrapped with PIECEWISE, outer with FULL
        inner_model = SimpleMLP().to("cuda")
        piecewise_wrapper = CUDAGraphWrapper(
            inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
        )
        inner_model.forward = MagicMock(wraps=inner_model.forward)
        outer_model = SimpleMLP().to("cuda")
        # When outer model is called, it calls the piecewise_wrapper
        outer_model.forward = MagicMock(
            wraps=outer_model.forward, side_effect=piecewise_wrapper
        )
        full_wrapper = CUDAGraphWrapper(
            outer_model, self.vllm_config, CUDAGraphMode.FULL
        )

        desc_1 = BatchDescriptor(num_tokens=1)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # --- Test runtime mode FULL---
        # Run with FULL mode context. Expect outer wrapper to capture.
        # The inner mock should be called once inside the graph capture.
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again. Expect outer wrapper to replay.
        # The outer model should NOT be called because the whole graph
        # is replayed.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "replay"
        assert outer_model.forward.call_count == 1  # No new call
        assert inner_model.forward.call_count == 1

        # --- Test runtime mode PIECEWISE ---
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        # Run with PIECEWISE mode context.
        # Expect outer wrapper to bypass and call inner wrapper.
        # Inner wrapper should capture.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again with PIECEWISE.
        # Outer bypasses, inner replays.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "bypass"
        assert outer_model.forward.call_count == 2
        assert inner_model.forward.call_count == 1