Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e00c094f
Unverified
Commit
e00c094f
authored
Oct 10, 2024
by
youkaichao
Committed by
GitHub
Oct 10, 2024
Browse files
[torch.compile] generic decorators (#9258)
parent
a78c6ba7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
34 deletions
+74
-34
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+58
-30
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+8
-2
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+8
-2
No files found.
vllm/compilation/decorators.py
View file @
e00c094f
from
typing
import
List
,
Optional
,
Union
import
inspect
from
typing
import
Dict
,
List
,
Union
import
torch
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
supports_dynamo
def
support_compile_llama_style
(
cls
:
type
):
def support_torch_compile(dynamic_arg_dims: Dict[str, Union[int, List[int]]]):
    """
    Build a class decorator that enables compilation of a class's
    `forward` method.

    `dynamic_arg_dims` maps the name of a `forward` argument to the
    dimension(s) of that argument that should be treated as dynamic. Each
    value is either a single integer or a list of integers.

    How an argument is handled depends on the value passed for it at call
    time:

    - if it is a `torch.Tensor`, the configured dimension(s) of the tensor
      are marked as dynamic.
    - if it is `None`, it is ignored.
    - if it is `IntermediateTensors`, the configured dimension(s) of every
      tensor it holds are marked as dynamic.
    - anything else raises an error.

    NOTE: an argument that is passed as `None` must stay `None` for the
    whole lifetime of the model; otherwise the model cannot be captured as
    a single computation graph.

    Raises:
        ValueError: when the returned decorator is applied to a class whose
            `forward` method has no parameter named after one of the keys
            of `dynamic_arg_dims`.
    """

    def cls_decorator_helper(cls: type):
        # Thin helper that forwards `dynamic_arg_dims` to
        # `_support_torch_compile`, so that function does not need an
        # extra level of indentation.
        forward_params = inspect.signature(cls.forward).parameters
        for arg_name in dynamic_arg_dims:
            if arg_name in forward_params:
                continue
            # Fail at decoration time rather than at the first call.
            raise ValueError(
                f"Argument {arg_name} not found in the forward method of {cls}"
            )
        return _support_torch_compile(cls, dynamic_arg_dims)

    return cls_decorator_helper
def
_support_torch_compile
(
cls
:
type
,
dynamic_arg_dims
:
Dict
[
str
,
Union
[
int
,
List
[
int
]]]):
"""
A decorator to add support for compiling the forward method of a class.
If a module's **forward signature** is compatible with llama, this
decorator can be used to enable the compilation of the forward method.
"""
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
...
...
@@ -37,48 +71,42 @@ def support_compile_llama_style(cls: type):
cls
.
__init__
=
__init__
def
__call__
(
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
def
__call__
(
self
,
*
args
,
**
kwargs
):
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
if
torch
.
compiler
.
is_compiling
():
return
self
.
forward
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
self
.
forward
(
*
args
,
**
kwargs
)
# the first compilation needs to have dynamic shapes marked
if
len
(
self
.
compiled_codes
)
<
1
:
if
input_ids
is
not
None
:
torch
.
_dynamo
.
mark_dynamic
(
input_ids
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
positions
,
0
)
if
inputs_embeds
is
not
None
:
torch
.
_dynamo
.
mark_dynamic
(
inputs_embeds
,
0
)
if
intermediate_tensors
is
not
None
:
for
tensors
in
intermediate_tensors
.
tensors
.
values
():
torch
.
_dynamo
.
mark_dynamic
(
tensors
,
0
)
sig
=
inspect
.
signature
(
self
.
__class__
.
forward
)
bound_args
=
sig
.
bind
(
self
,
*
args
,
**
kwargs
)
bound_args
.
apply_defaults
()
for
k
,
dims
in
dynamic_arg_dims
.
items
():
arg
=
bound_args
.
arguments
.
get
(
k
)
if
arg
is
not
None
:
if
isinstance
(
arg
,
torch
.
Tensor
):
torch
.
_dynamo
.
mark_dynamic
(
arg
,
dims
)
elif
isinstance
(
arg
,
IntermediateTensors
):
for
tensor
in
arg
.
tensors
.
values
():
torch
.
_dynamo
.
mark_dynamic
(
tensor
,
dims
)
else
:
raise
ValueError
(
"Unsupported dynamic dimensions"
f
"
{
dims
}
for argument
{
k
}
with type
{
type
(
arg
)
}
."
)
# if we don't use custom dispatcher, we can directly call the
# compiled function and let torch.compile handle the dispatching,
# with the overhead of guard evaluation and recompilation.
if
len
(
self
.
compiled_codes
)
<
1
or
not
self
.
use_custom_dispatcher
:
return
self
.
compiled_callable
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
self
.
compiled_callable
(
*
args
,
**
kwargs
)
# usually, capturing the model once is enough, and then we can
# dispatch to the compiled code directly, without going through
# the Dynamo guard mechanism.
with
self
.
dispatch_to_code
(
0
):
model_output
=
self
.
forward
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
model_output
=
self
.
forward
(
*
args
,
**
kwargs
)
return
model_output
cls
.
__call__
=
__call__
...
...
vllm/model_executor/models/gemma2.py
View file @
e00c094f
...
...
@@ -21,7 +21,7 @@ from torch import nn
from
transformers
import
Gemma2Config
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_compile
_llama_style
from
vllm.compilation.decorators
import
support_
torch_
compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
...
...
@@ -239,7 +239,13 @@ class Gemma2DecoderLayer(nn.Module):
return
hidden_states
,
residual
@
support_compile_llama_style
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
0
,
"inputs_embeds"
:
0
,
"intermediate_tensors"
:
0
,
})
class
Gemma2Model
(
nn
.
Module
):
def
__init__
(
...
...
vllm/model_executor/models/llama.py
View file @
e00c094f
...
...
@@ -28,7 +28,7 @@ from torch import nn
from
transformers
import
LlamaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_compile
_llama_style
from
vllm.compilation.decorators
import
support_
torch_
compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
...
...
@@ -266,7 +266,13 @@ class LlamaDecoderLayer(nn.Module):
return
hidden_states
,
residual
@
support_compile_llama_style
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
0
,
"inputs_embeds"
:
0
,
"intermediate_tensors"
:
0
,
})
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment