Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e00c094f
Unverified
Commit
e00c094f
authored
Oct 10, 2024
by
youkaichao
Committed by
GitHub
Oct 10, 2024
Browse files
[torch.compile] generic decorators (#9258)
parent
a78c6ba7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
74 additions
and
34 deletions
+74
-34
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+58
-30
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+8
-2
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+8
-2
No files found.
vllm/compilation/decorators.py
View file @
e00c094f
from
typing
import
List
,
Optional
,
Union
import
inspect
from
typing
import
Dict
,
List
,
Union
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
supports_dynamo
from
vllm.utils
import
supports_dynamo
def
support_compile_llama_style
(
cls
:
type
):
def
support_torch_compile
(
dynamic_arg_dims
:
Dict
[
str
,
Union
[
int
,
List
[
int
]]]):
"""
A decorator to add support for compiling the forward method of a class.
`dynamic_arg_dims` is a dictionary that maps argument names to the dynamic
dimensions of the argument. The dynamic dimensions can be either a single
integer or a list of integers.
Depending on the value of arguments:
- if it is a single integer, the corresponding dimension of the argument
will be marked as dynamic.
- if it is `None`, ignored.
- if it is `IntermediateTensors`, all the tensors in the intermediate
tensors will be marked as dynamic.
- otherwise, it will raise an error.
NOTE: if an argument is `None`, it should always be passed as `None` during
the lifetime of the model, otherwise, it cannot be captured as a single
computation graph.
"""
def
cls_decorator_helper
(
cls
:
type
):
# helper to pass `dynamic_arg_dims`` to `_support_torch_compile``
# to avoid too much indentation for `_support_torch_compile``
sig
=
inspect
.
signature
(
cls
.
forward
)
for
k
in
dynamic_arg_dims
:
if
k
not
in
sig
.
parameters
:
raise
ValueError
(
f
"Argument
{
k
}
not found in the forward method of
{
cls
}
"
)
return
_support_torch_compile
(
cls
,
dynamic_arg_dims
)
return
cls_decorator_helper
def
_support_torch_compile
(
cls
:
type
,
dynamic_arg_dims
:
Dict
[
str
,
Union
[
int
,
List
[
int
]]]):
"""
"""
A decorator to add support for compiling the forward method of a class.
A decorator to add support for compiling the forward method of a class.
If a module's **forward signature** is compatible with llama, this
decorator can be used to enable the compilation of the forward method.
"""
"""
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
...
@@ -37,48 +71,42 @@ def support_compile_llama_style(cls: type):
...
@@ -37,48 +71,42 @@ def support_compile_llama_style(cls: type):
cls
.
__init__
=
__init__
cls
.
__init__
=
__init__
def
__call__
(
def
__call__
(
self
,
*
args
,
**
kwargs
):
self
,
input_ids
:
Optional
[
torch
.
Tensor
],
positions
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
],
inputs_embeds
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Union
[
torch
.
Tensor
,
IntermediateTensors
]:
# torch.compiler.is_compiling() means we are inside the compilation
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
# need to compile the model inside.
if
torch
.
compiler
.
is_compiling
():
if
torch
.
compiler
.
is_compiling
():
return
self
.
forward
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
,
return
self
.
forward
(
*
args
,
**
kwargs
)
intermediate_tensors
,
inputs_embeds
)
# the first compilation needs to have dynamic shapes marked
# the first compilation needs to have dynamic shapes marked
if
len
(
self
.
compiled_codes
)
<
1
:
if
len
(
self
.
compiled_codes
)
<
1
:
if
input_ids
is
not
None
:
sig
=
inspect
.
signature
(
self
.
__class__
.
forward
)
torch
.
_dynamo
.
mark_dynamic
(
input_ids
,
0
)
bound_args
=
sig
.
bind
(
self
,
*
args
,
**
kwargs
)
torch
.
_dynamo
.
mark_dynamic
(
positions
,
0
)
bound_args
.
apply_defaults
()
if
inputs_embeds
is
not
None
:
for
k
,
dims
in
dynamic_arg_dims
.
items
():
torch
.
_dynamo
.
mark_dynamic
(
inputs_embeds
,
0
)
arg
=
bound_args
.
arguments
.
get
(
k
)
if
intermediate_tensors
is
not
None
:
if
arg
is
not
None
:
for
tensors
in
intermediate_tensors
.
tensors
.
values
():
if
isinstance
(
arg
,
torch
.
Tensor
):
torch
.
_dynamo
.
mark_dynamic
(
tensors
,
0
)
torch
.
_dynamo
.
mark_dynamic
(
arg
,
dims
)
elif
isinstance
(
arg
,
IntermediateTensors
):
for
tensor
in
arg
.
tensors
.
values
():
torch
.
_dynamo
.
mark_dynamic
(
tensor
,
dims
)
else
:
raise
ValueError
(
"Unsupported dynamic dimensions"
f
"
{
dims
}
for argument
{
k
}
with type
{
type
(
arg
)
}
."
)
# if we don't use custom dispatcher, we can directly call the
# if we don't use custom dispatcher, we can directly call the
# compiled function and let torch.compile handle the dispatching,
# compiled function and let torch.compile handle the dispatching,
# with the overhead of guard evaluation and recompilation.
# with the overhead of guard evaluation and recompilation.
if
len
(
self
.
compiled_codes
)
<
1
or
not
self
.
use_custom_dispatcher
:
if
len
(
self
.
compiled_codes
)
<
1
or
not
self
.
use_custom_dispatcher
:
return
self
.
compiled_callable
(
input_ids
,
positions
,
kv_caches
,
return
self
.
compiled_callable
(
*
args
,
**
kwargs
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
# usually, capturing the model once is enough, and then we can
# usually, capturing the model once is enough, and then we can
# dispatch to the compiled code directly, without going through
# dispatch to the compiled code directly, without going through
# the Dynamo guard mechanism.
# the Dynamo guard mechanism.
with
self
.
dispatch_to_code
(
0
):
with
self
.
dispatch_to_code
(
0
):
model_output
=
self
.
forward
(
input_ids
,
positions
,
kv_caches
,
model_output
=
self
.
forward
(
*
args
,
**
kwargs
)
attn_metadata
,
intermediate_tensors
,
inputs_embeds
)
return
model_output
return
model_output
cls
.
__call__
=
__call__
cls
.
__call__
=
__call__
...
...
vllm/model_executor/models/gemma2.py
View file @
e00c094f
...
@@ -21,7 +21,7 @@ from torch import nn
...
@@ -21,7 +21,7 @@ from torch import nn
from
transformers
import
Gemma2Config
from
transformers
import
Gemma2Config
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_compile
_llama_style
from
vllm.compilation.decorators
import
support_
torch_
compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -239,7 +239,13 @@ class Gemma2DecoderLayer(nn.Module):
...
@@ -239,7 +239,13 @@ class Gemma2DecoderLayer(nn.Module):
return
hidden_states
,
residual
return
hidden_states
,
residual
@
support_compile_llama_style
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
0
,
"inputs_embeds"
:
0
,
"intermediate_tensors"
:
0
,
})
class
Gemma2Model
(
nn
.
Module
):
class
Gemma2Model
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
...
vllm/model_executor/models/llama.py
View file @
e00c094f
...
@@ -28,7 +28,7 @@ from torch import nn
...
@@ -28,7 +28,7 @@ from torch import nn
from
transformers
import
LlamaConfig
from
transformers
import
LlamaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_compile
_llama_style
from
vllm.compilation.decorators
import
support_
torch_
compile
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
...
@@ -266,7 +266,13 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -266,7 +266,13 @@ class LlamaDecoderLayer(nn.Module):
return
hidden_states
,
residual
return
hidden_states
,
residual
@
support_compile_llama_style
@
support_torch_compile
(
dynamic_arg_dims
=
{
"input_ids"
:
0
,
"positions"
:
0
,
"inputs_embeds"
:
0
,
"intermediate_tensors"
:
0
,
})
class
LlamaModel
(
nn
.
Module
):
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment