# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests for ``CudagraphDispatcher``, ``CUDAGraphWrapper`` and their interplay."""

from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    CUDAGraphMode,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher


# Helper MLP for testing
class SimpleMLP(nn.Module):
    """Two stacked ``nn.Linear(10, 10)`` layers with no activation between them."""

    def __init__(self):
        super().__init__()
        # fc1 is built before fc2 so parameter initialization draws from the
        # RNG in the same order as before.
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        """Apply ``fc1`` and then ``fc2`` to ``x``."""
        hidden = self.fc1(x)
        out = self.fc2(hidden)
        return out


def _create_vllm_config(
    compilation_config: CompilationConfig,
    max_num_seqs: int = 8,
    lora_config: bool = False,
) -> MagicMock:
    """Build a ``MagicMock`` that stands in for a ``VllmConfig``.

    Only the attributes used by the dispatcher / wrapper tests are populated.
    The parts of ``VllmConfig.__post_init__`` those tests rely on are mimicked
    by hand.

    Args:
        compilation_config: Compilation config to attach. It is mutated in
            place: splitting ops (for ``VLLM_COMPILE``) and cudagraph size
            bookkeeping are filled in.
        max_num_seqs: Forwarded to ``SchedulerConfig.default_factory``.
        lora_config: If False, ``lora_config`` is explicitly set to ``None``.
            If True it is left untouched, so reading it yields the spec'd
            mock's auto-created attribute instead of ``None``.

    Returns:
        The populated mock config.
    """
    mock_config = MagicMock(spec=VllmConfig)
    mock_config.compilation_config = compilation_config
    mock_config.scheduler_config = SchedulerConfig.default_factory(
        max_num_seqs=max_num_seqs,
    )
    mock_config.parallel_config = ParallelConfig()
    mock_config.speculative_config = None  # No speculative decoding
    if not lora_config:
        mock_config.lora_config = None

    # Mimic the behavior of VllmConfig.__post_init__()
    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
        compilation_config.set_splitting_ops_for_v1(
            all2all_backend=mock_config.parallel_config.all2all_backend,
            data_parallel_size=mock_config.parallel_config.data_parallel_size,
        )

    # Mimic VllmConfig.__post_init__: derive the max capture size and the
    # batch-size padding lookup from the requested capture sizes.
    if compilation_config.cudagraph_capture_sizes:
        # max() instead of [-1] so an unsorted size list still yields the
        # largest capture size.
        compilation_config.max_cudagraph_capture_size = max(
            compilation_config.cudagraph_capture_sizes
        )

        compilation_config.post_init_cudagraph_sizes()
        mock_config.pad_for_cudagraph = (
            lambda batch_size: compilation_config.bs_to_padded_graph_size[batch_size]
        )

    return mock_config


class TestCudagraphDispatcher:
    """Unit tests for ``CudagraphDispatcher`` key setup and dispatch logic."""

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,lora_config",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
            ("FULL", CompilationMode.NONE, False),
            # Test case 1: Full CG for uniform batches, piecewise for mixed
            ("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
            # Test case 2: Full CG for uniform batches, no CG for mixed
            ("FULL_DECODE_ONLY", CompilationMode.NONE, False),
            # Test case 3: PIECEWISE for all
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
            # Test case 4: PIECEWISE for all, specialize LoRA cases
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ],
    )
    def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
        """Check cudagraph key initialization and per-batch dispatch results.

        Capture sizes are ``[1, 8]``. Each enabled mode is expected to hold
        2 keys, or 4 when LoRA is enabled.
        """
        # Setup dispatcher
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 8],
        )
        config = _create_vllm_config(
            comp_config, max_num_seqs=8, lora_config=lora_config
        )

        if (
            cudagraph_mode_str == "FULL_AND_PIECEWISE"
            and compilation_mode == CompilationMode.NONE
        ):
            # The dispatcher is expected to reject this combination
            # (presumably because piecewise graphs need VLLM_COMPILE).
            with pytest.raises(AssertionError):
                CudagraphDispatcher(config)
            return

        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        # Verify the keys are initialized correctly
        num_keys_per_mode = 4 if lora_config else 2
        piecewise_keys = dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]
        full_keys = dispatcher.cudagraph_keys[CUDAGraphMode.FULL]
        if cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert len(piecewise_keys) == num_keys_per_mode
        else:
            assert len(piecewise_keys) == 0
        if cudagraph_mode_str not in ["NONE", "PIECEWISE"]:
            assert len(full_keys) == num_keys_per_mode
        else:
            assert len(full_keys) == 0

        # Test dispatch logic
        # 1. non-uniform batch, size in cudagraph size list
        desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_full_exact
        elif cudagraph_mode_str in ["FULL_AND_PIECEWISE", "PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 2. uniform decode batch, size in cudagraph size list
        desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=True, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
        elif cudagraph_mode_str in ["FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"]:
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact
        elif cudagraph_mode_str == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_uniform_exact.relax_for_mixed_batch_cudagraphs()
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 3. No key match: 15 tokens exceeds every capture size
        rt_mode, key = dispatcher.dispatch(
            num_tokens=15, uniform_decode=False, has_lora=False
        )
        assert rt_mode == CUDAGraphMode.NONE
        assert key == BatchDescriptor(num_tokens=15)

        # 4. disable_full should have a fall back mode (e.g., cascade attention)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False, disable_full=True
        )
        if "PIECEWISE" in cudagraph_mode_str:  # substring check
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_full_exact.relax_for_mixed_batch_cudagraphs()
        else:
            assert rt_mode == CUDAGraphMode.NONE


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
    """Tests for a single FULL-mode ``CUDAGraphWrapper`` around ``SimpleMLP``."""

    def setup_method(self):
        self.vllm_config = _create_vllm_config(CompilationConfig())
        self.model = SimpleMLP().to("cuda")
        self.input_tensor = torch.randn(1, 10, device="cuda")

    def test_capture_and_replay(self):
        """First FULL call captures a graph; the second one replays it."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            wrapper(self.input_tensor)

        # 1. Capture
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            output1 = wrapper(self.input_tensor)
            # capturing phase should generate a zero output
            assert torch.allclose(output1, torch.zeros_like(output1))
            mock_cuda_graph.assert_called_once()

        assert batch_descriptor in wrapper.concrete_cudagraph_entries
        entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
        assert entry.cudagraph is not None

        # 2. Replay
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch.object(
                entry.cudagraph, "replay", wraps=entry.cudagraph.replay
            ) as mock_replay,
        ):
            output2 = wrapper(self.input_tensor)
            mock_replay.assert_called_once()

        # Compare with eager output
        eager_output = self.model(self.input_tensor)
        torch.testing.assert_close(eager_output, output2)

    def test_bypass_on_mode_mismatch(self):
        """A FULL wrapper run under PIECEWISE mode calls the model eagerly."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
            patch.object(
                self.model, "forward", wraps=self.model.forward
            ) as mock_forward,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
            mock_forward.assert_called_once()
        assert not wrapper.concrete_cudagraph_entries

    def test_bypass_on_mode_none(self):
        """A FULL wrapper run under NONE mode neither captures nor stores graphs."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
        assert not wrapper.concrete_cudagraph_entries


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
    """Capture / replay / bypass behavior of (nested) ``CUDAGraphWrapper``s."""

    def setup_method(self):
        # only FULL mode for non-uniform batches
        self.comp_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode="FULL",
            cudagraph_capture_sizes=[10, 20],
        )
        self.vllm_config = _create_vllm_config(self.comp_config)
        self.dispatcher = CudagraphDispatcher(self.vllm_config)
        self.dispatcher.initialize_cudagraph_keys(
            self.comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

    def _run_and_monitor_call(
        self, wrapper, input_tensor, runtime_mode, batch_descriptor
    ):
        """Run ``wrapper`` once and classify what it did.

        Returns:
            ``"capture_global"`` if ``torch.cuda.graph`` was entered (by this
            wrapper or any nested one); ``"replay"`` if this wrapper replayed
            its existing graph for ``batch_descriptor``; ``"bypass"`` if this
            wrapper called its runnable directly; otherwise ``"unknown"``.
        """
        with (
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context,
            patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable,
        ):
            entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None)

            context = set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=runtime_mode,
                batch_descriptor=batch_descriptor,
            )
            # Placeholder so the `.called` check below works even when there is
            # no existing graph to patch.
            mock_replay = MagicMock()
            if entry and entry.cudagraph:
                with (
                    context,
                    patch.object(
                        entry.cudagraph, "replay", new_callable=MagicMock
                    ) as mock_replay,
                ):
                    wrapper(input_tensor)
            else:
                with context:
                    wrapper(input_tensor)

            if mock_graph_context.called:
                # torch.cuda.graph is patched globally, so a capture is
                # detected regardless of whether the inner or the outer
                # wrapper performed it.
                return "capture_global"
            if mock_replay.called:
                # only for outer wrapper
                return "replay"
            if mock_runnable.call_count > 0:
                # only for outer wrapper
                return "bypass"
            return "unknown"

    @create_new_process_for_each_test("spawn")
    def test_capture_replay_bypass_logic(self):
        """Walk a FULL wrapper through capture, replay and bypass per shape."""
        model = SimpleMLP().to("cuda")
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
        max_bs = 16
        persistent_input_buffer = torch.zeros(max_bs, 10, device="cuda")
        input_1 = persistent_input_buffer[:1]
        input_2 = persistent_input_buffer[:2]
        input_3 = persistent_input_buffer[:3]

        desc_1 = BatchDescriptor(num_tokens=1)
        desc_2 = BatchDescriptor(num_tokens=2)
        desc_3_unseen = BatchDescriptor(num_tokens=3)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # The runtime mode and descriptor are handed to the wrapper directly.
        # This test targets the wrapper's capture/replay/bypass behavior.
        # Dispatcher padding/key lookup is covered by TestCudagraphDispatcher.

        # 1. Capture first shape
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "capture_global"

        # 2. Replay first shape
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "replay"

        # 3. Capture second shape
        action = self._run_and_monitor_call(
            full_wrapper, input_2, CUDAGraphMode.FULL, desc_2
        )
        assert action == "capture_global"

        # 4. Replay second shape
        action = self._run_and_monitor_call(
            full_wrapper, input_2, CUDAGraphMode.FULL, desc_2
        )
        assert action == "replay"

        # 5. Bypass when no graph applies (runtime mode NONE)
        action = self._run_and_monitor_call(
            full_wrapper, input_3, CUDAGraphMode.NONE, desc_3_unseen
        )
        assert action == "bypass"

        # capture unseen shape is not allowed after disable
        set_cudagraph_capturing_enabled(False)
        try:
            with pytest.raises(RuntimeError):
                self._run_and_monitor_call(
                    full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen
                )
        finally:
            # Restore the global flag even if the expectation above fails.
            set_cudagraph_capturing_enabled(True)

    @create_new_process_for_each_test("spawn")
    def test_nested_wrappers(self):
        """Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
        input_1 = torch.randn(1, 10, device="cuda")

        # Setup: Inner model is wrapped with PIECEWISE, outer with FULL
        inner_model = SimpleMLP().to("cuda")
        piecewise_wrapper = CUDAGraphWrapper(
            inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
        )
        inner_model.forward = MagicMock(wraps=inner_model.forward)
        outer_model = SimpleMLP().to("cuda")
        # When outer model is called, it calls the piecewise_wrapper
        outer_model.forward = MagicMock(
            wraps=outer_model.forward, side_effect=piecewise_wrapper
        )
        full_wrapper = CUDAGraphWrapper(
            outer_model, self.vllm_config, CUDAGraphMode.FULL
        )

        desc_1 = BatchDescriptor(num_tokens=1)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # --- Test runtime mode FULL---
        # Run with FULL mode context. Expect outer wrapper to capture.
        # The inner mock should be called once inside the graph capture.
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again. Expect outer wrapper to replay.
        # The outer model should NOT be called because the whole graph
        # is replayed.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "replay"
        assert outer_model.forward.call_count == 1  # No new call
        assert inner_model.forward.call_count == 1

        # --- Test runtime mode PIECEWISE ---
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        # Run with PIECEWISE mode context.
        # Expect outer wrapper to bypass and call inner wrapper.
        # Inner wrapper should capture.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again with PIECEWISE.
        # Outer bypasses, inner replays.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "bypass"
        assert outer_model.forward.call_count == 2
        assert inner_model.forward.call_count == 1