Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eea55cca
Unverified
Commit
eea55cca
authored
Nov 11, 2024
by
youkaichao
Committed by
GitHub
Nov 11, 2024
Browse files
[1/N] torch.compile user interface design (#10237)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
9cdba966
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
55 additions
and
37 deletions
+55
-37
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+9
-5
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+14
-7
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+14
-13
vllm/config.py
vllm/config.py
+18
-12
No files found.
tests/compile/piecewise/test_simple.py
View file @
eea55cca
...
...
@@ -12,10 +12,9 @@ from vllm.compilation.compile_context import set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.config
import
VllmConfig
from
vllm.utils
import
direct_register_custom_op
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
global_counter
=
0
# create a library to hold the custom op
...
...
@@ -48,7 +47,11 @@ direct_register_custom_op(
@
support_torch_compile
class
SillyModel
(
nn
.
Module
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
@@ -74,11 +77,12 @@ class SillyModel(nn.Module):
def
test_simple_piecewise_compile
():
model
=
SillyModel
()
directory
=
os
.
path
.
dirname
(
__file__
)
config
=
os
.
path
.
join
(
directory
,
"piecewise_compilation_config.json"
)
os
.
environ
[
"VLLM_TORCH_COMPILE_CONFIG"
]
=
config
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
model
=
SillyModel
(
vllm_config
=
VllmConfig
(),
prefix
=
''
)
inputs
=
torch
.
randn
(
100
).
cuda
()
...
...
tests/compile/piecewise/test_toy_llama.py
View file @
eea55cca
...
...
@@ -19,6 +19,7 @@ from vllm.compilation.config import CompilationConfig
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.config
import
VllmConfig
from
vllm.plugins
import
set_compilation_config
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -195,9 +196,15 @@ class LlamaDecoderLayer(nn.Module):
return
hidden_states
,
residual
@
support_torch_compile
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
LlamaConfig
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
config
:
LlamaConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
self
.
embedding_tokens
=
nn
.
Embedding
(
num_embeddings
=
config
.
vocab_size
,
...
...
@@ -265,10 +272,9 @@ def run_model(llama_config,
CompilationLevel
.
NO_COMPILATION
)
set_compilation_config
(
None
)
cls
=
LlamaModel
if
use_compile
:
cls
=
support_torch_compile
(
LlamaModel
)
model
=
cls
(
llama_config
).
eval
().
cuda
()
model
=
LlamaModel
(
config
=
llama_config
,
vllm_config
=
VllmConfig
(),
prefix
=
""
).
eval
().
cuda
()
B
=
16
# max batch size
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
...
...
@@ -357,7 +363,6 @@ def test_toy_llama():
def
benchmark
():
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
from
triton.testing
import
do_bench
cls
=
support_torch_compile
(
LlamaModel
)
# similar to llama 3.1-8B
llama_config
=
LlamaConfig
(
hidden_size
=
4096
,
...
...
@@ -390,7 +395,9 @@ def benchmark():
else
:
set_compilation_config
(
None
)
model
=
cls
(
llama_config
).
eval
().
cuda
().
to
(
torch
.
bfloat16
)
model
=
LlamaModel
(
config
=
llama_config
,
vllm_config
=
VllmConfig
(),
prefix
=
""
).
eval
().
cuda
().
to
(
torch
.
bfloat16
)
B
=
256
# max batch size
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
...
...
vllm/compilation/decorators.py
View file @
eea55cca
...
...
@@ -6,6 +6,7 @@ import torch
import
vllm.envs
as
envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
supports_dynamo
...
...
@@ -110,26 +111,26 @@ def _support_torch_compile(cls: type,
"""
A decorator to add support for compiling the forward method of a class.
"""
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
if
envs
.
VLLM_TORCH_COMPILE_LEVEL
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
]
or
not
supports_dynamo
():
if
TorchCompileWrapperWithCustomDispatcher
in
cls
.
__bases__
:
# support decorating multiple times
return
cls
# take care of method resolution order
# make sure super().__init__ is called on the base class
# other than TorchCompileWrapperWithCustomDispatcher
if
TorchCompileWrapperWithCustomDispatcher
not
in
cls
.
__bases__
:
# support decorating multiple times
cls
.
__bases__
=
cls
.
__bases__
+
(
TorchCompileWrapperWithCustomDispatcher
,
)
cls
.
__bases__
=
cls
.
__bases__
+
(
TorchCompileWrapperWithCustomDispatcher
,
)
old_init
=
cls
.
__init__
# type: ignore
def
__init__
(
self
,
*
args
,
**
kwargs
):
old_init
(
self
,
*
args
,
**
kwargs
)
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
):
old_init
(
self
,
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self
.
do_not_compile
=
envs
.
VLLM_TORCH_COMPILE_LEVEL
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
]
or
not
supports_dynamo
()
if
self
.
do_not_compile
:
return
TorchCompileWrapperWithCustomDispatcher
.
__init__
(
self
)
cls
.
__init__
=
__init__
# type: ignore
...
...
@@ -138,7 +139,7 @@ def _support_torch_compile(cls: type,
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
if
torch
.
compiler
.
is_compiling
():
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
():
return
self
.
forward
(
*
args
,
**
kwargs
)
# the first compilation needs to have dynamic shapes marked
...
...
vllm/config.py
View file @
eea55cca
...
...
@@ -2041,12 +2041,15 @@ class VllmConfig:
simplifies passing around the distinct configurations in the codebase.
"""
model_config
:
ModelConfig
cache_config
:
CacheConfig
parallel_config
:
ParallelConfig
scheduler_config
:
SchedulerConfig
device_config
:
DeviceConfig
load_config
:
LoadConfig
model_config
:
ModelConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
cache_config
:
CacheConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
parallel_config
:
ParallelConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
scheduler_config
:
SchedulerConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
device_config
:
DeviceConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
load_config
:
LoadConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
lora_config
:
Optional
[
LoRAConfig
]
=
None
speculative_config
:
Optional
[
SpeculativeConfig
]
=
None
decoding_config
:
Optional
[
DecodingConfig
]
=
None
...
...
@@ -2091,11 +2094,14 @@ class VllmConfig:
def
__post_init__
(
self
):
"""Verify configs are valid & consistent with each other.
"""
self
.
model_config
.
verify_async_output_proc
(
self
.
parallel_config
,
self
.
speculative_config
,
self
.
device_config
)
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
cache_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
if
self
.
model_config
is
not
None
:
self
.
model_config
.
verify_async_output_proc
(
self
.
parallel_config
,
self
.
speculative_config
,
self
.
device_config
)
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
if
self
.
cache_config
is
not
None
:
self
.
cache_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
if
self
.
lora_config
:
self
.
lora_config
.
verify_with_model_config
(
self
.
model_config
)
...
...
@@ -2149,4 +2155,4 @@ class VllmConfig:
self
.
scheduler_config
.
num_scheduler_steps
,
self
.
cache_config
.
enable_prefix_caching
,
self
.
model_config
.
use_async_output_proc
,
self
.
model_config
.
mm_processor_kwargs
)
\ No newline at end of file
self
.
model_config
.
mm_processor_kwargs
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment