Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
17c79f3c
Unverified
Commit
17c79f3c
authored
Oct 22, 2024
by
youkaichao
Committed by
GitHub
Oct 22, 2024
Browse files
[torch.compile] auto infer dynamic_arg_dims from type annotation (#9589)
parent
cd5601ac
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
65 additions
and
19 deletions
+65
-19
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+63
-5
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+1
-7
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+1
-7
No files found.
vllm/compilation/decorators.py
View file @
17c79f3c
import
inspect
import
inspect
from
typing
import
Dict
,
List
,
Union
from
typing
import
Dict
,
List
,
Optional
,
Union
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
supports_dynamo
from
vllm.utils
import
supports_dynamo
logger
=
init_logger
(
__name__
)
def
support_torch_compile
(
dynamic_arg_dims
:
Dict
[
str
,
Union
[
int
,
List
[
int
]]]):
def
support_torch_compile
(
cls
:
Optional
[
type
]
=
None
,
dynamic_arg_dims
:
Optional
[
Dict
[
str
,
Union
[
int
,
List
[
int
]]]]
=
None
):
"""
"""
A decorator to add support for compiling the forward method of a class.
A decorator to add support for compiling the forward method of a class.
Usage 1: use directly as a decorator without arguments:
```python
@support_torch_compile
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
...
```
Usage 2: use as a decorator with arguments:
```python
@support_torch_compile(dynamic_arg_dims={"x": 0, "y": 0})
class MyModel(nn.Module):
def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]):
...
```
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
dimensions of the argument. The dynamic dimensions can be either a single
dimensions of the argument. The dynamic dimensions can be either a single
integer or a list of integers.
integer or a list of integers.
Depending on the value of arguments:
If `dynamic_arg_dims` is `None`, it is inferred from the type annotations
of the `forward` method's parameters, based on the following default rules:
- if the argument is annotated as `torch.Tensor` or
`Optional[torch.Tensor]`, the first dimension will be
marked as dynamic.
- if the argument is annotated as `IntermediateTensors` or
`Optional[IntermediateTensors]`, the first dimension of all the
tensors in the intermediate tensors will be marked as dynamic.
At runtime, how the dimensions of a tensor are actually marked
depends on the value given for that argument in `dynamic_arg_dims`:
- if it is a single integer, the corresponding dimension of the argument
- if it is a single integer, the corresponding dimension of the argument
will be marked as dynamic.
will be marked as dynamic.
...
@@ -38,11 +72,35 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
...
@@ -38,11 +72,35 @@ def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
if
not
hasattr
(
cls
,
'forward'
):
if
not
hasattr
(
cls
,
'forward'
):
raise
TypeError
(
"decorated class should have a forward method."
)
raise
TypeError
(
"decorated class should have a forward method."
)
sig
=
inspect
.
signature
(
cls
.
forward
)
sig
=
inspect
.
signature
(
cls
.
forward
)
for
k
in
dynamic_arg_dims
:
inferred_dynamic_arg_dims
=
dynamic_arg_dims
if
inferred_dynamic_arg_dims
is
None
:
inferred_dynamic_arg_dims
=
{}
for
k
,
v
in
sig
.
parameters
.
items
():
if
v
.
annotation
in
[
torch
.
Tensor
,
Optional
[
torch
.
Tensor
],
IntermediateTensors
,
Optional
[
IntermediateTensors
]
]:
inferred_dynamic_arg_dims
[
k
]
=
0
logger
.
debug
((
"Inferred dynamic dimensions for "
"forward method of %s: %s"
),
cls
,
list
(
inferred_dynamic_arg_dims
.
keys
()))
if
len
(
inferred_dynamic_arg_dims
)
==
0
:
raise
ValueError
(
"No dynamic dimensions found in the forward method of "
f
"
{
cls
}
. Please provide dynamic_arg_dims explicitly."
)
for
k
in
inferred_dynamic_arg_dims
:
if
k
not
in
sig
.
parameters
:
if
k
not
in
sig
.
parameters
:
raise
ValueError
(
raise
ValueError
(
f
"Argument
{
k
}
not found in the forward method of
{
cls
}
"
)
f
"Argument
{
k
}
not found in the forward method of
{
cls
}
"
)
return
_support_torch_compile
(
cls
,
dynamic_arg_dims
)
return
_support_torch_compile
(
cls
,
inferred_dynamic_arg_dims
)
if
cls
is
not
None
:
# use `support_torch_compile` as a decorator without arguments
assert
isinstance
(
cls
,
type
)
return
cls_decorator_helper
(
cls
)
return
cls_decorator_helper
return
cls_decorator_helper
...
...
vllm/model_executor/models/gemma2.py
View file @
17c79f3c
...
@@ -241,13 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
...
@@ -241,13 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
return
hidden_states
,
residual
return
hidden_states
,
residual
@
support_torch_compile
(
@
support_torch_compile
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
0
,
"inputs_embeds"
:
0
,
"intermediate_tensors"
:
0
,
})
class
Gemma2Model
(
nn
.
Module
):
class
Gemma2Model
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
...
vllm/model_executor/models/llama.py
View file @
17c79f3c
...
@@ -268,13 +268,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -268,13 +268,7 @@ class LlamaDecoderLayer(nn.Module):
return
hidden_states
,
residual
return
hidden_states
,
residual
@
support_torch_compile
(
@
support_torch_compile
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
0
,
"inputs_embeds"
:
0
,
"intermediate_tensors"
:
0
,
})
class
LlamaModel
(
nn
.
Module
):
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment