Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
aa0ca5eb
Unverified
Commit
aa0ca5eb
authored
Feb 10, 2025
by
youkaichao
Committed by
GitHub
Feb 10, 2025
Browse files
[core][rlhf] add colocate example for RLHF (#12984)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
59fff4a0
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
78 additions
and
10 deletions
+78
-10
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-2
examples/offline_inference/rlhf_colocate.py
examples/offline_inference/rlhf_colocate.py
+76
-8
No files found.
.buildkite/test-pipeline.yaml
View file @
aa0ca5eb
...
@@ -128,7 +128,7 @@ steps:
...
@@ -128,7 +128,7 @@ steps:
-
tests/spec_decode/e2e/test_integration_dist_tp4
-
tests/spec_decode/e2e/test_integration_dist_tp4
-
tests/compile
-
tests/compile
-
examples/offline_inference/rlhf.py
-
examples/offline_inference/rlhf.py
-
examples/offline_inference/r
ay_placement
.py
-
examples/offline_inference/r
lhf_colocate
.py
commands
:
commands
:
-
pytest -v -s distributed/test_utils.py
-
pytest -v -s distributed/test_utils.py
-
pytest -v -s compile/test_basic_correctness.py
-
pytest -v -s compile/test_basic_correctness.py
...
@@ -137,7 +137,7 @@ steps:
...
@@ -137,7 +137,7 @@ steps:
# TODO: create a dedicated test section for multi-GPU example tests
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
# when we have multiple distributed example tests
-
python3 ../examples/offline_inference/rlhf.py
-
python3 ../examples/offline_inference/rlhf.py
-
RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/r
ay_placement
.py
-
RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/r
lhf_colocate
.py
-
label
:
Metrics, Tracing Test
# 10min
-
label
:
Metrics, Tracing Test
# 10min
num_gpus
:
2
num_gpus
:
2
...
...
examples/offline_inference/r
ay_placement
.py
→
examples/offline_inference/r
lhf_colocate
.py
View file @
aa0ca5eb
# SPDX-License-Identifier: Apache-2.0
# SPDX-License-Identifier: Apache-2.0
"""
"""
a simple demonstration to show how to control
a simple demonstration to show how to co-locate
the placement of the vLLM workers with Ray.
vLLM worker with training actors on the same GPUs,
The key is to set VLLM_RAY_PER_WORKER_GPUS and
for RLHF-like applications.
VLLM_RAY_BUNDLE_INDICES properly.
The key points:
- Control the placement of the vLLM workers with Ray, by setting
VLLM_RAY_PER_WORKER_GPUS and VLLM_RAY_BUNDLE_INDICES properly.
- Use cuda-ipc to pass tensors, since NCCL does not work when we have
multiple processes on the same GPU.
"""
"""
import
os
import
os
import
ray
import
ray
import
torch
from
ray.util.placement_group
import
placement_group
from
ray.util.placement_group
import
placement_group
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
...
@@ -19,7 +24,33 @@ class MyWorker(Worker):
...
@@ -19,7 +24,33 @@ class MyWorker(Worker):
def
report_device_id
(
self
)
->
str
:
def
report_device_id
(
self
)
->
str
:
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
return
current_platform
.
get_device_uuid
(
self
.
device
.
index
)
self
.
device_uuid
=
current_platform
.
get_device_uuid
(
self
.
device
.
index
)
return
self
.
device_uuid
def
update_weights_from_ipc_handles
(
self
,
ipc_handles
):
handles
=
ipc_handles
[
self
.
device_uuid
]
device_id
=
self
.
device
.
index
weights
=
[]
for
name
,
handle
in
handles
.
items
():
func
,
args
=
handle
list_args
=
list
(
args
)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args
[
6
]
=
device_id
tensor
=
func
(
*
list_args
)
weights
.
append
((
name
,
tensor
))
self
.
model_runner
.
model
.
load_weights
(
weights
=
weights
)
torch
.
cuda
.
synchronize
()
def
check_weights_changed
(
self
):
"""
Check if the weights are updated to 0.
"""
weights_updated
=
True
for
name
,
p
in
self
.
model_runner
.
model
.
named_parameters
():
weights_updated
=
weights_updated
and
torch
.
allclose
(
p
,
torch
.
zeros_like
(
p
))
return
weights_updated
class
MyLLM
(
LLM
):
class
MyLLM
(
LLM
):
...
@@ -40,12 +71,32 @@ class MyLLM(LLM):
...
@@ -40,12 +71,32 @@ class MyLLM(LLM):
class
RayTrainingActor
:
class
RayTrainingActor
:
def
report_device_id
(
self
)
->
str
:
def
__init__
(
self
):
# ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
from
transformers
import
AutoModelForCausalLM
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
"facebook/opt-125m"
)
self
.
model
.
to
(
"cuda:0"
)
for
name
,
p
in
self
.
model
.
named_parameters
():
p
.
data
.
zero_
()
torch
.
cuda
.
synchronize
()
# the argument for get_device_uuid is the index
# the argument for get_device_uuid is the index
# of the GPU in the visible devices.
# of the GPU in the visible devices.
# ray will set CUDA_VISIBLE_DEVICES to the assigned GPUs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
return
current_platform
.
get_device_uuid
(
0
)
self
.
device_uuid
=
current_platform
.
get_device_uuid
(
0
)
def
report_device_id
(
self
)
->
str
:
return
self
.
device_uuid
def
get_weight_ipc_handles
(
self
):
from
torch.multiprocessing.reductions
import
reduce_tensor
data
=
{}
for
name
,
p
in
self
.
model
.
named_parameters
():
# the training actor might only have a subset of the weights
# and need to all-gather the weights from all the actors.
# for demonstration, here we assume all training actors have
# the full weights.
data
[
name
]
=
reduce_tensor
(
p
.
detach
())
return
{
self
.
device_uuid
:
data
}
# ray manages 4 GPUs
# ray manages 4 GPUs
...
@@ -78,6 +129,8 @@ for bundle_index in [0, 1, 2, 3]:
...
@@ -78,6 +129,8 @@ for bundle_index in [0, 1, 2, 3]:
),
),
)(
RayTrainingActor
).
remote
()
)(
RayTrainingActor
).
remote
()
training_actors
.
append
(
training_actor
)
training_actors
.
append
(
training_actor
)
for
bundle_index
,
training_actor
in
enumerate
(
training_actors
):
device_id
=
ray
.
get
(
training_actor
.
report_device_id
.
remote
())
device_id
=
ray
.
get
(
training_actor
.
report_device_id
.
remote
())
print
(
f
"training actor
{
bundle_index
}
is on
{
device_id
}
"
)
print
(
f
"training actor
{
bundle_index
}
is on
{
device_id
}
"
)
training_actor_device_ids
.
append
(
device_id
)
training_actor_device_ids
.
append
(
device_id
)
...
@@ -119,3 +172,18 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
...
@@ -119,3 +172,18 @@ assert training_actor_device_ids[:2] == inference_engine_device_ids[0]
# the last two training actors should be
# the last two training actors should be
# on the same GPUs as the second inference engine
# on the same GPUs as the second inference engine
assert
training_actor_device_ids
[
2
:]
==
inference_engine_device_ids
[
1
]
assert
training_actor_device_ids
[
2
:]
==
inference_engine_device_ids
[
1
]
print
(
"gather all the IPC handles from the training actors"
)
ipc_handles
=
{}
for
actor
in
training_actors
:
ipc_handles
.
update
(
ray
.
get
(
actor
.
get_weight_ipc_handles
.
remote
()))
print
(
"update the weights of the inference engines"
)
for
llm
in
inference_engines
:
ray
.
get
(
llm
.
collective_rpc
.
remote
(
"update_weights_from_ipc_handles"
,
args
=
(
ipc_handles
,
)))
print
(
"check if the weights are updated"
)
for
llm
in
inference_engines
:
assert
ray
.
get
(
llm
.
collective_rpc
.
remote
(
"check_weights_changed"
,
args
=
tuple
()))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment