test_async_spec_decode.py 3.63 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that verifies no implicit GPU-CPU synchronization occurs during
speculative decoding generation under expected conditions.
"""

import multiprocessing
import sys
import traceback

import pytest
import torch


@pytest.fixture
def sync_tracker():
    """
    Instrument CommonAttentionMetadata.seq_lens_cpu for the duration of a test.

    While the fixture is active, every read of ``seq_lens_cpu`` that finds
    ``_seq_lens_cpu`` still unset bumps a counter and writes a banner plus the
    current call stack to stderr; the read is then delegated to the original
    getter. Reads that find ``_seq_lens_cpu`` already populated go straight to
    the original getter.

    Yields:
        An object exposing ``count`` (number of lazy inits seen so far) and
        ``assert_no_sync(msg="")`` (fails if that number is non-zero).

    On teardown the original property object is put back on the class and
    ``torch._dynamo.reset()`` is called.
    """
    from vllm.v1.attention.backends.utils import CommonAttentionMetadata

    # A multiprocessing.Value (not a plain int) so the tally is intended to be
    # shared with forked child processes.
    lazy_init_hits = multiprocessing.Value("i", 0)

    # Keep the untouched property: its getter is what the wrapper delegates
    # to, and the property object itself is what teardown reinstates.
    saved_property = CommonAttentionMetadata.seq_lens_cpu
    delegate_getter = saved_property.fget

    banner = "=" * 60

    class SyncTracker:
        """Read-only view over the lazy-init tally."""

        @property
        def count(self) -> int:
            return lazy_init_hits.value

        def assert_no_sync(self, msg: str = ""):
            observed = lazy_init_hits.value
            assert observed == 0, (
                f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
                f"{observed} times. See stack traces above. {msg}"
            )

    def tracking_seq_lens_cpu(self):
        # Fast path: value already materialised, nothing to report.
        if self._seq_lens_cpu is not None:
            return delegate_getter(self)

        # Take the lock so the increment and the read-back are one atomic step.
        with lazy_init_hits.get_lock():
            lazy_init_hits.value += 1
            hit_number = lazy_init_hits.value

        # Report right away (rather than at assertion time) so the stack that
        # caused the lazy init is captured in whichever process hit it.
        print(f"\n{banner}", file=sys.stderr)
        print(
            f"SYNC #{hit_number}: seq_lens_cpu lazy init triggered!",
            file=sys.stderr,
        )
        print(f"{banner}", file=sys.stderr)
        traceback.print_stack(file=sys.stderr)
        print(f"{banner}\n", file=sys.stderr)
        sys.stderr.flush()

        return delegate_getter(self)

    # Swap the instrumented getter in for the lifetime of the test.
    CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)

    yield SyncTracker()

    # Teardown: undo the patch, then reset torch._dynamo state.
    CommonAttentionMetadata.seq_lens_cpu = saved_property
    torch._dynamo.reset()


# Test configurations: (model, spec_model, method, num_spec_tokens)
# The positional values of each pytest.param line up, in order, with the
# argument names given to @pytest.mark.parametrize on the test below; `id` is
# the label pytest shows for that case.
SPEC_DECODE_CONFIGS = [
    pytest.param(
        "meta-llama/Llama-3.2-1B-Instruct",
        "nm-testing/Llama3_2_1B_speculator.eagle3",
        "eagle3",
        2,
        id="eagle3-llama",
    ),
    pytest.param(
        "eagle618/deepseek-v3-random",
        "eagle618/eagle-deepseek-v3-random",
        "eagle",
        2,
        id="eagle-mla-deepseek",
    ),
]


@pytest.mark.parametrize(
    "model,spec_model,method,num_spec_tokens",
    SPEC_DECODE_CONFIGS,
)
def test_no_sync_with_spec_decode(
    sync_tracker,
    model: str,
    spec_model: str,
    method: str,
    num_spec_tokens: int,
):
    """
    Test that no implicit GPU-CPU sync occurs during speculative decoding
    generation.

    Args:
        sync_tracker: Fixture object that counts seq_lens_cpu lazy inits and
            provides ``assert_no_sync()``.
        model: Model name passed to ``LLM(model=...)``.
        spec_model: Model name passed as ``speculative_config["model"]``.
        method: Value passed as ``speculative_config["method"]``.
        num_spec_tokens: Value passed as
            ``speculative_config["num_speculative_tokens"]``.

    The engine is released in a ``finally`` block so that a failure during
    generation or in the output checks does not leave it (and its GPU memory)
    alive for the next parametrized case.
    """
    # Import vLLM AFTER sync_tracker fixture has applied the patch
    from vllm import LLM, SamplingParams
    from vllm.distributed import cleanup_dist_env_and_memory

    llm = LLM(
        model=model,
        max_model_len=256,
        speculative_config={
            "method": method,
            "num_speculative_tokens": num_spec_tokens,
            "model": spec_model,
        },
        enforce_eager=True,
        async_scheduling=True,
    )

    try:
        outputs = llm.generate(
            ["Hello, my name is"],
            SamplingParams(temperature=0, max_tokens=10),
        )

        # Sanity-check that generation actually produced text for the single
        # prompt before judging the sync count.
        assert len(outputs) == 1
        assert len(outputs[0].outputs[0].text) > 0
    finally:
        # Best-effort teardown that must run on failure too; same order as the
        # success path always used: drop the engine, release cached CUDA
        # memory, then tear down vLLM's distributed state.
        del llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    # Checked last so the engine is already gone if this assertion fails.
    sync_tracker.assert_no_sync()