Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
99aa4edd
Unverified
Commit
99aa4edd
authored
Sep 16, 2024
by
youkaichao
Committed by
GitHub
Sep 16, 2024
Browse files
[torch.compile] register allreduce operations as custom ops (#8526)
parent
ee2bceaa
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
137 additions
and
50 deletions
+137
-50
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+3
-7
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+0
-12
csrc/ops.h
csrc/ops.h
+0
-2
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+0
-5
tests/compile/__init__.py
tests/compile/__init__.py
+0
-0
tests/compile/test_full_graph.py
tests/compile/test_full_graph.py
+13
-2
vllm/_custom_ops.py
vllm/_custom_ops.py
+0
-6
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+19
-2
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+102
-14
No files found.
.buildkite/test-pipeline.yaml
View file @
99aa4edd
...
...
@@ -163,13 +163,6 @@ steps:
-
python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
-
python3 offline_inference_encoder_decoder.py
-
label
:
torch compile integration test
source_file_dependencies
:
-
vllm/
commands
:
-
pytest -v -s ./compile/test_full_graph.py
-
pytest -v -s ./compile/test_wrapper.py
-
label
:
Prefix Caching Test
# 7min
#mirror_hardwares: [amd]
source_file_dependencies
:
...
...
@@ -348,7 +341,10 @@ steps:
-
vllm/executor/
-
vllm/model_executor/models/
-
tests/distributed/
-
vllm/compilation
commands
:
-
pytest -v -s ./compile/test_full_graph.py
-
pytest -v -s ./compile/test_wrapper.py
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
-
TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus
# Avoid importing model tests that cause CUDA reinitialization error
...
...
csrc/custom_all_reduce.cu
View file @
99aa4edd
...
...
@@ -55,18 +55,6 @@ bool _is_weak_contiguous(torch::Tensor& t) {
t
.
numel
()
*
t
.
element_size
());
}
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int64_t
max_size
,
int64_t
world_size
,
bool
full_nvlink
)
{
auto
inp_size
=
inp
.
numel
()
*
inp
.
element_size
();
// custom allreduce requires input byte size to be multiples of 16
if
(
inp_size
%
16
!=
0
)
return
false
;
if
(
!
_is_weak_contiguous
(
inp
))
return
false
;
if
(
world_size
==
2
||
full_nvlink
)
return
inp_size
<=
max_size
;
// for 4 or more non NVLink-capable GPUs, custom allreduce provides little
// performance improvement over NCCL.
return
false
;
}
void
_all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
cudaStream_t
stream
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
...
...
csrc/ops.h
View file @
99aa4edd
...
...
@@ -241,8 +241,6 @@ fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int64_t
rank
,
bool
full_nvlink
);
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int64_t
max_size
,
int64_t
world_size
,
bool
full_nvlink
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
...
...
csrc/torch_bindings.cpp
View file @
99aa4edd
...
...
@@ -411,11 +411,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
"bool full_nvlink) -> int"
);
custom_ar
.
impl
(
"init_custom_ar"
,
torch
::
kCUDA
,
&
init_custom_ar
);
custom_ar
.
def
(
"should_custom_ar(Tensor inp, int max_size, int world_size, "
"bool full_nvlink) -> bool"
);
custom_ar
.
impl
(
"should_custom_ar"
,
torch
::
kCUDA
,
&
should_custom_ar
);
custom_ar
.
def
(
"all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"
);
custom_ar
.
impl
(
"all_reduce_reg"
,
torch
::
kCUDA
,
&
all_reduce_reg
);
...
...
tests/compile/__init__.py
0 → 100644
View file @
99aa4edd
tests/compile/test_full_graph.py
View file @
99aa4edd
...
...
@@ -2,9 +2,20 @@ import os
import
pytest
from
vllm.utils
import
cuda_device_count_stateless
from
..utils
import
fork_new_process_for_each_test
@
pytest
.
mark
.
parametrize
(
"model"
,
[
"meta-llama/Meta-Llama-3-8B"
])
def
test_full_graph
(
model
):
@
pytest
.
mark
.
parametrize
(
"tp_size"
,
[
1
,
2
])
@
fork_new_process_for_each_test
def
test_full_graph
(
model
,
tp_size
):
# Skip the test if there are not enough CUDA devices.
if
cuda_device_count_stateless
()
<
tp_size
:
pytest
.
skip
(
"Not enough CUDA devices for the test."
)
# make sure these models can be captured in full graph mode
if
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
not
in
os
.
environ
:
os
.
environ
[
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
]
=
"1"
...
...
@@ -17,7 +28,7 @@ def test_full_graph(model):
"The future of AI is"
,
]
sampling_params
=
SamplingParams
(
temperature
=
0
)
llm
=
LLM
(
model
=
model
,
enforce_eager
=
True
)
llm
=
LLM
(
model
=
model
,
enforce_eager
=
True
,
tensor_parallel_size
=
tp_size
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
...
...
vllm/_custom_ops.py
View file @
99aa4edd
...
...
@@ -870,12 +870,6 @@ def init_custom_ar(meta: torch.Tensor, rank_data: torch.Tensor,
offsets
,
rank
,
full_nvlink
)
def
should_custom_ar
(
inp
:
torch
.
Tensor
,
max_size
:
int
,
world_size
:
int
,
full_nvlink
:
bool
)
->
bool
:
return
torch
.
ops
.
_C_custom_ar
.
should_custom_ar
(
inp
,
max_size
,
world_size
,
full_nvlink
)
def
all_reduce_reg
(
fa
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_C_custom_ar
.
all_reduce_reg
(
fa
,
inp
,
out
)
...
...
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
99aa4edd
...
...
@@ -33,6 +33,12 @@ def _can_p2p(rank: int, world_size: int) -> bool:
return
True
def
is_weak_contiguous
(
inp
:
torch
.
Tensor
):
return
inp
.
is_contiguous
()
or
(
inp
.
storage
().
nbytes
()
-
inp
.
storage_offset
()
*
inp
.
element_size
()
==
inp
.
numel
()
*
inp
.
element_size
())
class
CustomAllreduce
:
_SUPPORTED_WORLD_SIZES
=
[
2
,
4
,
6
,
8
]
...
...
@@ -224,8 +230,19 @@ class CustomAllreduce:
ops
.
register_graph_buffers
(
self
.
_ptr
,
handles
,
offsets
)
def
should_custom_ar
(
self
,
inp
:
torch
.
Tensor
):
return
ops
.
should_custom_ar
(
inp
,
self
.
max_size
,
self
.
world_size
,
self
.
full_nvlink
)
if
self
.
disabled
:
return
False
inp_size
=
inp
.
numel
()
*
inp
.
element_size
()
# custom allreduce requires input byte size to be multiples of 16
if
inp_size
%
16
!=
0
:
return
False
if
not
is_weak_contiguous
(
inp
):
return
False
# for 4 or more non NVLink-capable GPUs, custom allreduce provides
# little performance improvement over NCCL.
if
self
.
world_size
==
2
or
self
.
full_nvlink
:
return
inp_size
<
self
.
max_size
return
False
# all reduce, assuming inp tensor is IPC registered with register_buffer,
# or, in the context of cuda graphs, register_graph_buffers
...
...
vllm/distributed/parallel_state.py
View file @
99aa4edd
...
...
@@ -21,11 +21,12 @@ If you only need to use the distributed environment without model/pipeline
"""
import
contextlib
import
pickle
import
weakref
from
collections
import
namedtuple
from
contextlib
import
contextmanager
,
nullcontext
from
dataclasses
import
dataclass
from
multiprocessing
import
shared_memory
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
from
unittest.mock
import
patch
import
torch
...
...
@@ -69,6 +70,58 @@ def _split_tensor_dict(
return
metadata_list
,
tensor_list
_group_name_counter
:
Dict
[
str
,
int
]
=
{}
def
_get_unique_name
(
name
:
str
)
->
str
:
"""Get a unique name for the group.
Example:
_get_unique_name("tp") -> "tp:0"
_get_unique_name("tp") -> "tp:1"
"""
if
name
not
in
_group_name_counter
:
_group_name_counter
[
name
]
=
0
newname
=
f
"
{
name
}
:
{
_group_name_counter
[
name
]
}
"
_group_name_counter
[
name
]
+=
1
return
newname
_groups
:
Dict
[
str
,
Callable
[[],
"GroupCoordinator"
]]
=
{}
def
_register_group
(
group
:
"GroupCoordinator"
)
->
None
:
# looks like Python 3.8 does not understand `ReferenceType`
_groups
[
group
.
unique_name
]
=
weakref
.
ref
(
group
)
# type: ignore
@
torch
.
library
.
custom_op
(
"vllm::inplace_all_reduce"
,
mutates_args
=
[
"tensor"
])
def
inplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
group
.
_all_reduce
(
tensor
)
@
inplace_all_reduce
.
register_fake
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
return
@
torch
.
library
.
custom_op
(
"vllm::outplace_all_reduce"
,
mutates_args
=
[])
def
outplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
return
group
.
_all_reduce
(
tensor
)
@
outplace_all_reduce
.
register_fake
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
tensor
)
class
GroupCoordinator
:
"""
PyTorch ProcessGroup wrapper for a group of processes.
...
...
@@ -111,7 +164,11 @@ class GroupCoordinator:
use_custom_allreduce
:
bool
,
use_tpu_communicator
:
bool
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
):
group_name
=
group_name
or
"anonymous"
self
.
unique_name
=
_get_unique_name
(
group_name
)
_register_group
(
self
)
self
.
rank
=
torch
.
distributed
.
get_rank
()
self
.
local_rank
=
local_rank
...
...
@@ -149,28 +206,24 @@ class GroupCoordinator:
from
vllm.distributed.device_communicators.pynccl
import
(
PyNcclCommunicator
)
self
.
pynccl_comm
:
Optional
[
PyNcclCommunicator
]
self
.
pynccl_comm
:
Optional
[
PyNcclCommunicator
]
=
None
if
use_pynccl
and
self
.
world_size
>
1
:
self
.
pynccl_comm
=
PyNcclCommunicator
(
group
=
self
.
cpu_group
,
device
=
self
.
device
,
)
else
:
self
.
pynccl_comm
=
None
self
.
ca_comm
:
Optional
[
CustomAllreduce
]
self
.
ca_comm
:
Optional
[
CustomAllreduce
]
=
None
if
use_custom_allreduce
and
self
.
world_size
>
1
:
# Initialize a custom fast all-reduce implementation.
self
.
ca_comm
=
CustomAllreduce
(
group
=
self
.
cpu_group
,
device
=
self
.
device
,
)
else
:
self
.
ca_comm
=
None
from
vllm.distributed.device_communicators.tpu_communicator
import
(
TpuCommunicator
)
self
.
tpu_communicator
:
Optional
[
TpuCommunicator
]
self
.
tpu_communicator
:
Optional
[
TpuCommunicator
]
=
None
if
use_tpu_communicator
and
self
.
world_size
>
1
:
self
.
tpu_communicator
=
TpuCommunicator
(
group
=
self
.
cpu_group
)
...
...
@@ -264,16 +317,46 @@ class GroupCoordinator:
def
all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
User-facing all-reduce function before we actually call the
all-reduce operation.
We need this because Dynamo does not support passing an arbitrary
object (`self` in this case) to a custom op. We need to pass the
group name as a string, and then look up the group coordinator from
the group name, dispatch the all-reduce operation to the group
coordinator.
In addition, PyTorch custom ops do not support mutation or returning
a new tensor in the same op. So we need to figure out if the op is
in-place or out-of-place ahead of time.
"""
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
input_
if
self
.
tpu_communicator
is
not
None
and
\
not
self
.
tpu_communicator
.
disabled
:
# TPU handles Dynamo with its own logic.
return
self
.
_all_reduce
(
input_
)
if
self
.
ca_comm
is
not
None
and
self
.
ca_comm
.
should_custom_ar
(
input_
):
return
torch
.
ops
.
vllm
.
outplace_all_reduce
(
input_
,
group_name
=
self
.
unique_name
)
else
:
torch
.
ops
.
vllm
.
inplace_all_reduce
(
input_
,
group_name
=
self
.
unique_name
)
return
input_
def
_all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
The actual all-reduce implementation.
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
value as the output.
"""
ca_comm
=
self
.
ca_comm
# Bypass the function if we are using only 1 GPU.
if
self
.
world_size
==
1
:
return
input_
# For TPUs, use TPU communicator.
tpu_comm
=
self
.
tpu_communicator
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
...
...
@@ -758,6 +841,7 @@ def init_world_group(ranks: List[int], local_rank: int,
use_pynccl
=
False
,
use_custom_allreduce
=
False
,
use_tpu_communicator
=
False
,
group_name
=
"world"
,
)
...
...
@@ -767,6 +851,7 @@ def init_model_parallel_group(
backend
:
str
,
use_custom_allreduce
:
Optional
[
bool
]
=
None
,
use_message_queue_broadcaster
:
bool
=
False
,
group_name
:
Optional
[
str
]
=
None
,
)
->
GroupCoordinator
:
if
use_custom_allreduce
is
None
:
use_custom_allreduce
=
_ENABLE_CUSTOM_ALL_REDUCE
...
...
@@ -778,6 +863,7 @@ def init_model_parallel_group(
use_custom_allreduce
=
use_custom_allreduce
,
use_tpu_communicator
=
True
,
use_message_queue_broadcaster
=
use_message_queue_broadcaster
,
group_name
=
group_name
,
)
...
...
@@ -931,7 +1017,8 @@ def initialize_model_parallel(
_TP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
backend
,
use_message_queue_broadcaster
=
True
)
use_message_queue_broadcaster
=
True
,
group_name
=
"tp"
)
# Build the pipeline model-parallel groups.
num_pipeline_model_parallel_groups
:
int
=
(
world_size
//
...
...
@@ -947,7 +1034,8 @@ def initialize_model_parallel(
_PP
=
init_model_parallel_group
(
group_ranks
,
get_world_group
().
local_rank
,
backend
,
use_custom_allreduce
=
False
)
use_custom_allreduce
=
False
,
group_name
=
"pp"
)
def
ensure_model_parallel_initialized
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment