Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
4fd93750
Unverified
Commit
4fd93750
authored
Nov 16, 2024
by
youkaichao
Committed by
GitHub
Nov 16, 2024
Browse files
[2/N][torch.compile] make compilation cfg part of vllm cfg (#10383)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
661a34fd
Changes
27
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
282 additions
and
257 deletions
+282
-257
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+5
-3
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+12
-10
tests/compile/test_basic_correctness.py
tests/compile/test_basic_correctness.py
+1
-1
tests/compile/test_full_graph.py
tests/compile/test_full_graph.py
+1
-1
tests/compile/test_fusion.py
tests/compile/test_fusion.py
+1
-1
tests/compile/test_wrapper.py
tests/compile/test_wrapper.py
+3
-1
tests/compile/utils.py
tests/compile/utils.py
+1
-1
tests/model_executor/test_enabled_custom_ops.py
tests/model_executor/test_enabled_custom_ops.py
+26
-26
tests/tpu/test_compilation.py
tests/tpu/test_compilation.py
+1
-1
tests/tpu/test_custom_dispatcher.py
tests/tpu/test_custom_dispatcher.py
+1
-1
vllm/compilation/backends.py
vllm/compilation/backends.py
+13
-7
vllm/compilation/config.py
vllm/compilation/config.py
+0
-159
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+5
-5
vllm/compilation/fusion.py
vllm/compilation/fusion.py
+1
-1
vllm/compilation/inductor_pass.py
vllm/compilation/inductor_pass.py
+1
-1
vllm/compilation/levels.py
vllm/compilation/levels.py
+0
-8
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+6
-5
vllm/config.py
vllm/config.py
+189
-0
vllm/envs.py
vllm/envs.py
+0
-13
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+15
-12
No files found.
tests/compile/piecewise/test_simple.py
View file @
4fd93750
...
@@ -11,8 +11,8 @@ from torch.library import Library
...
@@ -11,8 +11,8 @@ from torch.library import Library
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.co
mpilation.levels
import
CompilationLevel
from
vllm.co
nfig
import
CompilationLevel
,
VllmConfig
from
vllm.
config
import
V
llm
C
onfig
from
vllm.
plugins
import
set_current_v
llm
_c
onfig
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
global_counter
=
0
global_counter
=
0
...
@@ -82,7 +82,9 @@ def test_simple_piecewise_compile():
...
@@ -82,7 +82,9 @@ def test_simple_piecewise_compile():
os
.
environ
[
"VLLM_TORCH_COMPILE_CONFIG"
]
=
config
os
.
environ
[
"VLLM_TORCH_COMPILE_CONFIG"
]
=
config
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
model
=
SillyModel
(
vllm_config
=
VllmConfig
(),
prefix
=
''
)
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
model
=
SillyModel
(
vllm_config
=
vllm_config
,
prefix
=
''
)
inputs
=
torch
.
randn
(
100
).
cuda
()
inputs
=
torch
.
randn
(
100
).
cuda
()
...
...
tests/compile/piecewise/test_toy_llama.py
View file @
4fd93750
...
@@ -15,12 +15,10 @@ from torch import nn
...
@@ -15,12 +15,10 @@ from torch import nn
from
torch.library
import
Library
from
torch.library
import
Library
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.config
import
CompilationConfig
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.config
import
VllmConfig
from
vllm.plugins
import
set_compilation_config
,
set_current_vllm_config
from
vllm.plugins
import
set_compilation_config
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
# create a library to hold the custom op
# create a library to hold the custom op
...
@@ -272,8 +270,10 @@ def run_model(llama_config,
...
@@ -272,8 +270,10 @@ def run_model(llama_config,
CompilationLevel
.
NO_COMPILATION
)
CompilationLevel
.
NO_COMPILATION
)
set_compilation_config
(
None
)
set_compilation_config
(
None
)
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
model
=
LlamaModel
(
config
=
llama_config
,
model
=
LlamaModel
(
config
=
llama_config
,
vllm_config
=
V
llm
C
onfig
()
,
vllm_config
=
v
llm
_c
onfig
,
prefix
=
""
).
eval
().
cuda
()
prefix
=
""
).
eval
().
cuda
()
B
=
16
# max batch size
B
=
16
# max batch size
...
@@ -395,8 +395,10 @@ def benchmark():
...
@@ -395,8 +395,10 @@ def benchmark():
else
:
else
:
set_compilation_config
(
None
)
set_compilation_config
(
None
)
vllm_config
=
VllmConfig
()
with
set_current_vllm_config
(
vllm_config
):
model
=
LlamaModel
(
config
=
llama_config
,
model
=
LlamaModel
(
config
=
llama_config
,
vllm_config
=
V
llm
C
onfig
()
,
vllm_config
=
v
llm
_c
onfig
,
prefix
=
""
).
eval
().
cuda
().
to
(
torch
.
bfloat16
)
prefix
=
""
).
eval
().
cuda
().
to
(
torch
.
bfloat16
)
B
=
256
# max batch size
B
=
256
# max batch size
...
...
tests/compile/test_basic_correctness.py
View file @
4fd93750
...
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional
...
@@ -3,7 +3,7 @@ from typing import Dict, List, Optional
import
pytest
import
pytest
from
vllm.co
mpilation.levels
import
CompilationLevel
from
vllm.co
nfig
import
CompilationLevel
from
vllm.utils
import
cuda_device_count_stateless
from
vllm.utils
import
cuda_device_count_stateless
from
..utils
import
compare_all_settings
from
..utils
import
compare_all_settings
...
...
tests/compile/test_full_graph.py
View file @
4fd93750
import
pytest
import
pytest
from
vllm.co
mpilation.levels
import
CompilationLevel
from
vllm.co
nfig
import
CompilationLevel
from
..utils
import
fork_new_process_for_each_test
from
..utils
import
fork_new_process_for_each_test
from
.utils
import
TEST_MODELS
,
check_full_graph_support
from
.utils
import
TEST_MODELS
,
check_full_graph_support
...
...
tests/compile/test_fusion.py
View file @
4fd93750
...
@@ -3,10 +3,10 @@ import torch
...
@@ -3,10 +3,10 @@ import torch
from
compressed_tensors.quantization
import
FP8_DTYPE
from
compressed_tensors.quantization
import
FP8_DTYPE
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.compilation.config
import
CompilationConfig
from
vllm.compilation.fusion
import
(
FusionPass
,
find_auto_fn
,
from
vllm.compilation.fusion
import
(
FusionPass
,
find_auto_fn
,
find_auto_fn_maybe
)
find_auto_fn_maybe
)
from
vllm.compilation.reshapes
import
RedundantReshapesPass
from
vllm.compilation.reshapes
import
RedundantReshapesPass
from
vllm.config
import
CompilationConfig
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
apply_fp8_linear
)
apply_fp8_linear
)
...
...
tests/compile/test_wrapper.py
View file @
4fd93750
...
@@ -3,6 +3,7 @@ from typing import Optional
...
@@ -3,6 +3,7 @@ from typing import Optional
import
torch
import
torch
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.config
import
CompilationLevel
class
MyMod
(
torch
.
nn
.
Module
):
class
MyMod
(
torch
.
nn
.
Module
):
...
@@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
...
@@ -18,7 +19,8 @@ class MyWrapper(TorchCompileWrapperWithCustomDispatcher):
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
):
self
.
model
=
model
self
.
model
=
model
compiled_callable
=
torch
.
compile
(
self
.
forward
,
backend
=
"eager"
)
compiled_callable
=
torch
.
compile
(
self
.
forward
,
backend
=
"eager"
)
super
().
__init__
(
compiled_callable
)
super
().
__init__
(
compiled_callable
,
compilation_level
=
CompilationLevel
.
DYNAMO_ONCE
)
def
forward
(
self
,
x
:
torch
.
Tensor
,
cache
:
Optional
[
torch
.
Tensor
]
=
None
):
def
forward
(
self
,
x
:
torch
.
Tensor
,
cache
:
Optional
[
torch
.
Tensor
]
=
None
):
# this is the function to be compiled
# this is the function to be compiled
...
...
tests/compile/utils.py
View file @
4fd93750
...
@@ -4,7 +4,7 @@ import torch
...
@@ -4,7 +4,7 @@ import torch
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.co
mpilation.levels
import
CompilationLevel
from
vllm.co
nfig
import
CompilationLevel
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
TEST_MODELS
=
[
TEST_MODELS
=
[
...
...
tests/model_executor/test_enabled_custom_ops.py
View file @
4fd93750
...
@@ -3,11 +3,13 @@ from typing import List
...
@@ -3,11 +3,13 @@ from typing import List
import
pytest
import
pytest
from
vllm.config
import
CompilationConfig
,
VllmConfig
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.activation
import
(
GeluAndMul
,
from
vllm.model_executor.layers.activation
import
(
GeluAndMul
,
ReLUSquaredActivation
,
ReLUSquaredActivation
,
SiluAndMul
)
SiluAndMul
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.plugins
import
set_current_vllm_config
# Registered subclass for test
# Registered subclass for test
...
@@ -51,12 +53,10 @@ class Relu3(ReLUSquaredActivation):
...
@@ -51,12 +53,10 @@ class Relu3(ReLUSquaredActivation):
])
])
def
test_enabled_ops
(
env
:
str
,
torch_level
:
int
,
ops_enabled
:
List
[
int
],
def
test_enabled_ops
(
env
:
str
,
torch_level
:
int
,
ops_enabled
:
List
[
int
],
default_on
:
bool
):
default_on
:
bool
):
os
.
environ
[
"VLLM_CUSTOM_OPS"
]
=
env
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
torch_level
)
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
torch_level
)
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
# Reset default_on (computed once):
custom_ops
=
env
.
split
(
","
)))
CustomOp
.
default_on
.
cache_clear
()
with
set_current_vllm_config
(
vllm_config
):
assert
CustomOp
.
default_on
()
==
default_on
assert
CustomOp
.
default_on
()
==
default_on
ops_enabled
=
[
bool
(
x
)
for
x
in
ops_enabled
]
ops_enabled
=
[
bool
(
x
)
for
x
in
ops_enabled
]
...
@@ -85,8 +85,8 @@ def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
...
@@ -85,8 +85,8 @@ def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
"env"
,
[
"all,none"
,
"all,+rms_norm,all"
,
"+rms_norm,-rms_norm"
])
"env"
,
[
"all,none"
,
"all,+rms_norm,all"
,
"+rms_norm,-rms_norm"
])
def
test_enabled_ops_invalid
(
env
:
str
):
def
test_enabled_ops_invalid
(
env
:
str
):
os
.
environ
[
"VLLM_CUSTOM_OPS"
]
=
env
with
pytest
.
raises
(
Exception
):
# noqa
CustomOp
.
default_on
.
cache_clear
()
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
env
.
split
(
","
)))
with
pytest
.
raises
(
AssertionError
):
with
set_current_vllm_config
(
vllm_config
):
RMSNorm
(
1024
).
enabled
()
RMSNorm
(
1024
).
enabled
()
tests/tpu/test_compilation.py
View file @
4fd93750
...
@@ -5,7 +5,7 @@ import tempfile
...
@@ -5,7 +5,7 @@ import tempfile
import
depyf
import
depyf
from
vllm.co
mpilation.levels
import
CompilationLevel
from
vllm.co
nfig
import
CompilationLevel
# disable custom dispatcher, let Dynamo takes over
# disable custom dispatcher, let Dynamo takes over
# all the control
# all the control
...
...
tests/tpu/test_custom_dispatcher.py
View file @
4fd93750
import
os
import
os
from
vllm.co
mpilation.levels
import
CompilationLevel
from
vllm.co
nfig
import
CompilationLevel
from
..utils
import
compare_two_settings
from
..utils
import
compare_two_settings
...
...
vllm/compilation/backends.py
View file @
4fd93750
...
@@ -10,13 +10,12 @@ import torch
...
@@ -10,13 +10,12 @@ import torch
import
torch.fx
as
fx
import
torch.fx
as
fx
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
CompilationConfig
,
CompilationLevel
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.utils
import
combine_fx_passes
,
weak_ref_tensors
from
vllm.utils
import
combine_fx_passes
,
weak_ref_tensors
from
.config
import
CompilationConfig
from
.counter
import
compilation_counter
from
.counter
import
compilation_counter
from
.fusion
import
FusionPass
from
.fusion
import
FusionPass
from
.levels
import
CompilationLevel
from
.reshapes
import
RedundantReshapesPass
from
.reshapes
import
RedundantReshapesPass
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -392,7 +391,10 @@ class VllmBackend:
...
@@ -392,7 +391,10 @@ class VllmBackend:
sym_tensor_indices
:
List
[
int
]
sym_tensor_indices
:
List
[
int
]
input_buffers
:
List
[
torch
.
Tensor
]
input_buffers
:
List
[
torch
.
Tensor
]
def
__init__
(
self
,
post_grad_passes
:
Sequence
[
Callable
]
=
()):
def
__init__
(
self
,
compilation_configs
:
CompilationConfig
,
):
global
global_graph_pool
global
global_graph_pool
if
global_graph_pool
is
None
:
if
global_graph_pool
is
None
:
global_graph_pool
=
torch
.
cuda
.
graph_pool_handle
()
global_graph_pool
=
torch
.
cuda
.
graph_pool_handle
()
...
@@ -401,11 +403,13 @@ class VllmBackend:
...
@@ -401,11 +403,13 @@ class VllmBackend:
# streams, it might not be safe to share a global pool.
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
# only investigate this when we use multiple streams
self
.
graph_pool
=
global_graph_pool
self
.
graph_pool
=
global_graph_pool
self
.
post_grad_passes
=
post_grad_passes
self
.
post_grad_passes
=
[]
self
.
sym_tensor_indices
=
[]
self
.
sym_tensor_indices
=
[]
self
.
input_buffers
=
[]
self
.
input_buffers
=
[]
self
.
compilation_configs
=
compilation_configs
# `torch.compile` is JIT compiled, so we don't need to
# `torch.compile` is JIT compiled, so we don't need to
# do anything here
# do anything here
...
@@ -437,10 +441,10 @@ class VllmBackend:
...
@@ -437,10 +441,10 @@ class VllmBackend:
assert
not
self
.
_called
,
"VllmBackend can only be called once"
assert
not
self
.
_called
,
"VllmBackend can only be called once"
self
.
graph
=
graph
self
.
graph
=
graph
# config is
rea
d now, because only here can
# config is
update
d now, because only here can
# we get the sizes to capture for cudagraph
# we get the sizes to capture for cudagraph
# from compilation context
# from compilation context
self
.
compilation_configs
=
CompilationConfig
.
select_and_init_config
()
self
.
compilation_configs
.
init_during_runtime
()
self
.
add_passes_to_config
()
self
.
add_passes_to_config
()
self
.
split_gm
,
self
.
piecewise_graphs
=
split_graph
(
self
.
split_gm
,
self
.
piecewise_graphs
=
split_graph
(
...
@@ -688,4 +692,6 @@ def select_default_backend(level: int) -> Union[str, Callable]:
...
@@ -688,4 +692,6 @@ def select_default_backend(level: int) -> Union[str, Callable]:
return
backend_str
return
backend_str
assert
level
==
CompilationLevel
.
PIECEWISE
assert
level
==
CompilationLevel
.
PIECEWISE
return
VllmBackend
()
from
vllm.plugins
import
get_current_vllm_config
compilation_config
=
get_current_vllm_config
().
compilation_config
return
VllmBackend
(
compilation_config
)
vllm/compilation/config.py
deleted
100644 → 0
View file @
661a34fd
import
copy
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
List
,
Optional
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
import
vllm.envs
as
envs
from
vllm.logger
import
init_logger
from
.compile_context
import
get_compile_context
logger
=
init_logger
(
__name__
)
class
CompilationConfig
(
BaseModel
):
"""
Configuration for compilation.
It has two parts:
- CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses.
Note that this is orthogonal to the cudagraph capture out
side of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
- None: capture sizes are inferred from compilation context.
- List[int]: capture sizes are specified.
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs.
- cudagraph_copy_inputs: whether to copy input tensors for
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for different sizes specified
in inductor_compile_sizes, using configurations
in inductor_compile_config.
- inductor_compile_sizes: sizes to compile for inductor.
- inductor_specialize_for_cudagraph_no_more_than: an optional integer
to specialize inductor for cudagraph sizes no more than the
specified size. It is useful when we want to specialize inductor
with a subset of cudagraph sizes.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses json format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- Custom inductor passes:
- dump_graph_stages: list of stages for which we want to dump the graph.
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graph. Default is .
- enable_fusion: whether to enable the custom fusion pass.
TODO better pass enabling system.
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
for the same size. We need to capture all the sizes we want to use.
- inductor: a graph compiled by inductor for a general shape can be used
for different sizes. Inductor can also compile for specific sizes,
where it can have more information to optimize the graph with fully
static shapes. However, we find the general shape compilation is
sufficient for most cases. It might be beneficial to compile for
certain small batchsizes, where inductor is good at optimizing.
"""
use_inductor
:
bool
=
True
inductor_specialize_for_cudagraph_no_more_than
:
Optional
[
int
]
=
None
inductor_compile_sizes
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
dict
)
inductor_compile_config
:
Dict
=
Field
(
default_factory
=
dict
)
inductor_passes
:
Dict
[
str
,
str
]
=
Field
(
default_factory
=
dict
)
use_cudagraph
:
bool
=
False
non_cudagraph_ops
:
List
[
str
]
=
Field
(
default_factory
=
list
)
cudagraph_num_of_warmups
:
int
=
0
cudagraph_capture_sizes
:
Optional
[
List
[
int
]]
=
None
cudagraph_copy_inputs
:
bool
=
False
dump_graph_stages
:
List
[
str
]
=
Field
(
default_factory
=
list
)
dump_graph_dir
:
Path
=
Field
(
default
=
Path
(
"."
))
enable_fusion
:
bool
=
True
# not configurable, computed after init
compile_sizes
:
List
[
int
]
=
PrivateAttr
capture_sizes
:
List
[
int
]
=
PrivateAttr
def
model_post_init
(
self
,
__context
:
Any
)
->
None
:
for
k
,
v
in
self
.
inductor_passes
.
items
():
if
not
isinstance
(
v
,
str
):
assert
callable
(
v
),
(
f
"pass
{
k
}
should be a function or a qualified name"
)
self
.
inductor_compile_config
[
k
]
=
v
continue
# resolve function from qualified name
names
=
v
.
split
(
"."
)
module
=
"."
.
join
(
names
[:
-
1
])
func_name
=
names
[
-
1
]
func
=
__import__
(
module
).
__dict__
[
func_name
]
self
.
inductor_compile_config
[
k
]
=
func
def
init_during_runtime
(
self
):
"""To complete the initialization of config,
we need to know the compile context, which is only available
during the first run of the model.
"""
context
=
get_compile_context
()
context
=
copy
.
deepcopy
(
context
)
if
context
is
not
None
else
[]
sizes_to_specialize
:
List
[
int
]
=
context
if
self
.
cudagraph_capture_sizes
is
None
:
self
.
capture_sizes
=
sizes_to_specialize
else
:
self
.
capture_sizes
=
self
.
cudagraph_capture_sizes
logger
.
info
((
"cudagraph sizes specified by model runner"
" %s is overridden by config %s"
),
sizes_to_specialize
,
self
.
cudagraph_capture_sizes
)
if
self
.
inductor_specialize_for_cudagraph_no_more_than
is
not
None
:
assert
self
.
inductor_compile_sizes
is
None
,
(
"inductor_compile_sizes should be None when "
"inductor_specialize_for_cudagraph_no_more_than is not None"
)
self
.
compile_sizes
=
[
x
for
x
in
self
.
capture_sizes
if
x
<=
self
.
inductor_specialize_for_cudagraph_no_more_than
]
else
:
assert
self
.
inductor_compile_sizes
is
not
None
,
(
"inductor_compile_sizes should not be None when "
"inductor_specialize_for_cudagraph_no_more_than is None"
)
self
.
compile_sizes
=
self
.
inductor_compile_sizes
@
staticmethod
def
select_and_init_config
()
->
"CompilationConfig"
:
"""The order of selecting config is:
1. Use the config specified in environment variable.
2. Use the config specified in plugins.
3. Use the default config.
"""
config_path
=
envs
.
VLLM_TORCH_COMPILE_CONFIG
if
config_path
is
not
None
:
with
open
(
config_path
)
as
json_file
:
config
=
CompilationConfig
.
model_validate_json
(
json_file
.
read
())
else
:
from
vllm.plugins
import
get_compilation_config
predefined_config
=
get_compilation_config
()
config
=
predefined_config
if
predefined_config
is
not
None
else
(
CompilationConfig
())
config
.
init_during_runtime
()
return
config
vllm/compilation/decorators.py
View file @
4fd93750
...
@@ -3,10 +3,8 @@ from typing import Dict, List, Optional, Union
...
@@ -3,10 +3,8 @@ from typing import Dict, List, Optional, Union
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.config
import
VllmConfig
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.sequence
import
IntermediateTensors
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
supports_dynamo
from
vllm.utils
import
supports_dynamo
...
@@ -126,12 +124,14 @@ def _support_torch_compile(cls: type,
...
@@ -126,12 +124,14 @@ def _support_torch_compile(cls: type,
old_init
(
self
,
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
old_init
(
self
,
vllm_config
=
vllm_config
,
prefix
=
prefix
,
**
kwargs
)
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# for CompilationLevel.DYNAMO_AS_IS , the upper level model runner
# will handle the compilation, so we don't need to do anything here.
# will handle the compilation, so we don't need to do anything here.
self
.
do_not_compile
=
envs
.
VLLM_TORCH_COMPILE_LEVEL
in
[
self
.
do_not_compile
=
\
vllm_config
.
compilation_config
.
level
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
]
or
not
supports_dynamo
()
]
or
not
supports_dynamo
()
if
self
.
do_not_compile
:
if
self
.
do_not_compile
:
return
return
TorchCompileWrapperWithCustomDispatcher
.
__init__
(
self
)
TorchCompileWrapperWithCustomDispatcher
.
__init__
(
self
,
compilation_level
=
vllm_config
.
compilation_config
.
level
)
cls
.
__init__
=
__init__
# type: ignore
cls
.
__init__
=
__init__
# type: ignore
...
...
vllm/compilation/fusion.py
View file @
4fd93750
...
@@ -6,8 +6,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
...
@@ -6,8 +6,8 @@ from torch._higher_order_ops.auto_functionalize import auto_functionalized
from
torch._inductor.pattern_matcher
import
(
Match
,
PatternMatcherPass
,
from
torch._inductor.pattern_matcher
import
(
Match
,
PatternMatcherPass
,
fwd_only
,
register_replacement
)
fwd_only
,
register_replacement
)
from
vllm.compilation.config
import
CompilationConfig
from
vllm.compilation.inductor_pass
import
InductorPass
from
vllm.compilation.inductor_pass
import
InductorPass
from
vllm.config
import
CompilationConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/compilation/inductor_pass.py
View file @
4fd93750
...
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
...
@@ -2,7 +2,7 @@ from abc import ABC, abstractmethod
import
torch
import
torch
from
vllm.
compilation.
config
import
CompilationConfig
from
vllm.config
import
CompilationConfig
# yapf: disable
# yapf: disable
from
vllm.distributed
import
get_tensor_model_parallel_rank
as
get_tp_rank
from
vllm.distributed
import
get_tensor_model_parallel_rank
as
get_tp_rank
from
vllm.distributed
import
(
from
vllm.distributed
import
(
...
...
vllm/compilation/levels.py
deleted
100644 → 0
View file @
661a34fd
# constants for the levels of the compilation process
class
CompilationLevel
:
NO_COMPILATION
=
0
DYNAMO_AS_IS
=
1
DYNAMO_ONCE
=
2
PIECEWISE
=
3
vllm/compilation/wrapper.py
View file @
4fd93750
...
@@ -8,8 +8,7 @@ from typing import Callable, List, Optional
...
@@ -8,8 +8,7 @@ from typing import Callable, List, Optional
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
CompilationLevel
from
.levels
import
CompilationLevel
class
TorchCompileWrapperWithCustomDispatcher
:
class
TorchCompileWrapperWithCustomDispatcher
:
...
@@ -25,7 +24,9 @@ class TorchCompileWrapperWithCustomDispatcher:
...
@@ -25,7 +24,9 @@ class TorchCompileWrapperWithCustomDispatcher:
`torch.compile` over the forward method.
`torch.compile` over the forward method.
"""
"""
def
__init__
(
self
,
compiled_callable
:
Optional
[
Callable
]
=
None
):
def
__init__
(
self
,
compiled_callable
:
Optional
[
Callable
]
=
None
,
compilation_level
:
int
=
0
):
if
compiled_callable
is
None
:
if
compiled_callable
is
None
:
# default compilation settings
# default compilation settings
...
@@ -38,7 +39,7 @@ class TorchCompileWrapperWithCustomDispatcher:
...
@@ -38,7 +39,7 @@ class TorchCompileWrapperWithCustomDispatcher:
backend
=
get_torch_compile_backend
()
backend
=
get_torch_compile_backend
()
if
backend
is
None
:
if
backend
is
None
:
from
vllm.compilation.backends
import
select_default_backend
from
vllm.compilation.backends
import
select_default_backend
backend
=
select_default_backend
(
envs
.
VLLM_TORCH_COMPILE_LEVEL
)
backend
=
select_default_backend
(
compilation_level
)
compiled_callable
=
torch
.
compile
(
compiled_callable
=
torch
.
compile
(
self
.
forward
,
self
.
forward
,
...
@@ -54,7 +55,7 @@ class TorchCompileWrapperWithCustomDispatcher:
...
@@ -54,7 +55,7 @@ class TorchCompileWrapperWithCustomDispatcher:
# subclasses can use this to switch between the custom dispatcher
# subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism.
# and the default Dynamo guard mechanism.
self
.
use_custom_dispatcher
:
bool
=
\
self
.
use_custom_dispatcher
:
bool
=
\
envs
.
VLLM_TORCH_COMPILE_LEVEL
>=
CompilationLevel
.
DYNAMO_ONCE
compilation_level
>=
CompilationLevel
.
DYNAMO_ONCE
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
"""Implement the dispatch logic here, beyond the torch.compile level.
"""Implement the dispatch logic here, beyond the torch.compile level.
...
...
vllm/config.py
View file @
4fd93750
...
@@ -3,10 +3,12 @@ import enum
...
@@ -3,10 +3,12 @@ import enum
import
json
import
json
import
warnings
import
warnings
from
dataclasses
import
dataclass
,
field
,
replace
from
dataclasses
import
dataclass
,
field
,
replace
from
pathlib
import
Path
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Dict
,
Final
,
List
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Dict
,
Final
,
List
,
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
Literal
,
Mapping
,
Optional
,
Set
,
Tuple
,
Type
,
Union
)
import
torch
import
torch
from
pydantic
import
BaseModel
,
Field
,
PrivateAttr
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -2052,6 +2054,185 @@ class ObservabilityConfig:
...
@@ -2052,6 +2054,185 @@ class ObservabilityConfig:
f
"installed. Original error:
\n
{
otel_import_error_traceback
}
"
)
f
"installed. Original error:
\n
{
otel_import_error_traceback
}
"
)
class
CompilationLevel
:
# constants for the levels of the compilation process
NO_COMPILATION
=
0
DYNAMO_AS_IS
=
1
DYNAMO_ONCE
=
2
PIECEWISE
=
3
class
CompilationConfig
(
BaseModel
):
"""
Configuration for compilation.
It has three parts:
- Top-level Compilation control:
- level: the level of compilation.
- 0: no compilation.
- 1: dynamo as is.
- 2: dynamo once.
- 3: piecewise compilation.
- custom_ops: fine-grained control over which custom ops to enable/disable.
Use 'all' to enable all, 'none' to disable all.
Also specify a list of custom op names to enable (prefixed with a '+'),
or disable (prefixed with a '-').
Examples:
- 'all,-op1' to enable all except op1
- 'none,+op1,+op2' to enable only op1 and op2
By default, all custom ops are enabled when running without Inductor
and disabled when running with Inductor (compile_level >= Inductor).
- CudaGraph capture:
- use_cudagraph: whether to use cudagraph inside compilation.
- False: cudagraph inside compilation is not used.
- True: cudagraph inside compilation is used. It requires
that all input buffers have fixed addresses.
Note that this is orthogonal to the cudagraph capture out
side of compilation.
TODO: move outside cudagraph logic into compilation.
torch.compile will handle cudagraph capture logic in the future.
- cudagraph_capture_sizes: sizes to capture cudagraph.
- None: capture sizes are inferred from compilation context.
- List[int]: capture sizes are specified.
- cudagraph_num_of_warmups: number of warmup runs for cudagraph.
It means the first several runs will be treated as warmup runs.
Only after that, the execution will be recorded, and the recorded
cudagraph will be used for subsequent runs.
- cudagraph_copy_inputs: whether to copy input tensors for
cudagraph. If the caller can guarantee that the same input buffers
are always used, it can set this to False. Otherwise, it should
set this to True, and the compiler will copy the input to an
internally managed buffer. Default is False.
- Inductor compilation:
- use_inductor: whether to use inductor compilation.
- False: inductor compilation is not used. graph runs in eager.
- True: inductor compilation is used. one graph for symbolic shape
is compiled. In addition, compile for different sizes specified
in inductor_compile_sizes, using configurations
in inductor_compile_config.
- inductor_compile_sizes: sizes to compile for inductor.
- inductor_specialize_for_cudagraph_no_more_than: an optional integer
to specialize inductor for cudagraph sizes no more than the
specified size. It is useful when we want to specialize inductor
with a subset of cudagraph sizes.
- inductor_compile_config: additional configurations for inductor.
- None: use default configurations.
- inductor_passes: additional passes for inductor. It is a dictionary
from pass name to pass function qualified name. We use function
name because the config uses json format. If we pass the config
from Python, functions can also be passed directly via Python object
constructor, e.g. `CompilationConfig(inductor_passes={"a": func})`
- custom inductor passes:
- dump_graph_stages: list of stages for which we want to dump the graph.
Each pass defines its own stages (before, after, maybe in-between).
- dump_graph_dir: directory to dump the graph. Default is .
- enable_fusion: whether to enable the custom fusion pass.
TODO better pass enabling system.
Why we have different sizes for cudagraph and inductor:
- cudagraph: a cudagraph captured for a specific size can only be used
for the same size. We need to capture all the sizes we want to use.
- inductor: a graph compiled by inductor for a general shape can be used
for different sizes. Inductor can also compile for specific sizes,
where it can have more information to optimize the graph with fully
static shapes. However, we find the general shape compilation is
sufficient for most cases. It might be beneficial to compile for
certain small batchsizes, where inductor is good at optimizing.
"""
# noqa
level
:
int
=
0
custom_ops
:
List
[
str
]
=
Field
(
default_factory
=
list
)
use_inductor
:
bool
=
True
inductor_specialize_for_cudagraph_no_more_than
:
Optional
[
int
]
=
None
inductor_compile_sizes
:
Optional
[
List
[
int
]]
=
Field
(
default_factory
=
dict
)
inductor_compile_config
:
Dict
=
Field
(
default_factory
=
dict
)
inductor_passes
:
Dict
[
str
,
str
]
=
Field
(
default_factory
=
dict
)
use_cudagraph
:
bool
=
False
non_cudagraph_ops
:
List
[
str
]
=
Field
(
default_factory
=
list
)
cudagraph_num_of_warmups
:
int
=
0
cudagraph_capture_sizes
:
Optional
[
List
[
int
]]
=
None
cudagraph_copy_inputs
:
bool
=
False
dump_graph_stages
:
List
[
str
]
=
Field
(
default_factory
=
list
)
dump_graph_dir
:
Path
=
Field
(
default
=
Path
(
"."
))
enable_fusion
:
bool
=
True
# not configurable, computed after init
compile_sizes
:
List
[
int
]
=
PrivateAttr
capture_sizes
:
List
[
int
]
=
PrivateAttr
def
model_post_init
(
self
,
__context
:
Any
)
->
None
:
self
.
level
=
envs
.
VLLM_TORCH_COMPILE_LEVEL
count_none
=
self
.
custom_ops
.
count
(
"none"
)
count_all
=
self
.
custom_ops
.
count
(
"all"
)
assert
count_none
+
count_all
<=
1
,
"Can only specify 'none' or 'all'"
for
k
,
v
in
self
.
inductor_passes
.
items
():
if
not
isinstance
(
v
,
str
):
assert
callable
(
v
),
(
f
"pass
{
k
}
should be a function or a qualified name"
)
self
.
inductor_compile_config
[
k
]
=
v
continue
# resolve function from qualified name
names
=
v
.
split
(
"."
)
module
=
"."
.
join
(
names
[:
-
1
])
func_name
=
names
[
-
1
]
func
=
__import__
(
module
).
__dict__
[
func_name
]
self
.
inductor_compile_config
[
k
]
=
func
def
init_during_runtime
(
self
):
"""To complete the initialization of config,
we need to know the compile context, which is only available
during the first run of the model.
"""
from
vllm.compilation.compile_context
import
get_compile_context
context
=
get_compile_context
()
context
=
copy
.
deepcopy
(
context
)
if
context
is
not
None
else
[]
sizes_to_specialize
:
List
[
int
]
=
context
if
self
.
cudagraph_capture_sizes
is
None
:
self
.
capture_sizes
=
sizes_to_specialize
else
:
self
.
capture_sizes
=
self
.
cudagraph_capture_sizes
logger
.
info
((
"cudagraph sizes specified by model runner"
" %s is overridden by config %s"
),
sizes_to_specialize
,
self
.
cudagraph_capture_sizes
)
if
self
.
inductor_specialize_for_cudagraph_no_more_than
is
not
None
:
assert
self
.
inductor_compile_sizes
is
None
,
(
"inductor_compile_sizes should be None when "
"inductor_specialize_for_cudagraph_no_more_than is not None"
)
self
.
compile_sizes
=
[
x
for
x
in
self
.
capture_sizes
if
x
<=
self
.
inductor_specialize_for_cudagraph_no_more_than
]
else
:
assert
self
.
inductor_compile_sizes
is
not
None
,
(
"inductor_compile_sizes should not be None when "
"inductor_specialize_for_cudagraph_no_more_than is None"
)
self
.
compile_sizes
=
self
.
inductor_compile_sizes
@
staticmethod
def
select_and_init_config
()
->
"CompilationConfig"
:
"""The order of selecting config is:
1. Use the config specified in environment variable.
2. Use the config specified in plugins.
3. Use the default config.
"""
config_path
=
envs
.
VLLM_TORCH_COMPILE_CONFIG
if
config_path
is
not
None
:
with
open
(
config_path
)
as
json_file
:
config
=
CompilationConfig
.
model_validate_json
(
json_file
.
read
())
else
:
from
vllm.plugins
import
get_compilation_config
predefined_config
=
get_compilation_config
()
config
=
predefined_config
if
predefined_config
is
not
None
else
(
CompilationConfig
())
return
config
@
dataclass
@
dataclass
class
VllmConfig
:
class
VllmConfig
:
"""Dataclass which contains all vllm-related configuration. This
"""Dataclass which contains all vllm-related configuration. This
...
@@ -2073,6 +2254,8 @@ class VllmConfig:
...
@@ -2073,6 +2254,8 @@ class VllmConfig:
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
observability_config
:
Optional
[
ObservabilityConfig
]
=
None
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
prompt_adapter_config
:
Optional
[
PromptAdapterConfig
]
=
None
quant_config
:
Optional
[
QuantizationConfig
]
=
None
quant_config
:
Optional
[
QuantizationConfig
]
=
None
compilation_config
:
CompilationConfig
=
field
(
default
=
None
,
init
=
True
)
# type: ignore
@
staticmethod
@
staticmethod
def
_get_quantization_config
(
def
_get_quantization_config
(
...
@@ -2133,6 +2316,12 @@ class VllmConfig:
...
@@ -2133,6 +2316,12 @@ class VllmConfig:
self
.
quant_config
=
VllmConfig
.
_get_quantization_config
(
self
.
quant_config
=
VllmConfig
.
_get_quantization_config
(
self
.
model_config
,
self
.
load_config
)
self
.
model_config
,
self
.
load_config
)
if
self
.
compilation_config
is
None
:
self
.
compilation_config
=
CompilationConfig
.
select_and_init_config
(
)
current_platform
.
check_and_update_config
(
self
)
def
__str__
(
self
):
def
__str__
(
self
):
return
(
"model=%r, speculative_config=%r, tokenizer=%r, "
return
(
"model=%r, speculative_config=%r, tokenizer=%r, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
"skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, "
...
...
vllm/envs.py
View file @
4fd93750
...
@@ -69,7 +69,6 @@ if TYPE_CHECKING:
...
@@ -69,7 +69,6 @@ if TYPE_CHECKING:
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_TORCH_COMPILE_LEVEL
:
int
=
0
VLLM_TORCH_COMPILE_LEVEL
:
int
=
0
VLLM_TORCH_COMPILE_CONFIG
:
Optional
[
str
]
=
None
VLLM_TORCH_COMPILE_CONFIG
:
Optional
[
str
]
=
None
VLLM_CUSTOM_OPS
:
List
[
str
]
=
[]
VLLM_DISABLED_KERNELS
:
List
[
str
]
=
[]
VLLM_DISABLED_KERNELS
:
List
[
str
]
=
[]
VLLM_USE_V1
:
bool
=
False
VLLM_USE_V1
:
bool
=
False
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
False
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
False
...
@@ -217,18 +216,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -217,18 +216,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_TORCH_COMPILE_CONFIG"
:
"VLLM_TORCH_COMPILE_CONFIG"
:
lambda
:
os
.
environ
.
get
(
"VLLM_TORCH_COMPILE_CONFIG"
,
None
),
lambda
:
os
.
environ
.
get
(
"VLLM_TORCH_COMPILE_CONFIG"
,
None
),
# Fine-grained control over which custom ops to enable/disable.
# Use 'all' to enable all, 'none' to disable all.
# Also specify a list of custom op names to enable (prefixed with a '+'),
# or disable (prefixed with a '-').
# Examples:
# - 'all,-op1' to enable all except op1
# - 'none,+op1,+op2' to enable only op1 and op2
# By default, all custom ops are enabled when running without Inductor
# and disabled when running with Inductor (compile_level >= Inductor).
"VLLM_CUSTOM_OPS"
:
lambda
:
os
.
environ
.
get
(
"VLLM_CUSTOM_OPS"
,
""
).
replace
(
" "
,
""
).
split
(
","
),
# local rank of the process in the distributed setting, used to determine
# local rank of the process in the distributed setting, used to determine
# the GPU device id
# the GPU device id
"LOCAL_RANK"
:
"LOCAL_RANK"
:
...
...
vllm/model_executor/custom_op.py
View file @
4fd93750
from
functools
import
lru_cache
from
typing
import
Dict
,
Type
from
typing
import
Dict
,
Type
import
torch.nn
as
nn
import
torch.nn
as
nn
import
vllm.envs
as
envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.plugins
import
get_current_vllm_config
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -87,6 +85,8 @@ class CustomOp(nn.Module):
...
@@ -87,6 +85,8 @@ class CustomOp(nn.Module):
@
classmethod
@
classmethod
def
enabled
(
cls
)
->
bool
:
def
enabled
(
cls
)
->
bool
:
# if no name, then it was not registered
# if no name, then it was not registered
compilation_config
=
get_current_vllm_config
().
compilation_config
custom_ops
=
compilation_config
.
custom_ops
if
not
hasattr
(
cls
,
"name"
):
if
not
hasattr
(
cls
,
"name"
):
print_warning_once
(
print_warning_once
(
f
"Custom op
{
cls
.
__name__
}
was not registered, "
f
"Custom op
{
cls
.
__name__
}
was not registered, "
...
@@ -94,22 +94,25 @@ class CustomOp(nn.Module):
...
@@ -94,22 +94,25 @@ class CustomOp(nn.Module):
f
"It will be enabled/disabled based on the global settings."
)
f
"It will be enabled/disabled based on the global settings."
)
return
CustomOp
.
default_on
()
return
CustomOp
.
default_on
()
enabled
=
f
"+
{
cls
.
name
}
"
in
envs
.
VLLM_CUSTOM_OPS
enabled
=
f
"+
{
cls
.
name
}
"
in
custom_ops
disabled
=
f
"-
{
cls
.
name
}
"
in
envs
.
VLLM_CUSTOM_OPS
disabled
=
f
"-
{
cls
.
name
}
"
in
custom_ops
assert
not
(
enabled
assert
not
(
enabled
and
disabled
),
f
"Cannot enable and disable
{
cls
.
name
}
"
and
disabled
),
f
"Cannot enable and disable
{
cls
.
name
}
"
return
(
CustomOp
.
default_on
()
or
enabled
)
and
not
disabled
return
(
CustomOp
.
default_on
()
or
enabled
)
and
not
disabled
# On by default if VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.PIECEWISE
# Specifying 'all' or 'none' in VLLM_CUSTOM_OPS takes precedence.
@
staticmethod
@
staticmethod
@
lru_cache
def
default_on
()
->
bool
:
def
default_on
()
->
bool
:
count_none
=
envs
.
VLLM_CUSTOM_OPS
.
count
(
"none"
)
"""
count_all
=
envs
.
VLLM_CUSTOM_OPS
.
count
(
"all"
)
On by default if level < CompilationLevel.PIECEWISE
assert
count_none
+
count_all
<=
1
,
"Can only specify 'none' or 'all'"
Specifying 'all' or 'none' in custom_op takes precedence.
return
envs
.
VLLM_TORCH_COMPILE_LEVEL
<
CompilationLevel
.
PIECEWISE
and
\
"""
from
vllm.config
import
CompilationLevel
compilation_config
=
get_current_vllm_config
().
compilation_config
custom_ops
=
compilation_config
.
custom_ops
count_none
=
custom_ops
.
count
(
"none"
)
count_all
=
custom_ops
.
count
(
"all"
)
return
compilation_config
.
level
<
CompilationLevel
.
PIECEWISE
and
\
not
count_none
>
0
or
count_all
>
0
not
count_none
>
0
or
count_all
>
0
# Dictionary of all custom ops (classes, indexed by registered name).
# Dictionary of all custom ops (classes, indexed by registered name).
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment