# test_cudagraph_dispatch.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import replace
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    CUDAGraphMode,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.config.lora import LoRAConfig
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher


class SimpleMLP(nn.Module):
    """Tiny two-layer MLP used as the wrapped model in these tests.

    Both layers are ``nn.Linear(10, 10)``, so the input's last dimension
    must be 10 and the output keeps the same shape as the input.
    """

    def __init__(self):
        super().__init__()
        # Registration order (fc1, then fc2) determines parameter order and
        # the sequence in which seeded initialization draws random numbers.
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        hidden = self.fc1(x)
        out = self.fc2(hidden)
        return out


def _create_vllm_config(
    compilation_config: CompilationConfig,
    max_num_seqs: int = 8,
    lora_config: bool = False,
) -> MagicMock:
    """Build a ``VllmConfig`` stand-in for the dispatcher/wrapper tests.

    The result is a ``MagicMock`` specced on ``VllmConfig``. The sub-configs
    the tests care about are real objects, so the code under test reads
    genuine values.

    Args:
        compilation_config: Real ``CompilationConfig`` to attach. It is
            mutated in place to mimic parts of ``VllmConfig.__post_init__``:
            splitting ops are set for ``VLLM_COMPILE``, and cudagraph sizes
            are post-processed when capture sizes are given.
        max_num_seqs: Forwarded to ``SchedulerConfig.default_factory``.
        lora_config: When True, attach a real ``LoRAConfig`` with
            ``max_loras=4`` and ``specialize_active_lora=True``. Otherwise
            ``lora_config`` is set to ``None``.

    Returns:
        The configured ``MagicMock``.
    """
    mock_config = MagicMock(spec=VllmConfig)
    mock_config.compilation_config = compilation_config
    mock_config.scheduler_config = SchedulerConfig.default_factory(
        max_num_seqs=max_num_seqs,
    )
    mock_config.parallel_config = ParallelConfig()
    mock_config.speculative_config = None  # No speculative decoding

    if lora_config:
        # A real LoRAConfig so LoRA-specialized keys are actually generated.
        mock_config.lora_config = LoRAConfig(
            max_loras=4,
            specialize_active_lora=True,
        )
    else:
        mock_config.lora_config = None

    # Mimic the behavior of VllmConfig.__post_init__().
    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
        compilation_config.set_splitting_ops_for_v1(
            all2all_backend=mock_config.parallel_config.all2all_backend,
            data_parallel_size=mock_config.parallel_config.data_parallel_size,
        )

    if compilation_config.cudagraph_capture_sizes:
        # Use max() rather than [-1] so an unsorted size list still yields
        # the largest capture size.
        compilation_config.max_cudagraph_capture_size = max(
            compilation_config.cudagraph_capture_sizes
        )
        compilation_config.post_init_cudagraph_sizes()

    return mock_config


class TestCudagraphDispatcher:
    """Tests for ``CudagraphDispatcher`` key initialization and dispatch."""

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,lora_config",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
            ("FULL", CompilationMode.NONE, False),
            # Test case 1: FULL_AND_PIECEWISE with CompilationMode.NONE; the
            # test expects dispatcher construction to raise AssertionError.
            ("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
            # Test case 2: Full CG for uniform batches, no CG for mixed
            ("FULL_DECODE_ONLY", CompilationMode.NONE, False),
            # Test case 3: PIECEWISE for all
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
            # Test case 4: PIECEWISE for all, specialize LoRA cases
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ],
    )
    def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
        """Check key counts and dispatch results for each cudagraph mode."""
        # Setup dispatcher
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 8],
        )
        config = _create_vllm_config(
            comp_config, max_num_seqs=8, lora_config=lora_config
        )
        if (
            cudagraph_mode_str == "FULL_AND_PIECEWISE"
            and compilation_mode == CompilationMode.NONE
        ):
            # This combination is rejected at construction time. With the
            # current parameter list, the FULL_AND_PIECEWISE branches below
            # are therefore never reached.
            with pytest.raises(AssertionError):
                CudagraphDispatcher(config)
            return

        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        # Verify the key is initialized correctly
        # With LoRA specialization (max_loras=4, specialize_active_lora=True):
        # - lora_cases = [0, 1, 2, 4, 5] (no-lora + powers of 2 up to 4 + max_loras+1)
        # - capture_sizes = [1, 8]
        # - Total keys = 2 sizes × 5 lora_cases = 10
        expected_num_keys = 10 if lora_config else 2
        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
                expected_num_keys
            )
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
                expected_num_keys
            )
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0

        # Test dispatch logic
        # 1. non-uniform batch, size in cudagraph size list
        # FULL mode uses exact keys with num_reqs set
        desc_full_with_reqs = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False)
        # PIECEWISE mode uses relaxed keys with num_reqs=None
        desc_piecewise = BatchDescriptor(num_tokens=8, num_reqs=None, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_full_with_reqs
        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_piecewise
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 2. uniform decode batch, size in cudagraph size list
        desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
        desc_non_uniform = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=True, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            # Pure FULL mode uses non-uniform keys for all batches
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_non_uniform
        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
            # These modes have separate uniform decode keys
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact
        elif cudagraph_mode_str == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == replace(desc_uniform_exact, num_reqs=None, uniform=False)
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 3. No key match
        rt_mode, key = dispatcher.dispatch(
            num_tokens=15, uniform_decode=False, has_lora=False
        )
        assert rt_mode == CUDAGraphMode.NONE
        assert key == BatchDescriptor(num_tokens=15)

        # 4. invalid_modes={FULL} should have a fall back mode
        #    (e.g., cascade attention)
        desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8,
            uniform_decode=False,
            has_lora=False,
            invalid_modes={CUDAGraphMode.FULL},
        )

        if "PIECEWISE" in cudagraph_mode_str:  # string contains check
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == replace(desc_full_exact, num_reqs=None, uniform=False)
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 5. valid_modes={NONE} always returns NONE even when keys exist
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8,
            uniform_decode=False,
            has_lora=False,
            valid_modes={CUDAGraphMode.NONE},
        )
        assert rt_mode == CUDAGraphMode.NONE
        assert key == BatchDescriptor(num_tokens=8)

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,expected_modes",
        [
            # FULL mode: only FULL keys, no PIECEWISE
            ("FULL", CompilationMode.NONE, [CUDAGraphMode.FULL]),
            # PIECEWISE mode: only PIECEWISE keys
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, [CUDAGraphMode.PIECEWISE]),
            # FULL_DECODE_ONLY: only FULL keys for uniform decode
            ("FULL_DECODE_ONLY", CompilationMode.NONE, [CUDAGraphMode.FULL]),
            # NONE mode: no keys
            ("NONE", CompilationMode.NONE, []),
        ],
    )
    def test_get_capture_descs(
        self, cudagraph_mode_str, compilation_mode, expected_modes
    ):
        """Test get_capture_descs returns correctly grouped and ordered descs."""
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 4, 8, 16],
        )

        config = _create_vllm_config(comp_config, max_num_seqs=16)
        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        capture_descs = dispatcher.get_capture_descs()

        # Verify we get the expected modes
        actual_modes = [mode for mode, _ in capture_descs]
        assert actual_modes == expected_modes

        # Verify each group is sorted largest-first
        for mode, descs in capture_descs:
            assert len(descs) > 0, "Each group should have at least one descriptor"
            num_tokens_list = [d.num_tokens for d in descs]
            assert num_tokens_list == sorted(num_tokens_list, reverse=True), (
                f"Descriptors for {mode} should be sorted largest-first"
            )

            # All descriptors in a group should have same uniform value
            uniform_values = [d.uniform for d in descs]
            assert len(set(uniform_values)) == 1, (
                "All descriptors in a group should have the same uniform value"
            )

    def test_get_capture_descs_empty_when_not_initialized(self):
        """Test that get_capture_descs returns empty list when keys not initialized."""
        comp_config = CompilationConfig(
            cudagraph_mode="FULL",
            mode=CompilationMode.NONE,
            cudagraph_capture_sizes=[1, 8],
        )
        config = _create_vllm_config(comp_config, max_num_seqs=8)
        dispatcher = CudagraphDispatcher(config)
        # Deliberately skip initialize_cudagraph_keys().

        assert dispatcher.get_capture_descs() == []


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
    """Capture, replay and bypass behavior of a single FULL ``CUDAGraphWrapper``."""

    def setup_method(self):
        self.vllm_config = _create_vllm_config(CompilationConfig())
        self.model = SimpleMLP().to("cuda")
        self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
        self.input_tensor = torch.randn(1, 10, device="cuda")

    def test_capture_and_replay(self):
        """First FULL call captures a graph; the second replays it."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            wrapper(self.input_tensor)

        # 1. Capture
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            output1 = wrapper(self.input_tensor)
            # capturing phase should generate a zero output
            assert torch.allclose(output1, torch.zeros_like(output1))
            mock_cuda_graph.assert_called_once()

        assert batch_descriptor in wrapper.concrete_cudagraph_entries
        entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
        assert entry.cudagraph is not None

        # 2. Replay
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch.object(
                entry.cudagraph, "replay", wraps=entry.cudagraph.replay
            ) as mock_replay,
        ):
            output2 = wrapper(self.input_tensor)
            mock_replay.assert_called_once()

        # Compare with eager output
        eager_output = self.model(self.input_tensor)
        torch.testing.assert_close(eager_output, output2)

    def test_bypass_on_mode_mismatch(self):
        """A FULL wrapper run under PIECEWISE calls the model without capturing."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
            patch.object(
                self.model, "forward", wraps=self.model.forward
            ) as mock_forward,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
            mock_forward.assert_called_once()
        assert not wrapper.concrete_cudagraph_entries

    def test_bypass_on_mode_none(self):
        """A FULL wrapper run under NONE neither captures nor stores an entry."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
        assert not wrapper.concrete_cudagraph_entries


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
    """End-to-end checks combining ``CudagraphDispatcher`` and ``CUDAGraphWrapper``."""

    def setup_method(self):
        # only FULL mode for non-uniform batches
        self.comp_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode="FULL",
            cudagraph_capture_sizes=[10, 20],
        )
        # NOTE(review): test_capture_replay_bypass_logic expects 1 and 2
        # tokens to dispatch to distinct FULL keys and 3 tokens to dispatch
        # to NONE. Confirm that holds for how dispatch() maps num_tokens onto
        # capture sizes [10, 20].
        self.vllm_config = _create_vllm_config(self.comp_config)
        self.dispatcher = CudagraphDispatcher(self.vllm_config)
        self.dispatcher.initialize_cudagraph_keys(
            self.comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

    def _run_and_monitor_call(
        self, wrapper, input_tensor, runtime_mode, batch_descriptor
    ):
        """Run ``wrapper`` once and report which path it took.

        Returns:
            "capture_global" if ``torch.cuda.graph`` was entered (by this
            wrapper or a nested one, since the patch is global); "replay" if
            the outer wrapper replayed an existing entry; "bypass" if the
            outer wrapper called its runnable directly; otherwise "unknown".
        """
        with (
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context,
            patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable,
        ):
            entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None)

            context = set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=runtime_mode,
                batch_descriptor=batch_descriptor,
            )

            # Placeholder so the `.called` check below is valid even when
            # there is no captured graph to patch.
            mock_replay = MagicMock()
            if entry and entry.cudagraph:
                # Replace replay with a plain MagicMock: only whether it was
                # invoked matters here.
                with (
                    context,
                    patch.object(
                        entry.cudagraph, "replay", new_callable=MagicMock
                    ) as mock_replay,
                ):
                    wrapper(input_tensor)
            else:
                with context:
                    wrapper(input_tensor)

            if mock_graph_context.called:
                # torch.cuda.graph is patched globally, so a capture is
                # detected regardless of whether the inner or the outer
                # wrapper performed it.
                return "capture_global"
            if mock_replay.called:
                # only for outer wrapper
                return "replay"
            if mock_runnable.call_count > 0:
                # only for outer wrapper
                return "bypass"
            return "unknown"

    @create_new_process_for_each_test("spawn")
    def test_capture_replay_bypass_logic(self):
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
        max_bs = 16
        persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
        input_1 = persistent_input_buffer[:1]
        input_2 = persistent_input_buffer[:2]
        input_3 = persistent_input_buffer[:3]

        desc_1 = BatchDescriptor(num_tokens=1)
        desc_2 = BatchDescriptor(num_tokens=2)
        desc_3_unseen = BatchDescriptor(num_tokens=3)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_1.num_tokens)
        # 1. Capture first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "capture_global"

        # 2. Replay first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "replay"

        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_2.num_tokens)
        # 3. Capture second shape
        action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
        assert action == "capture_global"

        # 4. Replay second shape
        action = self._run_and_monitor_call(
            full_wrapper, input_2, CUDAGraphMode.FULL, desc_2
        )
        assert action == "replay"

        # 5. Bypass if no key match
        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_3_unseen.num_tokens)
        assert rt_mode == CUDAGraphMode.NONE
        action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
        assert action == "bypass"

        # Capturing an unseen shape is not allowed once capturing is disabled.
        # Always restore the global flag, even if the expectation fails.
        set_cudagraph_capturing_enabled(False)
        try:
            with pytest.raises(RuntimeError):
                self._run_and_monitor_call(
                    full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen
                )
        finally:
            set_cudagraph_capturing_enabled(True)

    @create_new_process_for_each_test("spawn")
    def test_nested_wrappers(self):
        """Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
        input_1 = torch.randn(1, 10, device="cuda")

        # Setup: Inner model is wrapped with PIECEWISE, outer with FULL
        inner_model = SimpleMLP().to("cuda")
        piecewise_wrapper = CUDAGraphWrapper(
            inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
        )
        inner_model.forward = MagicMock(wraps=inner_model.forward)
        outer_model = SimpleMLP().to("cuda")
        # When outer model is called, it calls the piecewise_wrapper
        # (side_effect takes precedence over wraps for a MagicMock).
        outer_model.forward = MagicMock(
            wraps=outer_model.forward, side_effect=piecewise_wrapper
        )
        full_wrapper = CUDAGraphWrapper(
            outer_model, self.vllm_config, CUDAGraphMode.FULL
        )

        desc_1 = BatchDescriptor(num_tokens=1)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # --- Test runtime mode FULL---
        # Run with FULL mode context. Expect outer wrapper to capture.
        # The inner mock should be called once inside the graph capture.
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again. Expect outer wrapper to replay.
        # The outer model should NOT be called because the whole graph
        # is replayed.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "replay"
        assert outer_model.forward.call_count == 1  # No new call
        assert inner_model.forward.call_count == 1

        # --- Test runtime mode PIECEWISE ---
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        # Run with PIECEWISE mode context.
        # Expect outer wrapper to bypass and call inner wrapper.
        # Inner wrapper should capture.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again with PIECEWISE.
        # Outer bypasses, inner replays.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "bypass"
        assert outer_model.forward.call_count == 2
        assert inner_model.forward.call_count == 1