OpenDAS / vllm_cscc
Commit 235bfd5d (Unverified)
Authored Jul 15, 2025 by Ricardo Decal
Committed by GitHub, Jul 15, 2025
Parent: 68d28e37

[Docs] Improve documentation for RLHF example (#20598)

Signed-off-by: Ricardo Decal <rdecal@anyscale.com>

Showing 1 changed file with 49 additions and 36 deletions

examples/offline_inference/rlhf.py (+49 -36)
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
-a simple demonstration of RLHF with vLLM, inspired by
-the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
-It follows the design that, training processes and inference processes
-are different, and they live on different GPUs.
-Training processes send prompts to inference processes to generate data,
-and also synchronize the weights of the model by broadcasting the weights
-from the training process to the inference process.
-Note that this is a simple demonstration of one training instance and one
-inference instance. In practice, there could be multiple training instances
-and multiple inference instances. For the full implementation, please refer
-to the OpenRLHF framework.
+Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.
+The script separates training and inference workloads onto distinct GPUs
+so that Ray can manage process placement and inter-process communication.
+A Hugging Face Transformer model occupies GPU 0 for training, whereas a
+tensor-parallel vLLM inference engine occupies GPU 1–2.
+The example performs the following steps:
+* Load the training model on GPU 0.
+* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
+  and Ray placement groups.
+* Generate text from a list of prompts using the inference engine.
+* Update the weights of the training model and broadcast the updated weights
+  to the inference engine by using a Ray collective RPC group. Note that
+  for demonstration purposes we simply zero out the weights.
+For a production-ready implementation that supports multiple training and
+inference replicas, see the OpenRLHF framework:
+https://github.com/OpenRLHF/OpenRLHF
+This example assumes a single-node cluster with three GPUs, but Ray
+supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
+workloads. Residual GPU activity interferes with vLLM memory profiling and
+causes unexpected behavior.
 """
 import os
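The GPU layout described in the new docstring (GPU 0 for training, GPUs 1 and 2 for the tensor-parallel inference engine) can be sketched with a small helper. This is only an illustration under stated assumptions: `plan_gpu_layout` and its parameters are hypothetical names that do not exist in vLLM or Ray, and the example script itself hard-codes `"cuda:0"` and `"1,2"` rather than computing them.

```python
def plan_gpu_layout(total_gpus: int, tensor_parallel_size: int) -> dict:
    """Reserve GPU 0 for training and the following GPUs for inference.

    Hypothetical helper for illustration; not part of vLLM or Ray.
    """
    needed = 1 + tensor_parallel_size
    if total_gpus < needed:
        raise ValueError(f"need {needed} GPUs, but only {total_gpus} available")
    # Inference GPUs start right after the single training GPU.
    inference_gpus = list(range(1, 1 + tensor_parallel_size))
    return {
        "train_device": "cuda:0",
        "inference_gpus": inference_gpus,
        # Same comma-separated format as the CUDA_VISIBLE_DEVICES variable.
        "cuda_visible_devices": ",".join(str(i) for i in inference_gpus),
    }


layout = plan_gpu_layout(total_gpus=3, tensor_parallel_size=2)
print(layout["train_device"], layout["cuda_visible_devices"])
```

With three GPUs and a tensor-parallel size of 2, the helper reproduces the split that the script assumes.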
@@ -28,29 +42,27 @@ from vllm.utils import get_ip, get_open_port
 class MyLLM(LLM):
+    """Configure the vLLM worker for Ray placement group execution."""
     def __init__(self, *args, **kwargs):
-        # a hack to make the script work.
-        # stop ray from manipulating CUDA_VISIBLE_DEVICES
-        # at the top-level
+        # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
+        # so that vLLM can manage its own device placement within the worker.
         os.environ.pop("CUDA_VISIBLE_DEVICES", None)
         super().__init__(*args, **kwargs)
-"""
-Start the training process, here we use huggingface transformers
-as an example to hold a model on GPU 0.
-"""
+# Load the OPT-125M model onto GPU 0 for the training workload.
 train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
 train_model.to("cuda:0")
-"""
-Start the inference process, here we use vLLM to hold a model on GPU 1 and
-GPU 2. For the details on how to use ray, please refer to the ray
-documentation https://docs.ray.io/en/latest/ .
-"""
+# Initialize Ray and set the visible devices. The vLLM engine will
+# be placed on GPUs 1 and 2.
 os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
 ray.init()
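The script sets `CUDA_VISIBLE_DEVICES` to `"1,2"` before `ray.init()`, while `MyLLM.__init__` removes the variable with `pop("CUDA_VISIBLE_DEVICES", None)`. The `None` default matters: `pop` then returns quietly instead of raising `KeyError` when the variable is absent. A minimal sketch of that behavior follows, using a plain dict instead of `os.environ` so nothing in the real environment is touched; `strip_visible_devices` is a hypothetical helper name.

```python
def strip_visible_devices(env: dict) -> dict:
    """Return a copy of `env` without CUDA_VISIBLE_DEVICES.

    Hypothetical helper for illustration; the example script calls
    os.environ.pop(...) directly inside the worker.
    """
    cleaned = dict(env)
    # The None default makes this a no-op when the key is missing.
    cleaned.pop("CUDA_VISIBLE_DEVICES", None)
    return cleaned


driver_env = {"CUDA_VISIBLE_DEVICES": "1,2", "PATH": "/usr/bin"}
worker_env = strip_visible_devices(driver_env)
already_clean = strip_visible_devices({"PATH": "/usr/bin"})

print(worker_env)
print(already_clean)
```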
+# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
+# Learn more about Ray placement groups:
+# https://docs.ray.io/en/latest/placement-groups.html
 pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
 ray.get(pg_inference.ready())
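The bundle list passed to `placement_group` is built with list multiplication, which yields two entries that each request one GPU and no CPU. A short sketch of what that expression produces in plain Python is shown below. One detail worth knowing: `[d] * 2` repeats a reference to the same dict rather than copying it, which is harmless here because the script never mutates the bundles. The comprehension variant is an alternative, not what the script uses.

```python
# Same expression as in the script: two bundles, one GPU each.
bundles = [{"GPU": 1, "CPU": 0}] * 2

# Both list entries point at the same dict object.
shared = bundles[0] is bundles[1]

# A comprehension builds independent dicts, should mutation ever be needed.
independent = [{"GPU": 1, "CPU": 0} for _ in range(2)]

# Total GPUs requested across all bundles.
total_gpus = sum(bundle["GPU"] for bundle in bundles)

print(len(bundles), shared, total_gpus)
```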
 scheduling_inference = PlacementGroupSchedulingStrategy(
@@ -58,10 +70,9 @@ scheduling_inference = PlacementGroupSchedulingStrategy(
     placement_group_capture_child_tasks=True,
     placement_group_bundle_index=0,
 )
-"""
-launch the vLLM inference engine.
-here we use `enforce_eager` to reduce the start time.
-"""
+# Launch the vLLM inference engine. The `enforce_eager` flag reduces
+# start-up latency.
 llm = ray.remote(
     num_cpus=0,
     num_gpus=0,
@@ -74,7 +85,7 @@ llm = ray.remote(
     distributed_executor_backend="ray",
 )
-# Generate texts from the prompts.
+# Generate text from the prompts.
 prompts = [
     "Hello, my name is",
     "The president of the United States is",
@@ -93,8 +104,8 @@ for output in outputs:
     print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
     print("-" * 50)
-# set up the communication between the training process
-# and the inference engine.
+# Set up the communication channel between the training process and the
+# inference engine.
 master_address = get_ip()
 master_port = get_open_port()
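The communication channel needs a free port on the training host, which the script obtains from `get_open_port()` in `vllm.utils`. How that function is implemented is not shown in this diff. The sketch below shows the common standard-library idiom of binding to port 0 and letting the operating system pick a free port. `find_open_port` is a hypothetical name, and this is not claimed to be vLLM's implementation.

```python
import socket


def find_open_port(host: str = "127.0.0.1") -> int:
    """Ask the OS for a currently unused TCP port on `host`.

    Hypothetical helper for illustration; not vLLM's get_open_port.
    """
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        # Port 0 tells the OS to choose any free port.
        sock.bind((host, 0))
        return sock.getsockname()[1]


port = find_open_port()
print(port)
```

The port is released when the socket closes, so another process could claim it before it is reused. That race is inherent to this idiom.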
@@ -107,21 +118,23 @@ model_update_group = stateless_init_process_group(
 )
 ray.get(handle)
-# simulate training, modify the weights of the model.
+# Simulate a training step by zeroing out all model weights.
+# In a real RLHF training loop the weights would be updated using the gradient
+# from an RL objective such as PPO on a reward model.
 for name, p in train_model.named_parameters():
     p.data.zero_()
-# sync weight from the training process to the inference engine.
+# Synchronize the updated weights to the inference engine.
 for name, p in train_model.named_parameters():
     handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
     model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
     ray.get(handle)
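The synchronization loop follows a two-step pattern per parameter: the trainer first sends metadata (`name`, `dtype`, `shape`) through `update_weight`, then broadcasts the tensor itself, and finally waits on the RPC handle. The pure-Python simulation below sketches that ordering without Ray, NCCL, or GPUs. `FakeWorker` and `receive_broadcast` are hypothetical stand-ins, and the assumption that the worker allocates a receive buffer from the metadata is not confirmed by this diff.

```python
from math import prod


class FakeWorker:
    """Hypothetical stand-in for an inference worker; not a vLLM class."""

    def __init__(self, weights):
        self.weights = {name: list(values) for name, values in weights.items()}
        self.pending = None

    def update_weight(self, name, dtype, shape):
        # Step 1: metadata arrives first, so a buffer of the right size
        # can be prepared before any data is sent (assumed behavior).
        self.pending = (name, [dtype(0)] * prod(shape))

    def receive_broadcast(self, payload):
        # Step 2: the broadcast fills the prepared buffer.
        name, buffer = self.pending
        if len(payload) != len(buffer):
            raise ValueError("payload size does not match announced shape")
        buffer[:] = payload
        self.weights[name] = buffer
        self.pending = None


# Trainer-side weights after the simulated (zeroing) training step.
trainer_weights = {"layer.weight": [0.0, 0.0, 0.0, 0.0], "layer.bias": [0.0, 0.0]}
shapes = {"layer.weight": (2, 2), "layer.bias": (2,)}
worker = FakeWorker({"layer.weight": [0.5, -1.0, 2.0, 0.1], "layer.bias": [0.3, 0.7]})

for name, values in trainer_weights.items():
    worker.update_weight(name, float, shapes[name])
    worker.receive_broadcast(values)

weights_match = worker.weights == trainer_weights
print(weights_match)
```

Sending the metadata ahead of the payload means the receiver knows the size and type of each tensor before the collective call starts, which is presumably why the script issues the RPC before the broadcast.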
-# check if the weights are updated.
+# Verify that the inference weights have been updated.
 assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))
-# use the updated model to generate texts, they will be nonsense
-# because the weights are all zeros.
+# Generate text with the updated model. The output is expected to be nonsense
+# because the weights are zero.
 outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
 print("-" * 50)
 for output in outputs_updated: