Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d4bf085a
Unverified
Commit
d4bf085a
authored
Sep 21, 2024
by
Kunshang Ji
Committed by
GitHub
Sep 20, 2024
Browse files
[MISC] add support custom_op check (#8557)
Co-authored-by:
youkaichao
<
youkaichao@126.com
>
parent
0057894e
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
33 additions
and
22 deletions
+33
-22
vllm/distributed/parallel_state.py
vllm/distributed/parallel_state.py
+27
-22
vllm/utils.py
vllm/utils.py
+6
-0
No files found.
vllm/distributed/parallel_state.py
View file @
d4bf085a
...
@@ -36,6 +36,7 @@ from torch.distributed import Backend, ProcessGroup
...
@@ -36,6 +36,7 @@ from torch.distributed import Backend, ProcessGroup
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
supports_custom_op
@
dataclass
@
dataclass
...
@@ -95,32 +96,33 @@ def _register_group(group: "GroupCoordinator") -> None:
...
@@ -95,32 +96,33 @@ def _register_group(group: "GroupCoordinator") -> None:
_groups
[
group
.
unique_name
]
=
weakref
.
ref
(
group
)
# type: ignore
_groups
[
group
.
unique_name
]
=
weakref
.
ref
(
group
)
# type: ignore
@
torch
.
library
.
custom_op
(
"vllm::inplace_all_reduce"
,
mutates_args
=
[
"tensor"
])
if
supports_custom_op
():
def
inplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
group
.
_all_reduce
(
tensor
)
@
torch
.
library
.
custom_op
(
"vllm::inplace_all_reduce"
,
mutates_args
=
[
"tensor"
])
def
inplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
group
.
_all_reduce
(
tensor
)
@
inplace_all_reduce
.
register_fake
@
inplace_all_reduce
.
register_fake
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
None
:
return
return
@
torch
.
library
.
custom_op
(
"vllm::outplace_all_reduce"
,
mutates_args
=
[])
def
outplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
return
group
.
_all_reduce
(
tensor
)
@
torch
.
library
.
custom_op
(
"vllm::outplace_all_reduce"
,
mutates_args
=
[])
def
outplace_all_reduce
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
assert
group_name
in
_groups
,
f
"Group
{
group_name
}
is not found."
group
=
_groups
[
group_name
]()
if
group
is
None
:
raise
ValueError
(
f
"Group
{
group_name
}
is destroyed."
)
return
group
.
_all_reduce
(
tensor
)
@
outplace_all_reduce
.
register_fake
@
outplace_all_reduce
.
register_fake
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
def
_
(
tensor
:
torch
.
Tensor
,
group_name
:
str
)
->
torch
.
Tensor
:
return
torch
.
empty_like
(
tensor
)
return
torch
.
empty_like
(
tensor
)
class
GroupCoordinator
:
class
GroupCoordinator
:
...
@@ -335,6 +337,9 @@ class GroupCoordinator:
...
@@ -335,6 +337,9 @@ class GroupCoordinator:
if
self
.
world_size
==
1
:
if
self
.
world_size
==
1
:
return
input_
return
input_
if
not
supports_custom_op
():
return
self
.
_all_reduce
(
input_
)
if
self
.
tpu_communicator
is
not
None
and
\
if
self
.
tpu_communicator
is
not
None
and
\
not
self
.
tpu_communicator
.
disabled
:
not
self
.
tpu_communicator
.
disabled
:
# TPU handles Dynamo with its own logic.
# TPU handles Dynamo with its own logic.
...
...
vllm/utils.py
View file @
d4bf085a
...
@@ -1245,6 +1245,12 @@ def supports_dynamo() -> bool:
...
@@ -1245,6 +1245,12 @@ def supports_dynamo() -> bool:
return
base_torch_version
>=
Version
(
"2.4.0"
)
return
base_torch_version
>=
Version
(
"2.4.0"
)
# Some backends use pytorch version < 2.4.0 which doesn't
# support `torch.library.custom_op`.
def
supports_custom_op
()
->
bool
:
return
hasattr
(
torch
.
library
,
"custom_op"
)
class
AtomicCounter
:
class
AtomicCounter
:
"""An atomic, thread-safe counter"""
"""An atomic, thread-safe counter"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment