Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aa0ca5eb
Unverified
Commit
aa0ca5eb
authored
Feb 10, 2025
by
youkaichao
Committed by
GitHub
Feb 10, 2025
Browse files
[core][rlhf] add colocate example for RLHF (#12984)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
59fff4a0
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
78 additions
and
10 deletions
+78
-10
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-2
examples/offline_inference/rlhf_colocate.py
examples/offline_inference/rlhf_colocate.py
+76
-8
No files found.
.buildkite/test-pipeline.yaml
View file @
aa0ca5eb
...
...
@@ -128,7 +128,7 @@ steps:
-
tests/spec_decode/e2e/test_integration_dist_tp4
-
tests/compile
-
examples/offline_inference/rlhf.py
-
examples/offline_inference/r
ay_placement
.py
-
examples/offline_inference/r
lhf_colocate
.py
commands
:
-
pytest -v -s distributed/test_utils.py
-
pytest -v -s compile/test_basic_correctness.py
...
...
@@ -137,7 +137,7 @@ steps:
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
-
python3 ../examples/offline_inference/rlhf.py
-
RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/r
ay_placement
.py
-
RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/r
lhf_colocate
.py
-
label
:
Metrics, Tracing Test
# 10min
num_gpus
:
2
...
...
examples/offline_inference/r
ay_placement
.py
→
examples/offline_inference/r
lhf_colocate
.py
View file @
aa0ca5eb
# SPDX-License-Identifier: Apache-2.0
"""
A simple demonstration of how to co-locate vLLM workers with training
actors on the same GPUs, for RLHF-like applications.

The key points:

- Control the placement of the vLLM workers with Ray, by setting
  VLLM_RAY_PER_WORKER_GPUS and VLLM_RAY_BUNDLE_INDICES properly.
- Use CUDA IPC to pass tensors, since NCCL does not work when we have
  multiple processes on the same GPU.
"""
import
os
import
ray
import
torch
from
ray.util.placement_group
import
placement_group
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
...
...
@@ -19,7 +24,33 @@ class MyWorker(Worker):
def report_device_id(self) -> str:
    """Record and return the UUID of the device this worker runs on.

    The UUID is looked up for ``self.device.index`` and stored on
    ``self.device_uuid`` before it is returned, because
    ``update_weights_from_ipc_handles`` reads ``self.device_uuid`` to pick
    its entry out of the handle mapping.

    Returns:
        The device UUID reported by the current vLLM platform.
    """
    from vllm.platforms import current_platform

    # Assign first, then return: an early return here would leave
    # ``self.device_uuid`` unset for later weight updates.
    self.device_uuid = current_platform.get_device_uuid(self.device.index)
    return self.device_uuid
def update_weights_from_ipc_handles(self, ipc_handles):
    """Load model weights from tensors rebuilt out of shared handles.

    Args:
        ipc_handles: mapping indexed by device UUID. The entry for
            ``self.device_uuid`` maps each weight name to a
            ``(func, args)`` pair; calling ``func(*args)`` produces the
            tensor for that weight.
    """
    own_handles = ipc_handles[self.device_uuid]
    local_device = self.device.index

    def _rebuild(handle):
        # Turn one (func, args) pair back into a tensor on this device.
        rebuild_fn, rebuild_args = handle
        patched_args = list(rebuild_args)
        # the key is to change device id to the current device id
        # in case two processes have different CUDA_VISIBLE_DEVICES
        patched_args[6] = local_device
        return rebuild_fn(*patched_args)

    named_tensors = [(weight_name, _rebuild(handle))
                     for weight_name, handle in own_handles.items()]
    self.model_runner.model.load_weights(weights=named_tensors)
    torch.cuda.synchronize()
def check_weights_changed(self):
    """
    Check if the weights are updated to 0.

    Returns:
        True when every parameter of ``self.model_runner.model`` compares
        close to an all-zero tensor of the same shape (``torch.allclose``
        with its default tolerances), and also when the model has no
        parameters. False as soon as one parameter does not.
    """
    # all() stops at the first parameter that is not zero, so the
    # remaining parameters are not walked once the answer is known.
    return all(
        torch.allclose(p, torch.zeros_like(p))
        for _, p in self.model_runner.model.named_parameters())
class
MyLLM
(
LLM
):
...
...
@@ -40,12 +71,32 @@ class MyLLM(LLM):
class RayTrainingActor:
    """Stand-in for a training process that shares a GPU with a vLLM worker.

    On construction it loads a small model onto the GPU, zeroes every
    parameter, and records the UUID of the GPU it was given. It can then
    report that UUID and export handles to its weights.
    """

    def __init__(self):
        # ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
        from transformers import AutoModelForCausalLM
        self.model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
        self.model.to("cuda:0")
        # Zero every weight; the inference side later checks that its own
        # weights were updated to 0.
        for name, p in self.model.named_parameters():
            p.data.zero_()
        torch.cuda.synchronize()
        # the argument for get_device_uuid is the index
        # of the GPU in the visible devices.
        from vllm.platforms import current_platform
        # __init__ must not return a value, so the UUID is stored on the
        # instance and served by report_device_id().
        self.device_uuid = current_platform.get_device_uuid(0)

    def report_device_id(self) -> str:
        """Return the UUID of the GPU recorded when the actor was created."""
        return self.device_uuid

    def get_weight_ipc_handles(self):
        """Export a handle for every model parameter.

        Returns:
            A one-entry dict mapping ``self.device_uuid`` to a dict of
            ``{parameter name: reduce_tensor(parameter.detach())}``.
        """
        from torch.multiprocessing.reductions import reduce_tensor
        data = {}
        for name, p in self.model.named_parameters():
            # the training actor might only have a subset of the weights
            # and need to all-gather the weights from all the actors.
            # for demonstration, here we assume all training actors have
            # the full weights.
            data[name] = reduce_tensor(p.detach())
        return {self.device_uuid: data}
# ray manages 4 GPUs
...
...
@@ -78,6 +129,8 @@ for bundle_index in [0, 1, 2, 3]:
),
)(
RayTrainingActor
).
remote
()
training_actors
.
append
(
training_actor
)
# Ask each training actor, in creation order, which GPU it was placed on
# and remember the answer for the co-location checks further down.
for bundle_index, training_actor in enumerate(training_actors):
    device_id_ref = training_actor.report_device_id.remote()
    device_id = ray.get(device_id_ref)
    print(f"training actor {bundle_index} is on {device_id}")
    training_actor_device_ids.append(device_id)
...
...
@@ -119,3 +172,18 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
# the last two training actors should be
# on the same GPUs as the second inference engine
assert training_actor_device_ids[2:] == inference_engine_device_ids[1]

print("gather all the IPC handles from the training actors")
# Each actor returns {its device uuid: {weight name: handle}}; merging them
# gives one mapping that every worker can index by its own device uuid.
ipc_handles = {}
for actor in training_actors:
    ipc_handles.update(ray.get(actor.get_weight_ipc_handles.remote()))

print("update the weights of the inference engines")
for llm in inference_engines:
    ray.get(
        llm.collective_rpc.remote("update_weights_from_ipc_handles",
                                  args=(ipc_handles, )))

print("check if the weights are updated")
for llm in inference_engines:
    # collective_rpc returns a list with one result per worker. A non-empty
    # list is truthy even when it is [False, False], so the per-worker
    # results must be reduced with all() for the assertion to mean anything.
    assert all(
        ray.get(
            llm.collective_rpc.remote("check_weights_changed",
                                      args=tuple())))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment