# test_cudagraph_dispatch.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import replace
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    CUDAGraphMode,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.config.lora import LoRAConfig
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher


# Helper MLP for testing
class SimpleMLP(nn.Module):
    """Two stacked 10->10 linear layers; a minimal model for graph tests."""

    def __init__(self):
        super().__init__()
        # Registered in this order so parameter names/order are
        # fc1.weight, fc1.bias, fc2.weight, fc2.bias.
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        """Apply fc1 then fc2 (no activation in between)."""
        hidden = self.fc1(x)
        out = self.fc2(hidden)
        return out


38
def _create_vllm_config(
    compilation_config: CompilationConfig,
    max_num_seqs: int = 8,
    lora_config: bool = False,
) -> MagicMock:
    """Build a ``VllmConfig``-shaped mock for dispatcher/wrapper tests.

    Only the sub-configs these tests read are populated explicitly; any other
    attribute access goes through ``MagicMock(spec=VllmConfig)``.

    Args:
        compilation_config: Real ``CompilationConfig`` to attach. It is
            mutated in place to mimic parts of ``VllmConfig.__post_init__``
            (splitting ops for VLLM_COMPILE and cudagraph size finalization).
        max_num_seqs: Forwarded to ``SchedulerConfig.default_factory``.
        lora_config: Flag, not a config object. When True, attach a real
            ``LoRAConfig(max_loras=4, specialize_active_lora=True)``;
            otherwise ``lora_config`` is set to ``None``.

    Returns:
        The populated mock config.
    """
    mock_config = MagicMock(spec=VllmConfig)
    mock_config.compilation_config = compilation_config
    mock_config.scheduler_config = SchedulerConfig.default_factory(
        max_num_seqs=max_num_seqs,
    )
    mock_config.parallel_config = ParallelConfig()
    mock_config.speculative_config = None  # No speculative decoding

    if lora_config:
        # Create a real LoRAConfig with specialize_active_lora enabled
        mock_config.lora_config = LoRAConfig(
            max_loras=4,
            specialize_active_lora=True,
        )
    else:
        mock_config.lora_config = None

    # Mimic the behavior of VllmConfig.__post_init__(): configure splitting
    # ops when compiling with VLLM_COMPILE.
    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
        compilation_config.set_splitting_ops_for_v1(
            all2all_backend=mock_config.parallel_config.all2all_backend,
            data_parallel_size=mock_config.parallel_config.data_parallel_size,
        )

    # Mimic VllmConfig.__post_init__(): derive the max capture size and
    # finalize the cudagraph size lists.
    if compilation_config.cudagraph_capture_sizes:
        # max() rather than [-1] so an unsorted size list still yields the
        # largest capture size.
        compilation_config.max_cudagraph_capture_size = max(
            compilation_config.cudagraph_capture_sizes
        )
        compilation_config.post_init_cudagraph_sizes()

    return mock_config


class TestCudagraphDispatcher:
    """Unit tests for ``CudagraphDispatcher`` key setup and dispatch."""

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,lora_config",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
            ("FULL", CompilationMode.NONE, False),
            # Test case 1: Full CG for uniform batches, piecewise for mixed.
            # With CompilationMode.NONE the dispatcher is expected to reject
            # this combination (asserted in the test body).
            ("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
            # Test case 2: Full CG for uniform batches, no CG for mixed
            ("FULL_DECODE_ONLY", CompilationMode.NONE, False),
            # Test case 3: PIECEWISE for all
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
            # Test case 4: PIECEWISE for all, specialize LoRA cases
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ],
    )
    def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
        """Check key initialization and dispatch results for each mode.

        NOTE: the only ``FULL_AND_PIECEWISE`` case above uses
        ``CompilationMode.NONE`` and returns early after the expected
        AssertionError, so the ``FULL_AND_PIECEWISE`` branches further down
        only execute if a ``VLLM_COMPILE`` variant is added.
        """
        # Setup dispatcher
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 8],
        )
        config = _create_vllm_config(
            comp_config, max_num_seqs=8, lora_config=lora_config
        )
        if (
            cudagraph_mode_str == "FULL_AND_PIECEWISE"
            and compilation_mode == CompilationMode.NONE
        ):
            with pytest.raises(AssertionError):
                CudagraphDispatcher(config)
            return

        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        # Verify the key is initialized correctly
        # With LoRA specialization (max_loras=4, specialize_active_lora=True):
        # - lora_cases = [0, 1, 2, 4, 5] (no-lora + powers of 2 up to 4 + max_loras+1)
        # - capture_sizes = [1, 8]
        # - Total keys = 2 sizes × 5 lora_cases = 10
        expected_num_keys = 10 if lora_config else 2
        piecewise_keys = dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]
        full_keys = dispatcher.cudagraph_keys[CUDAGraphMode.FULL]
        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert len(piecewise_keys) == expected_num_keys
        else:
            assert len(piecewise_keys) == 0
        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
            assert len(full_keys) == expected_num_keys
        else:
            assert len(full_keys) == 0

        # Test dispatch logic
        # 1. non-uniform batch, size in cudagraph size list
        # FULL mode uses exact keys with num_reqs set
        desc_full_with_reqs = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False)
        # PIECEWISE mode uses relaxed keys with num_reqs=None
        desc_piecewise = BatchDescriptor(num_tokens=8, num_reqs=None, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_full_with_reqs
        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_piecewise
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 2. uniform decode batch, size in cudagraph size list
        desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
        desc_non_uniform = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=True, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            # Pure FULL mode uses non-uniform keys for all batches
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_non_uniform
        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
            # These modes have separate uniform decode keys
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact
        elif cudagraph_mode_str == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == replace(desc_uniform_exact, num_reqs=None, uniform=False)
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 3. No key match
        rt_mode, key = dispatcher.dispatch(
            num_tokens=15, uniform_decode=False, has_lora=False
        )
        assert rt_mode == CUDAGraphMode.NONE
        assert key == BatchDescriptor(num_tokens=15)

        # 4. disable_full should have a fall back mode (e.g., cascade attention)
        desc_fallback = BatchDescriptor(num_tokens=8, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
        )
        if "PIECEWISE" in cudagraph_mode_str:  # string contains check
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == replace(desc_fallback, num_reqs=None, uniform=False)
        else:
            assert rt_mode == CUDAGraphMode.NONE

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,expected_modes",
        [
            # FULL mode: only FULL keys, no PIECEWISE
            ("FULL", CompilationMode.NONE, [CUDAGraphMode.FULL]),
            # PIECEWISE mode: only PIECEWISE keys
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, [CUDAGraphMode.PIECEWISE]),
            # FULL_DECODE_ONLY: only FULL keys for uniform decode
            ("FULL_DECODE_ONLY", CompilationMode.NONE, [CUDAGraphMode.FULL]),
            # NONE mode: no keys
            ("NONE", CompilationMode.NONE, []),
        ],
    )
    def test_get_capture_descs(
        self, cudagraph_mode_str, compilation_mode, expected_modes
    ):
        """Test get_capture_descs returns correctly grouped and ordered descs."""
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 4, 8, 16],
        )

        config = _create_vllm_config(comp_config, max_num_seqs=16)
        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        capture_descs = dispatcher.get_capture_descs()

        # Verify we get the expected modes
        actual_modes = [mode for mode, _ in capture_descs]
        assert actual_modes == expected_modes

        # Verify each group is sorted largest-first
        for mode, descs in capture_descs:
            assert len(descs) > 0, "Each group should have at least one descriptor"
            num_tokens_list = [d.num_tokens for d in descs]
            assert num_tokens_list == sorted(num_tokens_list, reverse=True), (
                f"Descriptors for {mode} should be sorted largest-first"
            )

            # All descriptors in a group should have same uniform value
            uniform_values = [d.uniform for d in descs]
            assert len(set(uniform_values)) == 1, (
                "All descriptors in a group should have the same uniform value"
            )

    def test_get_capture_descs_empty_when_not_initialized(self):
        """Test that get_capture_descs returns empty list when keys not initialized."""
        comp_config = CompilationConfig(
            cudagraph_mode="FULL",
            mode=CompilationMode.NONE,
            cudagraph_capture_sizes=[1, 8],
        )
        config = _create_vllm_config(comp_config, max_num_seqs=8)
        dispatcher = CudagraphDispatcher(config)
        # Don't initialize keys

        assert dispatcher.get_capture_descs() == []

253
254
255
256
257
258
259
260
261
262

@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
    """Tests for a single ``CUDAGraphWrapper`` constructed in FULL mode."""

    def setup_method(self):
        self.vllm_config = _create_vllm_config(CompilationConfig())
        self.model = SimpleMLP().to("cuda")
        self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
        self.input_tensor = torch.randn(1, 10, device="cuda")

    def test_capture_and_replay(self):
        """First FULL call captures a graph, second replays it; the replayed
        output must match eager execution."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            wrapper(self.input_tensor)

        # 1. Capture
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            output1 = wrapper(self.input_tensor)
            # capturing phase should generate a zero output
            assert torch.allclose(output1, torch.zeros_like(output1))
            mock_cuda_graph.assert_called_once()

        assert batch_descriptor in wrapper.concrete_cudagraph_entries
        entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
        assert entry.cudagraph is not None

        # 2. Replay
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch.object(
                entry.cudagraph, "replay", wraps=entry.cudagraph.replay
            ) as mock_replay,
        ):
            output2 = wrapper(self.input_tensor)
            mock_replay.assert_called_once()

        # Compare with eager output
        eager_output = self.model(self.input_tensor)
        torch.testing.assert_close(eager_output, output2)

    def test_bypass_on_mode_mismatch(self):
        """A FULL wrapper run under a PIECEWISE context must not capture and
        must call the wrapped model's forward directly."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
            patch.object(
                self.model, "forward", wraps=self.model.forward
            ) as mock_forward,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
            mock_forward.assert_called_once()
        assert not wrapper.concrete_cudagraph_entries

    def test_bypass_on_mode_none(self):
        """A FULL wrapper run under a NONE context must not capture."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
        assert not wrapper.concrete_cudagraph_entries


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
    """Dispatcher + wrapper integration: capture, replay, bypass, nesting."""

    def setup_method(self):
        # only FULL mode for non-uniform batches
        self.comp_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode="FULL",
            cudagraph_capture_sizes=[10, 20],
        )
        self.vllm_config = _create_vllm_config(self.comp_config)
        self.dispatcher = CudagraphDispatcher(self.vllm_config)
        self.dispatcher.initialize_cudagraph_keys(
            self.comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

    def _run_and_monitor_call(
        self, wrapper, input_tensor, runtime_mode, batch_descriptor
    ):
        """Helper to run a single call and monitor the action.

        Returns one of:
          - "capture_global": ``torch.cuda.graph`` was entered (by any wrapper).
          - "replay": the outer wrapper's existing entry had ``replay`` called.
            ``replay`` is replaced by a MagicMock, so no real replay runs.
          - "bypass": the outer wrapper called its ``runnable`` directly.
          - "unknown": none of the above was observed.
        """
        with (
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context,
            patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable,
        ):
            entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None)

            context = set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=runtime_mode,
                batch_descriptor=batch_descriptor,
            )
            mock_replay = MagicMock()
            if entry is not None and entry.cudagraph is not None:
                with (
                    context,
                    patch.object(
                        entry.cudagraph, "replay", new_callable=MagicMock
                    ) as mock_replay,
                ):
                    wrapper(input_tensor)
            else:
                with context:
                    wrapper(input_tensor)

            if mock_graph_context.called:
                # note that this is globally mocked, so it will be detected
                # regardless of whether the inner or outer wrapper called it
                return "capture_global"
            if mock_replay.called:
                # only for outer wrapper
                return "replay"
            if mock_runnable.call_count > 0:
                # only for outer wrapper
                return "bypass"
            return "unknown"

    @create_new_process_for_each_test("spawn")
    def test_capture_replay_bypass_logic(self):
        """Walk a FULL wrapper through capture, replay and bypass paths."""
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
        max_bs = 16
        persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
        input_1 = persistent_input_buffer[:1]
        input_2 = persistent_input_buffer[:2]
        input_3 = persistent_input_buffer[:3]

        desc_1 = BatchDescriptor(num_tokens=1)
        desc_2 = BatchDescriptor(num_tokens=2)
        desc_3_unseen = BatchDescriptor(num_tokens=3)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_1.num_tokens)
        # 1. Capture first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "capture_global"

        # 2. Replay first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "replay"

        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_2.num_tokens)
        # 3. Capture second shape
        action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
        assert action == "capture_global"

        # 4. Replay second shape. Reuse the (mode, key) captured in step 3,
        # mirroring step 2, so the lookup targets the entry that was captured.
        action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
        assert action == "replay"

        # 5. Bypass if no key match
        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_3_unseen.num_tokens)
        assert rt_mode == CUDAGraphMode.NONE
        action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
        assert action == "bypass"

        # capture unseen shape is not allowed after disable
        set_cudagraph_capturing_enabled(False)
        try:
            with pytest.raises(RuntimeError):
                self._run_and_monitor_call(
                    full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen
                )
        finally:
            # Restore the global flag even if the expectation above fails.
            set_cudagraph_capturing_enabled(True)

    @create_new_process_for_each_test("spawn")
    def test_nested_wrappers(self):
        """Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
        input_1 = torch.randn(1, 10, device="cuda")

        # Setup: Inner model is wrapped with PIECEWISE, outer with FULL
        inner_model = SimpleMLP().to("cuda")
        piecewise_wrapper = CUDAGraphWrapper(
            inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
        )
        inner_model.forward = MagicMock(wraps=inner_model.forward)
        outer_model = SimpleMLP().to("cuda")
        # When outer model is called, it calls the piecewise_wrapper
        outer_model.forward = MagicMock(
            wraps=outer_model.forward, side_effect=piecewise_wrapper
        )
        full_wrapper = CUDAGraphWrapper(
            outer_model, self.vllm_config, CUDAGraphMode.FULL
        )

        desc_1 = BatchDescriptor(num_tokens=1)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # --- Test runtime mode FULL---
        # Run with FULL mode context. Expect outer wrapper to capture.
        # The inner mock should be called once inside the graph capture.
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again. Expect outer wrapper to replay.
        # The outer model should NOT be called because the whole graph
        # is replayed.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "replay"
        assert outer_model.forward.call_count == 1  # No new call
        assert inner_model.forward.call_count == 1

        # --- Test runtime mode PIECEWISE ---
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        # Run with PIECEWISE mode context.
        # Expect outer wrapper to bypass and call inner wrapper.
        # Inner wrapper should capture.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again with PIECEWISE.
        # Outer bypasses, inner replays.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "bypass"
        assert outer_model.forward.call_count == 2
        assert inner_model.forward.call_count == 1