Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
bf53e0c7
Unverified
Commit
bf53e0c7
authored
Jan 16, 2025
by
youkaichao
Committed by
GitHub
Jan 16, 2025
Browse files
Support torchrun and SPMD-style offline inference (#12071)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
dd7c9ad8
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
248 additions
and
30 deletions
+248
-30
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+1
-0
examples/offline_inference/torchrun_example.py
examples/offline_inference/torchrun_example.py
+64
-0
tests/distributed/test_torchrun_example.py
tests/distributed/test_torchrun_example.py
+56
-0
tests/engine/test_multiproc_workers.py
tests/engine/test_multiproc_workers.py
+1
-1
vllm/config.py
vllm/config.py
+4
-3
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+1
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+5
-0
vllm/executor/ray_distributed_executor.py
vllm/executor/ray_distributed_executor.py
+3
-3
vllm/executor/uniproc_executor.py
vllm/executor/uniproc_executor.py
+80
-1
vllm/lora/layers.py
vllm/lora/layers.py
+2
-2
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+10
-6
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+1
-1
vllm/worker/worker.py
vllm/worker/worker.py
+0
-3
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+20
-9
No files found.
.buildkite/test-pipeline.yaml
View file @
bf53e0c7
...
...
@@ -463,6 +463,7 @@ steps:
-
vllm/worker/worker.py
-
vllm/worker/model_runner.py
commands
:
-
torchrun --nproc-per-node=2 distributed/test_torchrun_example.py
-
pytest -v -s ./compile/test_basic_correctness.py
-
pytest -v -s ./compile/test_wrapper.py
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep 'Same node test passed'
...
...
examples/offline_inference/torchrun_example.py
0 → 100644
View file @
bf53e0c7
"""
experimental support for tensor-parallel inference with torchrun,
see https://github.com/vllm-project/vllm/issues/11400 for
the motivation and use case for this example.
run the script with `torchrun --nproc-per-node=2 torchrun_example.py`,
the argument 2 should match the `tensor_parallel_size` below.
see `tests/distributed/test_torchrun_example.py` for the unit test.
"""
from
vllm
import
LLM
,
SamplingParams
# Create prompts, the same across all ranks
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
# Create sampling parameters, the same across all ranks
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
# Use `distributed_executor_backend="external_launcher"` so that
# this llm engine/instance only creates one worker.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
tensor_parallel_size
=
2
,
distributed_executor_backend
=
"external_launcher"
,
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# all ranks will have the same outputs
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, "
f
"Generated text:
{
generated_text
!
r
}
"
)
"""
Further tips:
1. to communicate control messages across all ranks, use the cpu group,
a PyTorch ProcessGroup with GLOO backend.
```python
from vllm.distributed.parallel_state import get_world_group
cpu_group = get_world_group().cpu_group
torch_rank = dist.get_rank(group=cpu_group)
if torch_rank == 0:
# do something for rank 0, e.g. saving the results to disk.
```
2. to communicate data across all ranks, use the model's device group,
a PyTorch ProcessGroup with NCCL backend.
```python
from vllm.distributed.parallel_state import get_world_group
device_group = get_world_group().device_group
```
3. to access the model directly in every rank, use the following code:
```python
llm.llm_engine.model_executor.driver_worker.worker.model_runner.model
```
"""
tests/distributed/test_torchrun_example.py
0 → 100644
View file @
bf53e0c7
# unit test for `examples/offline_inference/torchrun_example.py`
import
random
import
torch.distributed
as
dist
from
vllm
import
LLM
,
SamplingParams
from
vllm.distributed.parallel_state
import
get_world_group
# Create prompts
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
)
# set different `gpu_memory_utilization` and `swap_space` for different ranks,
# to test if all ranks agree on the same kv cache configuration.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
tensor_parallel_size
=
2
,
distributed_executor_backend
=
"external_launcher"
,
gpu_memory_utilization
=
random
.
uniform
(
0.7
,
0.9
),
swap_space
=
random
.
randint
(
1
,
4
))
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
cpu_group
=
get_world_group
().
cpu_group
torch_rank
=
dist
.
get_rank
(
group
=
cpu_group
)
def
test_consistent_across_ranks
(
obj
):
if
torch_rank
==
0
:
dist
.
broadcast_object_list
([
obj
],
src
=
0
,
group
=
cpu_group
)
else
:
container
=
[
None
]
dist
.
broadcast_object_list
(
container
,
src
=
0
,
group
=
cpu_group
)
assert
container
[
0
]
==
obj
test_consistent_across_ranks
(
llm
.
llm_engine
.
vllm_config
.
cache_config
.
num_cpu_blocks
)
test_consistent_across_ranks
(
llm
.
llm_engine
.
vllm_config
.
cache_config
.
num_gpu_blocks
)
# all ranks should have the same outputs
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
test_consistent_across_ranks
(
prompt
)
test_consistent_across_ranks
(
generated_text
)
print
(
f
"Rank
{
torch_rank
}
, Prompt:
{
prompt
!
r
}
, "
f
"Generated text:
{
generated_text
!
r
}
"
)
tests/engine/test_multiproc_workers.py
View file @
bf53e0c7
...
...
@@ -22,7 +22,7 @@ class DummyWorkerWrapper(WorkerWrapperBase):
# simulate error case
raise
worker_input
return
self
.
rank
,
input
return
self
.
rpc_
rank
,
input
def
_start_workers
()
->
Tuple
[
List
[
ProcessWorkerWrapper
],
WorkerMonitor
]:
...
...
vllm/config.py
View file @
bf53e0c7
...
...
@@ -1338,14 +1338,15 @@ class ParallelConfig:
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.platforms
import
current_platform
if
self
.
distributed_executor_backend
not
in
(
"ray"
,
"mp"
,
"uni"
,
None
)
and
not
(
isinstance
(
"ray"
,
"mp"
,
"uni"
,
"external_launcher"
,
None
)
and
not
(
isinstance
(
self
.
distributed_executor_backend
,
type
)
and
issubclass
(
self
.
distributed_executor_backend
,
ExecutorBase
)):
raise
ValueError
(
"Unrecognized distributed executor backend "
f
"
{
self
.
distributed_executor_backend
}
. Supported "
"values are 'ray', 'mp' 'uni',
or custom ExecutorBase
"
" subclass."
)
"values are 'ray', 'mp' 'uni',
'external_launcher' or
"
"
custom ExecutorBase
subclass."
)
if
self
.
use_ray
:
from
vllm.executor
import
ray_utils
ray_utils
.
assert_ray_available
()
...
...
vllm/engine/arg_utils.py
View file @
bf53e0c7
...
...
@@ -388,7 +388,7 @@ class EngineArgs:
# Parallel arguments
parser
.
add_argument
(
'--distributed-executor-backend'
,
choices
=
[
'ray'
,
'mp'
],
choices
=
[
'ray'
,
'mp'
,
'uni'
,
'external_launcher'
],
default
=
EngineArgs
.
distributed_executor_backend
,
help
=
'Backend to use for distributed model '
'workers, either "ray" or "mp" (multiprocessing). If the product '
...
...
vllm/engine/llm_engine.py
View file @
bf53e0c7
...
...
@@ -457,6 +457,11 @@ class LLMEngine:
# JAX-style, single-process, multi-device executor.
from
vllm.executor.uniproc_executor
import
UniProcExecutor
executor_class
=
UniProcExecutor
elif
distributed_executor_backend
==
"external_launcher"
:
# executor with external launcher
from
vllm.executor.uniproc_executor
import
(
# noqa
ExecutorWithExternalLauncher
)
executor_class
=
ExecutorWithExternalLauncher
else
:
from
vllm.executor.uniproc_executor
import
UniProcExecutor
executor_class
=
UniProcExecutor
...
...
vllm/executor/ray_distributed_executor.py
View file @
bf53e0c7
...
...
@@ -172,7 +172,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
vllm_config
=
self
.
vllm_config
,
rank
=
rank
)
rpc_
rank
=
rank
)
else
:
worker
=
ray
.
remote
(
num_cpus
=
0
,
...
...
@@ -181,7 +181,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
scheduling_strategy
=
scheduling_strategy
,
**
ray_remote_kwargs
,
)(
RayWorkerWrapper
).
remote
(
vllm_config
=
self
.
vllm_config
,
rank
=
rank
)
rpc_
rank
=
rank
)
worker_metadata
.
append
(
RayWorkerMetaData
(
worker
=
worker
,
created_rank
=
rank
))
rank
+=
1
...
...
@@ -204,7 +204,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
# as the resource holder for the driver process.
self
.
driver_dummy_worker
=
worker
self
.
driver_worker
=
RayWorkerWrapper
(
vllm_config
=
self
.
vllm_config
,
rank
=
0
)
vllm_config
=
self
.
vllm_config
,
rpc_
rank
=
0
)
worker_metadata
.
pop
(
i
)
break
...
...
vllm/executor/uniproc_executor.py
View file @
bf53e0c7
import
os
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
import
torch
import
torch.distributed
as
dist
import
vllm.envs
as
envs
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.utils
import
get_distributed_init_method
,
get_ip
,
get_open_port
...
...
@@ -16,7 +21,7 @@ class UniProcExecutor(ExecutorBase):
"""Initialize the worker and load the model.
"""
self
.
driver_worker
=
WorkerWrapperBase
(
vllm_config
=
self
.
vllm_config
,
rank
=
0
)
rpc_
rank
=
0
)
distributed_init_method
=
get_distributed_init_method
(
get_ip
(),
get_open_port
())
local_rank
=
0
...
...
@@ -55,3 +60,77 @@ class UniProcExecutor(ExecutorBase):
UniProcExecutorAsync
=
UniProcExecutor
class
ExecutorWithExternalLauncher
(
UniProcExecutor
):
"""An executor that uses external launchers to launch engines,
specially designed for torchrun-compatible launchers, for
offline inference with tensor parallelism.
see https://github.com/vllm-project/vllm/issues/11400 for
the motivation, and examples/offline_inference/torchrun_example.py
for the usage example.
The key idea: although it is tensor-parallel inference, we only
create one worker per executor, users will launch multiple
engines with torchrun-compatible launchers, and all these engines
work together to process the same prompts. When scheduling is
deterministic, all the engines will generate the same outputs,
and they don't need to synchronize the states with each other.
"""
uses_ray
:
bool
=
False
def
_init_executor
(
self
)
->
None
:
"""Initialize the worker and load the model.
"""
assert
self
.
vllm_config
.
parallel_config
.
pipeline_parallel_size
==
1
,
\
(
"ExecutorWithExternalLauncher does not "
"support pipeline parallelism."
)
assert
self
.
vllm_config
.
scheduler_config
.
delay_factor
==
0.0
,
\
(
"ExecutorWithExternalLauncher needs deterministic "
"execution, so it"
"does not support delay_factor in scheduling"
)
assert
not
envs
.
VLLM_USE_V1
,
\
(
"V1 architecture cannot guarantee deterministic execution, "
"so it is not supported in ExecutorWithExternalLauncher."
)
self
.
driver_worker
=
WorkerWrapperBase
(
vllm_config
=
self
.
vllm_config
,
rpc_rank
=
0
)
# engines are launched in torchrun-compatible launchers
# so we can use the env:// method.
# required env vars:
# - RANK
# - MASTER_ADDR
# - MASTER_PORT
distributed_init_method
=
"env://"
rank
=
int
(
os
.
environ
[
"RANK"
])
local_rank
=
rank
is_driver_worker
=
True
kwargs
=
dict
(
vllm_config
=
self
.
vllm_config
,
local_rank
=
local_rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
is_driver_worker
=
is_driver_worker
,
)
self
.
collective_rpc
(
"init_worker"
,
args
=
([
kwargs
],
))
self
.
collective_rpc
(
"init_device"
)
self
.
collective_rpc
(
"load_model"
)
def
determine_num_available_blocks
(
self
)
->
Tuple
[
int
,
int
]:
"""
Determine the number of available KV blocks.
Add an additional all_reduce to get the min across all ranks.
Note that even if we have the same `gpu_memory_utilization` and
`swap_space`, the available memory in every rank might still
differ because NCCL can take different amounts of memory in
different ranks. Therefore, it is necessary to test if all ranks
agree on the same KV cache configuration.
"""
a
,
b
=
super
().
determine_num_available_blocks
()
from
vllm.distributed.parallel_state
import
get_world_group
cpu_group
=
get_world_group
().
cpu_group
a_tensor
=
torch
.
tensor
([
a
],
device
=
"cpu"
,
dtype
=
torch
.
int64
)
b_tensor
=
torch
.
tensor
([
b
],
device
=
"cpu"
,
dtype
=
torch
.
int64
)
dist
.
all_reduce
(
a_tensor
,
group
=
cpu_group
,
op
=
dist
.
ReduceOp
.
MIN
)
dist
.
all_reduce
(
b_tensor
,
group
=
cpu_group
,
op
=
dist
.
ReduceOp
.
MIN
)
return
a_tensor
.
item
(),
b_tensor
.
item
()
vllm/lora/layers.py
View file @
bf53e0c7
...
...
@@ -940,8 +940,8 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
return
self
.
base_layer
.
soft_cap
@
property
def
use_gather
(
self
):
return
self
.
base_layer
.
use_gather
def
use_
all_
gather
(
self
):
return
self
.
base_layer
.
use_
all_
gather
@
property
def
org_vocab_size
(
self
):
...
...
vllm/model_executor/layers/logits_processor.py
View file @
bf53e0c7
...
...
@@ -6,6 +6,7 @@ import torch
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.config
import
get_current_vllm_config
from
vllm.distributed
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_gather
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
...
...
@@ -44,8 +45,10 @@ class LogitsProcessor(nn.Module):
self
.
soft_cap
=
soft_cap
# Whether to use gather or all-gather to gather the logits.
self
.
use_gather
=
not
current_platform
.
is_tpu
(
)
and
not
envs
.
VLLM_USE_V1
parallel_config
=
get_current_vllm_config
().
parallel_config
self
.
use_all_gather
=
current_platform
.
is_tpu
()
\
or
envs
.
VLLM_USE_V1
\
or
parallel_config
.
distributed_executor_backend
==
"external_launcher"
# noqa
def
forward
(
self
,
...
...
@@ -88,16 +91,17 @@ class LogitsProcessor(nn.Module):
logits
=
lm_head
.
linear_method
.
apply
(
lm_head
,
hidden_states
,
bias
=
embedding_bias
)
if
self
.
use_gather
:
# None may be returned for rank > 0
logits
=
tensor_model_parallel_gather
(
logits
)
else
:
if
self
.
use_all_gather
:
# Gather is not supported for some devices such as TPUs.
# Use all-gather instead.
# NOTE(woosuk): Here, the outputs of every device should not be None
# because XLA requires strict SPMD among all devices. Every device
# should execute the same operations after gathering the logits.
logits
=
tensor_model_parallel_all_gather
(
logits
)
else
:
# None may be returned for rank > 0
logits
=
tensor_model_parallel_gather
(
logits
)
# Remove paddings in vocab (if any).
if
logits
is
not
None
:
logits
=
logits
[...,
:
self
.
org_vocab_size
]
...
...
vllm/v1/executor/multiproc_executor.py
View file @
bf53e0c7
...
...
@@ -246,7 +246,7 @@ class WorkerProc:
ready_path
:
str
,
):
self
.
rank
=
rank
wrapper
=
WorkerWrapperBase
(
vllm_config
=
vllm_config
,
rank
=
rank
)
wrapper
=
WorkerWrapperBase
(
vllm_config
=
vllm_config
,
rpc_
rank
=
rank
)
# TODO: move `init_worker` to executor level as a collective rpc call
all_kwargs
:
List
[
Dict
]
=
[
{}
for
_
in
range
(
vllm_config
.
parallel_config
.
world_size
)
...
...
vllm/worker/worker.py
View file @
bf53e0c7
...
...
@@ -55,9 +55,6 @@ class Worker(LocalOrDistributedWorkerBase):
self
.
rank
=
rank
self
.
distributed_init_method
=
distributed_init_method
self
.
is_driver_worker
=
is_driver_worker
if
is_driver_worker
:
assert
rank
%
self
.
parallel_config
.
tensor_parallel_size
==
0
,
\
"Driver worker should be rank 0 of tensor parallel group."
if
self
.
model_config
.
trust_remote_code
:
# note: lazy import to avoid importing torch before initializing
from
vllm.utils
import
init_cached_hf_modules
...
...
vllm/worker/worker_base.py
View file @
bf53e0c7
...
...
@@ -461,7 +461,8 @@ class LocalOrDistributedWorkerBase(WorkerBase):
class
WorkerWrapperBase
:
"""
The whole point of this class is to lazily initialize the worker.
This class represents one process in an executor/engine. It is responsible
for lazily initializing the worker and handling the worker's lifecycle.
We first instantiate the WorkerWrapper, which remembers the worker module
and class name. Then, when we call `update_environment_variables`, and the
real initialization happens in `init_worker`.
...
...
@@ -470,9 +471,19 @@ class WorkerWrapperBase:
def
__init__
(
self
,
vllm_config
:
VllmConfig
,
rank
:
int
=
0
,
rpc_
rank
:
int
=
0
,
)
->
None
:
self
.
rank
=
rank
"""
Initialize the worker wrapper with the given vllm_config and rpc_rank.
Note: rpc_rank is the rank of the worker in the executor. In most cases,
it is also the rank of the worker in the distributed group. However,
when multiple executors work together, they can be different.
e.g. in the case of SPMD-style offline inference with TP=2,
users can launch 2 engines/executors, each with only 1 worker.
All workers have rpc_rank=0, but they have different ranks in the TP
group.
"""
self
.
rpc_rank
=
rpc_rank
self
.
vllm_config
=
vllm_config
self
.
worker
:
Optional
[
WorkerBase
]
=
None
if
vllm_config
.
model_config
is
not
None
:
...
...
@@ -485,16 +496,16 @@ class WorkerWrapperBase:
def
adjust_rank
(
self
,
rank_mapping
:
Dict
[
int
,
int
])
->
None
:
"""
Adjust the rank based on the given mapping.
Adjust the
rpc_
rank based on the given mapping.
It is only used during the initialization of the executor,
to adjust the rank of workers after we create all workers.
to adjust the
rpc_
rank of workers after we create all workers.
"""
if
self
.
rank
in
rank_mapping
:
self
.
rank
=
rank_mapping
[
self
.
rank
]
if
self
.
rpc_
rank
in
rank_mapping
:
self
.
rpc_
rank
=
rank_mapping
[
self
.
rpc_
rank
]
def
update_environment_variables
(
self
,
envs_list
:
List
[
Dict
[
str
,
str
]])
->
None
:
envs
=
envs_list
[
self
.
rank
]
envs
=
envs_list
[
self
.
rpc_
rank
]
key
=
'CUDA_VISIBLE_DEVICES'
if
key
in
envs
and
key
in
os
.
environ
:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
...
...
@@ -507,7 +518,7 @@ class WorkerWrapperBase:
Here we inject some common logic before initializing the worker.
Arguments are passed to the worker class constructor.
"""
kwargs
=
all_kwargs
[
self
.
rank
]
kwargs
=
all_kwargs
[
self
.
rpc_
rank
]
enable_trace_function_call_for_thread
(
self
.
vllm_config
)
# see https://github.com/NVIDIA/nccl/issues/1234
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment