Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
663874e0
Unverified
Commit
663874e0
authored
Oct 04, 2024
by
youkaichao
Committed by
GitHub
Oct 04, 2024
Browse files
[torch.compile] improve allreduce registration (#9061)
parent
cc90419e
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
32 deletions
+21
-32
vllm/distributed/device_communicators/custom_all_reduce.py
vllm/distributed/device_communicators/custom_all_reduce.py
+6
-9
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+15
-23
No files found.
vllm/distributed/device_communicators/custom_all_reduce.py
View file @
663874e0
...
...
@@ -265,14 +265,12 @@ class CustomAllreduce:
def
custom_all_reduce
(
self
,
input
:
torch
.
Tensor
)
->
Optional
[
torch
.
Tensor
]:
# when custom allreduce is disabled, this will be None
if
self
.
disabled
:
if
self
.
disabled
or
not
self
.
should_custom_ar
(
input
)
:
return
None
if
self
.
_IS_CAPTURING
:
if
torch
.
cuda
.
is_current_stream_capturing
():
if
self
.
should_custom_ar
(
input
):
return
self
.
all_reduce_reg
(
input
)
else
:
if
self
.
should_custom_ar
(
input
):
# if warm up, mimic the allocation pattern
# since custom allreduce is out-of-place
return
torch
.
empty_like
(
input
)
...
...
@@ -281,7 +279,6 @@ class CustomAllreduce:
# custom allreduce incurs a cost of cudaMemcpy, which should
# be small(<=1% of overall latency) compared to the performance
# gains of using custom kernels
if
self
.
should_custom_ar
(
input
):
return
self
.
all_reduce_unreg
(
input
)
return
None
...
...
vllm/distributed/parallel_state.py
View file @
663874e0
...
...
@@ -105,7 +105,7 @@ if supports_custom_op():
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
group
.
_all_reduce
(
tensor
)
group
.
_all_reduce
_in_place
(
tensor
)
@
inplace_all_reduce
.
register_fake
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
...
...
@@ -118,7 +118,7 @@ if supports_custom_op():
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
return
group
.
_all_reduce
(
tensor
)
return
group
.
_all_reduce
_out_place
(
tensor
)
@
outplace_all_reduce
.
register_fake
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
...
...
@@ -338,14 +338,17 @@ class GroupCoordinator:
return
input_
if
not
supports_custom_op
():
return
self
.
_all_reduce
(
input_
)
self
.
_all_reduce_in_place
(
input_
)
return
input_
if
self
.
tpu_communicator
is
not
None
and
\
not
self
.
tpu_communicator
.
disabled
:
# TPU handles Dynamo with its own logic.
return
self
.
_
all_reduce
(
input_
)
return
self
.
tpu_communicator
.
all_reduce
(
input_
)
if
self
.
ca_comm
is
not
None
and
self
.
ca_comm
.
should_custom_ar
(
input_
):
if
self
.
ca_comm
is
not
None
and
\
not
self
.
ca_comm
.
disabled
and
\
self
.
ca_comm
.
should_custom_ar
(
input_
):
return
torch
.
ops
.
vllm
.
outplace_all_reduce
(
input_
,
group_name
=
self
.
unique_name
)
else
:
...
...
@@ -353,25 +356,15 @@ class GroupCoordinator:
group_name
=
self
.
unique_name
)
return
input_
def
_all_reduce
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
The actual all-reduce implementation.
NOTE: This operation will be applied in-place or out-of-place.
Always assume this function modifies its input, but use the return
value as the output.
"""
def
_all_reduce_out_place
(
self
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
ca_comm
=
self
.
ca_comm
# For TPUs, use TPU communicator.
tpu_comm
=
self
.
tpu_communicator
if
tpu_comm
is
not
None
and
not
tpu_comm
.
disabled
:
return
tpu_comm
.
all_reduce
(
input_
)
if
ca_comm
is
not
None
:
assert
ca_comm
is
not
None
assert
not
ca_comm
.
disabled
out
=
ca_comm
.
custom_all_reduce
(
input_
)
if
out
is
not
None
:
assert
out
is
not
None
return
out
def
_all_reduce_in_place
(
self
,
input_
:
torch
.
Tensor
)
->
None
:
pynccl_comm
=
self
.
pynccl_comm
if
(
pynccl_comm
is
not
None
and
not
pynccl_comm
.
disabled
):
pynccl_comm
.
all_reduce
(
input_
)
...
...
@@ -380,7 +373,6 @@ class GroupCoordinator:
ipex
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
else
:
torch
.
distributed
.
all_reduce
(
input_
,
group
=
self
.
device_group
)
return
input_
def
all_gather
(
self
,
input_
:
torch
.
Tensor
,
dim
:
int
=
-
1
)
->
torch
.
Tensor
:
world_size
=
self
.
world_size
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment