data_parallel_pause_resume.py 3.96 KB
Newer Older
raojy's avatar
raojy committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Test pause/resume with Data Parallel (DP) via HTTP API.

This example demonstrates coordinated pause/resume across multiple DP ranks.
The pause synchronizes across all DP engines via all-reduce.

Prerequisites:
    Start a vLLM server with data parallelism:

    $ VLLM_SERVER_DEV_MODE=1 vllm serve facebook/opt-125m \
        --enforce-eager \
        --data-parallel-size 4 \
        --tensor-parallel-size 1

    Then run this script:

    $ python data_parallel_pause_resume.py

The test verifies pause works by:
1. Starting a streaming generation request
2. Pausing the server mid-generation
3. Sleeping for PAUSE_DURATION seconds
4. Resuming the server
5. Verifying there was a gap in token generation matching the pause duration
"""

import argparse
import threading
import time

import requests
from openai import OpenAI

BASE_URL = "http://localhost:8000"
MODEL_NAME = "facebook/opt-125m"
PAUSE_DURATION = 3.0


def pause_generation(base_url: str, mode: str = "keep") -> None:
    """Pause generation via HTTP endpoint."""
    url = f"{base_url}/pause"
    response = requests.post(url, params={"mode": mode}, timeout=60)
    response.raise_for_status()
    print("Server paused")


def resume_generation(base_url: str) -> None:
    """Resume generation via HTTP endpoint."""
    url = f"{base_url}/resume"
    response = requests.post(url, timeout=60)
    response.raise_for_status()
    print("Server resumed")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--base-url", default=BASE_URL)
    parser.add_argument("--model", default=MODEL_NAME)
    args = parser.parse_args()

    client = OpenAI(
        base_url=f"{args.base_url}/v1",
        api_key="EMPTY",
    )

    prompt = "Write a long story about a dragon. Once upon a time"
    token_times: list[float] = []
    pause_token_idx = 0
    pause_triggered = threading.Event()

    def generator_thread():
        """Stream tokens and record timestamps."""
        stream = client.completions.create(
            model=args.model,
            prompt=prompt,
            max_tokens=50,
            stream=True,
        )
        for chunk in stream:
            if chunk.choices[0].text:
                token_times.append(time.monotonic())
                token_count = len(token_times)
                print(f"Token {token_count}: {chunk.choices[0].text!r}")

                # Signal controller after some tokens
                if token_count >= 5 and not pause_triggered.is_set():
                    pause_triggered.set()

    def controller_thread():
        """Pause and resume the server."""
        nonlocal pause_token_idx

        # Wait for some tokens
        pause_triggered.wait()

        print(f"\nPausing server (keep mode) at token {len(token_times)}...")
        pause_generation(args.base_url, mode="keep")
        pause_token_idx = len(token_times)
        print(f"Sleeping for {PAUSE_DURATION}s...")

        time.sleep(PAUSE_DURATION)

        print("Resuming server...")
        resume_generation(args.base_url)
        print("Resumed!\n")

    # Run both threads
    gen_thread = threading.Thread(target=generator_thread)
    ctrl_thread = threading.Thread(target=controller_thread)

    gen_thread.start()
    ctrl_thread.start()

    gen_thread.join()
    ctrl_thread.join()

    # Check gap at the pause point
    if pause_token_idx < len(token_times):
        pause_gap = token_times[pause_token_idx] - token_times[pause_token_idx - 1]
        print(
            f"\nGap after pause (token {pause_token_idx} -> "
            f"{pause_token_idx + 1}): {pause_gap:.3f}s"
        )
        if pause_gap >= PAUSE_DURATION * 0.9:
            print("Test passed! Pause synchronized across DP ranks.")
        else:
            print(f"Test failed! Expected ~{PAUSE_DURATION}s gap, got {pause_gap:.3f}s")
    else:
        print("Test failed! No tokens were generated after resuming.")


if __name__ == "__main__":
    main()