# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test that verifies no implicit GPU-CPU synchronization occurs during
speculative decoding generation under expected conditions.
"""

import multiprocessing
import sys
import traceback

import pytest
import torch


@pytest.fixture
def sync_tracker():
    """
    Fixture that patches CommonAttentionMetadata.seq_lens_cpu to detect
    lazy init syncs. Prints stack traces immediately when syncs occur.

    Yields an object exposing ``count`` (number of lazy inits observed so
    far) and ``assert_no_sync()`` (fails when that number is non-zero).
    On teardown the original property is put back and ``torch._dynamo`` is
    reset.
    """
    from vllm.v1.attention.backend import CommonAttentionMetadata

    # Shared counter for cross-process communication (inherited by fork)
    sync_count = multiprocessing.Value("i", 0)

    # Save original property so it can be delegated to and later restored
    original_prop = CommonAttentionMetadata.seq_lens_cpu
    original_fget = original_prop.fget

    # Create tracking wrapper
    def tracking_seq_lens_cpu(self):
        # A missing cached value means this access is the lazy-init path
        # the test wants to catch; record it before delegating.
        if self._seq_lens_cpu is None:
            # Increment counter under its lock and snapshot the new value
            with sync_count.get_lock():
                sync_count.value += 1
                count = sync_count.value
            # Print stack trace immediately (shows in subprocess output)
            print(f"\n{'=' * 60}", file=sys.stderr)
            print(f"SYNC #{count}: seq_lens_cpu lazy init triggered!", file=sys.stderr)
            print(f"{'=' * 60}", file=sys.stderr)
            traceback.print_stack(file=sys.stderr)
            print(f"{'=' * 60}\n", file=sys.stderr)
            sys.stderr.flush()
        # Always defer to the original getter so behavior is unchanged
        return original_fget(self)

    # Apply patch
    CommonAttentionMetadata.seq_lens_cpu = property(tracking_seq_lens_cpu)

    class SyncTracker:
        """Read-only view over the shared lazy-init counter."""

        @property
        def count(self) -> int:
            """Number of lazy inits recorded so far."""
            return sync_count.value

        def assert_no_sync(self, msg: str = ""):
            """Raise AssertionError if any lazy init was recorded.

            ``msg`` is appended to the failure message.
            """
            count = sync_count.value
            assert count == 0, (
                f"Unexpected GPU-CPU sync: seq_lens_cpu lazy init triggered "
                f"{count} times. See stack traces above. {msg}"
            )

    yield SyncTracker()

    # Restore original property
    CommonAttentionMetadata.seq_lens_cpu = original_prop
    torch._dynamo.reset()


# Test configurations: (model, spec_model, method, num_spec_tokens)
# Each entry is a pytest.param whose positional values line up with the
# argnames given to @pytest.mark.parametrize; `id` names the generated case.
SPEC_DECODE_CONFIGS = [
    pytest.param(
        "meta-llama/Llama-3.2-1B-Instruct",
        "nm-testing/Llama3_2_1B_speculator.eagle3",
        "eagle3",
        2,
        id="eagle3-llama",
    ),
    pytest.param(
        "eagle618/deepseek-v3-random",
        "eagle618/eagle-deepseek-v3-random",
        "eagle",
        2,
        id="eagle-mla-deepseek",
    ),
]


@pytest.mark.parametrize(
    "model,spec_model,method,num_spec_tokens",
    SPEC_DECODE_CONFIGS,
)
def test_no_sync_with_spec_decode(
    sync_tracker,
    model: str,
    spec_model: str,
    method: str,
    num_spec_tokens: int,
):
    """
    Test that no implicit GPU-CPU sync occurs during speculative decoding
    generation.

    Builds an LLM with the given speculative config (eager mode, async
    scheduling), generates a short greedy completion for one prompt, checks
    that non-empty text came back, tears the engine down, and finally
    asserts that the sync_tracker fixture recorded no lazy init.
    """
    # Import vLLM AFTER sync_tracker fixture has applied the patch
    from vllm import LLM, SamplingParams
    from vllm.distributed import cleanup_dist_env_and_memory

    llm = LLM(
        model=model,
        max_model_len=256,
        speculative_config={
            "method": method,
            "num_speculative_tokens": num_spec_tokens,
            "model": spec_model,
        },
        enforce_eager=True,
        async_scheduling=True,
    )

    try:
        outputs = llm.generate(
            ["Hello, my name is"],
            SamplingParams(temperature=0, max_tokens=10),
        )

        assert len(outputs) == 1
        assert len(outputs[0].outputs[0].text) > 0
    finally:
        # Tear the engine down even when generation or the output checks
        # fail, so a failing case does not leave its engine and memory
        # around for the next parametrized case.
        del llm
        torch.cuda.empty_cache()
        cleanup_dist_env_and_memory()

    sync_tracker.assert_no_sync()