# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from dataclasses import replace
from unittest.mock import MagicMock, patch

import pytest
import torch
import torch.nn as nn

from tests.utils import create_new_process_for_each_test
from vllm.compilation.cuda_graph import CUDAGraphWrapper
from vllm.compilation.monitor import set_cudagraph_capturing_enabled
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    CUDAGraphMode,
    ParallelConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.config.lora import LoRAConfig
from vllm.forward_context import BatchDescriptor, set_forward_context
from vllm.platforms import current_platform
from vllm.v1.cudagraph_dispatcher import CudagraphDispatcher

# Device type reported by the current platform; every model and tensor built
# in these tests is placed on it.
DEVICE_TYPE = current_platform.device_type


class SimpleMLP(nn.Module):
    """Tiny helper model for the tests: two stacked 10 -> 10 linear layers.

    The layers are applied back to back with no activation in between.
    """

    def __init__(self):
        super().__init__()
        width = 10
        # fc1 is registered before fc2 so parameter initialization draws from
        # the RNG in the same order as the layers are applied.
        self.fc1 = nn.Linear(width, width)
        self.fc2 = nn.Linear(width, width)

    def forward(self, x):
        hidden = self.fc1(x)
        out = self.fc2(hidden)
        return out


def _create_vllm_config(
    compilation_config: CompilationConfig,
    max_num_seqs: int = 8,
    lora_config: bool = False,
) -> MagicMock:
    """Build a ``MagicMock(spec=VllmConfig)`` populated with real sub-configs.

    Only the attributes the tests read are set. The parts of
    ``VllmConfig.__post_init__`` they rely on are replayed by hand on
    ``compilation_config``, which is mutated in place.

    Args:
        compilation_config: Compilation config to attach. It is updated in
            place (splitting ops for VLLM_COMPILE, cudagraph size bookkeeping).
        max_num_seqs: Forwarded to ``SchedulerConfig.default_factory``.
        lora_config: Despite the name, a flag. When True, a real
            ``LoRAConfig(max_loras=4, specialize_active_lora=True)`` is
            attached; otherwise ``lora_config`` is set to ``None``.

    Returns:
        The populated mock config.
    """
    mock_config = MagicMock(spec=VllmConfig)
    mock_config.compilation_config = compilation_config
    mock_config.scheduler_config = SchedulerConfig.default_factory(
        max_num_seqs=max_num_seqs,
    )
    mock_config.parallel_config = ParallelConfig()
    mock_config.speculative_config = None  # No speculative decoding

    if lora_config:
        # Create a real LoRAConfig with specialize_active_lora enabled
        mock_config.lora_config = LoRAConfig(
            max_loras=4,
            specialize_active_lora=True,
        )
    else:
        mock_config.lora_config = None

    # Mimic VllmConfig.__post_init__(): VLLM_COMPILE needs splitting ops that
    # are derived from the parallel config.
    if compilation_config.mode == CompilationMode.VLLM_COMPILE:
        compilation_config.set_splitting_ops_for_v1(
            all2all_backend=mock_config.parallel_config.all2all_backend,
            data_parallel_size=mock_config.parallel_config.data_parallel_size,
        )

    # Mimic VllmConfig.__post_init__(): record the largest capture size and
    # finalize the cudagraph sizes. max() is used instead of the last element
    # so an unsorted size list still yields the true maximum.
    if compilation_config.cudagraph_capture_sizes:
        compilation_config.max_cudagraph_capture_size = max(
            compilation_config.cudagraph_capture_sizes
        )
        compilation_config.post_init_cudagraph_sizes()

    return mock_config


class TestCudagraphDispatcher:
    """Unit tests for CudagraphDispatcher key initialization and dispatch."""

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,lora_config",
        [
            # Test case 0: Full CG for mixed batches, no separate routine
            ("FULL", CompilationMode.NONE, False),
            # Test case 1: Full CG for uniform batches, piecewise for mixed
            ("FULL_AND_PIECEWISE", CompilationMode.NONE, False),
            # Test case 2: Full CG for uniform batches, no CG for mixed
            ("FULL_DECODE_ONLY", CompilationMode.NONE, False),
            # Test case 3: PIECEWISE for all
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, False),
            # Test case 4: PIECEWISE for all, specialize LoRA cases
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, True),
        ],
    )
    def test_dispatcher(self, cudagraph_mode_str, compilation_mode, lora_config):
        """Check key counts and dispatch results for each cudagraph mode."""
        # Modes that own piecewise keys for mixed (non-uniform) batches.
        piecewise_modes = ("FULL_AND_PIECEWISE", "PIECEWISE")

        # Setup dispatcher
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 8],
        )
        config = _create_vllm_config(
            comp_config, max_num_seqs=8, lora_config=lora_config
        )
        if (
            cudagraph_mode_str == "FULL_AND_PIECEWISE"
            and compilation_mode == CompilationMode.NONE
        ):
            # The dispatcher constructor is expected to assert on this
            # mode/compilation combination.
            # NOTE(review): because of this early return, the
            # FULL_AND_PIECEWISE dispatch branches below are only reachable if
            # a VLLM_COMPILE variant of that mode is added to the parametrize
            # list.
            with pytest.raises(AssertionError):
                CudagraphDispatcher(config)
            return

        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        # Verify the key is initialized correctly
        # With LoRA specialization (max_loras=4, specialize_active_lora=True):
        # - lora_cases = [0, 1, 2, 4, 5] (no-lora + powers of 2 up to 4 + max_loras+1)
        # - capture_sizes = [1, 8]
        # - Total keys = 2 sizes × 5 lora_cases = 10
        expected_num_keys = 10 if lora_config else 2
        if cudagraph_mode_str in piecewise_modes:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == (
                expected_num_keys
            )
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.PIECEWISE]) == 0
        if cudagraph_mode_str not in ("NONE", "PIECEWISE"):
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == (
                expected_num_keys
            )
        else:
            assert len(dispatcher.cudagraph_keys[CUDAGraphMode.FULL]) == 0

        # Test dispatch logic
        # 1. non-uniform batch, size in cudagraph size list
        # FULL mode uses exact keys with num_reqs set
        desc_full_with_reqs = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False)
        # PIECEWISE mode uses relaxed keys with num_reqs=None
        desc_piecewise = BatchDescriptor(num_tokens=8, num_reqs=None, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=False, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_full_with_reqs
        elif cudagraph_mode_str in piecewise_modes:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == desc_piecewise
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 2. uniform decode batch, size in cudagraph size list
        desc_uniform_exact = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=True)
        desc_non_uniform = BatchDescriptor(num_tokens=8, num_reqs=8, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8, uniform_decode=True, has_lora=False
        )
        if cudagraph_mode_str == "FULL":
            # Pure FULL mode uses non-uniform keys for all batches
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_non_uniform
        elif cudagraph_mode_str in ("FULL_DECODE_ONLY", "FULL_AND_PIECEWISE"):
            # These modes have separate uniform decode keys
            assert rt_mode == CUDAGraphMode.FULL
            assert key == desc_uniform_exact
        elif cudagraph_mode_str == "PIECEWISE":
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == replace(desc_uniform_exact, num_reqs=None, uniform=False)
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 3. No key match
        rt_mode, key = dispatcher.dispatch(
            num_tokens=15, uniform_decode=False, has_lora=False
        )
        assert rt_mode == CUDAGraphMode.NONE
        assert key == BatchDescriptor(num_tokens=15)

        # 4. invalid_modes={FULL} should have a fall back mode
        #    (e.g., cascade attention)
        desc_full_exact = BatchDescriptor(num_tokens=8, uniform=False)
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8,
            uniform_decode=False,
            has_lora=False,
            invalid_modes={CUDAGraphMode.FULL},
        )
        if cudagraph_mode_str in piecewise_modes:
            assert rt_mode == CUDAGraphMode.PIECEWISE
            assert key == replace(desc_full_exact, num_reqs=None, uniform=False)
        else:
            assert rt_mode == CUDAGraphMode.NONE

        # 5. valid_modes={NONE} always returns NONE even when keys exist
        rt_mode, key = dispatcher.dispatch(
            num_tokens=8,
            uniform_decode=False,
            has_lora=False,
            valid_modes={CUDAGraphMode.NONE},
        )
        assert rt_mode == CUDAGraphMode.NONE
        assert key == BatchDescriptor(num_tokens=8)

    @pytest.mark.parametrize(
        "cudagraph_mode_str,compilation_mode,expected_modes",
        [
            # FULL mode: only FULL keys, no PIECEWISE
            ("FULL", CompilationMode.NONE, [CUDAGraphMode.FULL]),
            # PIECEWISE mode: only PIECEWISE keys
            ("PIECEWISE", CompilationMode.VLLM_COMPILE, [CUDAGraphMode.PIECEWISE]),
            # FULL_DECODE_ONLY: only FULL keys for uniform decode
            ("FULL_DECODE_ONLY", CompilationMode.NONE, [CUDAGraphMode.FULL]),
            # NONE mode: no keys
            ("NONE", CompilationMode.NONE, []),
        ],
    )
    def test_get_capture_descs(
        self, cudagraph_mode_str, compilation_mode, expected_modes
    ):
        """Test get_capture_descs returns correctly grouped and ordered descs."""
        comp_config = CompilationConfig(
            cudagraph_mode=cudagraph_mode_str,
            mode=compilation_mode,
            cudagraph_capture_sizes=[1, 4, 8, 16],
        )

        config = _create_vllm_config(comp_config, max_num_seqs=16)
        dispatcher = CudagraphDispatcher(config)
        dispatcher.initialize_cudagraph_keys(
            cudagraph_mode=comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

        capture_descs = dispatcher.get_capture_descs()

        # Verify we get the expected modes
        actual_modes = [mode for mode, _ in capture_descs]
        assert actual_modes == expected_modes

        # Verify each group is sorted largest-first
        for mode, descs in capture_descs:
            assert len(descs) > 0, "Each group should have at least one descriptor"
            num_tokens_list = [d.num_tokens for d in descs]
            assert num_tokens_list == sorted(num_tokens_list, reverse=True), (
                f"Descriptors for {mode} should be sorted largest-first"
            )

            # All descriptors in a group should have same uniform value
            uniform_values = [d.uniform for d in descs]
            assert len(set(uniform_values)) == 1, (
                "All descriptors in a group should have the same uniform value"
            )

    def test_get_capture_descs_empty_when_not_initialized(self):
        """Test that get_capture_descs returns empty list when keys not initialized."""
        comp_config = CompilationConfig(
            cudagraph_mode="FULL",
            mode=CompilationMode.NONE,
            cudagraph_capture_sizes=[1, 8],
        )
        config = _create_vllm_config(comp_config, max_num_seqs=8)
        dispatcher = CudagraphDispatcher(config)
        # Don't initialize keys

        assert dispatcher.get_capture_descs() == []


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCUDAGraphWrapper:
    """Capture, replay and bypass behavior of a standalone CUDAGraphWrapper."""

    def setup_method(self):
        self.vllm_config = _create_vllm_config(CompilationConfig())
        self.model = SimpleMLP().to(DEVICE_TYPE)
        self.input_tensor = torch.randn(1, 10, device=DEVICE_TYPE)

    def test_capture_and_replay(self):
        """A FULL wrapper captures on the first FULL call and then replays."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            wrapper(self.input_tensor)

        # 1. Capture
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            output1 = wrapper(self.input_tensor)
            # capturing phase should generate a zero output
            assert torch.allclose(output1, torch.zeros_like(output1))
            mock_cuda_graph.assert_called_once()

        assert batch_descriptor in wrapper.concrete_cudagraph_entries
        entry = wrapper.concrete_cudagraph_entries[batch_descriptor]
        assert entry.cudagraph is not None

        # 2. Replay
        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.FULL,
                batch_descriptor=batch_descriptor,
            ),
            patch.object(
                entry.cudagraph, "replay", wraps=entry.cudagraph.replay
            ) as mock_replay,
        ):
            output2 = wrapper(self.input_tensor)
            mock_replay.assert_called_once()

        # Compare with eager output
        eager_output = self.model(self.input_tensor)
        torch.testing.assert_close(eager_output, output2)

    def test_bypass_on_mode_mismatch(self):
        """A FULL wrapper runs eagerly (no capture) under a PIECEWISE context."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
            patch.object(
                self.model, "forward", wraps=self.model.forward
            ) as mock_forward,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
            mock_forward.assert_called_once()
        assert not wrapper.concrete_cudagraph_entries

    def test_bypass_on_mode_none(self):
        """A FULL wrapper never captures when the runtime mode is NONE."""
        wrapper = CUDAGraphWrapper(
            self.model, self.vllm_config, runtime_mode=CUDAGraphMode.FULL
        )
        batch_descriptor = BatchDescriptor(num_tokens=10)

        with (
            set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=CUDAGraphMode.NONE,
                batch_descriptor=batch_descriptor,
            ),
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_cuda_graph,
        ):
            wrapper(self.input_tensor)
            mock_cuda_graph.assert_not_called()
        assert not wrapper.concrete_cudagraph_entries


@pytest.mark.skipif(not current_platform.is_cuda(), reason="Skip if not cuda")
class TestCudagraphIntegration:
    """End-to-end checks combining CudagraphDispatcher and CUDAGraphWrapper."""

    def setup_method(self):
        # only FULL mode for non-uniform batches
        self.comp_config = CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            cudagraph_mode="FULL",
            cudagraph_capture_sizes=[10, 20],
        )
        self.vllm_config = _create_vllm_config(self.comp_config)
        self.dispatcher = CudagraphDispatcher(self.vllm_config)
        self.dispatcher.initialize_cudagraph_keys(
            self.comp_config.cudagraph_mode, uniform_decode_query_len=1
        )

    def _run_and_monitor_call(
        self, wrapper, input_tensor, runtime_mode, batch_descriptor
    ):
        """Run one wrapper call under a forward context and classify it.

        Returns:
            "capture_global" if ``torch.cuda.graph`` was entered (by any
            wrapper), "replay" if the outer wrapper's existing graph for
            ``batch_descriptor`` was replayed, "bypass" if the outer wrapper
            called its runnable directly, otherwise "unknown".
        """
        with (
            patch("torch.cuda.graph", wraps=torch.cuda.graph) as mock_graph_context,
            patch.object(wrapper, "runnable", wraps=wrapper.runnable) as mock_runnable,
        ):
            entry = wrapper.concrete_cudagraph_entries.get(batch_descriptor, None)

            context = set_forward_context(
                attn_metadata=None,
                vllm_config=self.vllm_config,
                cudagraph_runtime_mode=runtime_mode,
                batch_descriptor=batch_descriptor,
            )
            # Placeholder so `.called` is False when no graph exists yet.
            mock_replay = MagicMock()
            if entry and entry.cudagraph:
                with (
                    context,
                    patch.object(
                        entry.cudagraph, "replay", new_callable=MagicMock
                    ) as mock_replay,
                ):
                    wrapper(input_tensor)
            else:
                with context:
                    wrapper(input_tensor)

            if mock_graph_context.called:
                # torch.cuda.graph is patched globally, so a capture is
                # detected whether it is done by the inner or the outer wrapper
                return "capture_global"
            if mock_replay.called:
                # only for outer wrapper
                return "replay"
            if mock_runnable.call_count > 0:
                # only for outer wrapper
                return "bypass"
            return "unknown"

    @create_new_process_for_each_test("spawn")
    def test_capture_replay_bypass_logic(self):
        """Capture, replay and bypass on one FULL wrapper driven by dispatch."""
        model = SimpleMLP().to(DEVICE_TYPE)
        full_wrapper = CUDAGraphWrapper(model, self.vllm_config, CUDAGraphMode.FULL)
        max_bs = 16
        persistent_input_buffer = torch.zeros(max_bs, 10, device=DEVICE_TYPE)
        input_1 = persistent_input_buffer[:1]
        input_2 = persistent_input_buffer[:2]
        input_3 = persistent_input_buffer[:3]

        desc_1 = BatchDescriptor(num_tokens=1)
        desc_2 = BatchDescriptor(num_tokens=2)
        desc_3_unseen = BatchDescriptor(num_tokens=3)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # 1. Capture first shape
        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_1.num_tokens)
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "capture_global"

        # 2. Replay first shape
        action = self._run_and_monitor_call(full_wrapper, input_1, rt_mode, key)
        assert action == "replay"

        # 3. Capture second shape
        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_2.num_tokens)
        action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
        assert action == "capture_global"

        # 4. Replay second shape. Reuse the mode/key dispatched in step 3: the
        # graph was captured under that key, which need not equal desc_2.
        action = self._run_and_monitor_call(full_wrapper, input_2, rt_mode, key)
        assert action == "replay"

        # 5. Bypass if no key match
        rt_mode, key = self.dispatcher.dispatch(num_tokens=desc_3_unseen.num_tokens)
        assert rt_mode == CUDAGraphMode.NONE
        action = self._run_and_monitor_call(full_wrapper, input_3, rt_mode, key)
        assert action == "bypass"

        # capture unseen shape is not allowed after disable
        set_cudagraph_capturing_enabled(False)
        try:
            with pytest.raises(RuntimeError):
                self._run_and_monitor_call(
                    full_wrapper, input_3, CUDAGraphMode.FULL, desc_3_unseen
                )
        finally:
            # Restore the global flag even if the expectation above fails.
            set_cudagraph_capturing_enabled(True)

    @create_new_process_for_each_test("spawn")
    def test_nested_wrappers(self):
        """Tests a scenario with a PIECEWISE wrapper inside a FULL one."""
        input_1 = torch.randn(1, 10, device=DEVICE_TYPE)

        # Setup: Inner model is wrapped with PIECEWISE, outer with FULL
        inner_model = SimpleMLP().to(DEVICE_TYPE)
        piecewise_wrapper = CUDAGraphWrapper(
            inner_model, self.vllm_config, CUDAGraphMode.PIECEWISE
        )
        inner_model.forward = MagicMock(wraps=inner_model.forward)
        outer_model = SimpleMLP().to(DEVICE_TYPE)
        # When outer model is called, it calls the piecewise_wrapper
        outer_model.forward = MagicMock(
            wraps=outer_model.forward, side_effect=piecewise_wrapper
        )
        full_wrapper = CUDAGraphWrapper(
            outer_model, self.vllm_config, CUDAGraphMode.FULL
        )

        desc_1 = BatchDescriptor(num_tokens=1)

        # 0. global warmup
        with set_forward_context(
            attn_metadata=None,
            vllm_config=self.vllm_config,
            cudagraph_runtime_mode=CUDAGraphMode.NONE,
            batch_descriptor=None,
        ):
            full_wrapper(input_1)

        # --- Test runtime mode FULL---
        # Run with FULL mode context. Expect outer wrapper to capture.
        # The inner mock should be called once inside the graph capture.
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again. Expect outer wrapper to replay.
        # The outer model should NOT be called because the whole graph
        # is replayed.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.FULL, desc_1
        )
        assert action == "replay"
        assert outer_model.forward.call_count == 1  # No new call
        assert inner_model.forward.call_count == 1

        # --- Test runtime mode PIECEWISE ---
        outer_model.forward.reset_mock()
        inner_model.forward.reset_mock()
        # Run with PIECEWISE mode context.
        # Expect outer wrapper to bypass and call inner wrapper.
        # Inner wrapper should capture.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "capture_global"
        assert outer_model.forward.call_count == 1
        assert inner_model.forward.call_count == 1

        # Run again with PIECEWISE.
        # Outer bypasses, inner replays.
        action = self._run_and_monitor_call(
            full_wrapper, input_1, CUDAGraphMode.PIECEWISE, desc_1
        )
        assert action == "bypass"
        assert outer_model.forward.call_count == 2
        assert inner_model.forward.call_count == 1