Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
151b08e0
Unverified
Commit
151b08e0
authored
Mar 07, 2025
by
youkaichao
Committed by
GitHub
Mar 07, 2025
Browse files
[RLHF] use worker_extension_cls for compatibility with V0 and V1 (#14185)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
81b2f4a4
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
153 additions
and
100 deletions
+153
-100
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+4
-2
examples/offline_inference/rlhf.py
examples/offline_inference/rlhf.py
+3
-63
examples/offline_inference/rlhf_colocate.py
examples/offline_inference/rlhf_colocate.py
+1
-35
examples/offline_inference/rlhf_utils.py
examples/offline_inference/rlhf_utils.py
+105
-0
vllm/config.py
vllm/config.py
+4
-0
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+9
-0
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+27
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
151b08e0
...
@@ -145,8 +145,10 @@ steps:
...
@@ -145,8 +145,10 @@ steps:
-
pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
-
pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
# TODO: create a dedicated test section for multi-GPU example tests
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
# when we have multiple distributed example tests
-
python3 ../examples/offline_inference/rlhf.py
-
pushd ../examples/offline_inference
-
RAY_DEDUP_LOGS=0 python3 ../examples/offline_inference/rlhf_colocate.py
-
python3 rlhf.py
-
RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
-
popd
-
label
:
Metrics, Tracing Test
# 10min
-
label
:
Metrics, Tracing Test
# 10min
num_gpus
:
2
num_gpus
:
2
...
...
examples/offline_inference/rlhf.py
View file @
151b08e0
...
@@ -18,72 +18,11 @@ import ray
...
@@ -18,72 +18,11 @@ import ray
import
torch
import
torch
from
ray.util.placement_group
import
placement_group
from
ray.util.placement_group
import
placement_group
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
rlhf_utils
import
stateless_init_process_group
from
transformers
import
AutoModelForCausalLM
from
transformers
import
AutoModelForCausalLM
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.utils
import
get_ip
,
get_open_port
from
vllm.utils
import
get_ip
,
get_open_port
from
vllm.worker.worker
import
Worker
def
stateless_init_process_group
(
master_address
,
master_port
,
rank
,
world_size
,
device
):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.utils
import
StatelessProcessGroup
pg
=
StatelessProcessGroup
.
create
(
host
=
master_address
,
port
=
master_port
,
rank
=
rank
,
world_size
=
world_size
)
pynccl
=
PyNcclCommunicator
(
pg
,
device
=
device
)
return
pynccl
class
MyWorker
(
Worker
):
"""
The `MyWorker` class inherits from `Worker` to provide custom functions.
For simplicity, we define the `MyWorker` class in this self-contained
script. Normally, we should define the `MyWorker` class in a separate
file and pass the qualified name of the class to the `worker_cls`
parameter.
"""
def
init_weight_update_group
(
self
,
master_address
,
master_port
,
rank_offset
,
world_size
):
from
vllm.distributed.parallel_state
import
get_world_group
rank
=
get_world_group
().
rank
+
rank_offset
self
.
model_update_group
=
stateless_init_process_group
(
master_address
,
master_port
,
rank
,
world_size
,
self
.
device
,
)
def
update_weight
(
self
,
name
,
dtype
,
shape
):
weight
=
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
"cuda"
)
self
.
model_update_group
.
broadcast
(
weight
,
src
=
0
,
stream
=
torch
.
cuda
.
current_stream
())
self
.
model_runner
.
model
.
load_weights
(
weights
=
[(
name
,
weight
)])
del
weight
def
check_weights_changed
(
self
):
"""
Check if the weights are updated to 0.
"""
weights_updated
=
True
for
name
,
p
in
self
.
model_runner
.
model
.
named_parameters
():
weights_updated
=
weights_updated
and
torch
.
allclose
(
p
,
torch
.
zeros_like
(
p
))
return
weights_updated
class
MyLLM
(
LLM
):
class
MyLLM
(
LLM
):
...
@@ -129,7 +68,7 @@ llm = ray.remote(
...
@@ -129,7 +68,7 @@ llm = ray.remote(
)(
MyLLM
).
remote
(
)(
MyLLM
).
remote
(
model
=
"facebook/opt-125m"
,
model
=
"facebook/opt-125m"
,
enforce_eager
=
True
,
enforce_eager
=
True
,
worker_
cls
=
MyWorker
,
worker_
extension_cls
=
"rlhf_utils.WorkerExtension"
,
tensor_parallel_size
=
2
,
tensor_parallel_size
=
2
,
distributed_executor_backend
=
"ray"
,
distributed_executor_backend
=
"ray"
,
)
)
...
@@ -159,6 +98,7 @@ master_port = get_open_port()
...
@@ -159,6 +98,7 @@ master_port = get_open_port()
handle
=
llm
.
collective_rpc
.
remote
(
"init_weight_update_group"
,
handle
=
llm
.
collective_rpc
.
remote
(
"init_weight_update_group"
,
args
=
(
master_address
,
master_port
,
1
,
3
))
args
=
(
master_address
,
master_port
,
1
,
3
))
model_update_group
=
stateless_init_process_group
(
master_address
,
master_port
,
model_update_group
=
stateless_init_process_group
(
master_address
,
master_port
,
0
,
3
,
torch
.
device
(
"cuda:0"
))
0
,
3
,
torch
.
device
(
"cuda:0"
))
ray
.
get
(
handle
)
ray
.
get
(
handle
)
...
...
examples/offline_inference/rlhf_colocate.py
View file @
151b08e0
...
@@ -17,40 +17,6 @@ from ray.util.placement_group import placement_group
...
@@ -17,40 +17,6 @@ from ray.util.placement_group import placement_group
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
vllm
import
LLM
from
vllm
import
LLM
from
vllm.worker.worker
import
Worker
class
MyWorker
(
Worker
):
def
report_device_id
(
self
)
->
str
:
from
vllm.platforms
import
current_platform
self
.
device_uuid
=
current_platform
.
get_device_uuid
(
self
.
device
.
index
)
return
self
.
device_uuid
def
update_weights_from_ipc_handles
(
self
,
ipc_handles
):
handles
=
ipc_handles
[
self
.
device_uuid
]
device_id
=
self
.
device
.
index
weights
=
[]
for
name
,
handle
in
handles
.
items
():
func
,
args
=
handle
list_args
=
list
(
args
)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args
[
6
]
=
device_id
tensor
=
func
(
*
list_args
)
weights
.
append
((
name
,
tensor
))
self
.
model_runner
.
model
.
load_weights
(
weights
=
weights
)
torch
.
cuda
.
synchronize
()
def
check_weights_changed
(
self
):
"""
Check if the weights are updated to 0.
"""
weights_updated
=
True
for
name
,
p
in
self
.
model_runner
.
model
.
named_parameters
():
weights_updated
=
weights_updated
and
torch
.
allclose
(
p
,
torch
.
zeros_like
(
p
))
return
weights_updated
class
MyLLM
(
LLM
):
class
MyLLM
(
LLM
):
...
@@ -150,7 +116,7 @@ for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
...
@@ -150,7 +116,7 @@ for (i, bundle_indices) in enumerate([[0, 1], [2, 3]]):
)(
MyLLM
).
remote
(
)(
MyLLM
).
remote
(
model
=
"facebook/opt-125m"
,
model
=
"facebook/opt-125m"
,
enforce_eager
=
True
,
enforce_eager
=
True
,
worker_
cls
=
MyWorker
,
worker_
extension_cls
=
"rlhf_utils.ColocateWorkerExtension"
,
tensor_parallel_size
=
2
,
tensor_parallel_size
=
2
,
distributed_executor_backend
=
"ray"
,
distributed_executor_backend
=
"ray"
,
gpu_memory_utilization
=
0.4
,
gpu_memory_utilization
=
0.4
,
...
...
examples/offline_inference/rlhf_utils.py
0 → 100644
View file @
151b08e0
# SPDX-License-Identifier: Apache-2.0
import
torch
def
stateless_init_process_group
(
master_address
,
master_port
,
rank
,
world_size
,
device
):
"""
vLLM provides `StatelessProcessGroup` to create a process group
without considering the global process group in torch.distributed.
It is recommended to create `StatelessProcessGroup`, and then initialize
the data-plane communication (NCCL) between external (train processes)
and vLLM workers.
"""
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.utils
import
StatelessProcessGroup
pg
=
StatelessProcessGroup
.
create
(
host
=
master_address
,
port
=
master_port
,
rank
=
rank
,
world_size
=
world_size
)
pynccl
=
PyNcclCommunicator
(
pg
,
device
=
device
)
return
pynccl
class
WorkerExtension
:
"""
The class for vLLM's worker to inherit from.
By defining an extension class, the code can work no matter what is
the underlying worker class. This way, the code can be compatible
with both vLLM V0 and V1.
NOTE: we define this class in a separate module, and the main module
should pass the full qualified name as `worker_extension_cls` argument.
"""
def
init_weight_update_group
(
self
,
master_address
,
master_port
,
rank_offset
,
world_size
):
from
vllm.distributed.parallel_state
import
get_world_group
rank
=
get_world_group
().
rank
+
rank_offset
self
.
model_update_group
=
stateless_init_process_group
(
master_address
,
master_port
,
rank
,
world_size
,
self
.
device
,
)
def
update_weight
(
self
,
name
,
dtype
,
shape
):
weight
=
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
"cuda"
)
self
.
model_update_group
.
broadcast
(
weight
,
src
=
0
,
stream
=
torch
.
cuda
.
current_stream
())
self
.
model_runner
.
model
.
load_weights
(
weights
=
[(
name
,
weight
)])
del
weight
def
check_weights_changed
(
self
):
"""
Check if the weights are updated to 0.
"""
weights_updated
=
True
for
name
,
p
in
self
.
model_runner
.
model
.
named_parameters
():
weights_updated
=
weights_updated
and
torch
.
allclose
(
p
,
torch
.
zeros_like
(
p
))
return
weights_updated
class
ColocateWorkerExtension
:
"""
The class for vLLM's worker to inherit from, in the colocate setting.
By defining an extension class, the code can work no matter what is
the underlying worker class. This way, the code can be compatible
with both vLLM V0 and V1.
NOTE: we define this class in a separate module, and the main module
should pass the full qualified name as `worker_extension_cls` argument.
"""
def
report_device_id
(
self
)
->
str
:
from
vllm.platforms
import
current_platform
self
.
device_uuid
=
current_platform
.
get_device_uuid
(
self
.
device
.
index
)
return
self
.
device_uuid
def
update_weights_from_ipc_handles
(
self
,
ipc_handles
):
handles
=
ipc_handles
[
self
.
device_uuid
]
device_id
=
self
.
device
.
index
weights
=
[]
for
name
,
handle
in
handles
.
items
():
func
,
args
=
handle
list_args
=
list
(
args
)
# the key is to change device id to the current device id
# in case two processes have different CUDA_VISIBLE_DEVICES
list_args
[
6
]
=
device_id
tensor
=
func
(
*
list_args
)
weights
.
append
((
name
,
tensor
))
self
.
model_runner
.
model
.
load_weights
(
weights
=
weights
)
torch
.
cuda
.
synchronize
()
def
check_weights_changed
(
self
):
"""
Check if the weights are updated to 0.
"""
weights_updated
=
True
for
name
,
p
in
self
.
model_runner
.
model
.
named_parameters
():
weights_updated
=
weights_updated
and
torch
.
allclose
(
p
,
torch
.
zeros_like
(
p
))
return
weights_updated
vllm/config.py
View file @
151b08e0
...
@@ -1366,6 +1366,7 @@ class ParallelConfig:
...
@@ -1366,6 +1366,7 @@ class ParallelConfig:
# will be determined based on the platform.
# will be determined based on the platform.
worker_cls
:
str
=
"auto"
worker_cls
:
str
=
"auto"
sd_worker_cls
:
str
=
"auto"
sd_worker_cls
:
str
=
"auto"
worker_extension_cls
:
str
=
""
# world_size is TPxPP, it affects the number of workers we create.
# world_size is TPxPP, it affects the number of workers we create.
world_size
:
int
=
field
(
init
=
False
)
world_size
:
int
=
field
(
init
=
False
)
...
@@ -1523,6 +1524,9 @@ class ParallelConfig:
...
@@ -1523,6 +1524,9 @@ class ParallelConfig:
raise
ValueError
(
"Unable to use nsight profiling unless workers "
raise
ValueError
(
"Unable to use nsight profiling unless workers "
"run with Ray."
)
"run with Ray."
)
assert
isinstance
(
self
.
worker_extension_cls
,
str
),
(
"worker_extension_cls must be a string (qualified class name)."
)
@
dataclass
@
dataclass
class
SchedulerConfig
:
class
SchedulerConfig
:
...
...
vllm/engine/arg_utils.py
View file @
151b08e0
...
@@ -202,6 +202,7 @@ class EngineArgs:
...
@@ -202,6 +202,7 @@ class EngineArgs:
override_pooler_config
:
Optional
[
PoolerConfig
]
=
None
override_pooler_config
:
Optional
[
PoolerConfig
]
=
None
compilation_config
:
Optional
[
CompilationConfig
]
=
None
compilation_config
:
Optional
[
CompilationConfig
]
=
None
worker_cls
:
str
=
"auto"
worker_cls
:
str
=
"auto"
worker_extension_cls
:
str
=
""
kv_transfer_config
:
Optional
[
KVTransferConfig
]
=
None
kv_transfer_config
:
Optional
[
KVTransferConfig
]
=
None
...
@@ -1015,6 +1016,13 @@ class EngineArgs:
...
@@ -1015,6 +1016,13 @@ class EngineArgs:
type
=
str
,
type
=
str
,
default
=
"auto"
,
default
=
"auto"
,
help
=
'The worker class to use for distributed execution.'
)
help
=
'The worker class to use for distributed execution.'
)
parser
.
add_argument
(
'--worker-extension-cls'
,
type
=
str
,
default
=
""
,
help
=
'The worker extension class on top of the worker cls, '
'it is useful if you just want to add new functions to the worker '
'class without changing the existing functions.'
)
parser
.
add_argument
(
parser
.
add_argument
(
"--generation-config"
,
"--generation-config"
,
type
=
nullable_str
,
type
=
nullable_str
,
...
@@ -1209,6 +1217,7 @@ class EngineArgs:
...
@@ -1209,6 +1217,7 @@ class EngineArgs:
ray_workers_use_nsight
=
self
.
ray_workers_use_nsight
,
ray_workers_use_nsight
=
self
.
ray_workers_use_nsight
,
distributed_executor_backend
=
self
.
distributed_executor_backend
,
distributed_executor_backend
=
self
.
distributed_executor_backend
,
worker_cls
=
self
.
worker_cls
,
worker_cls
=
self
.
worker_cls
,
worker_extension_cls
=
self
.
worker_extension_cls
,
)
)
max_model_len
=
model_config
.
max_model_len
max_model_len
=
model_config
.
max_model_len
...
...
vllm/worker/worker_base.py
View file @
151b08e0
...
@@ -558,10 +558,37 @@ class WorkerWrapperBase:
...
@@ -558,10 +558,37 @@ class WorkerWrapperBase:
worker_class
=
resolve_obj_by_qualname
(
worker_class
=
resolve_obj_by_qualname
(
self
.
vllm_config
.
parallel_config
.
worker_cls
)
self
.
vllm_config
.
parallel_config
.
worker_cls
)
else
:
else
:
logger
.
warning
(
"passing worker_cls as a class object is strongly deprecated,"
" as the serialization of class objects can be tricky and"
" error-prone. To be safe, please keep the class in a separate"
" module and pass the qualified name of the class as a string."
)
assert
isinstance
(
self
.
vllm_config
.
parallel_config
.
worker_cls
,
assert
isinstance
(
self
.
vllm_config
.
parallel_config
.
worker_cls
,
bytes
)
bytes
)
worker_class
=
cloudpickle
.
loads
(
worker_class
=
cloudpickle
.
loads
(
self
.
vllm_config
.
parallel_config
.
worker_cls
)
self
.
vllm_config
.
parallel_config
.
worker_cls
)
if
self
.
vllm_config
.
parallel_config
.
worker_extension_cls
:
worker_extension_cls
=
resolve_obj_by_qualname
(
self
.
vllm_config
.
parallel_config
.
worker_extension_cls
)
extended_calls
=
[]
if
worker_extension_cls
not
in
worker_class
.
__bases__
:
# check any conflicts between worker and worker_extension_cls
for
attr
in
dir
(
worker_extension_cls
):
if
attr
.
startswith
(
"__"
):
continue
assert
not
hasattr
(
worker_class
,
attr
),
(
f
"Worker class
{
worker_class
}
already has an attribute"
f
"
{
attr
}
, which conflicts with the worker"
f
" extension class
{
worker_extension_cls
}
."
)
if
callable
(
getattr
(
worker_extension_cls
,
attr
)):
extended_calls
.
append
(
attr
)
# dynamically inherit the worker extension class
worker_class
.
__bases__
=
worker_class
.
__bases__
+
(
worker_extension_cls
,
)
logger
.
info
(
"Injected %s into %s for extended collective_rpc calls %s"
,
worker_extension_cls
,
worker_class
,
extended_calls
)
with
set_current_vllm_config
(
self
.
vllm_config
):
with
set_current_vllm_config
(
self
.
vllm_config
):
# To make vLLM config available during worker initialization
# To make vLLM config available during worker initialization
self
.
worker
=
worker_class
(
**
kwargs
)
self
.
worker
=
worker_class
(
**
kwargs
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment