Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
702bee46
Unverified
Commit
702bee46
authored
May 12, 2024
by
youkaichao
Committed by
GitHub
May 12, 2024
Browse files
[Core][Distributed] refactor custom allreduce to support multiple tp groups (#4754)
parent
a7be4d00
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
327 additions
and
226 deletions
+327
-226
tests/distributed/test_comm_ops.py
tests/distributed/test_comm_ops.py
+11
-11
tests/distributed/test_custom_all_reduce.py
tests/distributed/test_custom_all_reduce.py
+58
-29
tests/distributed/test_pynccl.py
tests/distributed/test_pynccl.py
+2
-2
vllm/distributed/communication_op.py
vllm/distributed/communication_op.py
+34
-11
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+176
-144
vllm/distributed/device_communicators/pynccl.py
vllm/distributed/device_communicators/pynccl.py
+3
-1
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+27
-1
vllm/test_utils.py
vllm/test_utils.py
+8
-9
vllm/worker/model_runner.py
vllm/worker/model_runner.py
+4
-11
vllm/worker/worker.py
vllm/worker/worker.py
+4
-7
No files found.
tests/distributed/test_comm_ops.py
View file @
702bee46
...
@@ -16,7 +16,7 @@ from vllm.test_utils import (init_test_distributed_environment,
...
@@ -16,7 +16,7 @@ from vllm.test_utils import (init_test_distributed_environment,
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
def
all_reduce_test_worker
(
t
ensor_parallel
_size
:
int
,
rank
:
int
,
def
all_reduce_test_worker
(
t
p_size
:
int
,
pp
_size
:
int
,
rank
:
int
,
distributed_init_port
:
str
):
distributed_init_port
:
str
):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# so that each worker can see all the GPUs
...
@@ -24,12 +24,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
...
@@ -24,12 +24,12 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
init_test_distributed_environment
(
1
,
tensor_parallel
_size
,
rank
,
init_test_distributed_environment
(
tp_size
,
pp
_size
,
rank
,
distributed_init_port
)
distributed_init_port
)
num_elements
=
8
num_elements
=
8
all_tensors
=
[
all_tensors
=
[
torch
.
arange
(
num_elements
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
*
torch
.
arange
(
num_elements
,
dtype
=
torch
.
float32
,
device
=
"cuda"
)
*
(
r
+
1
)
for
r
in
range
(
t
ensor_parallel
_size
)
(
r
+
1
)
for
r
in
range
(
t
p
_size
)
]
]
expected
=
torch
.
sum
(
torch
.
stack
(
all_tensors
,
dim
=
0
),
dim
=
0
)
expected
=
torch
.
sum
(
torch
.
stack
(
all_tensors
,
dim
=
0
),
dim
=
0
)
t
=
all_tensors
[
rank
]
t
=
all_tensors
[
rank
]
...
@@ -38,7 +38,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
...
@@ -38,7 +38,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
def
all_gather_test_worker
(
t
ensor_parallel
_size
:
int
,
rank
:
int
,
def
all_gather_test_worker
(
t
p_size
:
int
,
pp
_size
:
int
,
rank
:
int
,
distributed_init_port
:
str
):
distributed_init_port
:
str
):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# so that each worker can see all the GPUs
...
@@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
...
@@ -46,7 +46,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
init_test_distributed_environment
(
1
,
tensor_parallel
_size
,
rank
,
init_test_distributed_environment
(
tp_size
,
pp
_size
,
rank
,
distributed_init_port
)
distributed_init_port
)
num_dimensions
=
3
num_dimensions
=
3
tensor_size
=
list
(
range
(
2
,
num_dimensions
+
2
))
tensor_size
=
list
(
range
(
2
,
num_dimensions
+
2
))
...
@@ -57,7 +57,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
...
@@ -57,7 +57,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
all_tensors
=
[
all_tensors
=
[
torch
.
arange
(
total_size
,
dtype
=
torch
.
float32
,
torch
.
arange
(
total_size
,
dtype
=
torch
.
float32
,
device
=
"cuda"
).
reshape
(
tensor_size
)
*
(
r
+
1
)
device
=
"cuda"
).
reshape
(
tensor_size
)
*
(
r
+
1
)
for
r
in
range
(
t
ensor_parallel
_size
)
for
r
in
range
(
t
p
_size
)
]
]
expected
=
torch
.
cat
(
all_tensors
,
dim
=
all_gather_dimension
)
expected
=
torch
.
cat
(
all_tensors
,
dim
=
all_gather_dimension
)
t
=
all_tensors
[
rank
]
t
=
all_tensors
[
rank
]
...
@@ -66,7 +66,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
...
@@ -66,7 +66,7 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int,
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
def
broadcast_tensor_dict_test_worker
(
t
ensor_parallel
_size
:
int
,
rank
:
int
,
def
broadcast_tensor_dict_test_worker
(
t
p_size
:
int
,
pp
_size
:
int
,
rank
:
int
,
distributed_init_port
:
str
):
distributed_init_port
:
str
):
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# it is important to delete the CUDA_VISIBLE_DEVICES environment variable
# so that each worker can see all the GPUs
# so that each worker can see all the GPUs
...
@@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
...
@@ -74,7 +74,7 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
init_test_distributed_environment
(
1
,
tensor_parallel
_size
,
rank
,
init_test_distributed_environment
(
tp_size
,
pp
_size
,
rank
,
distributed_init_port
)
distributed_init_port
)
test_dict
=
{
test_dict
=
{
# device tensor
# device tensor
...
@@ -106,10 +106,10 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
...
@@ -106,10 +106,10 @@ def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int,
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
reason
=
"Need at least 2 GPUs to run the test."
)
reason
=
"Need at least 2 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"t
ensor_parallel
_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"t
p
_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"test_target"
,
[
@
pytest
.
mark
.
parametrize
(
"test_target"
,
[
all_reduce_test_worker
,
all_gather_test_worker
,
all_reduce_test_worker
,
all_gather_test_worker
,
broadcast_tensor_dict_test_worker
broadcast_tensor_dict_test_worker
])
])
def
test_multi_process_tensor_parallel
(
t
ensor_parallel
_size
,
test_target
):
def
test_multi_process_tensor_parallel
(
t
p
_size
,
test_target
):
multi_process_tensor_parallel
(
t
ensor_parallel
_size
,
test_target
)
multi_process_tensor_parallel
(
t
p
_size
,
1
,
test_target
)
tests/distributed/test_custom_all_reduce.py
View file @
702bee46
...
@@ -6,8 +6,10 @@ import ray
...
@@ -6,8 +6,10 @@ import ray
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
vllm.distributed
import
tensor_model_parallel_all_reduce
from
vllm.distributed.communication_op
import
(
# noqa
from
vllm.distributed.device_communicators
import
custom_all_reduce
graph_capture
,
tensor_model_parallel_all_reduce
)
from
vllm.distributed.parallel_state
import
(
get_tensor_model_parallel_group
,
get_tp_ca_communicator
)
from
vllm.test_utils
import
(
init_test_distributed_environment
,
from
vllm.test_utils
import
(
init_test_distributed_environment
,
multi_process_tensor_parallel
)
multi_process_tensor_parallel
)
...
@@ -18,17 +20,36 @@ for i, v in enumerate(test_sizes):
...
@@ -18,17 +20,36 @@ for i, v in enumerate(test_sizes):
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
def
graph_allreduce
(
world
_size
,
rank
,
distributed_init_port
):
def
graph_allreduce
(
tp_size
,
pp
_size
,
rank
,
distributed_init_port
):
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
init_test_distributed_environment
(
1
,
world
_size
,
rank
,
init_test_distributed_environment
(
tp_size
,
pp
_size
,
rank
,
distributed_init_port
)
distributed_init_port
)
custom_all_reduce
.
init_custom_ar
()
group
=
get_tensor_model_parallel_group
()
# A small all_reduce for warmup.
# this is needed because device communicators might be created lazily
# (e.g. NCCL). This will ensure that the communicator is initialized
# before any communication happens, so that this group can be used for
# graph capture immediately.
data
=
torch
.
zeros
(
1
)
data
=
data
.
to
(
device
=
device
)
torch
.
distributed
.
all_reduce
(
data
,
group
=
group
)
torch
.
cuda
.
synchronize
()
del
data
# we use the first group to communicate once
# and the second group to communicate twice
# and so on
# this is used to demonstrate that each group can
# communicate independently
num_communication
=
rank
//
tp_size
+
1
for
sz
in
test_sizes
:
for
sz
in
test_sizes
:
for
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]:
for
dtype
in
[
torch
.
float32
,
torch
.
float16
,
torch
.
bfloat16
]:
with
custom_all_reduce
.
capture
():
with
graph_
capture
():
# use integers so result matches NCCL exactly
# use integers so result matches NCCL exactly
inp1
=
torch
.
randint
(
1
,
inp1
=
torch
.
randint
(
1
,
16
,
(
sz
,
),
16
,
(
sz
,
),
...
@@ -41,44 +62,52 @@ def graph_allreduce(world_size, rank, distributed_init_port):
...
@@ -41,44 +62,52 @@ def graph_allreduce(world_size, rank, distributed_init_port):
torch
.
cuda
.
synchronize
()
torch
.
cuda
.
synchronize
()
graph
=
torch
.
cuda
.
CUDAGraph
()
graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
graph
):
with
torch
.
cuda
.
graph
(
graph
):
out1
=
tensor_model_parallel_all_reduce
(
inp1
)
for
i
in
range
(
num_communication
):
# the input buffer is immediately modified to test
out1
=
tensor_model_parallel_all_reduce
(
inp1
)
# synchronization
# the input buffer is immediately modified to test
dist
.
all_reduce
(
inp1
)
# synchronization
out2
=
tensor_model_parallel_all_reduce
(
inp2
)
dist
.
all_reduce
(
inp1
,
group
=
group
)
dist
.
all_reduce
(
inp2
)
out2
=
tensor_model_parallel_all_reduce
(
inp2
)
dist
.
all_reduce
(
inp2
,
group
=
group
)
graph
.
replay
()
graph
.
replay
()
assert
torch
.
allclose
(
out1
,
inp1
)
assert
torch
.
allclose
(
out1
,
inp1
)
assert
torch
.
allclose
(
out2
,
inp2
)
assert
torch
.
allclose
(
out2
,
inp2
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
@
ray
.
remote
(
num_gpus
=
1
,
max_calls
=
1
)
def
eager_allreduce
(
world
_size
,
rank
,
distributed_init_port
):
def
eager_allreduce
(
tp_size
,
pp
_size
,
rank
,
distributed_init_port
):
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
del
os
.
environ
[
"CUDA_VISIBLE_DEVICES"
]
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
rank
}
"
)
torch
.
cuda
.
set_device
(
device
)
torch
.
cuda
.
set_device
(
device
)
init_test_distributed_environment
(
1
,
world
_size
,
rank
,
init_test_distributed_environment
(
tp_size
,
pp
_size
,
rank
,
distributed_init_port
)
distributed_init_port
)
# we use the first group to communicate once
# and the second group to communicate twice
# and so on
# this is used to demonstrate that each group can
# communicate independently
num_communication
=
rank
//
tp_size
+
1
sz
=
1024
sz
=
1024
custom_all_reduce
.
init_custom_ar
()
fa
=
get_tp_ca_communicator
()
fa
=
custom_all_reduce
.
get_handle
()
inp
=
torch
.
ones
(
sz
,
dtype
=
torch
.
float32
,
device
=
device
)
inp
=
torch
.
ones
(
sz
,
dtype
=
torch
.
float32
,
device
=
device
)
out
=
fa
.
all_reduce_unreg
(
inp
)
out
=
inp
assert
torch
.
allclose
(
out
,
inp
*
world_size
)
for
_
in
range
(
num_communication
):
out
=
fa
.
all_reduce_unreg
(
out
)
assert
torch
.
allclose
(
out
,
inp
*
(
tp_size
**
num_communication
))
inp
=
torch
.
ones
(
sz
*
4
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
inp
=
torch
.
ones
(
sz
*
4
,
dtype
=
torch
.
bfloat16
,
device
=
device
)
out
=
fa
.
all_reduce_unreg
(
inp
)
out
=
inp
assert
torch
.
allclose
(
out
,
inp
*
world_size
)
for
_
in
range
(
num_communication
):
out
=
fa
.
all_reduce_unreg
(
out
)
assert
torch
.
allclose
(
out
,
inp
*
(
tp_size
**
num_communication
))
@
pytest
.
mark
.
skipif
(
torch
.
cuda
.
device_count
()
<
2
,
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
2
])
reason
=
"Need at least 2 GPUs to run the test."
)
@
pytest
.
mark
.
parametrize
(
"pipeline_parallel_size"
,
[
1
,
2
])
@
pytest
.
mark
.
parametrize
(
"tensor_parallel_size"
,
[
2
])
@
pytest
.
mark
.
parametrize
(
"test_target"
,
[
eager_allreduce
,
graph_allreduce
])
@
pytest
.
mark
.
parametrize
(
"test_target"
,
[
eager_allreduce
,
graph_allreduce
])
def
test_multi_process_tensor_parallel
(
tensor_parallel_size
,
test_target
):
def
test_custom_allreduce
(
tp_size
,
pipeline_parallel_size
,
test_target
):
multi_process_tensor_parallel
(
tensor_parallel_size
,
test_target
)
world_size
=
tp_size
*
pipeline_parallel_size
if
world_size
>
torch
.
cuda
.
device_count
():
pytest
.
skip
(
"Not enough GPUs to run the test."
)
if
__name__
==
"__main__"
:
multi_process_tensor_parallel
(
tp_size
,
pipeline_parallel_size
,
test_target
)
multi_process_tensor_parallel
(
2
,
graph_allreduce
)
tests/distributed/test_pynccl.py
View file @
702bee46
...
@@ -5,7 +5,7 @@ import pytest
...
@@ -5,7 +5,7 @@ import pytest
import
torch
import
torch
from
vllm.distributed.communication_op
import
(
# noqa
from
vllm.distributed.communication_op
import
(
# noqa
graph_
capture_
mode
,
tensor_model_parallel_all_reduce
)
graph_mode
,
tensor_model_parallel_all_reduce
)
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.device_communicators.pynccl
import
PyNcclCommunicator
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
from
vllm.distributed.device_communicators.pynccl_wrapper
import
NCCLLibrary
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
from
vllm.distributed.parallel_state
import
(
ensure_model_parallel_initialized
,
...
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
...
@@ -103,7 +103,7 @@ def multiple_tp_with_vllm_worker_fn():
device
=
torch
.
device
(
f
"cuda:
{
torch
.
distributed
.
get_rank
()
}
"
)
device
=
torch
.
device
(
f
"cuda:
{
torch
.
distributed
.
get_rank
()
}
"
)
ensure_model_parallel_initialized
(
2
,
2
)
ensure_model_parallel_initialized
(
2
,
2
)
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
,
device
=
device
)
tensor
=
torch
.
ones
(
16
,
1024
,
1024
,
dtype
=
torch
.
float32
,
device
=
device
)
with
graph_
capture_
mode
():
with
graph_mode
():
# two tp groups can communicate independently
# two tp groups can communicate independently
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
if
torch
.
distributed
.
get_rank
()
in
[
0
,
1
]:
tensor
=
tensor_model_parallel_all_reduce
(
tensor
)
tensor
=
tensor_model_parallel_all_reduce
(
tensor
)
...
...
vllm/distributed/communication_op.py
View file @
702bee46
from
collections
import
namedtuple
from
collections
import
namedtuple
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
,
nullcontext
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
...
@@ -9,12 +9,13 @@ from .parallel_state import (get_cpu_world_group,
...
@@ -9,12 +9,13 @@ from .parallel_state import (get_cpu_world_group,
get_tensor_model_parallel_group
,
get_tensor_model_parallel_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
,
get_tensor_model_parallel_world_size
,
get_tp_ca_communicator
,
get_tp_pynccl_communicator
)
get_tp_pynccl_communicator
)
@
contextmanager
@
contextmanager
def
graph_
capture_
mode
():
def
graph_mode
():
# In graph
captur
e, we have to be very careful about the collective
# In graph
mod
e, we have to be very careful about the collective
# operations. The current status is:
# operations. The current status is:
# allreduce \ Mode | Eager | Graph |
# allreduce \ Mode | Eager | Graph |
# --------------------------------------------
# --------------------------------------------
...
@@ -24,10 +25,32 @@ def graph_capture_mode():
...
@@ -24,10 +25,32 @@ def graph_capture_mode():
#
#
# Note that custom allreduce will have a runtime check, if the tensor size
# Note that custom allreduce will have a runtime check, if the tensor size
# is too large, it will fallback to the next available option.
# is too large, it will fallback to the next available option.
# In summary: When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or pynccl if it is disabled or not supported.
pynccl_comm
=
get_tp_pynccl_communicator
()
pynccl_comm
=
get_tp_pynccl_communicator
()
assert
pynccl_comm
is
not
None
if
pynccl_comm
is
None
:
with
pynccl_comm
.
change_state
(
enable
=
True
,
context
=
nullcontext
()
stream
=
torch
.
cuda
.
current_stream
()):
else
:
context
=
pynccl_comm
.
change_state
(
enable
=
True
,
stream
=
torch
.
cuda
.
current_stream
())
with
context
:
yield
@
contextmanager
def
graph_capture
():
"""
`graph_capture` is a context manager which should include the code that
is capturing the CUDA graph. Its main purpose is to ensure that the
some operations will be run after the graph is captured, before the graph
is replayed.
"""
ca_comm
=
get_tp_ca_communicator
()
context
=
nullcontext
()
if
ca_comm
is
None
else
ca_comm
.
capture
()
with
context
:
yield
yield
...
@@ -43,15 +66,15 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
...
@@ -43,15 +66,15 @@ def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor:
TLDR: always assume this function modifies its input, but use the return
TLDR: always assume this function modifies its input, but use the return
value as the output.
value as the output.
"""
"""
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
ca_comm
=
get_tp_ca_communicator
()
custom_all_reduce
)
# Bypass the function if we are using only 1 GPU.
# Bypass the function if we are using only 1 GPU.
if
get_tensor_model_parallel_world_size
()
==
1
:
if
get_tensor_model_parallel_world_size
()
==
1
:
return
input_
return
input_
out
=
custom_all_reduce
(
input_
)
if
ca_comm
is
not
None
:
if
out
is
not
None
:
out
=
ca_comm
.
custom_all_reduce
(
input_
)
return
out
if
out
is
not
None
:
return
out
pynccl_comm
=
get_tp_pynccl_communicator
()
pynccl_comm
=
get_tp_pynccl_communicator
()
if
(
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
):
if
(
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
):
pynccl_comm
.
all_reduce
(
input_
)
pynccl_comm
.
all_reduce
(
input_
)
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
702bee46
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Any
,
List
,
Optional
from
typing
import
Any
,
List
,
Optional
,
Union
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
torch.distributed
import
ProcessGroup
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.distributed.parallel_state
import
(
get_local_rank
,
get_tensor_model_parallel_cpu_group
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
try
:
try
:
import
pynvml
import
pynvml
from
vllm._C
import
custom_ar
from
vllm._C
import
custom_ar
@
contextmanager
def
_nvml
():
try
:
pynvml
.
nvmlInit
()
yield
finally
:
pynvml
.
nvmlShutdown
()
except
ImportError
:
except
ImportError
:
# For AMD GPUs
# For AMD GPUs
custom_ar
=
None
custom_ar
=
None
pynvml
=
None
pynvml
=
None
logger
=
init_logger
(
__name__
)
@
contextmanager
def
_nvml
():
try
:
yield
finally
:
pass
_CA_HANDLE
:
Optional
[
"CustomAllreduce"
]
=
None
_IS_CAPTURING
=
False
logger
=
init_logger
(
__name__
)
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
]
def
init_custom_ar
()
->
None
:
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
global
_CA_HANDLE
if
_CA_HANDLE
is
not
None
:
return
rank
=
get_tensor_model_parallel_rank
()
world_size
=
get_tensor_model_parallel_world_size
()
if
world_size
==
1
:
# No need to initialize custom allreduce for single GPU case.
return
if
world_size
not
in
_SUPPORTED_WORLD_SIZES
:
logger
.
warning
(
"Custom allreduce is disabled due to an unsupported world size: "
"%d. Supported world sizes: %s. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly."
,
world_size
,
str
(
_SUPPORTED_WORLD_SIZES
))
return
num_dev
=
torch
.
cuda
.
device_count
()
# note: num dev can be larger than world_size if we're only using
# first few GPUs
if
num_dev
<
world_size
:
logger
.
warning
(
"Cannot test GPU P2P because not all GPUs are visible to the "
"current process. This might be the case if 'CUDA_VISIBLE_DEVICES'"
" is set."
)
return
# we only use a subset of GPUs here
# so we only need to check the nvlink connectivity of these GPUs
num_dev
=
world_size
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
cuda_visible_devices
=
envs
.
CUDA_VISIBLE_DEVICES
if
cuda_visible_devices
:
device_ids
=
list
(
map
(
int
,
cuda_visible_devices
.
split
(
","
)))
else
:
device_ids
=
list
(
range
(
num_dev
))
# this checks hardware and driver support for NVLink
full_nvlink
=
_is_full_nvlink
(
device_ids
)
if
world_size
>
2
and
not
full_nvlink
:
logger
.
warning
(
"Custom allreduce is disabled because it's not supported on more"
" than two PCIe-only GPUs. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly."
)
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if
not
_can_p2p
(
rank
,
world_size
):
logger
.
warning
(
"Custom allreduce is disabled because your platform lacks GPU P2P"
" capability or P2P test failed. To silence this warning, specify"
" disable_custom_all_reduce=True explicitly."
)
return
_CA_HANDLE
=
CustomAllreduce
(
rank
,
world_size
,
full_nvlink
)
def
begin_capture
()
->
None
:
global
_IS_CAPTURING
_IS_CAPTURING
=
True
def
end_capture
()
->
None
:
global
_IS_CAPTURING
_IS_CAPTURING
=
False
def
is_capturing
()
->
bool
:
return
_IS_CAPTURING
and
_CA_HANDLE
is
not
None
def
get_handle
()
->
Optional
[
"CustomAllreduce"
]:
return
_CA_HANDLE
def
is_initialized
()
->
bool
:
return
_CA_HANDLE
is
not
None
@
contextmanager
def
capture
():
try
:
begin_capture
()
yield
finally
:
end_capture
()
handle
=
get_handle
()
if
handle
is
not
None
:
handle
.
register_graph_buffers
()
def
custom_all_reduce
(
input
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
ca_handle
=
get_handle
()
# when custom allreduce is disabled, this will be None
if
ca_handle
is
None
:
return
None
if
is_capturing
():
if
torch
.
cuda
.
is_current_stream_capturing
():
if
ca_handle
.
should_custom_ar
(
input
):
return
ca_handle
.
all_reduce_reg
(
input
)
else
:
if
ca_handle
.
should_custom_ar
(
input
):
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return
torch
.
empty_like
(
input
)
else
:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
if
ca_handle
.
should_custom_ar
(
input
):
return
ca_handle
.
all_reduce_unreg
(
input
)
return
None
@
contextmanager
def
_nvml
():
try
:
pynvml
.
nvmlInit
()
yield
finally
:
pynvml
.
nvmlShutdown
()
@
_nvml
()
@
_nvml
()
...
@@ -188,22 +76,112 @@ def _can_p2p(rank: int, world_size: int) -> bool:
...
@@ -188,22 +76,112 @@ def _can_p2p(rank: int, world_size: int) -> bool:
class
CustomAllreduce
:
class
CustomAllreduce
:
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
]
# max_size: max supported allreduce size
# max_size: max supported allreduce size
def
__init__
(
self
,
def
__init__
(
self
,
rank
,
group
:
Optional
[
ProcessGroup
]
=
None
,
world_size
,
device
:
Optional
[
Union
[
int
,
str
,
torch
.
device
]]
=
None
,
full_nvlink
,
max_size
=
8192
*
1024
)
->
None
:
max_size
=
8192
*
1024
)
->
None
:
"""
Args:
group: the process group to work on. If None, it will use the
default process group.
device: the device to bind the CustomAllreduce to. If None,
it will be bind to f"cuda:{local_rank}".
It is the caller's responsibility to make sure each communicator
is bind to a unique device, and all communicators in this group
are in the same node.
"""
self
.
_IS_CAPTURING
=
False
self
.
disabled
=
True
if
custom_ar
is
None
:
# disable because of missing custom allreduce library
# e.g. in a non-cuda environment
return
group
=
group
or
get_tensor_model_parallel_cpu_group
()
self
.
group
=
group
assert
dist
.
get_backend
(
group
)
!=
dist
.
Backend
.
NCCL
,
(
"CustomAllreduce should be attached to a non-NCCL group."
)
rank
=
dist
.
get_rank
(
group
=
self
.
group
)
world_size
=
dist
.
get_world_size
(
group
=
self
.
group
)
if
world_size
==
1
:
# No need to initialize custom allreduce for single GPU case.
return
if
world_size
not
in
CustomAllreduce
.
_SUPPORTED_WORLD_SIZES
:
logger
.
warning
(
"Custom allreduce is disabled due to an unsupported world"
" size: %d. Supported world sizes: %s. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly."
,
world_size
,
str
(
CustomAllreduce
.
_SUPPORTED_WORLD_SIZES
))
return
if
device
is
None
:
local_rank
=
get_local_rank
()
device
=
torch
.
device
(
f
"cuda:
{
local_rank
}
"
)
elif
isinstance
(
device
,
int
):
device
=
torch
.
device
(
f
"cuda:
{
device
}
"
)
elif
isinstance
(
device
,
str
):
device
=
torch
.
device
(
device
)
# now `device` is a `torch.device` object
assert
isinstance
(
device
,
torch
.
device
)
self
.
device
=
device
cuda_visible_devices
=
envs
.
CUDA_VISIBLE_DEVICES
if
cuda_visible_devices
:
device_ids
=
list
(
map
(
int
,
cuda_visible_devices
.
split
(
","
)))
else
:
device_ids
=
list
(
range
(
torch
.
cuda
.
device_count
()))
physical_device_id
=
device_ids
[
device
.
index
]
tensor
=
torch
.
tensor
([
physical_device_id
],
dtype
=
torch
.
int
,
device
=
"cpu"
)
gather_list
=
[
torch
.
tensor
([
0
],
dtype
=
torch
.
int
,
device
=
"cpu"
)
for
_
in
range
(
world_size
)
]
dist
.
all_gather
(
gather_list
,
tensor
,
group
=
self
.
group
)
physical_device_ids
=
[
t
.
item
()
for
t
in
gather_list
]
# test nvlink first, this will filter out most of the cases
# where custom allreduce is not supported
# this checks hardware and driver support for NVLink
full_nvlink
=
_is_full_nvlink
(
physical_device_ids
)
if
world_size
>
2
and
not
full_nvlink
:
logger
.
warning
(
"Custom allreduce is disabled because it's not supported on"
" more than two PCIe-only GPUs. To silence this warning, "
"specify disable_custom_all_reduce=True explicitly."
)
return
# test P2P capability, this checks software/cudaruntime support
# this is expensive to compute at the first time
# then we cache the result
if
not
_can_p2p
(
rank
,
world_size
):
logger
.
warning
(
"Custom allreduce is disabled because your platform lacks "
"GPU P2P capability or P2P test failed. To silence this "
"warning, specify disable_custom_all_reduce=True explicitly."
)
return
self
.
disabled
=
False
# buffers memory are owned by this Python class and passed to C++
# buffers memory are owned by this Python class and passed to C++
# meta data composes of two parts: meta data for synchronization
# meta data composes of two parts: meta data for synchronization
# (256 bytes) and a temporary buffer for storing intermediate
# (256 bytes) and a temporary buffer for storing intermediate
# allreduce results.
# allreduce results.
self
.
meta
=
torch
.
zeros
(
custom_ar
.
meta_size
()
+
max_size
,
self
.
meta
=
torch
.
zeros
(
custom_ar
.
meta_size
()
+
max_size
,
dtype
=
torch
.
uint8
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
device
=
self
.
device
)
# This is a pre-registered IPC buffer. In eager mode, input tensors
# This is a pre-registered IPC buffer. In eager mode, input tensors
# are first copied into this buffer before allreduce is performed
# are first copied into this buffer before allreduce is performed
self
.
buffer
=
torch
.
empty
(
max_size
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
self
.
buffer
=
torch
.
empty
(
max_size
,
dtype
=
torch
.
uint8
,
device
=
self
.
device
)
# This is a buffer for storing the tuples of pointers pointing to
# This is a buffer for storing the tuples of pointers pointing to
# IPC buffers from all ranks. Each registered tuple has size of
# IPC buffers from all ranks. Each registered tuple has size of
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
# 8*world_size bytes where world_size is at most 8. Allocating 8MB
...
@@ -211,8 +189,9 @@ class CustomAllreduce:
...
@@ -211,8 +189,9 @@ class CustomAllreduce:
# needs less than 10000 of registered tuples.
# needs less than 10000 of registered tuples.
self
.
rank_data
=
torch
.
empty
(
8
*
1024
*
1024
,
self
.
rank_data
=
torch
.
empty
(
8
*
1024
*
1024
,
dtype
=
torch
.
uint8
,
dtype
=
torch
.
uint8
,
device
=
"cuda"
)
device
=
self
.
device
)
self
.
max_size
=
max_size
self
.
max_size
=
max_size
self
.
rank
=
rank
self
.
world_size
=
world_size
self
.
world_size
=
world_size
handles
,
offsets
=
self
.
_get_ipc_meta
(
self
.
meta
)
handles
,
offsets
=
self
.
_get_ipc_meta
(
self
.
meta
)
self
.
full_nvlink
=
full_nvlink
self
.
full_nvlink
=
full_nvlink
...
@@ -221,6 +200,21 @@ class CustomAllreduce:
...
@@ -221,6 +200,21 @@ class CustomAllreduce:
self
.
full_nvlink
)
self
.
full_nvlink
)
self
.
register_buffer
(
self
.
buffer
)
self
.
register_buffer
(
self
.
buffer
)
@
contextmanager
def
capture
(
self
):
"""
The main responsibility of this context manager is the
`register_graph_buffers` call at the end of the context.
It records all the buffer addresses used in the CUDA graph.
"""
try
:
self
.
_IS_CAPTURING
=
True
yield
finally
:
self
.
_IS_CAPTURING
=
False
if
not
self
.
disabled
:
self
.
register_graph_buffers
()
def
_get_ipc_meta
(
self
,
inp
:
torch
.
Tensor
):
def
_get_ipc_meta
(
self
,
inp
:
torch
.
Tensor
):
data
=
inp
.
untyped_storage
().
_share_cuda_
()
data
=
inp
.
untyped_storage
().
_share_cuda_
()
shard_data
=
(
shard_data
=
(
...
@@ -230,14 +224,29 @@ class CustomAllreduce:
...
@@ -230,14 +224,29 @@ class CustomAllreduce:
return
self
.
_gather_ipc_meta
(
shard_data
)
return
self
.
_gather_ipc_meta
(
shard_data
)
def
_gather_ipc_meta
(
self
,
shard_data
):
def
_gather_ipc_meta
(
self
,
shard_data
):
all_data
:
List
[
Optional
[
Any
]]
=
[
None
]
*
self
.
world_size
# Note: don't use `[[None]] * self.world_size` here
dist
.
all_gather_object
(
all_data
,
shard_data
)
# because it will create a list of the same reference
all_data
:
List
[
Optional
[
Any
]]
=
[[
None
]
for
i
in
range
(
self
.
world_size
)]
all_data
[
self
.
rank
][
0
]
=
shard_data
ranks
=
dist
.
get_process_group_ranks
(
group
=
self
.
group
)
ranks
.
sort
()
for
i
,
rank
in
enumerate
(
ranks
):
dist
.
broadcast_object_list
(
all_data
[
i
],
src
=
rank
,
group
=
self
.
group
,
device
=
"cpu"
)
# we cannot directly use `dist.all_gather_object` here
# because it is incompatible with `gloo` backend under inference mode.
# see https://github.com/pytorch/pytorch/issues/126032 for details.
handles
=
[]
handles
=
[]
offsets
=
[]
offsets
=
[]
for
i
in
range
(
len
(
all_data
)):
for
i
in
range
(
len
(
all_data
)):
handles
.
append
(
all_data
[
i
][
0
])
# type: ignore
handles
.
append
(
all_data
[
i
][
0
]
[
0
]
)
# type: ignore
offsets
.
append
(
all_data
[
i
][
1
])
# type: ignore
offsets
.
append
(
all_data
[
i
][
0
][
1
])
# type: ignore
return
handles
,
offsets
return
handles
,
offsets
def
register_buffer
(
self
,
inp
:
torch
.
Tensor
):
def
register_buffer
(
self
,
inp
:
torch
.
Tensor
):
...
@@ -269,8 +278,31 @@ class CustomAllreduce:
...
@@ -269,8 +278,31 @@ class CustomAllreduce:
custom_ar
.
all_reduce_unreg
(
self
.
_ptr
,
inp
,
self
.
buffer
,
out
)
custom_ar
.
all_reduce_unreg
(
self
.
_ptr
,
inp
,
self
.
buffer
,
out
)
return
out
return
out
def
custom_all_reduce
(
self
,
input
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
# when custom allreduce is disabled, this will be None
if
self
.
disabled
:
return
None
if
self
.
_IS_CAPTURING
:
if
torch
.
cuda
.
is_current_stream_capturing
():
if
self
.
should_custom_ar
(
input
):
return
self
.
all_reduce_reg
(
input
)
else
:
if
self
.
should_custom_ar
(
input
):
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return
torch
.
empty_like
(
input
)
else
:
# note: outside of cuda graph context,
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
if
self
.
should_custom_ar
(
input
):
return
self
.
all_reduce_unreg
(
input
)
return
None
def
close
(
self
):
def
close
(
self
):
if
self
.
_ptr
:
if
not
self
.
disabled
and
self
.
_ptr
:
custom_ar
.
dispose
(
self
.
_ptr
)
custom_ar
.
dispose
(
self
.
_ptr
)
self
.
_ptr
=
0
self
.
_ptr
=
0
...
...
vllm/distributed/device_communicators/pynccl.py
View file @
702bee46
...
@@ -96,8 +96,10 @@ class PyNcclCommunicator:
...
@@ -96,8 +96,10 @@ class PyNcclCommunicator:
self
.
stream
=
torch
.
cuda
.
Stream
()
self
.
stream
=
torch
.
cuda
.
Stream
()
# A small all_reduce for warmup.
# A small all_reduce for warmup.
self
.
all_reduce
(
torch
.
zeros
(
1
,
device
=
device
))
data
=
torch
.
zeros
(
1
,
device
=
device
)
self
.
all_reduce
(
data
)
self
.
stream
.
synchronize
()
self
.
stream
.
synchronize
()
del
data
# by default it is disabled, e.g. in profiling models and prefill phase.
# by default it is disabled, e.g. in profiling models and prefill phase.
# to use it, use under `with obj.change_state(enable=True)`, usually
# to use it, use under `with obj.change_state(enable=True)`, usually
...
...
vllm/distributed/parallel_state.py
View file @
702bee46
...
@@ -13,10 +13,13 @@ from vllm.logger import init_logger
...
@@ -13,10 +13,13 @@ from vllm.logger import init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_ENABLE_CUSTOM_ALL_REDUCE
=
True
# Tensor model parallel group that the current rank belongs to.
# Tensor model parallel group that the current rank belongs to.
_TP_DEVICE_GROUP
:
Optional
[
ProcessGroup
]
=
None
_TP_DEVICE_GROUP
:
Optional
[
ProcessGroup
]
=
None
_TP_CPU_GROUP
:
Optional
[
ProcessGroup
]
=
None
_TP_CPU_GROUP
:
Optional
[
ProcessGroup
]
=
None
_TP_PYNCCL_COMMUNICATOR
=
None
_TP_PYNCCL_COMMUNICATOR
=
None
_TP_CA_COMMUNICATOR
=
None
# Pipeline model parallel group that the current rank belongs to.
# Pipeline model parallel group that the current rank belongs to.
_PP_DEVICE_GROUP
:
Optional
[
ProcessGroup
]
=
None
_PP_DEVICE_GROUP
:
Optional
[
ProcessGroup
]
=
None
...
@@ -47,11 +50,21 @@ _PP_GLOBAL_RANKS: Optional[List[int]] = None
...
@@ -47,11 +50,21 @@ _PP_GLOBAL_RANKS: Optional[List[int]] = None
_LOCAL_RANK
=
-
1
_LOCAL_RANK
=
-
1
def
set_custom_all_reduce
(
enable
:
bool
):
global
_ENABLE_CUSTOM_ALL_REDUCE
_ENABLE_CUSTOM_ALL_REDUCE
=
enable
def
get_tp_pynccl_communicator
():
def
get_tp_pynccl_communicator
():
global
_TP_PYNCCL_COMMUNICATOR
global
_TP_PYNCCL_COMMUNICATOR
return
_TP_PYNCCL_COMMUNICATOR
return
_TP_PYNCCL_COMMUNICATOR
def
get_tp_ca_communicator
():
global
_TP_CA_COMMUNICATOR
return
_TP_CA_COMMUNICATOR
def
get_local_rank
():
def
get_local_rank
():
global
_LOCAL_RANK
global
_LOCAL_RANK
return
_LOCAL_RANK
return
_LOCAL_RANK
...
@@ -100,6 +113,9 @@ def init_distributed_environment(
...
@@ -100,6 +113,9 @@ def init_distributed_environment(
if
torch
.
cuda
.
is_available
():
if
torch
.
cuda
.
is_available
():
data
=
data
.
to
(
device
=
f
"cuda:
{
local_rank
}
"
)
data
=
data
.
to
(
device
=
f
"cuda:
{
local_rank
}
"
)
torch
.
distributed
.
all_reduce
(
data
)
torch
.
distributed
.
all_reduce
(
data
)
if
torch
.
cuda
.
is_available
():
torch
.
cuda
.
synchronize
()
del
data
def
initialize_model_parallel
(
def
initialize_model_parallel
(
...
@@ -149,7 +165,8 @@ def initialize_model_parallel(
...
@@ -149,7 +165,8 @@ def initialize_model_parallel(
rank
=
torch
.
distributed
.
get_rank
()
rank
=
torch
.
distributed
.
get_rank
()
# Build the tensor model-parallel groups.
# Build the tensor model-parallel groups.
global
_TP_DEVICE_GROUP
,
_TP_CPU_GROUP
,
_TP_PYNCCL_COMMUNICATOR
global
_TP_DEVICE_GROUP
,
_TP_CPU_GROUP
global
_TP_PYNCCL_COMMUNICATOR
,
_TP_CA_COMMUNICATOR
assert
_TP_DEVICE_GROUP
is
None
,
(
assert
_TP_DEVICE_GROUP
is
None
,
(
"tensor model parallel group is already initialized"
)
"tensor model parallel group is already initialized"
)
for
i
in
range
(
num_tensor_model_parallel_groups
):
for
i
in
range
(
num_tensor_model_parallel_groups
):
...
@@ -168,6 +185,15 @@ def initialize_model_parallel(
...
@@ -168,6 +185,15 @@ def initialize_model_parallel(
device
=
_LOCAL_RANK
,
device
=
_LOCAL_RANK
,
)
)
# Initialize a custom fast all-reduce implementation.
if
_ENABLE_CUSTOM_ALL_REDUCE
:
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
CustomAllreduce
)
_TP_CA_COMMUNICATOR
=
CustomAllreduce
(
group
=
_TP_CPU_GROUP
,
device
=
_LOCAL_RANK
,
)
# Build the pipeline model-parallel groups.
# Build the pipeline model-parallel groups.
global
_PP_DEVICE_GROUP
global
_PP_DEVICE_GROUP
global
_PP_GLOBAL_RANKS
global
_PP_GLOBAL_RANKS
...
...
vllm/test_utils.py
View file @
702bee46
...
@@ -6,24 +6,24 @@ from vllm.utils import get_open_port
...
@@ -6,24 +6,24 @@ from vllm.utils import get_open_port
def
init_test_distributed_environment
(
def
init_test_distributed_environment
(
p
ipeline_parallel
_size
:
int
,
t
p_size
:
int
,
tensor_parallel
_size
:
int
,
pp
_size
:
int
,
rank
:
int
,
rank
:
int
,
distributed_init_port
:
str
,
distributed_init_port
:
str
,
local_rank
:
int
=
-
1
,
local_rank
:
int
=
-
1
,
)
->
None
:
)
->
None
:
distributed_init_method
=
f
"tcp://localhost:
{
distributed_init_port
}
"
distributed_init_method
=
f
"tcp://localhost:
{
distributed_init_port
}
"
init_distributed_environment
(
init_distributed_environment
(
world_size
=
p
ipeline_parallel_size
*
tensor_parallel
_size
,
world_size
=
p
p_size
*
tp
_size
,
rank
=
rank
,
rank
=
rank
,
distributed_init_method
=
distributed_init_method
,
distributed_init_method
=
distributed_init_method
,
local_rank
=
local_rank
)
local_rank
=
local_rank
)
ensure_model_parallel_initialized
(
tensor_parallel_size
,
ensure_model_parallel_initialized
(
tp_size
,
pp_size
)
pipeline_parallel_size
)
def
multi_process_tensor_parallel
(
def
multi_process_tensor_parallel
(
tensor_parallel_size
:
int
,
tp_size
:
int
,
pp_size
:
int
,
test_target
,
test_target
,
)
->
None
:
)
->
None
:
# Using ray helps debugging the error when it failed
# Using ray helps debugging the error when it failed
...
@@ -32,10 +32,9 @@ def multi_process_tensor_parallel(
...
@@ -32,10 +32,9 @@ def multi_process_tensor_parallel(
distributed_init_port
=
get_open_port
()
distributed_init_port
=
get_open_port
()
refs
=
[]
refs
=
[]
for
rank
in
range
(
t
ensor_parallel
_size
):
for
rank
in
range
(
t
p_size
*
pp
_size
):
refs
.
append
(
refs
.
append
(
test_target
.
remote
(
tensor_parallel_size
,
rank
,
test_target
.
remote
(
tp_size
,
pp_size
,
rank
,
distributed_init_port
))
distributed_init_port
))
ray
.
get
(
refs
)
ray
.
get
(
refs
)
ray
.
shutdown
()
ray
.
shutdown
()
vllm/worker/model_runner.py
View file @
702bee46
...
@@ -12,8 +12,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -12,8 +12,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
ModelConfig
,
ParallelConfig
,
SchedulerConfig
,
VisionLanguageConfig
)
VisionLanguageConfig
)
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed
import
broadcast_tensor_dict
from
vllm.distributed.communication_op
import
graph_capture_mode
from
vllm.distributed.communication_op
import
graph_capture
,
graph_mode
from
vllm.distributed.device_communicators
import
custom_all_reduce
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.layers
import
LoRAMapping
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
...
@@ -942,13 +941,7 @@ class ModelRunner:
...
@@ -942,13 +941,7 @@ class ModelRunner:
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
bs
for
bs
in
_BATCH_SIZES_TO_CAPTURE
if
bs
<=
graph_batch_size
]
]
# NOTE(woosuk): There are 3 backends for all-reduce: custom all-reduce
with
graph_capture
():
# kernel, pynccl, and PyTorch NCCL. When using CUDA graph, we use
# either custom all-reduce kernel or pynccl. When not using CUDA
# graph, we use either custom all-reduce kernel or PyTorch NCCL.
# We always prioritize using custom all-reduce kernel but fall back
# to PyTorch or pynccl if it is disabled or not supported.
with
custom_all_reduce
.
capture
():
# NOTE: Capturing the largest batch size first may help reduce the
# NOTE: Capturing the largest batch size first may help reduce the
# memory usage of CUDA graph.
# memory usage of CUDA graph.
for
batch_size
in
reversed
(
batch_size_capture_list
):
for
batch_size
in
reversed
(
batch_size_capture_list
):
...
@@ -1040,7 +1033,7 @@ class CUDAGraphRunner:
...
@@ -1040,7 +1033,7 @@ class CUDAGraphRunner:
# Run the model once without capturing the graph.
# Run the model once without capturing the graph.
# This is to make sure that the captured graph does not include the
# This is to make sure that the captured graph does not include the
# kernel launches for initial benchmarking (e.g., Triton autotune).
# kernel launches for initial benchmarking (e.g., Triton autotune).
with
graph_
capture_
mode
():
with
graph_mode
():
self
.
model
(
self
.
model
(
input_ids
,
input_ids
,
positions
,
positions
,
...
@@ -1055,7 +1048,7 @@ class CUDAGraphRunner:
...
@@ -1055,7 +1048,7 @@ class CUDAGraphRunner:
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
# https://stackoverflow.com/questions/31039022/python-multi-line-with-statement
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
self
.
_graph
=
torch
.
cuda
.
CUDAGraph
()
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
):
# noqa: SIM117
with
torch
.
cuda
.
graph
(
self
.
_graph
,
pool
=
memory_pool
):
# noqa: SIM117
with
graph_
capture_
mode
():
with
graph_mode
():
hidden_states
=
self
.
model
(
hidden_states
=
self
.
model
(
input_ids
,
input_ids
,
positions
,
positions
,
...
...
vllm/worker/worker.py
View file @
702bee46
...
@@ -11,9 +11,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
...
@@ -11,9 +11,8 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
VisionLanguageConfig
)
VisionLanguageConfig
)
from
vllm.distributed
import
(
broadcast_tensor_dict
,
from
vllm.distributed
import
(
broadcast_tensor_dict
,
ensure_model_parallel_initialized
,
ensure_model_parallel_initialized
,
init_distributed_environment
)
init_distributed_environment
,
from
vllm.distributed.device_communicators.custom_all_reduce
import
(
set_custom_all_reduce
)
init_custom_ar
)
from
vllm.lora.request
import
LoRARequest
from
vllm.lora.request
import
LoRARequest
from
vllm.model_executor
import
set_random_seed
from
vllm.model_executor
import
set_random_seed
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
from
vllm.sequence
import
ExecuteModelRequest
,
PoolerOutput
,
SamplerOutput
...
@@ -302,16 +301,14 @@ def init_worker_distributed_environment(
...
@@ -302,16 +301,14 @@ def init_worker_distributed_environment(
local_rank
:
int
=
-
1
,
local_rank
:
int
=
-
1
,
)
->
None
:
)
->
None
:
"""Initialize the distributed environment."""
"""Initialize the distributed environment."""
set_custom_all_reduce
(
not
parallel_config
.
disable_custom_all_reduce
)
init_distributed_environment
(
parallel_config
.
world_size
,
rank
,
init_distributed_environment
(
parallel_config
.
world_size
,
rank
,
distributed_init_method
,
local_rank
)
distributed_init_method
,
local_rank
)
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
ensure_model_parallel_initialized
(
parallel_config
.
tensor_parallel_size
,
parallel_config
.
pipeline_parallel_size
)
parallel_config
.
pipeline_parallel_size
)
# Initialize a custom fast all-reduce implementation.
if
not
parallel_config
.
disable_custom_all_reduce
:
init_custom_ar
()
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
def
_check_if_gpu_supports_dtype
(
torch_dtype
:
torch
.
dtype
):
# Check if the GPU supports the dtype.
# Check if the GPU supports the dtype.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment