test_cudagraph_dispatch.py 17.7 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
12
13
from vllm.config import (
    CompilationConfig,
14
    CompilationMode,
15
16
17
18
19
    CUDAGraphMode,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher


# Helper MLP for testing
class SimpleMLP(nn.Module):
    """Tiny two-layer linear model (10 -> 10 -> 10) used as a test runnable."""

    def __init__(self):
        super().__init__()
        # Layer attribute names (fc1/fc2) and creation order are part of the
        # module's parameter naming, so they are kept as-is.
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        """Apply ``fc1`` then ``fc2`` to ``x`` with no activation in between."""
        hidden = self.fc1(x)
        output = self.fc2(hidden)
        return output


def _create_vllm_config(
    compilation_config: CompilationConfig,
    max_num_seqs: int = 8,
    lora_config: bool = False,
) -> MagicMock:
    """Build a ``MagicMock(spec=VllmConfig)`` stand-in for the tests below.

    Args:
        compilation_config: Compilation config attached to the mock. It is
            mutated in place to mimic parts of ``VllmConfig.__post_init__``.
        max_num_seqs: Forwarded to ``SchedulerConfig.default_factory``.
        lora_config: When False, ``lora_config`` on the mock is set to None.
            When True it is left as the mock's auto-created (non-None)
            attribute.

    Returns:
        The mock config with compilation, scheduler, parallel and speculative
        configs populated, plus ``pad_for_cudagraph`` when capture sizes are
        configured.
    """
    mock_config = MagicMock(spec=VllmConfig)
    mock_config.compilation_config = compilation_config
    mock_config.scheduler_config = SchedulerConfig.default_factory(
        max_num_seqs=max_num_seqs,
    )
    mock_config.parallel_config = ParallelConfig()
    mock_config.speculative_config = None  # No speculative decoding
    if not lora_config:
        mock_config.lora_config = None

    # Mimic the behavior of VllmConfig.__post_init__()
    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
        compilation_config.set_splitting_ops_for_v1()

    if compilation_config.cudagraph_capture_sizes:
        # Use max() rather than the last element so an unsorted size list
        # still yields the true maximum capture size.
        compilation_config.max_cudagraph_capture_size = max(
            compilation_config.cudagraph_capture_sizes
        )
        compilation_config.post_init_cudagraph_sizes()
        mock_config.pad_for_cudagraph = (
            lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
        )

    return mock_config


class TestCudagraphDispatcher:
    """Unit tests for CudagraphDispatcher key setup and dispatch decisions."""

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,lora_config",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
            ("FULL", CompilationMode.NONE, False),
            # Test case 1: FULL_AND_PIECEWISE without VLLM_COMPILE; the
            # dispatcher constructor is expected to raise AssertionError
            ("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
            # Test case 2: Full CG for uniform batches, no CG for mixed
            ("FULL_DECODE_ONLY", CompilationMode.NONE, False),
            # Test case 3: PIECEWISE for all
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
            # Test case 4: PIECEWISE for all, specialize LoRA cases
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ],
    )
    def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
        """Check key initialization and dispatch results for each mode."""
        # Setup dispatcher
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 8],
        )

        config = _create_vllm_config(
            comp_config, max_num_seqs=8, lora_config=lora_config
        )
        if (
            cudagraph_mode_str == "FULL_AND_PIECEWISE"
            and compilation_mode == CompilationMode.NONE
        ):
            # Only the constructor's assertion is under test here.
            with pytest.raises(AssertionError):
                CudagraphDispatcher(config)
            return

        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        # Verify the keys are initialized correctly
        # (4 keys with LoRA specialization, 2 otherwise).
        expected_num_keys = 4 if lora_config else 2
        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert (
                len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE])
                == expected_num_keys
            )
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
            assert (
                len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL])
                == expected_num_keys
            )
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0

        # Test dispatch logic
        # 1. non-uniform batch, size in cudagraph size list
        desc_full_exact = BatchDescriptor(
            num_tokens=8,
            uniform=False,
        )
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_full_exact
        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 2. uniform decode batch, size in cudagraph size list
        desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=True, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact
        elif cudagraph_mode_str == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 3. No key match
        rt_mode, key = dispatcher.dispatch(
            num_tokens=15, uniform_decode=False, has_lora=False
        )
        assert rt_mode == CUDAGraphMode.NONE
        assert key == BatchDescriptor(num_tokens=15)

        # 4. Cascade attention should have a fall back mode
        # (reuses the non-uniform 8-token descriptor from step 1)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False, use_cascade_attn=True
        )
        if "PIECEWISE" in cudagraph_mode_str:  # string contains check
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
        else:
            assert rt_mode == CUDAGraphMode.NONE


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
    """Unit tests for CUDAGraphWrapper capture, replay and bypass paths."""

    def setup_method(self):
        self.vllm_config = _create_vllm_config(CompilationConfig())
        self.model = SimpleMLP().to("cuda")
        self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
        self.input_tensor = torch.randn(1, 10, device="cuda")

    def test_capture_and_replay(self):
        """A FULL wrapper captures on the first FULL call, then replays."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            wrapper(self.input_tensor)

        # 1. Capture
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            output1 = wrapper(self.input_tensor)
            # capturing phase should generate a zero output
            assert torch.allclose(output1, torch.zeros_like(output1))
            mock_cuda_graph.assert_called_once()

        assert batch_descriptor in wrapper.concrete_cudagraph_entries
        entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
        assert entry.cudagraph is not None

        # 2. Replay
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch.object(
                entry.cudagraph, "replay", wraps=entry.cudagraph.replay
            ) as mock_replay,
        ):
            output2 = wrapper(self.input_tensor)
            mock_replay.assert_called_once()

        # Compare with eager output
        eager_output = self.model(self.input_tensor)
        torch.testing.assert_close(eager_output, output2)

    def test_bypass_on_mode_mismatch(self):
        """A FULL wrapper runs eagerly, without capturing, in PIECEWISE mode."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
            patch.object(
                self.model, "forward", wraps=self.model.forward
            ) as mock_forward,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
            mock_forward.assert_called_once()
        assert not wrapper.concrete_cudagraph_entries

    def test_bypass_on_mode_none(self):
        """No graph is captured when the runtime mode is NONE."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
        assert not wrapper.concrete_cudagraph_entries


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
    """Integration tests combining CudagraphDispatcher and CUDAGraphWrapper."""

    def setup_method(self):
        # only FULL mode for non-uniform batches
        self.comp_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode="FULL",
            cudagraph_capture_sizes=[10, 20],
        )
        self.vllm_config = _create_vllm_config(self.comp_config)
        self.dispatcher = CudagraphDispatcher(self.vllm_config)
        self.dispatcher.initialize_cudagraph_keys(
            self.comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

    def _run_and_monitor_call(
        self, wrapper, input_tensor, runtime_mode, batch_descriptor
    ):
        """Helper to run a single call and monitor the action.

        Returns "capture_global" if ``torch.cuda.graph`` was entered,
        "replay" if the outer wrapper's existing graph was replayed,
        "bypass" if the outer wrapper's runnable was called directly,
        and "unknown" otherwise.
        """
        with (
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context,
            patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable,
        ):
            entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None)

            context = set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=runtime_mode,
                batch_descriptor=batch_descriptor,
            )
            # Placeholder so `.called` is False when there is no graph to
            # replay; replaced by the patched `replay` below otherwise.
            mock_replay = MagicMock()
            if entry and entry.cudagraph:
                with (
                    context,
                    patch.object(
                        entry.cudagraph, "replay", new_callable=MagicMock
                    ) as mock_replay,
                ):
                    wrapper(input_tensor)
            else:
                with context:
                    wrapper(input_tensor)

            if mock_graph_context.called:
                # note that this is globally mocked, so it will be detected
                # whether it is called by the inner or the outer wrapper
                return "capture_global"
            if mock_replay.called:
                # only for outer wrapper
                return "replay"
            if mock_runnable.call_count > 0:
                # only for outer wrapper
                return "bypass"
            return "unknown"

    @create_new_process_for_each_test("spawn")
    def test_capture_replay_bypass_logic(self):
        """Walk a FULL wrapper through capture, replay and bypass."""
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
        max_bs = 16
        persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
        input_1 = persistent_input_buffer[:1]
        input_2 = persistent_input_buffer[:2]
        input_3 = persistent_input_buffer[:3]

        desc_1 = BatchDescriptor(num_tokens=1)
        desc_2 = BatchDescriptor(num_tokens=2)
        desc_3_unseen = BatchDescriptor(num_tokens=3)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # NOTE(review): test_dispatcher calls dispatch(num_tokens=...,
        # uniform_decode=..., has_lora=...); passing a BatchDescriptor
        # positionally here looks like an older signature -- confirm against
        # CudagraphDispatcher.dispatch.
        rt_mode, key = self.dispatcher.dispatch(desc_1)
        # 1. Capture first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "capture_global"

        # 2. Replay first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "replay"

        rt_mode, key = self.dispatcher.dispatch(desc_2)
        # 3. Capture second shape
        action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
        assert action == "capture_global"

        # 4. Replay second shape
        action = self._run_and_monitor_call(
            full_wrapper, input_2, CUDAGraphMode.FULL, desc_2
        )
        assert action == "replay"

        # 5. Bypass if no key match
        rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
        assert rt_mode == CUDAGraphMode.NONE
        action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
        assert action == "bypass"

        # capture unseen shape is not allowed after disable
        set_cudagraph_capturing_enabled(False)
        try:
            with pytest.raises(RuntimeError):
                self._run_and_monitor_call(
                    full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen
                )
        finally:
            # Restore the global flag even if the expectation above fails.
            set_cudagraph_capturing_enabled(True)

    @create_new_process_for_each_test("spawn")
    def test_nested_wrappers(self):
        """Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
        input_1 = torch.randn(1, 10, device="cuda")

        # Setup: Inner model is wrapped with PIECEWISE, outer with FULL
        inner_model = SimpleMLP().to("cuda")
        piecewise_wrapper = CUDAGraphWrapper(
            inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
        )
        inner_model.forward = MagicMock(wraps=inner_model.forward)
        outer_model = SimpleMLP().to("cuda")
        # When outer model is called, it calls the piecewise_wrapper
        # (side_effect takes precedence over wraps for the return value).
        outer_model.forward = MagicMock(
            wraps=outer_model.forward, side_effect=piecewise_wrapper
        )
        full_wrapper = CUDAGraphWrapper(
            outer_model, self.vllm_config, CUDAGraphMode.FULL
        )

        desc_1 = BatchDescriptor(num_tokens=1)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # --- Test runtime mode FULL---
        # Run with FULL mode context. Expect outer wrapper to capture.
        # The inner mock should be called once inside the graph capture.
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again. Expect outer wrapper to replay.
        # The outer model should NOT be called because the whole graph
        # is replayed.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "replay"
        assert outer_model.forward.call_count == 1  # No new call
        assert inner_model.forward.call_count == 1

        # --- Test runtime mode PIECEWISE ---
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        # Run with PIECEWISE mode context.
        # Expect outer wrapper to bypass and call inner wrapper.
        # Inner wrapper should capture.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again with PIECEWISE.
        # Outer bypasses, inner replays.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "bypass"
        assert outer_model.forward.call_count == 2
        assert inner_model.forward.call_count == 1