Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3e472d88
Unverified
Commit
3e472d88
authored
Feb 22, 2025
by
youkaichao
Committed by
GitHub
Feb 22, 2025
Browse files
[core] set up data parallel communication (#13591)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
7f6bae56
Changes
17
Show whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
416 additions
and
28 deletions
+416
-28
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+2
-0
examples/offline_inference/data_parallel.py
examples/offline_inference/data_parallel.py
+76
-0
vllm/config.py
vllm/config.py
+57
-0
vllm/distributed/device_communicators/cuda_communicator.py
vllm/distributed/device_communicators/cuda_communicator.py
+2
-2
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+7
-4
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+62
-14
vllm/distributed/utils.py
vllm/distributed/utils.py
+90
-1
vllm/envs.py
vllm/envs.py
+20
-0
vllm/forward_context.py
vllm/forward_context.py
+31
-3
vllm/utils.py
vllm/utils.py
+18
-0
vllm/v1/engine/core.py
vllm/v1/engine/core.py
+3
-0
vllm/v1/engine/core_client.py
vllm/v1/engine/core_client.py
+14
-0
vllm/v1/engine/llm_engine.py
vllm/v1/engine/llm_engine.py
+24
-2
vllm/v1/executor/multiproc_executor.py
vllm/v1/executor/multiproc_executor.py
+1
-1
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+1
-1
vllm/v1/worker/gpu_worker.py
vllm/v1/worker/gpu_worker.py
+3
-0
vllm/worker/worker_base.py
vllm/worker/worker_base.py
+5
-0
No files found.
.buildkite/test-pipeline.yaml
View file @
3e472d88
...
@@ -134,7 +134,9 @@ steps:
...
@@ -134,7 +134,9 @@ steps:
-
tests/compile/test_basic_correctness
-
tests/compile/test_basic_correctness
-
examples/offline_inference/rlhf.py
-
examples/offline_inference/rlhf.py
-
examples/offline_inference/rlhf_colocate.py
-
examples/offline_inference/rlhf_colocate.py
-
tests/examples/offline_inference/data_parallel.py
commands
:
commands
:
-
VLLM_USE_V1=1 python3 ../examples/offline_inference/data_parallel.py
-
pytest -v -s distributed/test_utils.py
-
pytest -v -s distributed/test_utils.py
-
pytest -v -s compile/test_basic_correctness.py
-
pytest -v -s compile/test_basic_correctness.py
-
pytest -v -s distributed/test_pynccl.py
-
pytest -v -s distributed/test_pynccl.py
...
...
examples/offline_inference/data_parallel.py
0 → 100644
View file @
3e472d88
# SPDX-License-Identifier: Apache-2.0
# usage: VLLM_USE_V1=1 python examples/offline_inference/data_parallel.py
# we need to have a launcher to create multiple data parallel
# ranks. And each rank will create a vLLM instance to process its own prompts.
import
os
from
vllm
import
LLM
,
SamplingParams
from
vllm.utils
import
get_open_port
def
main
(
dp_size
,
dp_rank
,
dp_master_ip
,
dp_master_port
,
GPUs_per_dp_rank
):
os
.
environ
[
"VLLM_DP_RANK"
]
=
str
(
dp_rank
)
os
.
environ
[
"VLLM_DP_SIZE"
]
=
str
(
dp_size
)
os
.
environ
[
"VLLM_DP_MASTER_IP"
]
=
dp_master_ip
os
.
environ
[
"VLLM_DP_MASTER_PORT"
]
=
str
(
dp_master_port
)
# set devices for each dp_rank
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
=
","
.
join
(
str
(
i
)
for
i
in
range
(
dp_rank
*
GPUs_per_dp_rank
,
(
dp_rank
+
1
)
*
GPUs_per_dp_rank
))
# Sample prompts.
prompts
=
[
"Hello, my name is"
,
"The president of the United States is"
,
"The capital of France is"
,
"The future of AI is"
,
]
# with DP, each rank should process different prompts.
# usually all the DP ranks process a full dataset,
# and each rank processes a different part of the dataset.
promts_per_rank
=
len
(
prompts
)
//
dp_size
start
=
dp_rank
*
promts_per_rank
end
=
start
+
promts_per_rank
prompts
=
prompts
[
start
:
end
]
if
len
(
prompts
)
==
0
:
# if any rank has no prompts to process,
# we need to set a placeholder prompt
prompts
=
[
"Placeholder"
]
print
(
f
"DP rank
{
dp_rank
}
needs to process
{
len
(
prompts
)
}
prompts"
)
# Create a sampling params object.
# since we are doing data parallel, every rank can have different
# sampling params. here we set different max_tokens for different
# ranks for demonstration.
sampling_params
=
SamplingParams
(
temperature
=
0.8
,
top_p
=
0.95
,
max_tokens
=
16
*
(
dp_rank
+
1
))
# Create an LLM.
llm
=
LLM
(
model
=
"facebook/opt-125m"
,
tensor_parallel_size
=
2
,
enforce_eager
=
True
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
# Print the outputs.
for
output
in
outputs
:
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"DP rank
{
dp_rank
}
, Prompt:
{
prompt
!
r
}
, "
f
"Generated text:
{
generated_text
!
r
}
"
)
if
__name__
==
"__main__"
:
from
multiprocessing
import
Process
dp_size
=
2
GPUs_per_dp_rank
=
2
dp_master_ip
=
"127.0.0.1"
dp_master_port
=
get_open_port
()
procs
=
[]
for
i
in
range
(
dp_size
):
proc
=
Process
(
target
=
main
,
args
=
(
dp_size
,
i
,
dp_master_ip
,
dp_master_port
,
GPUs_per_dp_rank
))
proc
.
start
()
procs
.
append
(
proc
)
for
proc
in
procs
:
proc
.
join
()
vllm/config.py
View file @
3e472d88
...
@@ -16,6 +16,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
...
@@ -16,6 +16,7 @@ from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
import
torch
import
torch
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
from
torch.distributed
import
ProcessGroup
,
ReduceOp
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -1296,6 +1297,11 @@ class ParallelConfig:
...
@@ -1296,6 +1297,11 @@ class ParallelConfig:
pipeline_parallel_size
:
int
=
1
# Number of pipeline parallel groups.
pipeline_parallel_size
:
int
=
1
# Number of pipeline parallel groups.
tensor_parallel_size
:
int
=
1
# Number of tensor parallel groups.
tensor_parallel_size
:
int
=
1
# Number of tensor parallel groups.
data_parallel_size
:
int
=
1
# Number of data parallel groups.
data_parallel_rank
:
int
=
0
# Rank of the data parallel group.
# IP of the data parallel master.
data_parallel_master_ip
:
str
=
"127.0.0.1"
data_parallel_master_port
:
int
=
29500
# Port of the data parallel master.
# Maximum number of multiple batches
# Maximum number of multiple batches
# when load model sequentially. To avoid RAM OOM when using tensor
# when load model sequentially. To avoid RAM OOM when using tensor
...
@@ -1329,10 +1335,55 @@ class ParallelConfig:
...
@@ -1329,10 +1335,55 @@ class ParallelConfig:
worker_cls
:
str
=
"auto"
worker_cls
:
str
=
"auto"
sd_worker_cls
:
str
=
"auto"
sd_worker_cls
:
str
=
"auto"
# world_size is TPxPP, it affects the number of workers we create.
world_size
:
int
=
field
(
init
=
False
)
world_size
:
int
=
field
(
init
=
False
)
# world_size_across_dp is TPxPPxDP, it is the size of the world
# including data parallelism.
world_size_across_dp
:
int
=
field
(
init
=
False
)
rank
:
int
=
0
rank
:
int
=
0
def
get_next_dp_init_port
(
self
)
->
int
:
"""
We might need to initialize process groups in multiple
processes that is related to data parallelism,
e.g. both in the worker and in the engine, which
can live in different processes. To avoid port conflicts, we
increment the port number each time we need to initialize a
new process group related to data parallelism.
"""
answer
=
self
.
data_parallel_master_port
self
.
data_parallel_master_port
+=
1
return
answer
def
stateless_init_dp_group
(
self
)
->
"ProcessGroup"
:
from
vllm.distributed.utils
import
(
stateless_init_torch_distributed_process_group
)
# use gloo since the engine process might not have cuda device
dp_group
=
stateless_init_torch_distributed_process_group
(
self
.
data_parallel_master_ip
,
self
.
get_next_dp_init_port
(),
self
.
data_parallel_rank
,
self
.
data_parallel_size
,
backend
=
"gloo"
)
return
dp_group
@
staticmethod
def
has_unfinished_dp
(
dp_group
:
"ProcessGroup"
,
has_unfinished
:
bool
)
->
bool
:
tensor
=
torch
.
tensor
([
has_unfinished
],
dtype
=
torch
.
int32
,
device
=
"cpu"
)
# dp rank 0: has_unfinished_seqs=True
# dp rank 1: has_unfinished_seqs=False
# aggregated: has_unfinished_seqs=True
# so this is an OR operation, i.e. MAX in integers
torch
.
distributed
.
all_reduce
(
tensor
,
op
=
ReduceOp
.
MAX
,
group
=
dp_group
)
aggregated_has_unfinished
=
bool
(
tensor
.
item
())
return
aggregated_has_unfinished
def
compute_hash
(
self
):
def
compute_hash
(
self
):
"""
"""
Provide a hash that uniquely identifies all the configs
Provide a hash that uniquely identifies all the configs
...
@@ -1350,6 +1401,12 @@ class ParallelConfig:
...
@@ -1350,6 +1401,12 @@ class ParallelConfig:
self
.
world_size
=
self
.
pipeline_parallel_size
*
\
self
.
world_size
=
self
.
pipeline_parallel_size
*
\
self
.
tensor_parallel_size
self
.
tensor_parallel_size
self
.
data_parallel_size
=
envs
.
VLLM_DP_SIZE
self
.
data_parallel_rank
=
envs
.
VLLM_DP_RANK
self
.
data_parallel_master_ip
=
envs
.
VLLM_DP_MASTER_IP
self
.
data_parallel_master_port
=
envs
.
VLLM_DP_MASTER_PORT
self
.
world_size_across_dp
=
self
.
world_size
*
self
.
data_parallel_size
ray_only_devices
=
[
"tpu"
]
ray_only_devices
=
[
"tpu"
]
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
if
(
current_platform
.
device_type
in
ray_only_devices
if
(
current_platform
.
device_type
in
ray_only_devices
...
...
vllm/distributed/device_communicators/cuda_communicator.py
View file @
3e472d88
...
@@ -16,8 +16,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
...
@@ -16,8 +16,8 @@ class CudaCommunicator(DeviceCommunicatorBase):
device_group
:
Optional
[
ProcessGroup
]
=
None
,
device_group
:
Optional
[
ProcessGroup
]
=
None
,
unique_name
:
str
=
""
):
unique_name
:
str
=
""
):
super
().
__init__
(
cpu_group
,
device
,
device_group
,
unique_name
)
super
().
__init__
(
cpu_group
,
device
,
device_group
,
unique_name
)
if
"
p
p"
in
unique_name
:
if
"
t
p"
not
in
unique_name
:
#
pipeline parallel does not need
custom allreduce
#
only tp uses
custom allreduce
use_custom_allreduce
=
False
use_custom_allreduce
=
False
else
:
else
:
from
vllm.distributed.parallel_state
import
(
from
vllm.distributed.parallel_state
import
(
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
3e472d88
...
@@ -87,6 +87,7 @@ class CustomAllreduce:
...
@@ -87,6 +87,7 @@ class CustomAllreduce:
return
return
rank
=
dist
.
get_rank
(
group
=
self
.
group
)
rank
=
dist
.
get_rank
(
group
=
self
.
group
)
self
.
rank
=
rank
world_size
=
dist
.
get_world_size
(
group
=
self
.
group
)
world_size
=
dist
.
get_world_size
(
group
=
self
.
group
)
if
world_size
==
1
:
if
world_size
==
1
:
# No need to initialize custom allreduce for single GPU case.
# No need to initialize custom allreduce for single GPU case.
...
@@ -201,7 +202,9 @@ class CustomAllreduce:
...
@@ -201,7 +202,9 @@ class CustomAllreduce:
@
staticmethod
@
staticmethod
def
free_shared_buffer
(
pointers
:
List
[
int
],
def
free_shared_buffer
(
pointers
:
List
[
int
],
group
:
Optional
[
ProcessGroup
]
=
None
)
->
None
:
group
:
Optional
[
ProcessGroup
]
=
None
,
rank
:
Optional
[
int
]
=
None
)
->
None
:
if
rank
is
None
:
rank
=
dist
.
get_rank
(
group
=
group
)
rank
=
dist
.
get_rank
(
group
=
group
)
lib
=
CudaRTLibrary
()
lib
=
CudaRTLibrary
()
lib
.
cudaFree
(
ctypes
.
c_void_p
(
pointers
[
rank
]))
lib
.
cudaFree
(
ctypes
.
c_void_p
(
pointers
[
rank
]))
...
@@ -298,8 +301,8 @@ class CustomAllreduce:
...
@@ -298,8 +301,8 @@ class CustomAllreduce:
if
not
self
.
disabled
and
self
.
_ptr
:
if
not
self
.
disabled
and
self
.
_ptr
:
ops
.
dispose
(
self
.
_ptr
)
ops
.
dispose
(
self
.
_ptr
)
self
.
_ptr
=
0
self
.
_ptr
=
0
self
.
free_shared_buffer
(
self
.
meta_ptrs
)
self
.
free_shared_buffer
(
self
.
meta_ptrs
,
rank
=
self
.
rank
)
self
.
free_shared_buffer
(
self
.
buffer_ptrs
)
self
.
free_shared_buffer
(
self
.
buffer_ptrs
,
rank
=
self
.
rank
)
def
__del__
(
self
):
def
__del__
(
self
):
self
.
close
()
self
.
close
()
vllm/distributed/parallel_state.py
View file @
3e472d88
...
@@ -750,6 +750,13 @@ get_tensor_model_parallel_group = get_tp_group
...
@@ -750,6 +750,13 @@ get_tensor_model_parallel_group = get_tp_group
_PP
:
Optional
[
GroupCoordinator
]
=
None
_PP
:
Optional
[
GroupCoordinator
]
=
None
_DP
:
Optional
[
GroupCoordinator
]
=
None
def
get_dp_group
()
->
GroupCoordinator
:
assert
_DP
is
not
None
,
(
"data parallel group is not initialized"
)
return
_DP
def
get_pp_group
()
->
GroupCoordinator
:
def
get_pp_group
()
->
GroupCoordinator
:
assert
_PP
is
not
None
,
(
assert
_PP
is
not
None
,
(
...
@@ -811,6 +818,21 @@ def init_distributed_environment(
...
@@ -811,6 +818,21 @@ def init_distributed_environment(
"world_size=%d rank=%d local_rank=%d "
"world_size=%d rank=%d local_rank=%d "
"distributed_init_method=%s backend=%s"
,
world_size
,
rank
,
local_rank
,
"distributed_init_method=%s backend=%s"
,
world_size
,
rank
,
local_rank
,
distributed_init_method
,
backend
)
distributed_init_method
,
backend
)
from
vllm.config
import
get_current_vllm_config
config
=
get_current_vllm_config
()
if
config
is
not
None
and
config
.
parallel_config
.
data_parallel_size
>
1
:
parallel_config
=
config
.
parallel_config
# adjust to take into account data parallelism
# offset the rank by the data parallel rank
rank
=
parallel_config
.
data_parallel_rank
*
world_size
+
rank
# adjust the world size to take into account data parallelism
world_size
=
parallel_config
.
world_size_across_dp
ip
=
parallel_config
.
data_parallel_master_ip
port
=
parallel_config
.
get_next_dp_init_port
()
distributed_init_method
=
f
"tcp://
{
ip
}
:
{
port
}
"
# noqa
logger
.
info
(
"Adjusting world_size=%d rank=%d distributed_init_method=%s for DP"
,
world_size
,
rank
,
distributed_init_method
)
if
not
torch
.
distributed
.
is_initialized
():
if
not
torch
.
distributed
.
is_initialized
():
assert
distributed_init_method
is
not
None
,
(
assert
distributed_init_method
is
not
None
,
(
"distributed_init_method must be provided when initializing "
"distributed_init_method must be provided when initializing "
...
@@ -870,20 +892,28 @@ def initialize_model_parallel(
...
@@ -870,20 +892,28 @@ def initialize_model_parallel(
# Get world size and rank. Ensure some consistencies.
# Get world size and rank. Ensure some consistencies.
assert
torch
.
distributed
.
is_initialized
()
assert
torch
.
distributed
.
is_initialized
()
world_size
:
int
=
torch
.
distributed
.
get_world_size
()
world_size
:
int
=
torch
.
distributed
.
get_world_size
()
rank
=
torch
.
distributed
.
get_rank
()
backend
=
backend
or
torch
.
distributed
.
get_backend
(
backend
=
backend
or
torch
.
distributed
.
get_backend
(
get_world_group
().
device_group
)
get_world_group
().
device_group
)
data_parallel_size
=
1
from
vllm.config
import
get_current_vllm_config
config
=
get_current_vllm_config
()
if
config
is
not
None
:
data_parallel_size
=
config
.
parallel_config
.
data_parallel_size
# the layout order is: DP x PP x TP
# to get group_ranks for each dimension, transpose that dimension to the
# last dimension, then reshape to 2D, then unbind the last dimension
all_ranks
=
torch
.
arange
(
world_size
).
reshape
(
data_parallel_size
,
pipeline_model_parallel_size
,
tensor_model_parallel_size
)
# noqa
# Build the tensor model-parallel groups.
# Build the tensor model-parallel groups.
num_tensor_model_parallel_groups
:
int
=
(
world_size
//
tensor_model_parallel_size
)
global
_TP
global
_TP
assert
_TP
is
None
,
(
"tensor model parallel group is already initialized"
)
assert
_TP
is
None
,
(
"tensor model parallel group is already initialized"
)
group_ranks
=
[]
group_ranks
=
all_ranks
.
view
(
-
1
,
tensor_model_parallel_size
).
unbind
(
0
)
for
i
in
range
(
num_tensor_model_parallel_groups
):
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
ranks
=
list
(
range
(
i
*
tensor_model_parallel_size
,
(
i
+
1
)
*
tensor_model_parallel_size
))
group_ranks
.
append
(
ranks
)
# message queue broadcaster is only used in tensor model parallel group
# message queue broadcaster is only used in tensor model parallel group
_TP
=
init_model_parallel_group
(
group_ranks
,
_TP
=
init_model_parallel_group
(
group_ranks
,
...
@@ -893,20 +923,33 @@ def initialize_model_parallel(
...
@@ -893,20 +923,33 @@ def initialize_model_parallel(
group_name
=
"tp"
)
group_name
=
"tp"
)
# Build the pipeline model-parallel groups.
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups
:
int
=
(
world_size
//
pipeline_model_parallel_size
)
global
_PP
global
_PP
assert
_PP
is
None
,
(
assert
_PP
is
None
,
(
"pipeline model parallel group is already initialized"
)
"pipeline model parallel group is already initialized"
)
group_ranks
=
[]
group_ranks
=
all_ranks
.
transpose
(
1
,
2
).
reshape
(
for
i
in
range
(
num_pipeline_model_parallel_groups
):
-
1
,
pipeline_model_parallel_size
).
unbind
(
0
)
ranks
=
list
(
range
(
i
,
world_size
,
num_pipeline_model_parallel_groups
))
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
group_ranks
.
append
(
ranks
)
_PP
=
init_model_parallel_group
(
group_ranks
,
_PP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
get_world_group
().
local_rank
,
backend
,
backend
,
group_name
=
"pp"
)
group_name
=
"pp"
)
global
_DP
assert
_DP
is
None
,
(
"data parallel group is already initialized"
)
group_ranks
=
all_ranks
.
transpose
(
0
,
2
).
reshape
(
-
1
,
data_parallel_size
).
unbind
(
0
)
group_ranks
=
[
x
.
tolist
()
for
x
in
group_ranks
]
_DP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
backend
,
group_name
=
"dp"
)
logger
.
info
(
"rank %s in world size %s is assigned as "
"DP rank %s, PP rank %s, TP rank %s"
,
rank
,
world_size
,
_DP
.
rank_in_group
,
_PP
.
rank_in_group
,
_TP
.
rank_in_group
)
def
ensure_kv_transfer_initialized
(
vllm_config
:
"VllmConfig"
)
->
None
:
def
ensure_kv_transfer_initialized
(
vllm_config
:
"VllmConfig"
)
->
None
:
"""
"""
...
@@ -1011,6 +1054,11 @@ def destroy_model_parallel():
...
@@ -1011,6 +1054,11 @@ def destroy_model_parallel():
_PP
.
destroy
()
_PP
.
destroy
()
_PP
=
None
_PP
=
None
global
_DP
if
_DP
:
_DP
.
destroy
()
_DP
=
None
def
destroy_distributed_environment
():
def
destroy_distributed_environment
():
global
_WORLD
global
_WORLD
...
...
vllm/distributed/utils.py
View file @
3e472d88
...
@@ -11,7 +11,11 @@ from collections import deque
...
@@ -11,7 +11,11 @@ from collections import deque
from
typing
import
Any
,
Deque
,
Dict
,
Optional
,
Sequence
,
Tuple
from
typing
import
Any
,
Deque
,
Dict
,
Optional
,
Sequence
,
Tuple
import
torch
import
torch
from
torch.distributed
import
TCPStore
from
torch.distributed
import
ProcessGroup
,
TCPStore
from
torch.distributed.distributed_c10d
import
(
Backend
,
PrefixStore
,
_get_default_timeout
,
is_nccl_available
)
from
torch.distributed.rendezvous
import
rendezvous
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -227,3 +231,88 @@ class StatelessProcessGroup:
...
@@ -227,3 +231,88 @@ class StatelessProcessGroup:
world_size
=
world_size
,
world_size
=
world_size
,
store
=
store
,
store
=
store
,
data_expiration_seconds
=
data_expiration_seconds
)
data_expiration_seconds
=
data_expiration_seconds
)
def
stateless_init_torch_distributed_process_group
(
host
:
str
,
port
:
int
,
rank
:
int
,
world_size
:
int
,
backend
:
str
)
->
ProcessGroup
:
"""
A replacement for `torch.distributed.init_process_group` that does not
pollute the global state. The created ProcessGroup object can be used for
some operations such as `allreduce`, because it does not depend on the
global rank. However, some operations such as `broadcast` cannot be used
because it depends on the global rank.
# TODO: ask for help from PyTorch team if we need the `broadcast` operation.
This function is useful when we are not sure about the total number of
processes in the process group. For example, we may have process
1, 2, ..., 8 who want to communicate, and process 9 might be the same
process as process 1, or it might be a different process; process 10
might be the same process as process 5, or it might be a different process.
In this case, how can we reliably form a communication channel within
process 9 and 10, without affecting the communication channel within
process 1, 2, ..., 8?
One possible solution is to figure out if process 9 and 10 are the same
as process 1 and 5 beforehand, and then form a communication channel
based on the information, adjusting the ranks and world_size etc. However,
figuring out the information is not always easy, and it will interfere
with the main communication channel.
Our solution is to always form a communication channel with process 1, 2,
..., 8, and then use this function to form another communication channel
with process 9 and 10. This way, regardless of whether process 9 and 10
are the same as process 1 and 5, the main communication channel is
always formed with process 1, 2, ..., 8, and the additional communication
channel is formed with process 9 and 10.
"""
init_method
=
f
"tcp://
{
host
}
:
{
port
}
"
backend
=
Backend
(
backend
)
# it is basically string
timeout
=
_get_default_timeout
(
backend
)
store
,
rank
,
world_size
=
next
(
rendezvous
(
init_method
,
rank
,
world_size
,
timeout
=
timeout
))
store
.
set_timeout
(
timeout
)
group_rank
=
rank
group_size
=
world_size
# Use a PrefixStore to avoid accidental overrides of keys used by
# different systems (e.g. RPC) in case the store is multi-tenant.
prefix_store
=
PrefixStore
(
init_method
,
store
)
pg_options
=
ProcessGroup
.
Options
(
backend
=
backend
,
timeout
=
timeout
)
pg
:
ProcessGroup
=
ProcessGroup
(
prefix_store
,
group_rank
,
group_size
,
pg_options
,
)
if
backend
==
"gloo"
:
from
torch.distributed.distributed_c10d
import
ProcessGroupGloo
backend_class
=
ProcessGroupGloo
(
prefix_store
,
group_rank
,
group_size
,
timeout
=
timeout
)
backend_type
=
ProcessGroup
.
BackendType
.
GLOO
device
=
torch
.
device
(
"cpu"
)
elif
backend
==
"nccl"
:
assert
is_nccl_available
()
from
torch.distributed.distributed_c10d
import
ProcessGroupNCCL
backend_options
=
ProcessGroupNCCL
.
Options
()
backend_options
.
_timeout
=
timeout
backend_class
=
ProcessGroupNCCL
(
prefix_store
,
group_rank
,
group_size
,
backend_options
)
backend_type
=
ProcessGroup
.
BackendType
.
NCCL
device
=
torch
.
device
(
"cuda"
)
backend_class
.
_set_sequence_number_for_group
()
pg
.
_register_backend
(
device
,
backend_type
,
backend_class
)
return
pg
vllm/envs.py
View file @
3e472d88
...
@@ -90,6 +90,10 @@ if TYPE_CHECKING:
...
@@ -90,6 +90,10 @@ if TYPE_CHECKING:
VLLM_RAY_BUNDLE_INDICES
:
str
=
""
VLLM_RAY_BUNDLE_INDICES
:
str
=
""
VLLM_CUDART_SO_PATH
:
Optional
[
str
]
=
None
VLLM_CUDART_SO_PATH
:
Optional
[
str
]
=
None
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
:
bool
=
True
VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH
:
bool
=
True
VLLM_DP_RANK
:
int
=
0
VLLM_DP_SIZE
:
int
=
1
VLLM_DP_MASTER_IP
:
str
=
""
VLLM_DP_MASTER_PORT
:
int
=
0
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -593,6 +597,22 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -593,6 +597,22 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH"
:
"VLLM_USE_HPU_CONTIGUOUS_CACHE_FETCH"
:
lambda
:
os
.
environ
.
get
(
"VLLM_CONTIGUOUS_PA"
,
"true"
).
lower
()
in
lambda
:
os
.
environ
.
get
(
"VLLM_CONTIGUOUS_PA"
,
"true"
).
lower
()
in
(
"1"
,
"true"
),
(
"1"
,
"true"
),
# Rank of the process in the data parallel setting
"VLLM_DP_RANK"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_DP_RANK"
,
"0"
)),
# World size of the data parallel setting
"VLLM_DP_SIZE"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_DP_SIZE"
,
"1"
)),
# IP address of the master node in the data parallel setting
"VLLM_DP_MASTER_IP"
:
lambda
:
os
.
getenv
(
"VLLM_DP_MASTER_IP"
,
"127.0.0.1"
),
# Port of the master node in the data parallel setting
"VLLM_DP_MASTER_PORT"
:
lambda
:
int
(
os
.
getenv
(
"VLLM_DP_MASTER_PORT"
,
"0"
)),
}
}
# end-env-vars-definition
# end-env-vars-definition
...
...
vllm/forward_context.py
View file @
3e472d88
...
@@ -4,9 +4,10 @@ import time
...
@@ -4,9 +4,10 @@ import time
from
collections
import
defaultdict
from
collections
import
defaultdict
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
Optional
from
typing
import
TYPE_CHECKING
,
Any
,
Dict
,
List
,
Optional
import
torch
import
torch
import
torch.distributed
as
dist
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
...
@@ -32,6 +33,8 @@ class ForwardContext:
...
@@ -32,6 +33,8 @@ class ForwardContext:
attn_metadata
:
"AttentionMetadata"
# set dynamically for each forward pass
attn_metadata
:
"AttentionMetadata"
# set dynamically for each forward pass
# TODO: remove after making all virtual_engines share the same kv cache
# TODO: remove after making all virtual_engines share the same kv cache
virtual_engine
:
int
# set dynamically for each forward pass
virtual_engine
:
int
# set dynamically for each forward pass
num_tokens_across_dp
:
Optional
[
List
[
int
]]
=
None
# set dynamically for each forward pass
_forward_context
:
Optional
[
ForwardContext
]
=
None
_forward_context
:
Optional
[
ForwardContext
]
=
None
...
@@ -48,7 +51,8 @@ def get_forward_context() -> ForwardContext:
...
@@ -48,7 +51,8 @@ def get_forward_context() -> ForwardContext:
@
contextmanager
@
contextmanager
def
set_forward_context
(
attn_metadata
:
Any
,
def
set_forward_context
(
attn_metadata
:
Any
,
vllm_config
:
VllmConfig
,
vllm_config
:
VllmConfig
,
virtual_engine
:
int
=
0
):
virtual_engine
:
int
=
0
,
num_tokens
:
int
=
0
):
"""A context manager that stores the current forward context,
"""A context manager that stores the current forward context,
can be attention metadata, etc.
can be attention metadata, etc.
Here we can inject common logic for every model forward pass.
Here we can inject common logic for every model forward pass.
...
@@ -57,12 +61,36 @@ def set_forward_context(attn_metadata: Any,
...
@@ -57,12 +61,36 @@ def set_forward_context(attn_metadata: Any,
need_to_track_batchsize
=
track_batchsize
and
attn_metadata
is
not
None
need_to_track_batchsize
=
track_batchsize
and
attn_metadata
is
not
None
if
need_to_track_batchsize
:
if
need_to_track_batchsize
:
forward_start_time
=
time
.
perf_counter
()
forward_start_time
=
time
.
perf_counter
()
num_tokens_across_dp
=
None
if
vllm_config
.
parallel_config
.
data_parallel_size
>
1
:
dp_size
=
vllm_config
.
parallel_config
.
data_parallel_size
dp_rank
=
vllm_config
.
parallel_config
.
data_parallel_rank
if
attn_metadata
is
not
None
:
if
hasattr
(
attn_metadata
,
"num_prefill_tokens"
):
# for v0 attention backends
batchsize
=
attn_metadata
.
num_prefill_tokens
+
\
attn_metadata
.
num_decode_tokens
else
:
# for v1 attention backends
batchsize
=
attn_metadata
.
num_input_tokens
else
:
batchsize
=
num_tokens
num_tokens_across_dp
=
[
0
]
*
dp_size
num_tokens_across_dp
[
dp_rank
]
=
batchsize
num_tokens_tensor
=
torch
.
tensor
(
num_tokens_across_dp
,
device
=
"cpu"
,
dtype
=
torch
.
int32
)
from
vllm.distributed.parallel_state
import
get_dp_group
dist
.
all_reduce
(
num_tokens_tensor
,
group
=
get_dp_group
().
cpu_group
)
num_tokens_across_dp
=
num_tokens_tensor
.
tolist
()
global
_forward_context
global
_forward_context
prev_context
=
_forward_context
prev_context
=
_forward_context
_forward_context
=
ForwardContext
(
_forward_context
=
ForwardContext
(
attn_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
attn_layers
=
vllm_config
.
compilation_config
.
static_forward_context
,
virtual_engine
=
virtual_engine
,
virtual_engine
=
virtual_engine
,
attn_metadata
=
attn_metadata
)
attn_metadata
=
attn_metadata
,
num_tokens_across_dp
=
num_tokens_across_dp
)
try
:
try
:
yield
yield
finally
:
finally
:
...
...
vllm/utils.py
View file @
3e472d88
...
@@ -501,6 +501,24 @@ def get_open_zmq_ipc_path() -> str:
...
@@ -501,6 +501,24 @@ def get_open_zmq_ipc_path() -> str:
def
get_open_port
()
->
int
:
def
get_open_port
()
->
int
:
"""
Get an open port for the vLLM process to listen on.
An edge case to handle, is when we run data parallel,
we need to avoid ports that are potentially used by
the data parallel master process.
Right now we reserve 10 ports for the data parallel master
process. Currently it uses 2 ports.
"""
if
"VLLM_DP_MASTER_PORT"
in
os
.
environ
:
dp_port
=
envs
.
VLLM_DP_MASTER_PORT
while
True
:
port
=
_get_open_port
()
if
port
>=
dp_port
and
port
<
dp_port
+
10
:
continue
return
port
return
_get_open_port
()
def
_get_open_port
()
->
int
:
port
=
envs
.
VLLM_PORT
port
=
envs
.
VLLM_PORT
if
port
is
not
None
:
if
port
is
not
None
:
while
True
:
while
True
:
...
...
vllm/v1/engine/core.py
View file @
3e472d88
...
@@ -219,6 +219,9 @@ class EngineCore:
...
@@ -219,6 +219,9 @@ class EngineCore:
def
wake_up
(
self
):
def
wake_up
(
self
):
self
.
model_executor
.
wake_up
()
self
.
model_executor
.
wake_up
()
def
execute_dummy_batch
(
self
):
self
.
model_executor
.
collective_rpc
(
"execute_dummy_batch"
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
self
.
model_executor
.
add_lora
(
lora_request
)
self
.
model_executor
.
add_lora
(
lora_request
)
...
...
vllm/v1/engine/core_client.py
View file @
3e472d88
...
@@ -87,6 +87,12 @@ class EngineCoreClient(ABC):
...
@@ -87,6 +87,12 @@ class EngineCoreClient(ABC):
def
wake_up
(
self
)
->
None
:
def
wake_up
(
self
)
->
None
:
raise
NotImplementedError
raise
NotImplementedError
def
execute_dummy_batch
(
self
)
->
None
:
raise
NotImplementedError
async
def
execute_dummy_batch_async
(
self
)
->
None
:
raise
NotImplementedError
def
abort_requests
(
self
,
request_ids
:
List
[
str
])
->
None
:
def
abort_requests
(
self
,
request_ids
:
List
[
str
])
->
None
:
raise
NotImplementedError
raise
NotImplementedError
...
@@ -156,6 +162,9 @@ class InprocClient(EngineCoreClient):
...
@@ -156,6 +162,9 @@ class InprocClient(EngineCoreClient):
def
wake_up
(
self
)
->
None
:
def
wake_up
(
self
)
->
None
:
self
.
engine_core
.
wake_up
()
self
.
engine_core
.
wake_up
()
def
execute_dummy_batch
(
self
)
->
None
:
self
.
engine_core
.
execute_dummy_batch
()
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
None
:
self
.
engine_core
.
add_lora
(
lora_request
)
self
.
engine_core
.
add_lora
(
lora_request
)
...
@@ -331,6 +340,8 @@ class SyncMPClient(MPClient):
...
@@ -331,6 +340,8 @@ class SyncMPClient(MPClient):
def
wake_up
(
self
)
->
None
:
def
wake_up
(
self
)
->
None
:
self
.
_call_utility
(
"wake_up"
)
self
.
_call_utility
(
"wake_up"
)
def
execute_dummy_batch
(
self
)
->
None
:
self
.
_call_utility
(
"execute_dummy_batch"
)
class
AsyncMPClient
(
MPClient
):
class
AsyncMPClient
(
MPClient
):
"""Asyncio-compatible client for multi-proc EngineCore."""
"""Asyncio-compatible client for multi-proc EngineCore."""
...
@@ -414,5 +425,8 @@ class AsyncMPClient(MPClient):
...
@@ -414,5 +425,8 @@ class AsyncMPClient(MPClient):
async
def
wake_up_async
(
self
)
->
None
:
async
def
wake_up_async
(
self
)
->
None
:
await
self
.
_call_utility_async
(
"wake_up"
)
await
self
.
_call_utility_async
(
"wake_up"
)
async
def
execute_dummy_batch_async
(
self
)
->
None
:
await
self
.
_call_utility_async
(
"execute_dummy_batch"
)
async
def
add_lora_async
(
self
,
lora_request
:
LoRARequest
)
->
None
:
async
def
add_lora_async
(
self
,
lora_request
:
LoRARequest
)
->
None
:
await
self
.
_call_utility_async
(
"add_lora"
,
lora_request
)
await
self
.
_call_utility_async
(
"add_lora"
,
lora_request
)
vllm/v1/engine/llm_engine.py
View file @
3e472d88
...
@@ -4,7 +4,7 @@ from typing import Dict, List, Mapping, Optional, Type, Union
...
@@ -4,7 +4,7 @@ from typing import Dict, List, Mapping, Optional, Type, Union
from
typing_extensions
import
TypeVar
from
typing_extensions
import
TypeVar
from
vllm.config
import
VllmConfig
from
vllm.config
import
ParallelConfig
,
VllmConfig
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.engine.metrics_types
import
StatLoggerBase
from
vllm.envs
import
VLLM_ENABLE_V1_MULTIPROCESSING
from
vllm.envs
import
VLLM_ENABLE_V1_MULTIPROCESSING
...
@@ -47,6 +47,13 @@ class LLMEngine:
...
@@ -47,6 +47,13 @@ class LLMEngine:
self
.
model_config
=
vllm_config
.
model_config
self
.
model_config
=
vllm_config
.
model_config
self
.
cache_config
=
vllm_config
.
cache_config
self
.
cache_config
=
vllm_config
.
cache_config
# important: init dp group before init the engine_core
self
.
parallel_config
=
vllm_config
.
parallel_config
self
.
dp_enabled
=
self
.
parallel_config
.
data_parallel_size
>
1
# noqa
self
.
should_execute_dummy_batch
=
False
if
self
.
dp_enabled
:
self
.
dp_group
=
self
.
parallel_config
.
stateless_init_dp_group
()
# Tokenizer (+ ensure liveness if running in another process).
# Tokenizer (+ ensure liveness if running in another process).
self
.
tokenizer
=
init_tokenizer_from_configs
(
self
.
tokenizer
=
init_tokenizer_from_configs
(
model_config
=
vllm_config
.
model_config
,
model_config
=
vllm_config
.
model_config
,
...
@@ -106,7 +113,17 @@ class LLMEngine:
...
@@ -106,7 +113,17 @@ class LLMEngine:
return
self
.
output_processor
.
get_num_unfinished_requests
()
return
self
.
output_processor
.
get_num_unfinished_requests
()
def
has_unfinished_requests
(
self
)
->
bool
:
def
has_unfinished_requests
(
self
)
->
bool
:
return
self
.
output_processor
.
has_unfinished_requests
()
has_unfinished
=
self
.
output_processor
.
has_unfinished_requests
()
if
not
self
.
dp_enabled
:
return
has_unfinished
return
self
.
has_unfinished_requests_dp
(
has_unfinished
)
def
has_unfinished_requests_dp
(
self
,
has_unfinished
:
bool
)
->
bool
:
aggregated_has_unfinished
=
ParallelConfig
.
has_unfinished_dp
(
self
.
dp_group
,
has_unfinished
)
if
not
has_unfinished
and
aggregated_has_unfinished
:
self
.
should_execute_dummy_batch
=
True
return
aggregated_has_unfinished
@
classmethod
@
classmethod
def
validate_outputs
(
cls
,
outputs
,
output_type
):
def
validate_outputs
(
cls
,
outputs
,
output_type
):
...
@@ -145,6 +162,11 @@ class LLMEngine:
...
@@ -145,6 +162,11 @@ class LLMEngine:
def
step
(
self
)
->
List
[
RequestOutput
]:
def
step
(
self
)
->
List
[
RequestOutput
]:
if
self
.
should_execute_dummy_batch
:
self
.
should_execute_dummy_batch
=
False
self
.
engine_core
.
execute_dummy_batch
()
return
[]
# 1) Get EngineCoreOutput from the EngineCore.
# 1) Get EngineCoreOutput from the EngineCore.
outputs
=
self
.
engine_core
.
get_output
()
outputs
=
self
.
engine_core
.
get_output
()
...
...
vllm/v1/executor/multiproc_executor.py
View file @
3e472d88
...
@@ -239,7 +239,7 @@ class WorkerProc:
...
@@ -239,7 +239,7 @@ class WorkerProc:
ready_socket
.
send_string
(
WorkerProc
.
READY_STR
)
ready_socket
.
send_string
(
WorkerProc
.
READY_STR
)
ready_socket
.
send
(
payload
)
ready_socket
.
send
(
payload
)
self
.
work
er
.
init_device
()
wrapp
er
.
init_device
()
self
.
worker
.
load_model
()
self
.
worker
.
load_model
()
@
staticmethod
@
staticmethod
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
3e472d88
...
@@ -1167,7 +1167,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
...
@@ -1167,7 +1167,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
for
k
,
v
in
self
.
intermediate_tensors
.
items
()
for
k
,
v
in
self
.
intermediate_tensors
.
items
()
})
})
with
set_forward_context
(
None
,
self
.
vllm_config
):
with
set_forward_context
(
None
,
self
.
vllm_config
,
num_tokens
=
num_tokens
):
hidden_states
=
model
(
hidden_states
=
model
(
input_ids
=
input_ids
,
input_ids
=
input_ids
,
positions
=
positions
,
positions
=
positions
,
...
...
vllm/v1/worker/gpu_worker.py
View file @
3e472d88
...
@@ -235,6 +235,9 @@ class Worker(WorkerBase):
...
@@ -235,6 +235,9 @@ class Worker(WorkerBase):
else
:
else
:
self
.
profiler
.
stop
()
self
.
profiler
.
stop
()
def
execute_dummy_batch
(
self
)
->
None
:
self
.
model_runner
.
_dummy_run
(
1
)
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
def
add_lora
(
self
,
lora_request
:
LoRARequest
)
->
bool
:
return
self
.
model_runner
.
add_lora
(
lora_request
)
return
self
.
model_runner
.
add_lora
(
lora_request
)
...
...
vllm/worker/worker_base.py
View file @
3e472d88
...
@@ -567,6 +567,11 @@ class WorkerWrapperBase:
...
@@ -567,6 +567,11 @@ class WorkerWrapperBase:
self
.
worker
=
worker_class
(
**
kwargs
)
self
.
worker
=
worker_class
(
**
kwargs
)
assert
self
.
worker
is
not
None
assert
self
.
worker
is
not
None
def
init_device
(
self
):
with
set_current_vllm_config
(
self
.
vllm_config
):
# To make vLLM config available during device initialization
self
.
worker
.
init_device
()
# type: ignore
def
execute_method
(
self
,
method
:
Union
[
str
,
bytes
],
*
args
,
**
kwargs
):
def
execute_method
(
self
,
method
:
Union
[
str
,
bytes
],
*
args
,
**
kwargs
):
try
:
try
:
target
=
self
if
self
.
worker
is
None
else
self
.
worker
target
=
self
if
self
.
worker
is
None
else
self
.
worker
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment