# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
    CompilationConfig,
    CompilationLevel,
    CUDAGraphMode,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher


class SimpleMLP(nn.Module):
    """Tiny two-layer MLP (10 -> 10 -> 10) used as the model under test.

    The two linear layers are applied back to back with no activation.
    """

    def __init__(self):
        super().__init__()
        # Keep fc1 registered before fc2 so parameter initialization draws
        # from the RNG in the same order.
        self.fc1 = nn.Linear(in_features=10, out_features=10)
        self.fc2 = nn.Linear(in_features=10, out_features=10)

    def forward(self, x):
        hidden = self.fc1(x)
        out = self.fc2(hidden)
        return out


def _create_vllm_config(
    compilation_config: CompilationConfig, max_num_seqs: int = 8
) -> MagicMock:
    """Build a ``VllmConfig`` stand-in carrying only what these tests need.

    Args:
        compilation_config: Compilation settings attached to the mock. When
            its level is ``CompilationLevel.PIECEWISE``,
            ``set_splitting_ops_for_v1()`` is invoked on this very object.
        max_num_seqs: Value used for ``scheduler_config.max_num_seqs``.

    Returns:
        A ``MagicMock`` specced on ``VllmConfig`` whose
        ``compilation_config``, ``scheduler_config`` and ``parallel_config``
        attributes are real config objects.
    """
    mock_config = MagicMock(spec=VllmConfig)
    mock_config.compilation_config = compilation_config
    mock_config.scheduler_config = SchedulerConfig(max_num_seqs=max_num_seqs)
    mock_config.parallel_config = ParallelConfig()

    # Mimic the behavior of VllmConfig.__post_init__(), which a mock does
    # not run: piecewise compilation needs its splitting ops set up.
    if compilation_config.level == CompilationLevel.PIECEWISE:
        compilation_config.set_splitting_ops_for_v1()

    return mock_config


class TestCudagraphDispatcher:
    """Unit tests for ``CudagraphDispatcher`` key setup and dispatch."""

    @pytest.mark.parametrize(
        "case_id,cudagraph_mode_str,compilation_level",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
            (0, "FULL", CompilationLevel.NO_COMPILATION),
            # Test case 1: Full CG for uniform batches, piecewise for mixed
            (1, "FULL_AND_PIECEWISE", CompilationLevel.NO_COMPILATION),
            # Test case 2: Full CG for uniform batches, no CG for mixed
            (2, "FULL_DECODE_ONLY", CompilationLevel.NO_COMPILATION),
            # Test case 3: Piecewise for all
            (3, "PIECEWISE", CompilationLevel.PIECEWISE),
        ],
    )
    def test_dispatcher(self, case_id, cudagraph_mode_str, compilation_level):
        """Verify the registered keys and the runtime mode chosen per batch.

        ``case_id`` only labels the parametrized case. It must still be
        accepted here, because pytest rejects ``parametrize`` argument names
        that the test function does not take.
        """
        # Setup dispatcher
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            level=compilation_level,
            cudagraph_capture_sizes=[1, 8],
        )

        config = _create_vllm_config(comp_config, max_num_seqs=8)
        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        # Verify the keys are initialized correctly: the two capture sizes
        # above are expected to yield two keys for each enabled mode.
        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 2
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 2
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0

        # Test dispatch logic
        # 1. non-uniform batch, size in cudagraph size list
        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
        rt_mode, key = dispatcher.dispatch(desc_full_exact)
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_full_exact
        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 2. uniform decode batch, size in cudagraph size list
        desc_uniform_exact = BatchDescriptor(num_tokens=8, uniform_decode=True)
        rt_mode, key = dispatcher.dispatch(desc_uniform_exact)
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact.non_uniform
        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact
        elif cudagraph_mode_str == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_uniform_exact.non_uniform
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 3. No key match
        desc_no_match = BatchDescriptor(num_tokens=15, uniform_decode=False)
        rt_mode, key = dispatcher.dispatch(desc_no_match)
        assert rt_mode == CUDAGraphMode.NONE
        assert key is None

        # 4. Cascade attention should have a fall back mode
        desc_full_exact = BatchDescriptor(num_tokens=8, uniform_decode=False)
        rt_mode, key = dispatcher.dispatch(desc_full_exact, use_cascade_attn=True)
        if "PIECEWISE" in cudagraph_mode_str:  # string contains check
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact.non_uniform
        else:
            assert rt_mode == CUDAGraphMode.NONE


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
    """Capture, replay and bypass behavior of a single ``CUDAGraphWrapper``."""

    def setup_method(self):
        self.vllm_config = _create_vllm_config(CompilationConfig())
        self.model = SimpleMLP().to("cuda")
        self.persistent_input_buffer = torch.zeros(1, 10, device="cuda")
        self.input_tensor = torch.randn(1, 10, device="cuda")

    @create_new_process_for_each_test("spawn")
    def test_capture_and_replay(self):
        """A FULL wrapper captures on its first FULL call, then replays."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            wrapper(self.input_tensor)

        # 1. Capture
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            output1 = wrapper(self.input_tensor)
            # capturing phase should generate a zero output
            assert torch.allclose(output1, torch.zeros_like(output1))
            mock_cuda_graph.assert_called_once()

        assert batch_descriptor in wrapper.concrete_cudagraph_entries
        entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
        assert entry.cudagraph is not None

        # 2. Replay
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch.object(
                entry.cudagraph, "replay", wraps=entry.cudagraph.replay
            ) as mock_replay,
        ):
            output2 = wrapper(self.input_tensor)
            mock_replay.assert_called_once()

        # Compare with eager output
        eager_output = self.model(self.input_tensor)
        torch.testing.assert_close(eager_output, output2)

    @create_new_process_for_each_test("spawn")
    def test_bypass_on_mode_mismatch(self):
        """A FULL wrapper runs eagerly (no capture) under a PIECEWISE context."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
            patch.object(
                self.model, "forward", wraps=self.model.forward
            ) as mock_forward,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
            mock_forward.assert_called_once()
        assert not wrapper.concrete_cudagraph_entries

    @create_new_process_for_each_test("spawn")
    def test_bypass_on_mode_none(self):
        """A FULL wrapper does not capture when the runtime mode is NONE."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
        assert not wrapper.concrete_cudagraph_entries


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
    """Dispatcher decisions fed into real ``CUDAGraphWrapper`` instances."""

    def setup_method(self):
        # only FULL mode for non-uniform batches.
        # The capture sizes must cover the batch sizes (1 and 2) that the
        # tests dispatch and expect to capture/replay; size 3 stays unseen.
        self.comp_config = CompilationConfig(
            level=CompilationLevel.PIECEWISE,
            cudagraph_mode="FULL",
            cudagraph_capture_sizes=[1, 2],
        )
        self.vllm_config = _create_vllm_config(self.comp_config)
        self.dispatcher = CudagraphDispatcher(self.vllm_config)
        self.dispatcher.initialize_cudagraph_keys(
            self.comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

    def _run_and_monitor_call(
        self, wrapper, input_tensor, runtime_mode, batch_descriptor
    ):
        """Helper to run a single call and monitor the action.

        Returns:
            "capture_global" if ``torch.cuda.graph`` was entered anywhere,
            "replay" if the outer wrapper's existing graph was replayed,
            "bypass" if the outer wrapper ran its runnable eagerly, and
            "unknown" otherwise.
        """
        with (
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context,
            patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable,
        ):
            entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None)

            context = set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=runtime_mode,
                batch_descriptor=batch_descriptor,
            )
            # Placeholder so `mock_replay.called` is False when no graph
            # exists yet for this descriptor.
            mock_replay = MagicMock()
            if entry and entry.cudagraph:
                with (
                    context,
                    patch.object(
                        entry.cudagraph, "replay", new_callable=MagicMock
                    ) as mock_replay,
                ):
                    wrapper(input_tensor)
            else:
                with context:
                    wrapper(input_tensor)

            if mock_graph_context.called:
                # note that this is globally mocked, so it is detected
                # regardless of whether the inner or outer wrapper calls it
                return "capture_global"
            if mock_replay.called:
                # only for outer wrapper
                return "replay"
            if mock_runnable.call_count > 0:
                # only for outer wrapper
                return "bypass"
            return "unknown"

    @create_new_process_for_each_test("spawn")
    def test_capture_replay_bypass_logic(self):
        """Dispatched shapes capture then replay; unknown shapes bypass."""
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
        max_bs = 16
        persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
        input_1 = persistent_input_buffer[:1]
        input_2 = persistent_input_buffer[:2]
        input_3 = persistent_input_buffer[:3]

        desc_1 = BatchDescriptor(num_tokens=1)
        desc_2 = BatchDescriptor(num_tokens=2)
        desc_3_unseen = BatchDescriptor(num_tokens=3)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        rt_mode, key = self.dispatcher.dispatch(desc_1)
        # 1. Capture first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "capture_global"

        # 2. Replay first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "replay"

        rt_mode, key = self.dispatcher.dispatch(desc_2)
        # 3. Capture second shape
        action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
        assert action == "capture_global"

        # 4. Replay second shape
        action = self._run_and_monitor_call(
            full_wrapper, input_2, CUDAGraphMode.FULL, desc_2
        )
        assert action == "replay"

        # 5. Bypass if no key match
        rt_mode, key = self.dispatcher.dispatch(desc_3_unseen)
        assert rt_mode == CUDAGraphMode.NONE
        action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
        assert action == "bypass"

        # capture unseen shape is not allowed after disable
        set_cudagraph_capturing_enabled(False)
        try:
            with pytest.raises(RuntimeError):
                self._run_and_monitor_call(
                    full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen
                )
        finally:
            # Restore the global flag even if the expectation above fails.
            set_cudagraph_capturing_enabled(True)

    @create_new_process_for_each_test("spawn")
    def test_nested_wrappers(self):
        """Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
        input_1 = torch.randn(1, 10, device="cuda")

        # Setup: Inner model is wrapped with PIECEWISE, outer with FULL
        inner_model = SimpleMLP().to("cuda")
        piecewise_wrapper = CUDAGraphWrapper(
            inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
        )
        inner_model.forward = MagicMock(wraps=inner_model.forward)
        outer_model = SimpleMLP().to("cuda")
        # When outer model is called, it calls the piecewise_wrapper
        outer_model.forward = MagicMock(
            wraps=outer_model.forward, side_effect=piecewise_wrapper
        )
        full_wrapper = CUDAGraphWrapper(
            outer_model, self.vllm_config, CUDAGraphMode.FULL
        )

        desc_1 = BatchDescriptor(num_tokens=1)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # --- Test runtime mode FULL---
        # Run with FULL mode context. Expect outer wrapper to capture.
        # The inner mock should be called once inside the graph capture.
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again. Expect outer wrapper to replay.
        # The outer model should NOT be called because the whole graph
        # is replayed.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "replay"
        assert outer_model.forward.call_count == 1  # No new call
        assert inner_model.forward.call_count == 1

        # --- Test runtime mode PIECEWISE ---
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        # Run with PIECEWISE mode context.
        # Expect outer wrapper to bypass and call inner wrapper.
        # Inner wrapper should capture.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again with PIECEWISE.
        # Outer bypasses, inner replays.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "bypass"
        assert outer_model.forward.call_count == 2
        assert inner_model.forward.call_count == 1