Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
eea55cca
Unverified
Commit
eea55cca
authored
Nov 11, 2024
by
youkaichao
Committed by
GitHub
Nov 11, 2024
Browse files
[1/N] torch.compile user interface design (#10237)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
9cdba966
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
55 additions
and
37 deletions
+55
-37
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+9
-5
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+14
-7
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+14
-13
vllm/config.py
vllm/config.py
+18
-12
No files found.
tests/compile/piecewise/test_simple.py
View file @
eea55cca
...
@@ -12,10 +12,9 @@ from vllm.compilation.compile_context import set_compile_context
...
@@ -12,10 +12,9 @@ from vllm.compilation.compile_context import set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.config
import
VllmConfig
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
global_counter
=
0
global_counter
=
0
# create a library to hold the custom op
# create a library to hold the custom op
...
@@ -48,7 +47,11 @@ direct_register_custom_op(
...
@@ -48,7 +47,11 @@ direct_register_custom_op(
@
support_torch_compile
@
support_torch_compile
class
SillyModel
(
nn
.
Module
):
class
SillyModel
(
nn
.
Module
):
def
__init__
(
self
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
super
().
__init__
()
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
@@ -74,11 +77,12 @@ class SillyModel(nn.Module):
...
@@ -74,11 +77,12 @@ class SillyModel(nn.Module):
def
test_simple_piecewise_compile
():
def
test_simple_piecewise_compile
():
model
=
SillyModel
()
directory
=
os
.
path
.
dirname
(
__file__
)
directory
=
os
.
path
.
dirname
(
__file__
)
config
=
os
.
path
.
join
(
directory
,
"piecewise_compilation_config.json"
)
config
=
os
.
path
.
join
(
directory
,
"piecewise_compilation_config.json"
)
os
.
environ
[
"VLLM_TORCH_COMPILE_CONFIG"
]
=
config
os
.
environ
[
"VLLM_TORCH_COMPILE_CONFIG"
]
=
config
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
model
=
SillyModel
(
vllm_config
=
VllmConfig
(),
prefix
=
''
)
inputs
=
torch
.
randn
(
100
).
cuda
()
inputs
=
torch
.
randn
(
100
).
cuda
()
...
...
tests/compile/piecewise/test_toy_llama.py
View file @
eea55cca
...
@@ -19,6 +19,7 @@ from vllm.compilation.config import CompilationConfig
...
@@ -19,6 +19,7 @@ from vllm.compilation.config import CompilationConfig
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.config
import
VllmConfig
from
vllm.plugins
import
set_compilation_config
from
vllm.plugins
import
set_compilation_config
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
@@ -195,9 +196,15 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -195,9 +196,15 @@ class LlamaDecoderLayer(nn.Module):
return
hidden_states
,
residual
return
hidden_states
,
residual
@
support_torch_compile
class
LlamaModel
(
nn
.
Module
):
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
self
,
config
:
LlamaConfig
)
->
None
:
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
config
:
LlamaConfig
,
prefix
:
str
=
''
,
**
kwargs
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
embedding_tokens
=
nn
.
Embedding
(
self
.
embedding_tokens
=
nn
.
Embedding
(
num_embeddings
=
config
.
vocab_size
,
num_embeddings
=
config
.
vocab_size
,
...
@@ -265,10 +272,9 @@ def run_model(llama_config,
...
@@ -265,10 +272,9 @@ def run_model(llama_config,
CompilationLevel
.
NO_COMPILATION
)
CompilationLevel
.
NO_COMPILATION
)
set_compilation_config
(
None
)
set_compilation_config
(
None
)
cls
=
LlamaModel
model
=
LlamaModel
(
config
=
llama_config
,
if
use_compile
:
vllm_config
=
VllmConfig
(),
cls
=
support_torch_compile
(
LlamaModel
)
prefix
=
""
).
eval
().
cuda
()
model
=
cls
(
llama_config
).
eval
().
cuda
()
B
=
16
# max batch size
B
=
16
# max batch size
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
...
@@ -357,7 +363,6 @@ def test_toy_llama():
...
@@ -357,7 +363,6 @@ def test_toy_llama():
def
benchmark
():
def
benchmark
():
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
from
triton.testing
import
do_bench
from
triton.testing
import
do_bench
cls
=
support_torch_compile
(
LlamaModel
)
# similar to llama 3.1-8B
# similar to llama 3.1-8B
llama_config
=
LlamaConfig
(
hidden_size
=
4096
,
llama_config
=
LlamaConfig
(
hidden_size
=
4096
,
...
@@ -390,7 +395,9 @@ def benchmark():
...
@@ -390,7 +395,9 @@ def benchmark():
else
:
else
:
set_compilation_config
(
None
)
set_compilation_config
(
None
)
model
=
cls
(
llama_config
).
eval
().
cuda
().
to
(
torch
.
bfloat16
)
model
=
LlamaModel
(
config
=
llama_config
,
vllm_config
=
VllmConfig
(),
prefix
=
""
).
eval
().
cuda
().
to
(
torch
.
bfloat16
)
B
=
256
# max batch size
B
=
256
# max batch size
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
input_ids
=
torch
.
randint
(
0
,
llama_config
.
vocab_size
,
(
B
,
)).
cuda
()
...
...
vllm/compilation/decorators.py
View file @
eea55cca
...
@@ -6,6 +6,7 @@ import torch
...
@@ -6,6 +6,7 @@ import torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.config
import
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
supports_dynamo
from
vllm.utils
import
supports_dynamo
...
@@ -110,26 +111,26 @@ def _support_torch_compile(cls: type,
...
@@ -110,26 +111,26 @@ def _support_torch_compile(cls: type,
"""
"""
A decorator to add support for compiling the forward method of a class.
A decorator to add support for compiling the forward method of a class.
"""
"""
if
TorchCompileWrapperWithCustomDispatcher
in
cls
.
__bases__
:
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# support decorating multiple times
# will handle the compilation, so we don't need to do anything here.
if
envs
.
VLLM_TORCH_COMPILE_LEVEL
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
]
or
not
supports_dynamo
():
return
cls
return
cls
# take care of method resolution order
# take care of method resolution order
# make sure super().__init__ is called on the base class
# make sure super().__init__ is called on the base class
# other than TorchCompileWrapperWithCustomDispatcher
# other than TorchCompileWrapperWithCustomDispatcher
if
TorchCompileWrapperWithCustomDispatcher
not
in
cls
.
__bases__
:
cls
.
__bases__
=
cls
.
__bases__
+
(
TorchCompileWrapperWithCustomDispatcher
,
)
# support decorating multiple times
cls
.
__bases__
=
cls
.
__bases__
+
(
TorchCompileWrapperWithCustomDispatcher
,
)
old_init
=
cls
.
__init__
# type: ignore
old_init
=
cls
.
__init__
# type: ignore
def
__init__
(
self
,
*
args
,
**
kwargs
):
def
__init__
(
self
,
*
,
vllm_config
:
VllmConfig
,
prefix
:
str
=
''
,
**
kwargs
):
old_init
(
self
,
*
args
,
**
kwargs
)
old_init
(
self
,
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
self
.
do_not_compile
=
envs
.
VLLM_TORCH_COMPILE_LEVEL
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
]
or
not
supports_dynamo
()
if
self
.
do_not_compile
:
return
TorchCompileWrapperWithCustomDispatcher
.
__init__
(
self
)
TorchCompileWrapperWithCustomDispatcher
.
__init__
(
self
)
cls
.
__init__
=
__init__
# type: ignore
cls
.
__init__
=
__init__
# type: ignore
...
@@ -138,7 +139,7 @@ def _support_torch_compile(cls: type,
...
@@ -138,7 +139,7 @@ def _support_torch_compile(cls: type,
# torch.compiler.is_compiling() means we are inside the compilation
# torch.compiler.is_compiling() means we are inside the compilation
# e.g. TPU has the compilation logic in model runner, so we don't
# e.g. TPU has the compilation logic in model runner, so we don't
# need to compile the model inside.
# need to compile the model inside.
if
torch
.
compiler
.
is_compiling
():
if
self
.
do_not_compile
or
torch
.
compiler
.
is_compiling
():
return
self
.
forward
(
*
args
,
**
kwargs
)
return
self
.
forward
(
*
args
,
**
kwargs
)
# the first compilation needs to have dynamic shapes marked
# the first compilation needs to have dynamic shapes marked
...
...
vllm/config.py
View file @
eea55cca
...
@@ -2041,12 +2041,15 @@ class VllmConfig:
...
@@ -2041,12 +2041,15 @@ class VllmConfig:
simplifies passing around the distinct configurations in the codebase.
simplifies passing around the distinct configurations in the codebase.
"""
"""
model_config
:
ModelConfig
model_config
:
ModelConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
cache_config
:
CacheConfig
cache_config
:
CacheConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
parallel_config
:
ParallelConfig
parallel_config
:
ParallelConfig
=
field
(
default
=
None
,
scheduler_config
:
SchedulerConfig
init
=
True
)
# type: ignore
device_config
:
DeviceConfig
scheduler_config
:
SchedulerConfig
=
field
(
default
=
None
,
load_config
:
LoadConfig
init
=
True
)
# type: ignore
device_config
:
DeviceConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
load_config
:
LoadConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
lora_config
:
Optional
[
LoRAConfig
]
=
None
lora_config
:
Optional
[
LoRAConfig
]
=
None
speculative_config
:
Optional
[
SpeculativeConfig
]
=
None
speculative_config
:
Optional
[
SpeculativeConfig
]
=
None
decoding_config
:
Optional
[
DecodingConfig
]
=
None
decoding_config
:
Optional
[
DecodingConfig
]
=
None
...
@@ -2091,11 +2094,14 @@ class VllmConfig:
...
@@ -2091,11 +2094,14 @@ class VllmConfig:
def
__post_init__
(
self
):
def
__post_init__
(
self
):
"""Verify configs are valid & consistent with each other.
"""Verify configs are valid & consistent with each other.
"""
"""
self
.
model_config
.
verify_async_output_proc
(
self
.
parallel_config
,
if
self
.
model_config
is
not
None
:
self
.
speculative_config
,
self
.
model_config
.
verify_async_output_proc
(
self
.
parallel_config
,
self
.
device_config
)
self
.
speculative_config
,
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
device_config
)
self
.
cache_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
self
.
model_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
if
self
.
cache_config
is
not
None
:
self
.
cache_config
.
verify_with_parallel_config
(
self
.
parallel_config
)
if
self
.
lora_config
:
if
self
.
lora_config
:
self
.
lora_config
.
verify_with_model_config
(
self
.
model_config
)
self
.
lora_config
.
verify_with_model_config
(
self
.
model_config
)
...
@@ -2149,4 +2155,4 @@ class VllmConfig:
...
@@ -2149,4 +2155,4 @@ class VllmConfig:
self
.
scheduler_config
.
num_scheduler_steps
,
self
.
scheduler_config
.
num_scheduler_steps
,
self
.
cache_config
.
enable_prefix_caching
,
self
.
cache_config
.
enable_prefix_caching
,
self
.
model_config
.
use_async_output_proc
,
self
.
model_config
.
use_async_output_proc
,
self
.
model_config
.
mm_processor_kwargs
)
self
.
model_config
.
mm_processor_kwargs
)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment