Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
92e793d9
Unverified
Commit
92e793d9
authored
Jan 16, 2025
by
youkaichao
Committed by
GitHub
Jan 16, 2025
Browse files
[core] LLM.collective_rpc interface and RLHF example (#12084)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
bf53e0c7
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
270 additions
and
35 deletions
+270
-35
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+4
-0
examples/offline_inference/rlhf.py
examples/offline_inference/rlhf.py
+191
-0
vllm/__init__.py
vllm/__init__.py
+39
-0
vllm/entrypoints/llm.py
vllm/entrypoints/llm.py
+25
-0
vllm/plugins/__init__.py
vllm/plugins/__init__.py
+0
-31
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+11
-4
No files found.
.buildkite/test-pipeline.yaml
View file @
92e793d9
...
@@ -126,11 +126,15 @@ steps:
...
@@ -126,11 +126,15 @@ steps:
-
tests/distributed
-
tests/distributed
-
tests/spec_decode/e2e/test_integration_dist_tp4
-
tests/spec_decode/e2e/test_integration_dist_tp4
-
tests/compile
-
tests/compile
-
examples/offline_inference/rlhf.py
commands
:
commands
:
-
pytest -v -s distributed/test_utils.py
-
pytest -v -s distributed/test_utils.py
-
pytest -v -s compile/test_basic_correctness.py
-
pytest -v -s compile/test_basic_correctness.py
-
pytest -v -s distributed/test_pynccl.py
-
pytest -v -s distributed/test_pynccl.py
-
pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
-
pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
# TODO: create a dedicated test section for multi-GPU example tests
# when we have multiple distributed example tests
-
python3 ../examples/offline_inference/rlhf.py
-
label
:
Metrics, Tracing Test
# 10min
-
label
:
Metrics, Tracing Test
# 10min
num_gpus
:
2
num_gpus
:
2
...
...
examples/offline_inference/rlhf.py
0 → 100644
View file @
92e793d9
"""
a simple demonstration of RLHF with vLLM, inspired by
the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
It follows a design in which training processes and inference processes
are separate and live on different GPUs.
Training processes send prompts to inference processes to generate data,
and also synchronize the weights of the model by broadcasting the weights
from the training process to the inference process.
Note that this is a simple demonstration of one training instance and one
inference instance. In practice, there could be multiple training instances
and multiple inference instances. For the full implementation, please refer
to the OpenRLHF framework.
"""
import
os
import
ray
import
torch
from
ray.util.placement_group
import
placement_group
from
ray.util.scheduling_strategies
import
PlacementGroupSchedulingStrategy
from
transformers
import
AutoModelForCausalLM
from
vllm
import
LLM
,
SamplingParams
,
configure_as_vllm_process
from
vllm.utils
import
get_ip
,
get_open_port
from
vllm.worker.worker
import
Worker
def stateless_init_process_group(master_address, master_port, rank, world_size,
                                 device):
    """
    Build a NCCL communicator on top of a vLLM `StatelessProcessGroup`.

    The `StatelessProcessGroup` is created from the given host, port, rank
    and world size, without touching the global process group in
    torch.distributed. It is then wrapped in a `PyNcclCommunicator` bound to
    `device`, which provides the data-plane communication (NCCL) between
    external (training) processes and vLLM workers.

    Returns the `PyNcclCommunicator` instance.
    """
    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    from vllm.distributed.utils import StatelessProcessGroup

    process_group = StatelessProcessGroup.create(
        host=master_address,
        port=master_port,
        rank=rank,
        world_size=world_size,
    )
    return PyNcclCommunicator(process_group, device=device)
class MyWorker(Worker):
    """
    A `Worker` subclass carrying the extra methods this example invokes
    through `collective_rpc`.

    It is defined inline to keep the script self-contained. Normally the
    class would live in its own file and its qualified name would be passed
    to the `worker_cls` parameter.
    """

    def init_weight_update_group(self, master_address, master_port,
                                 rank_offset, world_size):
        """
        Join the weight-update communicator and keep it on the worker.

        The rank used in the new group is this worker's rank in the vLLM
        world group shifted by `rank_offset`; the resulting communicator is
        stored as `self.model_update_group`.
        """
        from vllm.distributed.parallel_state import get_world_group

        group_rank = get_world_group().rank + rank_offset
        self.model_update_group = stateless_init_process_group(
            master_address=master_address,
            master_port=master_port,
            rank=group_rank,
            world_size=world_size,
            device=self.device,
        )

    def update_weight(self, name, dtype, shape):
        """
        Receive one tensor broadcast from rank 0 and load it into the model
        under the parameter name `name`.
        """
        # Scratch buffer that the broadcast fills in.
        received = torch.empty(shape, dtype=dtype, device="cuda")
        current = torch.cuda.current_stream()
        self.model_update_group.broadcast(received, src=0, stream=current)

        self.model_runner.model.load_weights(weights=[(name, received)])

        # Drop the local reference to the scratch buffer right away.
        del received

    def check_weights_changed(self):
        """
        Check if the weights are updated to 0.

        Returns True only when every named parameter of the model is
        element-wise close to zero.
        """
        model = self.model_runner.model
        return all(
            torch.allclose(param, torch.zeros_like(param))
            for _, param in model.named_parameters())
class MyLLM(LLM):
    """
    `LLM` subclass used as the ray actor class in this example.

    It only differs from `LLM` in that it clears `CUDA_VISIBLE_DEVICES`
    from the process environment before the engine is constructed.
    """

    def __init__(self, *args, **kwargs):
        """Clear `CUDA_VISIBLE_DEVICES`, then forward all arguments to
        `LLM.__init__`."""
        # a hack to make the script work.
        # stop ray from manipulating CUDA_VISIBLE_DEVICES
        # at the top-level
        # `pop` with a default is used instead of `del` so that construction
        # does not fail with KeyError when the variable is not set at all.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        super().__init__(*args, **kwargs)
"""
Start the training process, here we use huggingface transformers
as an example to hold a model on GPU 0.
It is important for all the processes outside of vLLM to call
`configure_as_vllm_process` to set some common environment variables
the same as vLLM workers.
"""
# Apply the common vLLM process configuration to this (training) process
# before any model is created.
configure_as_vllm_process()

# The training-side model lives on the first GPU.
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
"""
Start the inference process, here we use vLLM to hold a model on GPU 1 and
GPU 2. For the details on how to use ray, please refer to the ray
documentation https://docs.ray.io/en/latest/ .
"""
# Restrict the devices visible from here on to GPU 1 and GPU 2 before ray
# is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

# Reserve two bundles with one GPU each, and wait until the placement group
# is ready.
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group=pg_inference,
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)
"""
launch the vLLM inference engine.
here we use `enforce_eager` to reduce the start time.
"""
# The actor wrapping `MyLLM` itself requests no CPU/GPU resources; it is
# scheduled inside the placement group created above.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
    scheduling_strategy=scheduling_inference,
)(MyLLM).remote(
    model="facebook/opt-125m",
    enforce_eager=True,
    worker_cls=MyWorker,
    tensor_parallel_size=2,
    distributed_executor_backend="ray",
)

# Generate texts from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]

sampling_params = SamplingParams(temperature=0)

outputs = ray.get(llm.generate.remote(prompts, sampling_params))

for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")

# set up the communication between the training process
# and the inference engine.
master_address = get_ip()
master_port = get_open_port()

# Rank layout of the weight-update group (world size 3): this training
# process is rank 0, and the two vLLM workers use rank_offset=1, i.e. their
# world-group rank plus one.
# The RPC is only *launched* here (`.remote`); the result is collected after
# this process has joined the group itself.
# NOTE(review): presumably group creation blocks until all ranks have joined,
# which is why the RPC must not be awaited before the local call - confirm.
handle = llm.collective_rpc.remote("init_weight_update_group",
                                   args=(master_address, master_port, 1, 3))
model_update_group = stateless_init_process_group(master_address, master_port,
                                                  0, 3, torch.device("cuda:0"))
ray.get(handle)

# simulate training, modify the weights of the model.
for name, p in train_model.named_parameters():
    p.data.zero_()

# sync weight from the training process to the inference engine.
# For each parameter: first ask the workers (asynchronously) to run
# `update_weight` with the parameter's name, dtype and shape, then broadcast
# the tensor from rank 0, and finally wait for the workers to finish.
for name, p in train_model.named_parameters():
    handle = llm.collective_rpc.remote("update_weight",
                                       args=(name, p.dtype, p.shape))
    model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
    ray.get(handle)

# check if the weights are updated.
# `collective_rpc` yields one result per worker; all of them must be truthy.
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))

# use the updated model to generate texts, they will be nonsense
# because the weights are all zeros.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
for output in outputs_updated:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, "
          f"Generated text: {generated_text!r}")
vllm/__init__.py
View file @
92e793d9
...
@@ -17,6 +17,44 @@ from vllm.sampling_params import SamplingParams
...
@@ -17,6 +17,44 @@ from vllm.sampling_params import SamplingParams
from
.version
import
__version__
,
__version_tuple__
from
.version
import
__version__
,
__version_tuple__
def configure_as_vllm_process():
    """
    Apply the configuration and environment variables shared by every
    process created by vllm and by every process that interacts with
    vllm workers.
    """
    import os

    import torch

    # see https://github.com/NVIDIA/nccl/issues/1234
    os.environ['NCCL_CUMEM_ENABLE'] = '0'

    # see https://github.com/vllm-project/vllm/issues/10480
    os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1'
    # see https://github.com/vllm-project/vllm/issues/10619
    torch._inductor.config.compile_threads = 1

    from vllm.platforms import current_platform

    if current_platform.is_xpu():
        # see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
        torch._dynamo.config.disable = True
        return

    if not current_platform.is_hpu():
        # Nothing platform-specific left to configure.
        return

    # NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
    # does not support torch.compile
    # Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
    # torch.compile support
    lazy_mode = os.environ.get('PT_HPU_LAZY_MODE', '1')
    if lazy_mode == '1':
        torch._dynamo.config.disable = True
        # NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
        # requires enabling lazy collectives
        # see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
        os.environ['PT_HPU_ENABLE_LAZY_COLLECTIVES'] = 'true'
__all__
=
[
__all__
=
[
"__version__"
,
"__version__"
,
"__version_tuple__"
,
"__version_tuple__"
,
...
@@ -42,4 +80,5 @@ __all__ = [
...
@@ -42,4 +80,5 @@ __all__ = [
"AsyncEngineArgs"
,
"AsyncEngineArgs"
,
"initialize_ray_cluster"
,
"initialize_ray_cluster"
,
"PoolingParams"
,
"PoolingParams"
,
"configure_as_vllm_process"
,
]
]
vllm/entrypoints/llm.py
View file @
92e793d9
...
@@ -4,6 +4,7 @@ from contextlib import contextmanager
...
@@ -4,6 +4,7 @@ from contextlib import contextmanager
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
from
typing
import
(
Any
,
ClassVar
,
Dict
,
List
,
Optional
,
Sequence
,
Tuple
,
Type
,
Union
,
cast
,
overload
)
Union
,
cast
,
overload
)
import
cloudpickle
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
typing_extensions
import
deprecated
from
typing_extensions
import
deprecated
...
@@ -186,6 +187,13 @@ class LLM:
...
@@ -186,6 +187,13 @@ class LLM:
if
"disable_log_stats"
not
in
kwargs
:
if
"disable_log_stats"
not
in
kwargs
:
kwargs
[
"disable_log_stats"
]
=
True
kwargs
[
"disable_log_stats"
]
=
True
if
"worker_cls"
in
kwargs
:
worker_cls
=
kwargs
[
"worker_cls"
]
# if the worker_cls is not qualified string name,
# we serialize it using cloudpickle to avoid pickling issues
if
isinstance
(
worker_cls
,
type
):
kwargs
[
"worker_cls"
]
=
cloudpickle
.
dumps
(
worker_cls
)
if
compilation_config
is
not
None
:
if
compilation_config
is
not
None
:
if
isinstance
(
compilation_config
,
(
int
,
dict
)):
if
isinstance
(
compilation_config
,
(
int
,
dict
)):
compilation_config_instance
=
CompilationConfig
.
from_cli
(
compilation_config_instance
=
CompilationConfig
.
from_cli
(
...
@@ -455,6 +463,23 @@ class LLM:
...
@@ -455,6 +463,23 @@ class LLM:
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
outputs
=
self
.
_run_engine
(
use_tqdm
=
use_tqdm
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
return
self
.
engine_class
.
validate_outputs
(
outputs
,
RequestOutput
)
def collective_rpc(self,
                   method: str,
                   timeout: Optional[float] = None,
                   args: Tuple = (),
                   kwargs: Optional[Dict] = None) -> List[Any]:
    """
    Run a method on all workers, with homogeneous arguments.

    This is the main extension point for the LLM entrypoint: a custom
    worker class can be supplied through the `worker_cls` argument, new
    methods can be implemented on it, and those methods can then be
    invoked by name through this API.

    It is recommended to use this API to only pass control messages,
    and set up data-plane communication to pass data.

    The call is forwarded unchanged to the `collective_rpc` of the
    engine's model executor, and its result is returned as-is.
    """
    executor = self.llm_engine.model_executor
    return executor.collective_rpc(method, timeout, args, kwargs)
def
beam_search
(
def
beam_search
(
self
,
self
,
prompts
:
List
[
Union
[
TokensPrompt
,
TextPrompt
]],
prompts
:
List
[
Union
[
TokensPrompt
,
TextPrompt
]],
...
...
vllm/plugins/__init__.py
View file @
92e793d9
import
logging
import
logging
import
os
from
typing
import
Callable
,
Dict
from
typing
import
Callable
,
Dict
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -50,34 +47,6 @@ def load_general_plugins():
...
@@ -50,34 +47,6 @@ def load_general_plugins():
processes. They should be designed in a way that they can be loaded
processes. They should be designed in a way that they can be loaded
multiple times without causing issues.
multiple times without causing issues.
"""
"""
# all processes created by vllm will load plugins,
# and here we can inject some common environment variables
# for all processes.
# see https://github.com/vllm-project/vllm/issues/10480
os
.
environ
[
'TORCHINDUCTOR_COMPILE_THREADS'
]
=
'1'
# see https://github.com/vllm-project/vllm/issues/10619
torch
.
_inductor
.
config
.
compile_threads
=
1
from
vllm.platforms
import
current_platform
if
current_platform
.
is_xpu
():
# see https://github.com/pytorch/pytorch/blob/43c5f59/torch/_dynamo/config.py#L158
torch
.
_dynamo
.
config
.
disable
=
True
if
current_platform
.
is_hpu
():
# NOTE(kzawora): PT HPU lazy backend (PT_HPU_LAZY_MODE = 1)
# does not support torch.compile
# Eager backend (PT_HPU_LAZY_MODE = 0) must be selected for
# torch.compile support
is_lazy
=
os
.
environ
.
get
(
'PT_HPU_LAZY_MODE'
,
'1'
)
==
'1'
if
is_lazy
:
torch
.
_dynamo
.
config
.
disable
=
True
# NOTE(kzawora) multi-HPU inference with HPUGraphs (lazy-only)
# requires enabling lazy collectives
# see https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html # noqa: E501
os
.
environ
[
'PT_HPU_ENABLE_LAZY_COLLECTIVES'
]
=
'true'
global
plugins_loaded
global
plugins_loaded
if
plugins_loaded
:
if
plugins_loaded
:
return
return
...
...
vllm/worker/worker_base.py
View file @
92e793d9
...
@@ -4,6 +4,7 @@ import time
...
@@ -4,6 +4,7 @@ import time
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Tuple
,
Type
,
Union
import
cloudpickle
import
torch
import
torch
from
vllm.config
import
ObservabilityConfig
,
VllmConfig
from
vllm.config
import
ObservabilityConfig
,
VllmConfig
...
@@ -521,14 +522,20 @@ class WorkerWrapperBase:
...
@@ -521,14 +522,20 @@ class WorkerWrapperBase:
kwargs
=
all_kwargs
[
self
.
rpc_rank
]
kwargs
=
all_kwargs
[
self
.
rpc_rank
]
enable_trace_function_call_for_thread
(
self
.
vllm_config
)
enable_trace_function_call_for_thread
(
self
.
vllm_config
)
# see https://github.com/NVIDIA/nccl/issues/1234
from
vllm
import
configure_as_vllm_process
os
.
environ
[
'NCCL_CUMEM_ENABLE'
]
=
'0'
configure_as_vllm_process
()
from
vllm.plugins
import
load_general_plugins
from
vllm.plugins
import
load_general_plugins
load_general_plugins
()
load_general_plugins
()
worker_class
=
resolve_obj_by_qualname
(
if
isinstance
(
self
.
vllm_config
.
parallel_config
.
worker_cls
,
str
):
self
.
vllm_config
.
parallel_config
.
worker_cls
)
worker_class
=
resolve_obj_by_qualname
(
self
.
vllm_config
.
parallel_config
.
worker_cls
)
else
:
assert
isinstance
(
self
.
vllm_config
.
parallel_config
.
worker_cls
,
bytes
)
worker_class
=
cloudpickle
.
loads
(
self
.
vllm_config
.
parallel_config
.
worker_cls
)
self
.
worker
=
worker_class
(
**
kwargs
)
self
.
worker
=
worker_class
(
**
kwargs
)
assert
self
.
worker
is
not
None
assert
self
.
worker
is
not
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment