"vllm/entrypoints/openai/engine/protocol.py" did not exist on "05b044e698bb3c151871d94b64fabd87188de9ef"
rlhf.py 5.05 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
"""
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.

The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
tensor-parallel vLLM inference engine occupies GPU 1–2.

The example performs the following steps:

* Load the training model on GPU 0.
* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
  and Ray placement groups.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
  to the inference engine by using a Ray collective RPC group. Note that
  for demonstration purposes we simply zero out the weights.

For a production-ready implementation that supports multiple training and
inference replicas, see the OpenRLHF framework:
https://github.com/OpenRLHF/OpenRLHF

This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
29
"""
30

31
32
33
34
35
36
import os

import ray
import torch
from ray.util.placement_group import placement_group
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
37
from rlhf_utils import stateless_init_process_group
38
39
from transformers import AutoModelForCausalLM

40
from vllm import LLM, SamplingParams
41
42
43
44
from vllm.utils import get_ip, get_open_port


class MyLLM(LLM):
    """LLM subclass meant to be launched as a Ray actor in a placement group.

    The only customization is that ``CUDA_VISIBLE_DEVICES`` is dropped from
    the process environment before the regular ``LLM`` constructor runs.
    All positional and keyword arguments are forwarded to ``LLM`` unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Ray sets a top-level CUDA_VISIBLE_DEVICES variable; discard it
        # (when present) so that vLLM can manage its own device placement
        # within the worker.
        if "CUDA_VISIBLE_DEVICES" in os.environ:
            del os.environ["CUDA_VISIBLE_DEVICES"]
        super().__init__(*args, **kwargs)


54
# Training side: load the OPT-125M checkpoint and place it on GPU 0.
# `.to()` hands back the module it was called on, so loading and device
# placement can be written as a single expression.
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m").to("cuda:0")
57
58
59

# Set the visible devices and initialize Ray. The vLLM engine will be placed
# on GPUs 1 and 2; the training model above already lives on GPU 0.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

# Reserve GPUs 1–2 for the vLLM inference engine with a placement group made
# of two bundles, each asking for one GPU and no CPU.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0} for _ in range(2)])

# Block until the requested resources have actually been reserved.
ray.get(pg_inference.ready())

# Scheduling strategy used when launching the engine actor: it targets bundle
# 0 of the group and lets child tasks be captured by the same group.
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_bundle_index=0,
    placement_group_capture_child_tasks=True,
)
73
74
75

# Turn MyLLM into a Ray actor class. The actor itself requests no CPUs or
# GPUs (num_cpus=0, num_gpus=0) and is scheduled through the inference
# placement group.
# NOTE(review): presumably the GPUs are claimed by the tensor-parallel
# workers via the placement group rather than by this actor — confirm.
remote_llm_cls = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM)

# Launch the vLLM inference engine with tensor parallelism across two
# workers, using Ray as the distributed executor backend. The
# `enforce_eager` flag reduces start-up latency. `worker_extension_cls`
# names the extension class from `rlhf_utils` whose methods are invoked
# later through `collective_rpc`.
llm = remote_llm_cls.remote(
    model="facebook/opt-125m",
    enforce_eager=True,
    worker_extension_cls="rlhf_utils.WorkerExtension",
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
)

88
# First generation pass: run the prompts through the inference engine before
# any weight update. `prompts` and `sampling_params` are reused for the
# second pass further down.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0)

# `generate` runs inside the remote actor; wait for its result here.
generation_ref = llm.generate.remote(prompts, sampling_params)
outputs = ray.get(generation_ref)

# Print every prompt with its first completion, separated by rules.
separator = "-" * 50
print(separator)
for output in outputs:
    prompt, generated_text = output.prompt, output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print(separator)
106

107
108
# Set up the communication channel between the training process and the
# inference engine. Both sides are given the same address and port.
master_address, master_port = get_ip(), get_open_port()

# Ask the inference workers to join the weight-update group. The call is
# only dispatched here; its result is awaited after the local side has
# joined as well.
# NOTE(review): the trailing (1, 3) presumably mean "rank offset 1, world
# size 3" (one trainer + two inference workers) — confirm against
# rlhf_utils.
worker_group_args = (master_address, master_port, 1, 3)
init_handle = llm.collective_rpc.remote(
    "init_weight_update_group", args=worker_group_args
)

# Join the same group from the training process on cuda:0.
# NOTE(review): the (0, 3) presumably mean "rank 0 of 3" — confirm against
# rlhf_utils.stateless_init_process_group.
train_device = torch.device("cuda:0")
model_update_group = stateless_init_process_group(
    master_address, master_port, 0, 3, train_device
)

# Wait until the workers have finished their side of the setup.
ray.get(init_handle)

121
122
123
# Simulate a training step by zeroing out all model weights in place.
# In a real RLHF training loop the weights would be updated using the gradient
# from an RL objective such as PPO on a reward model.
# Parameter names are not needed here, so iterate the tensors directly.
for param in train_model.parameters():
    param.data.zero_()

127
# Synchronize the updated weights to the inference engine, one parameter at
# a time. For each parameter the order is:
#   1. dispatch `update_weight` on the workers with the tensor's name,
#      dtype and shape,
#   2. broadcast the tensor from rank 0 on the current CUDA stream,
#   3. wait for the workers to finish before moving on.
for param_name, param in train_model.named_parameters():
    weight_meta = (param_name, param.dtype, param.shape)
    pending_update = llm.collective_rpc.remote("update_weight", args=weight_meta)
    model_update_group.broadcast(param, src=0, stream=torch.cuda.current_stream())
    ray.get(pending_update)

133
# Verify that the inference weights have been updated: every entry returned
# by the workers' `check_weights_changed` call must be truthy.
weights_changed = ray.get(llm.collective_rpc.remote("check_weights_changed"))
assert all(weights_changed)

136
137
# Second generation pass with the updated model, using the same prompts and
# sampling parameters as before. The output is expected to be nonsense
# because the weights are zero.
updated_generation_ref = llm.generate.remote(prompts, sampling_params)
outputs_updated = ray.get(updated_generation_ref)

# Print every prompt with its first completion, separated by rules.
separator = "-" * 50
print(separator)
for output in outputs_updated:
    prompt, generated_text = output.prompt, output.outputs[0].text
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print(separator)