# rlhf_http.py
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
via HTTP API, with native weight syncing APIs.

Unlike rlhf.py which creates a vLLM instance programmatically, this script
assumes you have already started a vLLM server using `vllm serve`. It uses:
- OpenAI-compatible API for inference requests
- HTTP endpoints for weight transfer control plane
- NCCL for actual weight data transfer

Prerequisites:
    Start a vLLM server with weight transfer enabled:

    $ VLLM_SERVER_DEV_MODE=1 vllm serve facebook/opt-125m \
        --enforce-eager \
        --weight-transfer-config '{"backend": "nccl"}' \
        --load-format dummy

    Then run this script:

    $ python rlhf_http.py

The example performs the following steps:

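* Query the server's world size and load the training model on the GPU whose
  index equals that world size (i.e. the first GPU after those assumed to be
  used by the vLLM server).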
* Generate text using the vLLM server via OpenAI-compatible API. The output
  is expected to be nonsense because the server is initialized with dummy weights.
* Initialize weight transfer via HTTP endpoint.
* Broadcast the real weights from the training model to the vLLM server
  using NCCL.
* Generate text again to show normal output after the weight update.
"""

import requests
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM

from vllm.distributed.weight_transfer.nccl_engine import (
    NCCLWeightTransferEngine,
)
from vllm.utils.network_utils import get_ip, get_open_port

# Root URL of the running vLLM server. The weight-transfer control endpoints
# are addressed directly under it, and the OpenAI-compatible API under "/v1".
BASE_URL = "http://localhost:8000"
# Model identifier used both to load the trainer-side copy of the model and as
# the `model` field of completion requests sent to the server.
MODEL_NAME = "facebook/opt-125m"


def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
    """Return one completion text per prompt, in the same order as ``prompts``.

    Each prompt is sent as its own request to the completions endpoint of
    ``client`` with ``max_tokens=32`` and ``temperature=0``; only the text of
    the first returned choice is kept.
    """

    def _complete(text: str) -> str:
        # One request per prompt; take the first (and only requested) choice.
        reply = client.completions.create(
            model=model,
            prompt=text,
            max_tokens=32,
            temperature=0,
        )
        return reply.choices[0].text

    return [_complete(text) for text in prompts]


def init_weight_transfer_engine(
    base_url: str,
    master_address: str,
    master_port: int,
    rank_offset: int,
    world_size: int,
) -> None:
    """Ask the server to set up its side of the weight-transfer channel.

    POSTs the rendezvous parameters, wrapped under the ``"init_info"`` key, to
    ``{base_url}/init_weight_transfer_engine`` with a 60 second timeout.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    init_info = {
        "master_address": master_address,
        "master_port": master_port,
        "rank_offset": rank_offset,
        "world_size": world_size,
    }
    reply = requests.post(
        f"{base_url}/init_weight_transfer_engine",
        json={"init_info": init_info},
        timeout=60,
    )
    reply.raise_for_status()


def update_weights(
    base_url: str,
    names: list[str],
    dtype_names: list[str],
    shapes: list[list[int]],
    packed: bool = False,
) -> None:
    """Send the weight metadata for an update to the server.

    POSTs the parameter names, dtype names, shapes and the ``packed`` flag,
    wrapped under the ``"update_info"`` key, to ``{base_url}/update_weights``
    with a 300 second timeout.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    update_info = {
        "names": names,
        "dtype_names": dtype_names,
        "shapes": shapes,
        "packed": packed,
    }
    reply = requests.post(
        f"{base_url}/update_weights",
        json={"update_info": update_info},
        timeout=300,
    )
    reply.raise_for_status()


def pause_generation(base_url: str) -> None:
    """POST to ``{base_url}/pause`` (60 s timeout); raise on an error status."""
    requests.post(f"{base_url}/pause", timeout=60).raise_for_status()


def resume_generation(base_url: str) -> None:
    """POST to ``{base_url}/resume`` (60 s timeout); raise on an error status."""
    requests.post(f"{base_url}/resume", timeout=60).raise_for_status()


def get_world_size(base_url: str) -> int:
    """Fetch the ``"world_size"`` field reported by the vLLM server.

    Issues a GET to ``{base_url}/get_world_size`` with a 10 second timeout.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        KeyError: if the JSON body has no ``"world_size"`` entry.
    """
    reply = requests.get(f"{base_url}/get_world_size", timeout=10)
    reply.raise_for_status()
    body = reply.json()
    return body["world_size"]


def _print_completions(prompts: list[str], completions: list[str]) -> None:
    """Print each prompt with its generated text, followed by a rule line."""
    for prompt, generated_text in zip(prompts, completions):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)


def main() -> None:
    """Run the end-to-end demo against an already running vLLM server.

    Steps: query the server's world size, load the training model, generate
    with the server's current weights, set up the weight-transfer group,
    broadcast the trainer's weights to the server, and generate again.

    The two HTTP control-plane calls that must overlap with trainer-side NCCL
    work are run on a worker thread through a ``ThreadPoolExecutor``; their
    results are collected with ``Future.result()`` so that an HTTP failure is
    re-raised here instead of being lost inside the worker thread. Generation
    is resumed in a ``finally`` block so a failed weight update does not leave
    the server paused.
    """
    # Local import: only this entry point needs it.
    from concurrent.futures import ThreadPoolExecutor

    # Get the inference world size from the vLLM server
    inference_world_size = get_world_size(BASE_URL)
    world_size = inference_world_size + 1  # +1 for the trainer
    # The trainer uses the GPU index right after the inference ranks
    # (assumes the server occupies GPUs 0..inference_world_size-1).
    device = f"cuda:{inference_world_size}"
    torch.cuda.set_device(device)

    # Load the training model
    print(f"Loading training model: {MODEL_NAME}")
    train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
    train_model.to(device)

    # Create OpenAI client pointing to the vLLM server
    client = OpenAI(
        base_url=f"{BASE_URL}/v1",
        api_key="EMPTY",  # vLLM doesn't require an API key by default
    )

    # Test prompts
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Generate text before weight update. The output is expected to be nonsense
    # because the server is initialized with dummy weights.
    print("-" * 50)
    print("Generating text BEFORE weight update (expect nonsense):")
    print("-" * 50)
    outputs = generate_completions(client, MODEL_NAME, prompts)
    _print_completions(prompts, outputs)

    # Set up the communication channel between the training process and the
    # vLLM server. The trainer is rank 0, vLLM worker(s) start at rank_offset.
    master_address = get_ip()
    master_port = get_open_port()
    rank_offset = 1

    print(f"Initializing weight transfer: master={master_address}:{master_port}")

    # Collect weight metadata for the update request
    names = []
    dtype_names = []
    shapes = []
    for name, p in train_model.named_parameters():
        names.append(name)
        # e.g. torch.bfloat16 -> "bfloat16"
        dtype_names.append(str(p.dtype).split(".")[-1])
        shapes.append(list(p.shape))

    # One worker is enough: the init request has finished before the update
    # request is submitted.
    with ThreadPoolExecutor(max_workers=1) as executor:
        # Issue the init request from the worker thread: the request and the
        # trainer-side init below have to run concurrently (the server is
        # expected to wait for the NCCL connection before replying).
        init_future = executor.submit(
            init_weight_transfer_engine,
            BASE_URL,
            master_address,
            master_port,
            rank_offset,
            world_size,
        )

        # Initialize NCCL process group on trainer side
        model_update_group = NCCLWeightTransferEngine.trainer_init(
            dict(
                master_address=master_address,
                master_port=master_port,
                world_size=world_size,
            ),
        )

        # Wait for init_weight_transfer_engine to complete; re-raises any
        # exception the HTTP call hit in the worker thread.
        init_future.result()

        # Pause generation before weight sync
        pause_generation(BASE_URL)
        try:
            # Start the update_weights call on the worker thread since it will
            # block waiting for NCCL broadcasts.
            # packed=True enables efficient batched tensor broadcasting
            update_future = executor.submit(
                update_weights,
                BASE_URL,
                names,
                dtype_names,
                shapes,
                True,  # packed=True
            )

            # Broadcast all weights from trainer to vLLM workers
            print("Broadcasting weights via NCCL...")
            NCCLWeightTransferEngine.trainer_send_weights(
                iterator=train_model.named_parameters(),
                group=model_update_group,
                packed=True,
            )

            # Wait for update_weights to complete; re-raises its exception,
            # if any, so a failed update is not silently ignored.
            update_future.result()
        finally:
            # Resume generation after weight sync, even if the sync failed,
            # so the server is not left paused.
            resume_generation(BASE_URL)

    # Generate text after weight update. The output is expected to be normal
    # because the real weights are now loaded.
    print("-" * 50)
    print("Generating text AFTER weight update:")
    print("-" * 50)
    outputs_updated = generate_completions(client, MODEL_NAME, prompts)
    _print_completions(prompts, outputs_updated)


# Run the demo only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()