test_cudagraph_dispatch.py 20.9 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
12
13
from vllm.config import (
    CompilationConfig,
14
    CompilationMode,
15
16
17
18
19
    CUDAGraphMode,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
20
from vllm.config.lora import LoRAConfig
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher


class SimpleMLP(nn.Module):
    """Tiny helper model for tests: two stacked 10 -> 10 linear layers.

    There is no activation between the layers; ``forward`` is simply
    ``fc2(fc1(x))``.
    """

    def __init__(self):
        super().__init__()
        width = 10
        self.fc1 = nn.Linear(width, width)
        self.fc2 = nn.Linear(width, width)

    def forward(self, x):
        """Run ``x`` through ``fc1`` and then ``fc2``."""
        hidden = self.fc1(x)
        return self.fc2(hidden)


def _create_vllm_config(
    compilation_config: CompilationConfig,
    max_num_seqs: int = 8,
    lora_config: bool = False,
) -> MagicMock:
    """Build a ``VllmConfig`` stand-in for dispatcher / wrapper tests.

    A ``MagicMock(spec=VllmConfig)`` is populated with real sub-configs. The
    parts of ``VllmConfig.__post_init__`` that these tests rely on are then
    replayed by hand.

    Args:
        compilation_config: Compilation config to attach. It is mutated in
            place: splitting ops (for ``VLLM_COMPILE``) and cudagraph sizes
            are finalized on it.
        max_num_seqs: ``max_num_seqs`` passed to the scheduler config.
        lora_config: When True, attach a real ``LoRAConfig`` with
            ``max_loras=4`` and ``specialize_active_lora=True``. Otherwise
            ``lora_config`` is set to ``None``.

    Returns:
        The populated mock config.
    """
    mock_config = MagicMock(spec=VllmConfig)
    mock_config.compilation_config = compilation_config
    mock_config.scheduler_config = SchedulerConfig.default_factory(
        max_num_seqs=max_num_seqs,
    )
    mock_config.parallel_config = ParallelConfig()
    mock_config.speculative_config = None  # No speculative decoding
    if lora_config:
        # Create a real LoRAConfig with specialize_active_lora enabled
        mock_config.lora_config = LoRAConfig(
            max_loras=4,
            specialize_active_lora=True,
        )
    else:
        mock_config.lora_config = None

    # Mimic the behavior of VllmConfig.__post_init__().
    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
        compilation_config.set_splitting_ops_for_v1(
            all2all_backend=mock_config.parallel_config.all2all_backend,
            data_parallel_size=mock_config.parallel_config.data_parallel_size,
        )

    if compilation_config.cudagraph_capture_sizes:
        # Use max() rather than [-1]: nothing here guarantees the sizes are
        # sorted ascending before post_init_cudagraph_sizes() runs.
        compilation_config.max_cudagraph_capture_size = max(
            compilation_config.cudagraph_capture_sizes
        )
        compilation_config.post_init_cudagraph_sizes()

    return mock_config


class TestCudagraphDispatcher:
    """Unit tests for ``CudagraphDispatcher`` key setup and dispatch."""

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,lora_config",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
            ("FULL", CompilationMode.NONE, False),
            # Test case 1: FULL_AND_PIECEWISE without VLLM_COMPILE; the test
            # expects the dispatcher constructor to reject it
            ("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
            # Test case 2: Full CG for uniform batches, no CG for mixed
            ("FULL_DECODE_ONLY", CompilationMode.NONE, False),
            # Test case 3: PIECEWISE for all
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
            # Test case 4: PIECEWISE for all, specialize LoRA cases
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ],
    )
    def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
        """Check key counts and dispatch decisions for each cudagraph mode.

        Covers four dispatch scenarios:
        1. non-uniform batch with an exact capture size;
        2. uniform decode batch with an exact capture size;
        3. a size with no matching key;
        4. ``disable_full=True``.
        """
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 8],
        )
        config = _create_vllm_config(
            comp_config, max_num_seqs=8, lora_config=lora_config
        )
        if (
            cudagraph_mode_str == "FULL_AND_PIECEWISE"
            and compilation_mode == CompilationMode.NONE
        ):
            # This combination is expected to fail an assertion at
            # construction time; nothing else is checked for it.
            with pytest.raises(AssertionError):
                CudagraphDispatcher(config)
            return

        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        # Verify the keys are initialized correctly.
        # With LoRA specialization (max_loras=4, specialize_active_lora=True):
        # - lora_cases = [0, 1, 2, 4, 5] (no-lora + powers of 2 up to 4 + max_loras+1)
        # - capture_sizes = [1, 8]
        # - Total keys = 2 sizes x 5 lora_cases = 10
        expected_num_keys = 10 if lora_config else 2
        num_piecewise_keys = len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE])
        num_full_keys = len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL])
        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert num_piecewise_keys == expected_num_keys
        else:
            assert num_piecewise_keys == 0
        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
            assert num_full_keys == expected_num_keys
        else:
            assert num_full_keys == 0

        # Test dispatch logic
        # 1. non-uniform batch, size in cudagraph size list
        desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_full_exact
        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 2. uniform decode batch, size in cudagraph size list
        desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=True, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact
        elif cudagraph_mode_str == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 3. No key match
        rt_mode, key = dispatcher.dispatch(
            num_tokens=15, uniform_decode=False, has_lora=False
        )
        assert rt_mode == CUDAGraphMode.NONE
        assert key == BatchDescriptor(num_tokens=15)

        # 4. disable_full should have a fall back mode (e.g., cascade attention)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
        )
        if "PIECEWISE" in cudagraph_mode_str:  # string contains check
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
        else:
            assert rt_mode == CUDAGraphMode.NONE

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,expected_modes",
        [
            # FULL mode: only FULL keys, no PIECEWISE
            ("FULL", CompilationMode.NONE, [CUDAGraphMode.FULL]),
            # PIECEWISE mode: only PIECEWISE keys
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, [CUDAGraphMode.PIECEWISE]),
            # FULL_DECODE_ONLY: only FULL keys for uniform decode
            ("FULL_DECODE_ONLY", CompilationMode.NONE, [CUDAGraphMode.FULL]),
            # NONE mode: no keys
            ("NONE", CompilationMode.NONE, []),
        ],
    )
    def test_get_capture_descs(
        self, cudagraph_mode_str, compilation_mode, expected_modes
    ):
        """Test get_capture_descs returns correctly grouped and ordered descs."""
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 4, 8, 16],
        )

        config = _create_vllm_config(comp_config, max_num_seqs=16)
        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        capture_descs = dispatcher.get_capture_descs()

        # Verify we get the expected modes
        actual_modes = [mode for mode, _ in capture_descs]
        assert actual_modes == expected_modes

        # Verify each group is sorted largest-first
        for mode, descs in capture_descs:
            assert len(descs) > 0, "Each group should have at least one descriptor"
            num_tokens_list = [d.num_tokens for d in descs]
            assert num_tokens_list == sorted(num_tokens_list, reverse=True), (
                f"Descriptors for {mode} should be sorted largest-first"
            )

            # All descriptors in a group should have same uniform value
            uniform_values = [d.uniform for d in descs]
            assert len(set(uniform_values)) == 1, (
                "All descriptors in a group should have the same uniform value"
            )

    def test_get_capture_descs_empty_when_not_initialized(self):
        """Test that get_capture_descs returns empty list when keys not initialized."""
        comp_config = CompilationConfig(
            cudagraph_mode="FULL",
            mode=CompilationMode.NONE,
            cudagraph_capture_sizes=[1, 8],
        )
        config = _create_vllm_config(comp_config, max_num_seqs=8)
        dispatcher = CudagraphDispatcher(config)
        # Don't initialize keys

        assert dispatcher.get_capture_descs() == []


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
    """Tests for ``CUDAGraphWrapper`` capture, replay and bypass behavior."""

    def setup_method(self):
        self.vllm_config = _create_vllm_config(CompilationConfig())
        self.model = SimpleMLP().to("cuda")
        self.input_tensor = torch.randn(1, 10, device="cuda")

    def _forward_context(self, runtime_mode, batch_descriptor):
        """Return a forward context for ``self.vllm_config``.

        The context has no attention metadata and uses the given runtime
        mode and batch descriptor.
        """
        return set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=runtime_mode,
            batch_descriptor=batch_descriptor,
        )

    def test_capture_and_replay(self):
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        # 0. global warmup
        with self._forward_context(CUDAGraphMode.NONE, None):
            wrapper(self.input_tensor)

        # 1. Capture
        with (
            self._forward_context(CUDAGraphMode.FULL, batch_descriptor),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            output1 = wrapper(self.input_tensor)
            # capturing phase should generate a zero output
            assert torch.allclose(output1, torch.zeros_like(output1))
            mock_cuda_graph.assert_called_once()

        assert batch_descriptor in wrapper.concrete_cudagraph_entries
        entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
        assert entry.cudagraph is not None

        # 2. Replay
        with (
            self._forward_context(CUDAGraphMode.FULL, batch_descriptor),
            patch.object(
                entry.cudagraph, "replay", wraps=entry.cudagraph.replay
            ) as mock_replay,
        ):
            output2 = wrapper(self.input_tensor)
            mock_replay.assert_called_once()

        # Compare with eager output
        eager_output = self.model(self.input_tensor)
        torch.testing.assert_close(eager_output, output2)

    def test_bypass_on_mode_mismatch(self):
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        # A FULL wrapper under a PIECEWISE context must run the model eagerly.
        with (
            self._forward_context(CUDAGraphMode.PIECEWISE, batch_descriptor),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
            patch.object(
                self.model, "forward", wraps=self.model.forward
            ) as mock_forward,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
            mock_forward.assert_called_once()
        assert not wrapper.concrete_cudagraph_entries

    def test_bypass_on_mode_none(self):
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        # Runtime mode NONE must never trigger a capture.
        with (
            self._forward_context(CUDAGraphMode.NONE, batch_descriptor),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
        assert not wrapper.concrete_cudagraph_entries


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
    """End-to-end tests driving ``CUDAGraphWrapper`` with dispatcher output."""

    def setup_method(self):
        # only FULL mode for non-uniform batches
        self.comp_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode="FULL",
            cudagraph_capture_sizes=[10, 20],
        )
        self.vllm_config = _create_vllm_config(self.comp_config)
        self.dispatcher = CudagraphDispatcher(self.vllm_config)
        self.dispatcher.initialize_cudagraph_keys(
            self.comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

    def _run_and_monitor_call(
        self, wrapper, input_tensor, runtime_mode, batch_descriptor
    ):
        """Helper to run a single call and monitor the action.

        Runs ``wrapper(input_tensor)`` under a forward context with the given
        runtime mode and batch descriptor. If ``wrapper`` already has a
        captured graph for ``batch_descriptor``, that graph's ``replay`` is
        replaced by a plain ``MagicMock`` for the duration of the call.

        Returns:
            ``"capture_global"`` if ``torch.cuda.graph`` was entered;
            otherwise ``"replay"`` if the entry's ``replay`` was called;
            otherwise ``"bypass"`` if ``wrapper.runnable`` was invoked;
            otherwise ``"unknown"``.
        """
        with (
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context,
            patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable,
        ):
            entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None)
            context = set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=runtime_mode,
                batch_descriptor=batch_descriptor,
            )
            # Placeholder so ``mock_replay.called`` is False when there is no
            # captured graph to patch.
            mock_replay = MagicMock()
            if entry and entry.cudagraph:
                with (
                    context,
                    patch.object(
                        entry.cudagraph, "replay", new_callable=MagicMock
                    ) as mock_replay,
                ):
                    wrapper(input_tensor)
            else:
                with context:
                    wrapper(input_tensor)

            if mock_graph_context.called:
                # torch.cuda.graph is patched globally, so this detects a
                # capture regardless of whether the inner or the outer
                # wrapper performed it.
                return "capture_global"
            if mock_replay.called:
                # only for outer wrapper
                return "replay"
            if mock_runnable.call_count > 0:
                # only for outer wrapper
                return "bypass"
            return "unknown"

    @create_new_process_for_each_test("spawn")
    def test_capture_replay_bypass_logic(self):
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
        max_bs = 16
        persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
        input_1 = persistent_input_buffer[:1]
        input_2 = persistent_input_buffer[:2]
        input_3 = persistent_input_buffer[:3]

        desc_1 = BatchDescriptor(num_tokens=1)
        desc_2 = BatchDescriptor(num_tokens=2)
        desc_3_unseen = BatchDescriptor(num_tokens=3)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_1.num_tokens)
        # 1. Capture first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "capture_global"

        # 2. Replay first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "replay"

        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_2.num_tokens)
        # 3. Capture second shape
        action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
        assert action == "capture_global"

        # 4. Replay second shape
        action = self._run_and_monitor_call(
            full_wrapper, input_2, CUDAGraphMode.FULL, desc_2
        )
        assert action == "replay"

        # 5. Bypass if no key match
        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_3_unseen.num_tokens)
        assert rt_mode == CUDAGraphMode.NONE
        action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
        assert action == "bypass"

        # capture unseen shape is not allowed after disable
        set_cudagraph_capturing_enabled(False)
        try:
            with pytest.raises(RuntimeError):
                self._run_and_monitor_call(
                    full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen
                )
        finally:
            # Re-enable capturing even if the check above fails, so the flag
            # is never left disabled.
            set_cudagraph_capturing_enabled(True)

    @create_new_process_for_each_test("spawn")
    def test_nested_wrappers(self):
        """Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
        input_1 = torch.randn(1, 10, device="cuda")

        # Setup: Inner model is wrapped with PIECEWISE, outer with FULL
        inner_model = SimpleMLP().to("cuda")
        piecewise_wrapper = CUDAGraphWrapper(
            inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
        )
        inner_model.forward = MagicMock(wraps=inner_model.forward)
        outer_model = SimpleMLP().to("cuda")
        # When outer model is called, it calls the piecewise_wrapper
        outer_model.forward = MagicMock(
            wraps=outer_model.forward, side_effect=piecewise_wrapper
        )
        full_wrapper = CUDAGraphWrapper(
            outer_model, self.vllm_config, CUDAGraphMode.FULL
        )

        desc_1 = BatchDescriptor(num_tokens=1)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # --- Test runtime mode FULL---
        # Run with FULL mode context. Expect outer wrapper to capture.
        # The inner mock should be called once inside the graph capture.
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again. Expect outer wrapper to replay.
        # The outer model should NOT be called because the whole graph
        # is replayed.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "replay"
        assert outer_model.forward.call_count == 1  # No new call
        assert inner_model.forward.call_count == 1

        # --- Test runtime mode PIECEWISE ---
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        # Run with PIECEWISE mode context.
        # Expect outer wrapper to bypass and call inner wrapper.
        # Inner wrapper should capture.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again with PIECEWISE.
        # Outer bypasses, inner replays.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "bypass"
        assert outer_model.forward.call_count == 2
        assert inner_model.forward.call_count == 1