Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f877a7d1
Unverified
Commit
f877a7d1
authored
Dec 01, 2024
by
Cyrus Leung
Committed by
GitHub
Nov 30, 2024
Browse files
[Misc] Improve type annotations for `support_torch_compile` (#10763)
Signed-off-by:
DarkLight1337
<
tlleungac@connect.ust.hk
>
parent
13370712
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
29 additions
and
9 deletions
+29
-9
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+29
-9
No files found.
vllm/compilation/decorators.py
View file @
f877a7d1
import
inspect
from
typing
import
Dict
,
List
,
Optional
,
Union
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
TypeVar
,
Union
,
overload
import
torch
import
torch.nn
as
nn
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
...
...
@@ -12,10 +13,27 @@ from vllm.utils import supports_dynamo
logger
=
init_logger
(
__name__
)
_T
=
TypeVar
(
"_T"
,
bound
=
type
[
nn
.
Module
])
@
overload
def
support_torch_compile
(
*
,
dynamic_arg_dims
:
Optional
[
Dict
[
str
,
Union
[
int
,
List
[
int
]]]],
)
->
Callable
[[
_T
],
_T
]:
...
@
overload
def
support_torch_compile
(
cls
:
_T
)
->
_T
:
...
def
support_torch_compile
(
cls
:
Optional
[
type
]
=
None
,
dynamic_arg_dims
:
Optional
[
Dict
[
str
,
Union
[
int
,
List
[
int
]]]]
=
None
):
cls
:
Optional
[
_T
]
=
None
,
*
,
dynamic_arg_dims
:
Optional
[
Dict
[
str
,
Union
[
int
,
List
[
int
]]]]
=
None
,
)
->
Union
[
Callable
[[
_T
],
_T
],
_T
]:
"""
A decorator to add support for compiling the forward method of a class.
...
...
@@ -66,7 +84,7 @@ def support_torch_compile(
computation graph.
"""
def
cls_decorator_helper
(
cls
:
type
)
:
def
cls_decorator_helper
(
cls
:
_T
)
->
_T
:
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
# to avoid too much indentation for `_support_torch_compile``
if
not
hasattr
(
cls
,
'forward'
):
...
...
@@ -105,8 +123,10 @@ def support_torch_compile(
return
cls_decorator_helper
def
_support_torch_compile
(
cls
:
type
,
dynamic_arg_dims
:
Dict
[
str
,
Union
[
int
,
List
[
int
]]]):
def
_support_torch_compile
(
cls
:
_T
,
dynamic_arg_dims
:
Dict
[
str
,
Union
[
int
,
List
[
int
]]],
)
->
_T
:
"""
A decorator to add support for compiling the forward method of a class.
"""
...
...
@@ -119,7 +139,7 @@ def _support_torch_compile(cls: type,
# other than TorchCompileWrapperWithCustomDispatcher
cls
.
__bases__
=
cls
.
__bases__
+
(
TorchCompileWrapperWithCustomDispatcher
,
)
old_init
=
cls
.
__init__
# type: ignore
old_init
=
cls
.
__init__
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
):
old_init
(
self
,
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
...
...
@@ -135,7 +155,7 @@ def _support_torch_compile(cls: type,
TorchCompileWrapperWithCustomDispatcher
.
__init__
(
self
,
compilation_level
=
vllm_config
.
compilation_config
.
level
)
cls
.
__init__
=
__init__
# type: ignore
cls
.
__init__
=
__init__
def
__call__
(
self
,
*
args
,
**
kwargs
):
# torch.compiler.is_compiling() means we are inside the compilation
...
...
@@ -180,5 +200,5 @@ def _support_torch_compile(cls: type,
model_output
=
self
.
forward
(
*
args
,
**
kwargs
)
return
model_output
cls
.
__call__
=
__call__
# type: ignore
cls
.
__call__
=
__call__
return
cls
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment