Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3d2779c2
Unverified
Commit
3d2779c2
authored
May 15, 2025
by
Lucia Fang
Committed by
GitHub
May 15, 2025
Browse files
[Feature] Support Pipeline Parallism in torchrun SPMD offline inference for V1 (#17827)
Signed-off-by:
Lucia Fang
<
fanglu@fb.com
>
parent
6b31c84a
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
55 additions
and
27 deletions
+55
-27
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-0
examples/offline_inference/torchrun_example.py
examples/offline_inference/torchrun_example.py
+14
-9
tests/distributed/test_torchrun_example.py
tests/distributed/test_torchrun_example.py
+2
-1
vllm/config.py
vllm/config.py
+0
-1
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+4
-2
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+3
-2
vllm/executor/uniproc_executor.py
vllm/executor/uniproc_executor.py
+0
-3
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+27
-6
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+3
-3
No files found.
.buildkite/test-pipeline.yaml
View file @
3d2779c2
...
...
@@ -148,6 +148,8 @@ steps:
# test with tp=2 and external_dp=2
-
VLLM_USE_V1=0 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
-
torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with tp=2 and pp=2
-
PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
# test with internal dp
-
python3 ../examples/offline_inference/data_parallel.py
-
TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
...
...
examples/offline_inference/torchrun_example.py
View file @
3d2779c2
...
...
@@ -8,6 +8,8 @@ the argument 2 should match the `tensor_parallel_size` below.
see `tests/distributed/test_torchrun_example.py` for the unit test.
"""
import
torch.distributed
as
dist
from
vllm
import
LLM
,
SamplingParams
# Create prompts, the same across all ranks
...
...
@@ -27,23 +29,26 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# all ranks have the same random seed, so that sampling can be
# deterministic across ranks.
llm
=
LLM
(
model
=
"
facebook/opt-125m
"
,
model
=
"
meta-llama/Llama-3.1-8B
"
,
tensor_parallel_size
=
2
,
pipeline_parallel_size
=
2
,
distributed_executor_backend
=
"external_launcher"
,
seed
=
0
,
max_model_len
=
32768
,
seed
=
1
,
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# all ranks will have the same outputs
print
(
"-"
*
50
)
for
output
in
outputs
:
if
dist
.
get_rank
()
==
0
:
print
(
"-"
*
50
)
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
\n
"
f
"Generated text:
{
generated_text
!
r
}
"
)
f
"Generated text:
{
generated_text
!
r
}
\n
"
)
print
(
"-"
*
50
)
"""
"""
Further tips:
1. to communicate control messages across all ranks, use the cpu group,
...
...
tests/distributed/test_torchrun_example.py
View file @
3d2779c2
# SPDX-License-Identifier: Apache-2.0
# unit test for `examples/offline_inference/torchrun_example.py`
import
os
import
random
import
torch.distributed
as
dist
...
...
@@ -25,6 +25,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# to test if all ranks agree on the same kv cache configuration.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
tensor_parallel_size
=
2
,
pipeline_parallel_size
=
int
(
os
.
getenv
(
"PP_SIZE"
,
1
)),
distributed_executor_backend
=
"external_launcher"
,
gpu_memory_utilization
=
random
.
uniform
(
0.7
,
0.9
),
swap_space
=
random
.
randint
(
1
,
4
),
...
...
vllm/config.py
View file @
3d2779c2
...
...
@@ -1695,7 +1695,6 @@ class ParallelConfig:
"""Port of the data parallel master."""
enable_expert_parallel
:
bool
=
False
"""Use expert parallelism instead of tensor parallelism for MoE layers."""
max_parallel_loading_workers
:
Optional
[
int
]
=
None
"""Maximum number of parallel loading workers when loading model
sequentially in multiple batches. To avoid RAM OOM when using tensor
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
3d2779c2
...
...
@@ -265,6 +265,7 @@ class CustomAllreduce:
def
close
(
self
):
if
not
self
.
disabled
and
self
.
_ptr
:
if
ops
is
not
None
:
ops
.
dispose
(
self
.
_ptr
)
self
.
_ptr
=
0
self
.
free_shared_buffer
(
self
.
meta_ptrs
,
rank
=
self
.
rank
)
...
...
@@ -298,4 +299,5 @@ class CustomAllreduce:
rank
:
Optional
[
int
]
=
0
)
->
None
:
if
rank
is
None
:
rank
=
dist
.
get_rank
(
group
=
group
)
if
ops
is
not
None
:
ops
.
free_shared_buffer
(
pointers
[
rank
])
vllm/engine/arg_utils.py
View file @
3d2779c2
...
...
@@ -1383,9 +1383,10 @@ class EngineArgs:
return
False
if
(
self
.
pipeline_parallel_size
>
1
and
self
.
distributed_executor_backend
not
in
[
"ray"
,
"mp"
]):
and
self
.
distributed_executor_backend
not
in
(
"ray"
,
"mp"
,
"external_launcher"
)):
name
=
"Pipeline Parallelism without Ray distributed executor "
\
"or multiprocessing executor"
"or multiprocessing executor
or external launcher
"
_raise_or_fallback
(
feature_name
=
name
,
recommend_to_remove
=
False
)
return
False
...
...
vllm/executor/uniproc_executor.py
View file @
3d2779c2
...
...
@@ -86,9 +86,6 @@ class ExecutorWithExternalLauncher(UniProcExecutor):
def
_init_executor
(
self
)
->
None
:
"""Initialize the worker and load the model.
"""
assert
self
.
vllm_config
.
parallel_config
.
pipeline_parallel_size
==
1
,
\
(
"ExecutorWithExternalLauncher does not "
"support pipeline parallelism."
)
assert
self
.
vllm_config
.
scheduler_config
.
delay_factor
==
0.0
,
\
(
"ExecutorWithExternalLauncher needs deterministic "
"execution, so it"
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
3d2779c2
...
...
@@ -22,7 +22,8 @@ from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group
)
from
vllm.distributed.kv_transfer.kv_connector.v1
import
KVConnectorBase_V1
from
vllm.distributed.parallel_state
import
(
get_pp_group
,
graph_capture
,
prepare_communication_buffer_for_model
)
get_pp_group
,
get_tp_group
,
graph_capture
,
prepare_communication_buffer_for_model
)
from
vllm.forward_context
import
get_forward_context
,
set_forward_context
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.rotary_embedding
import
MRotaryEmbedding
...
...
@@ -1162,13 +1163,32 @@ class GPUModelRunner(LoRAModelRunnerMixin):
hidden_states
,
aux_hidden_states
=
model_output
else
:
hidden_states
=
model_output
# Broadcast PP output for external_launcher (torchrun)
# to make sure we are synced across pp ranks
# TODO: Support overlapping mirco-batches
# https://github.com/vllm-project/vllm/issues/18019
broadcast_pp_output
=
\
self
.
parallel_config
.
distributed_executor_backend
\
==
"external_launcher"
and
len
(
get_pp_group
().
ranks
)
>
0
if
not
get_pp_group
().
is_last_rank
:
# For mid-pipeline stages, return the hidden states.
if
not
broadcast_pp_output
:
return
hidden_states
assert
isinstance
(
hidden_states
,
IntermediateTensors
)
get_pp_group
().
send_tensor_dict
(
hidden_states
.
tensors
,
all_gather_group
=
get_tp_group
())
logits
=
None
else
:
sample_hidden_states
=
hidden_states
[
logits_indices
]
logits
=
self
.
model
.
compute_logits
(
sample_hidden_states
,
None
)
if
broadcast_pp_output
:
model_output_broadcast_data
=
{
"logits"
:
logits
.
contiguous
(),
}
if
logits
is
not
None
else
{}
model_output_broadcast_data
=
get_pp_group
().
broadcast_tensor_dict
(
model_output_broadcast_data
,
src
=
len
(
get_pp_group
().
ranks
)
-
1
)
assert
model_output_broadcast_data
is
not
None
logits
=
model_output_broadcast_data
[
"logits"
]
# Apply structured output bitmasks if present
if
scheduler_output
.
grammar_bitmask
is
not
None
:
...
...
@@ -1186,6 +1206,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# creates a new tensor with separate storage from the original
# logits tensor. This means any in-place operations on bonus_logits
# won't affect the original logits tensor.
assert
logits
is
not
None
bonus_logits
=
logits
[
spec_decode_metadata
.
bonus_logits_indices
]
sampler_output
=
self
.
sampler
(
logits
=
bonus_logits
,
...
...
vllm/v1/worker/gpu_worker.py
View file @
3d2779c2
...
...
@@ -275,13 +275,13 @@ class Worker(WorkerBase):
output
=
self
.
model_runner
.
execute_model
(
scheduler_output
,
intermediate_tensors
)
if
not
get_pp_group
().
is_last_rank
:
parallel_config
=
self
.
vllm_config
.
parallel_config
if
parallel_config
.
distributed_executor_backend
!=
"external_launcher"
\
and
not
get_pp_group
().
is_last_rank
:
assert
isinstance
(
output
,
IntermediateTensors
)
get_pp_group
().
send_tensor_dict
(
output
.
tensors
,
all_gather_group
=
get_tp_group
())
return
None
assert
isinstance
(
output
,
ModelRunnerOutput
)
return
output
if
self
.
is_driver_worker
else
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment