# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM
via its HTTP API, with IPC-based weight-syncing APIs.

Unlike rlhf_nccl.py, which uses NCCL and can use separate GPUs, this script
uses CUDA IPC, which requires the training model and the vLLM server to be on
the same GPU. Memory must be carefully managed to fit both models.

Unlike rlhf.py, which creates a vLLM instance programmatically, this script
assumes you have already started a vLLM server using `vllm serve`. It uses:
- the OpenAI-compatible API for inference requests
- HTTP endpoints for the weight-transfer control plane
- CUDA IPC for the actual weight data transfer

Prerequisites:

    Start a vLLM server with weight transfer enabled and reduced GPU memory
    utilization to leave room for the training model:

    $ VLLM_SERVER_DEV_MODE=1 VLLM_ALLOW_INSECURE_SERIALIZATION=1 \
        vllm serve facebook/opt-125m --enforce-eager \
        --weight-transfer-config '{"backend": "ipc"}' \
        --load-format dummy \
        --gpu-memory-utilization 0.5

    Then run this script:

    $ python rlhf_http_ipc.py

The example performs the following steps:

* Load the training model on GPU 0 (the same GPU as the vLLM server).
* Generate text using the vLLM server via the OpenAI-compatible API. The output
  is expected to be nonsense because the server is initialized with dummy weights.
* Initialize weight transfer via an HTTP endpoint (a no-op for IPC).
* Pause generation, broadcast the real weights from the training model to the
  vLLM server using CUDA IPC handles, then resume generation.
* Generate text again to show normal output after the weight update.
"""

import os

import requests
import torch
from openai import OpenAI
from transformers import AutoModelForCausalLM

from vllm.distributed.weight_transfer.ipc_engine import (
    IPCTrainerSendWeightsArgs,
    IPCWeightTransferEngine,
)

# Root URL of the already-running vLLM server (see the module docstring).
BASE_URL = "http://localhost:8000"
# Must match the model the server was started with.
MODEL_NAME = "facebook/opt-125m"

# Sending CUDA IPC handles to the server requires vLLM's insecure
# serialization to be enabled on this (trainer) side as well; the server is
# started with the same environment variable.
os.environ["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

def generate_completions(client: OpenAI, model: str, prompts: list[str]) -> list[str]:
    """Return one completion per prompt, in the same order as ``prompts``.

    Each prompt is sent as its own request to the server's OpenAI-compatible
    completions endpoint (at most 32 new tokens, temperature 0). The text of
    the first choice of each response is collected.
    """

    def _complete_one(text: str) -> str:
        reply = client.completions.create(
            model=model,
            prompt=text,
            max_tokens=32,
            temperature=0,
        )
        return reply.choices[0].text

    return [_complete_one(text) for text in prompts]


def init_weight_transfer_engine(base_url: str) -> None:
    """Initialize the server-side weight transfer engine over HTTP.

    For the IPC backend this is a no-op on the server, but the endpoint is
    still called before any weights are sent.

    Args:
        base_url: Root URL of the vLLM server (e.g. ``http://localhost:8000``).

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    url = f"{base_url}/init_weight_transfer_engine"
    # The IPC backend takes no initialization parameters, so init_info is empty.
    payload = {"init_info": {}}
    response = requests.post(url, json=payload, timeout=60)
    response.raise_for_status()


def pause_generation(base_url: str) -> None:
    """Ask the vLLM server to pause generation.

    Sends ``POST {base_url}/pause`` with a 60 s timeout and raises
    ``requests.HTTPError`` if the server answers with an error status.
    """
    requests.post(f"{base_url}/pause", timeout=60).raise_for_status()


def resume_generation(base_url: str) -> None:
    """Ask the vLLM server to resume generation.

    Sends ``POST {base_url}/resume`` with a 60 s timeout and raises
    ``requests.HTTPError`` if the server answers with an error status.
    """
    requests.post(f"{base_url}/resume", timeout=60).raise_for_status()


def get_world_size(base_url: str) -> int:
    """Query the vLLM server for its world size.

    Issues ``GET {base_url}/get_world_size`` with a 10 s timeout, raises
    ``requests.HTTPError`` on an error status, and returns the ``world_size``
    field of the JSON reply.
    """
    reply = requests.get(f"{base_url}/get_world_size", timeout=10)
    reply.raise_for_status()
    body = reply.json()
    return body["world_size"]


def _print_generations(prompts: list[str], outputs: list[str]) -> None:
    """Print each prompt alongside its generated text, followed by a rule."""
    for prompt, generated_text in zip(prompts, outputs):
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)


def main():
    """Run the end-to-end RLHF weight-sync example against a running server.

    Steps:
        1. Load the training model on ``cuda:0``.
        2. Sample from the server, which starts with dummy weights.
        3. Send the training model's parameters to the server over CUDA IPC,
           using the HTTP endpoints as the control plane.
        4. Sample again.
    """
    # IPC requires the training model to be on the same GPU as the vLLM server.
    # The server should be started on GPU 0 with reduced memory utilization.
    device = "cuda:0"
    torch.accelerator.set_device_index(device)

    # Load the training model on the same GPU as the server.
    # Use bfloat16 to reduce its memory footprint.
    print(f"Loading training model: {MODEL_NAME} on {device}")
    print(
        "Note: Ensure the vLLM server was started with --gpu-memory-utilization 0.5 "
        "or lower to leave room for the training model."
    )
    train_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, dtype=torch.bfloat16)
    train_model.to(device)
    # eval() only switches modules such as dropout to inference behavior; it
    # does not by itself reduce memory usage.
    train_model.eval()

    # Create an OpenAI client pointing to the vLLM server.
    client = OpenAI(
        base_url=f"{BASE_URL}/v1",
        api_key="EMPTY",  # vLLM doesn't require an API key by default
    )

    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    # Generate text before the weight update. The output is expected to be
    # nonsense because the server is initialized with dummy weights.
    print("-" * 50)
    print("Generating text BEFORE weight update (expect nonsense):")
    print("-" * 50)
    _print_generations(prompts, generate_completions(client, MODEL_NAME, prompts))

    print("Initializing weight transfer (IPC backend)...")
    # Initialize weight transfer on the server (no-op for IPC, but still required).
    init_weight_transfer_engine(BASE_URL)

    # Pause generation for the duration of the weight sync.
    pause_generation(BASE_URL)
    try:
        # Send weights via CUDA IPC handles, with HTTP as the control plane.
        print("Broadcasting weights via CUDA IPC (HTTP)...")
        trainer_args = IPCTrainerSendWeightsArgs(mode="http", url=BASE_URL)
        IPCWeightTransferEngine.trainer_send_weights(
            iterator=train_model.named_parameters(),
            trainer_args=trainer_args,
        )
    finally:
        # Resume even if the transfer raised, so a failed sync does not leave
        # the server paused.
        resume_generation(BASE_URL)

    # Generate text after the weight update. The output is expected to be
    # normal because the real weights are now loaded.
    print("-" * 50)
    print("Generating text AFTER weight update:")
    print("-" * 50)
    _print_generations(prompts, generate_completions(client, MODEL_NAME, prompts))

    # Note: The training model and IPC handles remain in memory.
    # In a real RLHF training loop, you would update the training model
    # and create new IPC handles for each weight update.

if __name__ == "__main__":
    # Run the example only when this file is executed directly as a script.
    main()