Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e4d652ea
Unverified
Commit
e4d652ea
authored
Oct 10, 2024
by
youkaichao
Committed by
GitHub
Oct 10, 2024
Browse files
[torch.compile] integration with compilation control (#9058)
parent
78c0b416
Changes
22
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
387 additions
and
90 deletions
+387
-90
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+12
-8
tests/compile/test_basic_correctness.py
tests/compile/test_basic_correctness.py
+48
-0
tests/compile/test_full_graph.py
tests/compile/test_full_graph.py
+11
-4
tests/compile/test_full_graph_multi_gpu.py
tests/compile/test_full_graph_multi_gpu.py
+0
-22
tests/compile/test_full_graph_smoke.py
tests/compile/test_full_graph_smoke.py
+0
-13
tests/compile/utils.py
tests/compile/utils.py
+9
-15
tests/tpu/test_compilation.py
tests/tpu/test_compilation.py
+3
-1
tests/tpu/test_custom_dispatcher.py
tests/tpu/test_custom_dispatcher.py
+8
-5
vllm/compilation/backends.py
vllm/compilation/backends.py
+114
-1
vllm/compilation/compile_context.py
vllm/compilation/compile_context.py
+23
-0
vllm/compilation/decorators.py
vllm/compilation/decorators.py
+85
-0
vllm/compilation/levels.py
vllm/compilation/levels.py
+9
-0
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+24
-3
vllm/envs.py
vllm/envs.py
+3
-13
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+2
-1
vllm/model_executor/models/gemma2.py
vllm/model_executor/models/gemma2.py
+2
-0
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+2
-0
vllm/model_executor/models/llava.py
vllm/model_executor/models/llava.py
+5
-3
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+14
-0
vllm/plugins/__init__.py
vllm/plugins/__init__.py
+13
-1
No files found.
.buildkite/test-pipeline.yaml
View file @
e4d652ea
...
@@ -121,7 +121,9 @@ steps:
...
@@ -121,7 +121,9 @@ steps:
-
vllm/core/
-
vllm/core/
-
tests/distributed
-
tests/distributed
-
tests/spec_decode/e2e/test_integration_dist_tp4
-
tests/spec_decode/e2e/test_integration_dist_tp4
-
tests/compile
commands
:
commands
:
-
pytest -v -s compile/test_basic_correctness.py
-
pytest -v -s distributed/test_pynccl.py
-
pytest -v -s distributed/test_pynccl.py
-
pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
-
pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
...
@@ -231,14 +233,16 @@ steps:
...
@@ -231,14 +233,16 @@ steps:
-
vllm/
-
vllm/
-
tests/compile
-
tests/compile
commands
:
commands
:
-
pytest -v -s compile/test_
full_graph_smoke
.py
-
pytest -v -s compile/test_
basic_correctness
.py
-
label
:
"
PyTorch
Fullgraph
Test"
# 18min
# TODO: re-write in comparison tests, and fix symbolic shape
source_file_dependencies
:
# for quantization ops.
-
vllm/
# - label: "PyTorch Fullgraph Test" # 18min
-
tests/compile
# source_file_dependencies:
commands
:
# - vllm/
-
pytest -v -s compile/test_full_graph.py
# - tests/compile
# commands:
# - pytest -v -s compile/test_full_graph.py
-
label
:
Kernels Test %N
# 1h each
-
label
:
Kernels Test %N
# 1h each
mirror_hardwares
:
[
amd
]
mirror_hardwares
:
[
amd
]
...
@@ -394,7 +398,7 @@ steps:
...
@@ -394,7 +398,7 @@ steps:
-
tests/distributed/
-
tests/distributed/
-
vllm/compilation
-
vllm/compilation
commands
:
commands
:
-
pytest -v -s ./compile/test_
full_graph_multi_gpu
.py
-
pytest -v -s ./compile/test_
basic_correctness
.py
-
pytest -v -s ./compile/test_wrapper.py
-
pytest -v -s ./compile/test_wrapper.py
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
-
VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
-
TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus
-
TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus
...
...
tests/compile/test_basic_correctness.py
0 → 100644
View file @
e4d652ea
from
typing
import
Dict
,
List
,
Optional
import
pytest
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.utils
import
cuda_device_count_stateless
from
..utils
import
compare_all_settings
# we cannot afford testing the full Cartesian product
# of all models and all levels
@pytest.mark.parametrize(
    "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
    [
        ("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate",
         True),
        ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
         ["--quantization", "compressed-tensors"], 1, 1, "FLASH_ATTN",
         "generate", True),
        ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
        # TODO: add multi-modality test for llava
        ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate",
         False),
    ])
def test_compile_correctness(model, model_args, pp_size, tp_size,
                             attn_backend, method, fullgraph, monkeypatch):
    """Run one model under several compilation levels and compare them.

    One argument list and one environment are built per compilation level
    (NO_COMPILATION, DYNAMO_AS_IS, DYNAMO_ONCE) and handed to
    ``compare_all_settings`` together with ``method``.

    Environment variables are set through pytest's ``monkeypatch`` fixture
    so they are restored after each parametrized case instead of leaking
    into whatever test runs next in the same process.
    """
    # this test is run under multiple suites, with different GPUs.
    # make sure we only run the test with correct CUDA devices.
    # don't use "<", as it will duplicate the tests.
    if cuda_device_count_stateless() != pp_size * tp_size:
        pytest.skip("Not correct CUDA devices for the test.")

    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
    if not fullgraph:
        monkeypatch.setenv("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "0")

    # don't test VLLM_TORCH_COMPILE_LEVEL == 3 case
    # inductor will change the output, so we cannot compare them.
    levels = [
        CompilationLevel.NO_COMPILATION,
        CompilationLevel.DYNAMO_AS_IS,
        CompilationLevel.DYNAMO_ONCE,
    ]
    all_envs: List[Optional[Dict[str, str]]] = [{
        "VLLM_TORCH_COMPILE_LEVEL": str(level)
    } for level in levels]

    common_args = (["--enforce-eager"] + model_args +
                   ["--max_model_len", "1024"] + ["-pp", str(pp_size)] +
                   ["-tp", str(tp_size)])
    # one independent copy per level: the count always tracks `levels`
    # and no two settings share (alias) the same list object.
    all_args = [list(common_args) for _ in levels]

    compare_all_settings(model, all_args, all_envs, method=method)
tests/compile/test_full_graph.py
View file @
e4d652ea
import
pytest
import
pytest
from
vllm.compilation.
backend
s
import
vllm_backend
from
vllm.compilation.
level
s
import
CompilationLevel
from
..utils
import
fork_new_process_for_each_test
from
.utils
import
TEST_MODELS
,
check_full_graph_support
from
.utils
import
TEST_MODELS
,
check_full_graph_support
@
pytest
.
mark
.
parametrize
(
"model_info"
,
TEST_MODELS
)
@
pytest
.
mark
.
parametrize
(
"model_info"
,
TEST_MODELS
)
@
pytest
.
mark
.
parametrize
(
"backend"
,
[
"eager"
,
vllm_backend
])
@
pytest
.
mark
.
parametrize
(
def
test_full_graph
(
model_info
,
backend
):
"optimization_level"
,
[
CompilationLevel
.
DYNAMO_ONCE
,
CompilationLevel
.
INDUCTOR
])
@
fork_new_process_for_each_test
def
test_full_graph
(
model_info
,
optimization_level
):
model
=
model_info
[
0
]
model
=
model_info
[
0
]
model_kwargs
=
model_info
[
1
]
model_kwargs
=
model_info
[
1
]
check_full_graph_support
(
model
,
model_kwargs
,
backend
,
tp_size
=
1
)
check_full_graph_support
(
model
,
model_kwargs
,
optimization_level
,
tp_size
=
1
)
tests/compile/test_full_graph_multi_gpu.py
deleted
100644 → 0
View file @
78c0b416
import
pytest
from
vllm.compilation.backends
import
vllm_backend
from
vllm.utils
import
cuda_device_count_stateless
from
..utils
import
fork_new_process_for_each_test
from
.utils
import
TEST_MODELS_SMOKE
,
check_full_graph_support
@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
@fork_new_process_for_each_test
def test_full_graph_multi_gpu(model_info, tp_size, backend):
    """Check full-graph support for a smoke model with tensor parallelism.

    ``model_info`` is a ``(model name, model kwargs)`` pair. The test is
    skipped when fewer than ``tp_size`` CUDA devices are available.
    """
    model, model_kwargs = model_info[0], model_info[1]

    # Skip the test if there are not enough CUDA devices.
    available_devices = cuda_device_count_stateless()
    if available_devices < tp_size:
        pytest.skip("Not enough CUDA devices for the test.")

    check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size)
tests/compile/test_full_graph_smoke.py
deleted
100644 → 0
View file @
78c0b416
import
pytest
from
vllm.compilation.backends
import
vllm_backend
from
.utils
import
TEST_MODELS_SMOKE
,
check_full_graph_support
@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
    """Smoke-check full-graph support on a single device (``tp_size=1``).

    ``model_info`` is a ``(model name, model kwargs)`` pair.
    """
    model_name, extra_kwargs = model_info[0], model_info[1]
    check_full_graph_support(model_name, extra_kwargs, backend, tp_size=1)
tests/compile/utils.py
View file @
e4d652ea
...
@@ -4,16 +4,9 @@ import torch
...
@@ -4,16 +4,9 @@ import torch
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.
plugins
import
set_torch_compile_backend
from
vllm.
compilation.levels
import
CompilationLevel
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
TEST_MODELS_SMOKE
=
[
(
"nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples"
,
{
"quantization"
:
"compressed-tensors"
}),
(
"meta-llama/Meta-Llama-3-8B"
,
{}),
]
TEST_MODELS
=
[
TEST_MODELS
=
[
(
"facebook/opt-125m"
,
{}),
(
"facebook/opt-125m"
,
{}),
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
{
(
"nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change"
,
{
...
@@ -68,20 +61,21 @@ if not is_hip() and is_quant_method_supported("awq"):
...
@@ -68,20 +61,21 @@ if not is_hip() and is_quant_method_supported("awq"):
}))
}))
def
check_full_graph_support
(
model
,
model_kwargs
,
backend
,
tp_size
=
1
):
def
check_full_graph_support
(
model
,
model_kwargs
,
optimization_level
,
tp_size
=
1
):
# make sure these models can be captured in full graph mode
# make sure these models can be captured in full graph mode
if
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
not
in
os
.
environ
:
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
optimization_level
)
os
.
environ
[
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
]
=
"1"
os
.
environ
[
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
]
=
"1"
os
.
environ
[
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
]
=
"1"
# Inductor doesn't support fp8/gptq_marlin_24 yet.
# Inductor doesn't support fp8/gptq_marlin_24 yet.
quantization
=
model_kwargs
.
get
(
"quantization"
)
quantization
=
model_kwargs
.
get
(
"quantization"
)
if
(
quantization
==
"fp8"
or
quantization
==
"gptq_marlin"
if
(
quantization
==
"fp8"
or
quantization
==
"gptq_marlin"
or
quantization
==
"gptq_marlin_24"
)
and
backend
!=
"eager"
:
or
quantization
==
"gptq_marlin_24"
)
and
optimization_level
>=
CompilationLevel
.
INDUCTOR
:
return
return
set_torch_compile_backend
(
backend
)
prompts
=
[
prompts
=
[
"Hello, my name is"
,
"Hello, my name is"
,
"The president of the United States is"
,
"The president of the United States is"
,
...
...
tests/tpu/test_compilation.py
View file @
e4d652ea
...
@@ -5,9 +5,11 @@ import tempfile
...
@@ -5,9 +5,11 @@ import tempfile
import
depyf
import
depyf
from
vllm.compilation.levels
import
CompilationLevel
# disable custom dispatcher, let Dynamo take over
# disable custom dispatcher, let Dynamo take over
# all the control
# all the control
os
.
environ
[
'VLLM_
DYNAMO_USE_CUSTOM_DISPATCHER'
]
=
"0"
os
.
environ
[
'VLLM_
TORCH_COMPILE_LEVEL'
]
=
str
(
CompilationLevel
.
DYNAMO_AS_IS
)
temp_dir
=
tempfile
.
mkdtemp
()
temp_dir
=
tempfile
.
mkdtemp
()
with
depyf
.
prepare_debug
(
temp_dir
):
with
depyf
.
prepare_debug
(
temp_dir
):
...
...
tests/tpu/test_custom_dispatcher.py
View file @
e4d652ea
import
os
import
os
from
vllm.compilation.levels
import
CompilationLevel
from
..utils
import
compare_two_settings
from
..utils
import
compare_two_settings
# --enforce-eager on TPU causes graph compilation
# --enforce-eager on TPU causes graph compilation
...
@@ -9,8 +11,9 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
...
@@ -9,8 +11,9 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
def
test_custom_dispatcher
():
def
test_custom_dispatcher
():
compare_two_settings
(
"google/gemma-2b"
,
compare_two_settings
(
"google/gemma-2b"
,
arg1
=
[
"--enforce-eager"
],
arg1
=
[
"--enforce-eager"
],
arg2
=
[
"--enforce-eager"
],
arg2
=
[
"--enforce-eager"
],
env1
=
{
"VLLM_
DYNAMO_USE_CUSTOM_DISPATCHER"
:
"0"
},
env1
=
{
"VLLM_
TORCH_COMPILE_LEVEL"
:
str
(
CompilationLevel
.
DYNAMO_ONCE
)
},
env2
=
{
})
env2
=
{
"VLLM_TORCH_COMPILE_LEVEL"
:
str
(
CompilationLevel
.
DYNAMO_AS_IS
)
})
vllm/compilation/backends.py
View file @
e4d652ea
import
copy
import
operator
import
operator
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Union
import
torch
import
torch
import
torch.fx
as
fx
import
torch.fx
as
fx
from
vllm.logger
import
init_logger
from
.compile_context
import
get_compile_context
from
.levels
import
CompilationLevel
logger
=
init_logger
(
__name__
)
def
fix_functionalization
(
graph
:
fx
.
Graph
):
def
fix_functionalization
(
graph
:
fx
.
Graph
):
"""
"""
...
@@ -148,9 +157,113 @@ def fix_functionalization(graph: fx.Graph):
...
@@ -148,9 +157,113 @@ def fix_functionalization(graph: fx.Graph):
# print(graph.python_code(root_module="self", verbose=True).src, file=f)
# print(graph.python_code(root_module="self", verbose=True).src, file=f)
def
vllm_backend
(
graph
,
example_inputs
):
def
wrap_inductor
(
graph
,
example_inputs
,
additional_inductor_config
):
from
torch._inductor
import
config
from
torch._inductor
import
config
current_config
=
config
.
shallow_copy_dict
()
current_config
=
config
.
shallow_copy_dict
()
from
torch._inductor.compile_fx
import
compile_fx
from
torch._inductor.compile_fx
import
compile_fx
if
additional_inductor_config
is
not
None
:
current_config
.
update
(
additional_inductor_config
)
if
current_config
[
'post_grad_custom_post_pass'
]
is
not
None
:
logger
.
warning
(
"post_grad_custom_post_pass is already set in the config. "
"Overwriting it with the fix_functionalization"
)
current_config
[
'post_grad_custom_post_pass'
]
=
fix_functionalization
current_config
[
'post_grad_custom_post_pass'
]
=
fix_functionalization
return
compile_fx
(
graph
,
example_inputs
,
config_patches
=
current_config
)
return
compile_fx
(
graph
,
example_inputs
,
config_patches
=
current_config
)
def vllm_backend(
        graph,
        example_inputs,
        additional_inductor_config: Optional[Dict] = None) -> Callable:
    """Backend that compiles ``graph`` with Inductor, once for general
    shapes and lazily again for selected concrete sizes.

    Args:
        graph: the captured graph handed over by Dynamo.
        example_inputs: example inputs for ``graph``; entries that are
            ``torch.SymInt`` are treated as the runtime shape values.
        additional_inductor_config: extra Inductor config entries that are
            forwarded to ``wrap_inductor`` for every compilation.

    Returns:
        A callable taking the graph's positional inputs. Its first call
        always runs the general-shape graph. Later calls run a graph
        specialized for the current shapes when the first shape value is
        listed in the compile context, and the general-shape graph
        otherwise.
    """
    context = get_compile_context()
    # Deep-copy so that later changes to the caller's context object
    # cannot alter the sizes this backend specializes for.
    sizes_to_specialize: List[int] = (copy.deepcopy(context)
                                      if context is not None else [])

    # flags for all the seen shapes, whether we need to specialize
    runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {}

    # if we need to specialize, the compiled graph for that shape
    runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {}

    # this is the first compilation, we will compile a graph with
    # dynamic shape, as the caller will mark first dimension as dynamic
    logger.info("Compiling a graph for general shapes")
    graph_for_symbolic_shape = wrap_inductor(graph, example_inputs,
                                             additional_inductor_config)

    # TODO: Dynamo does not pass all dynamic shapes.
    # Need to investigate why. It works now because all the dynamic
    # shapes have the same value, and either of them can be used.
    sym_shape_indices = [
        i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt)
    ]

    first_run = True

    # this is the function we return to Dynamo to run finally
    def compiled_graph_wrapper(*args):
        # only `first_run` is rebound here; the two dicts are mutated in
        # place and therefore need no `nonlocal` declaration.
        nonlocal first_run

        runtime_shapes: Tuple[int, ...] = tuple(args[i]
                                                for i in sym_shape_indices)

        if first_run:
            # the first compilation is for profiling, we directly run it
            first_run = False
            return graph_for_symbolic_shape(*args)

        if runtime_shapes not in runtime_shapes_to_compile_flags:
            # we haven't seen this shape before
            # query if we need to specialize for this shape
            # we only specialize for the first dimension.
            # TODO: investigate if any model needs to specialize
            # beyond the first dimension
            # When no symbolic shape was passed in, `runtime_shapes` is
            # empty: there is nothing to specialize on, so never index it.
            runtime_shapes_to_compile_flags[runtime_shapes] = (
                bool(runtime_shapes)
                and runtime_shapes[0] in sizes_to_specialize)

        if not runtime_shapes_to_compile_flags[runtime_shapes]:
            # we don't need to specialize for this shape
            return graph_for_symbolic_shape(*args)

        if runtime_shapes not in runtime_shapes_to_compiled_graph:
            # we need to specialize for this shape and have not compiled
            # a graph for it yet, so compile one now
            logger.info("Compiling a graph for shapes %s", runtime_shapes)
            runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor(
                graph, args, additional_inductor_config)

        return runtime_shapes_to_compiled_graph[runtime_shapes](*args)

    return compiled_graph_wrapper
def select_default_backend(level: int) -> Union[str, Callable]:
    """Pick the ``torch.compile`` backend for a compilation level.

    Args:
        level: one of the ``CompilationLevel`` constants.

    Returns:
        ``"eager"`` for DYNAMO_AS_IS and DYNAMO_ONCE. For INDUCTOR and
        INDUCTOR_MAX_AUTOTUNE, ``vllm_backend`` with the additional
        Inductor configs bound in (``max_autotune`` forced to ``True`` for
        INDUCTOR_MAX_AUTOTUNE).

    Raises:
        AssertionError: if ``level`` is none of the four levels above.
    """
    if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
        return "eager"

    assert level in [
        CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
    ], f"Invalid level {level}"

    from functools import partial

    from vllm.plugins import get_inductor_additional_configs

    # Work on a copy: the getter hands back the plugin module's own dict,
    # and forcing max_autotune below must not leak into that shared state
    # (it would otherwise stay enabled for later, lower-level calls).
    additional_configs = dict(get_inductor_additional_configs())

    if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE:
        # warn only when the user explicitly disabled max_autotune
        if not additional_configs.get("max_autotune", True):
            logger.warning(
                "max_autotune is disabled, but is overridden by level %s",
                CompilationLevel.INDUCTOR_MAX_AUTOTUNE)
        additional_configs['max_autotune'] = True

    return partial(vllm_backend,
                   additional_inductor_config=additional_configs)
vllm/compilation/compile_context.py
0 → 100644
View file @
e4d652ea
from
contextlib
import
contextmanager
from
typing
import
Any
_compile_context
:
Any
=
None
def get_compile_context() -> Any:
    """Return the compile context that is currently active.

    This is the module-level value last installed by
    ``set_compile_context``, or ``None`` when none is installed.
    """
    return _compile_context
@contextmanager
def set_compile_context(context: Any):
    """Install ``context`` as the compile context for a ``with`` block.

    The context is usually a list of sizes to specialize. The previously
    installed value is put back when the block exits, including when it
    exits with an exception.
    """
    global _compile_context
    saved_context, _compile_context = _compile_context, context
    try:
        yield
    finally:
        _compile_context = saved_context
vllm/compilation/decorators.py
0 → 100644
View file @
e4d652ea
from
typing
import
List
,
Optional
,
Union
import
torch
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.compilation.wrapper
import
TorchCompileWrapperWithCustomDispatcher
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
supports_dynamo
def support_compile_llama_style(cls: type):
    """
    Class decorator that makes a module's forward method compilable.

    It applies to any module whose **forward signature** is compatible with
    llama: ``(input_ids, positions, kv_caches, attn_metadata,
    intermediate_tensors, inputs_embeds=None)``.

    The class is returned untouched when the compilation level is
    NO_COMPILATION or DYNAMO_AS_IS, or when Dynamo is not supported.
    Otherwise ``TorchCompileWrapperWithCustomDispatcher`` is appended to the
    class' bases and its ``__init__`` and ``__call__`` are replaced in place.
    """
    # DYNAMO_AS_IS is taken care of by the upper level model runner, so
    # there is nothing to set up here for that level either.
    untouched_levels = [
        CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
    ]
    if (envs.VLLM_TORCH_COMPILE_LEVEL in untouched_levels
            or not supports_dynamo()):
        return cls

    # Mind the method resolution order: the wrapper is appended last so
    # that super().__init__ inside the class still reaches its original
    # base class rather than TorchCompileWrapperWithCustomDispatcher.
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    original_init = cls.__init__

    def __init__(self, *args, **kwargs):
        # build the module first, then set up the compile wrapper state
        original_init(self, *args, **kwargs)
        TorchCompileWrapperWithCustomDispatcher.__init__(self)

    cls.__init__ = __init__

    def __call__(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # Already inside a compilation (e.g. TPU compiles in the model
        # runner): run the plain forward, no nested compilation needed.
        if torch.compiler.is_compiling():
            return self.forward(input_ids, positions, kv_caches,
                                attn_metadata, intermediate_tensors,
                                inputs_embeds)

        nothing_compiled_yet = len(self.compiled_codes) < 1

        # The first compilation must see the first dimension of every
        # tensor input marked as dynamic.
        if nothing_compiled_yet:
            dynamic_inputs = []
            if input_ids is not None:
                dynamic_inputs.append(input_ids)
            dynamic_inputs.append(positions)
            if inputs_embeds is not None:
                dynamic_inputs.append(inputs_embeds)
            if intermediate_tensors is not None:
                dynamic_inputs.extend(intermediate_tensors.tensors.values())
            for tensor in dynamic_inputs:
                torch._dynamo.mark_dynamic(tensor, 0)

        # Without the custom dispatcher (or before anything is compiled),
        # call the compiled function and let torch.compile dispatch, at
        # the price of guard evaluation and possible recompilation.
        if nothing_compiled_yet or not self.use_custom_dispatcher:
            return self.compiled_callable(input_ids, positions, kv_caches,
                                          attn_metadata, intermediate_tensors,
                                          inputs_embeds)

        # Capturing the model once is usually enough; from then on go
        # straight to the first compiled code, bypassing Dynamo's guards.
        with self.dispatch_to_code(0):
            return self.forward(input_ids, positions, kv_caches,
                                attn_metadata, intermediate_tensors,
                                inputs_embeds)

    cls.__call__ = __call__
    return cls
vllm/compilation/levels.py
0 → 100644
View file @
e4d652ea
# constants for the levels of the compilation process
class CompilationLevel:
    """Integer constants naming the levels of the compilation process.

    The values are plain ints in increasing order, so they can be compared
    with ``<`` / ``>=`` and round-tripped through ``str()`` / ``int()``.
    """

    # no compilation at all
    NO_COMPILATION = 0
    # first of the two Dynamo levels
    DYNAMO_AS_IS = 1
    # second of the two Dynamo levels
    DYNAMO_ONCE = 2
    # first of the two Inductor levels
    INDUCTOR = 3
    # Inductor level with max-autotune
    INDUCTOR_MAX_AUTOTUNE = 4
vllm/compilation/wrapper.py
View file @
e4d652ea
...
@@ -3,12 +3,14 @@ import sys
...
@@ -3,12 +3,14 @@ import sys
from
abc
import
abstractmethod
from
abc
import
abstractmethod
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
types
import
CodeType
from
types
import
CodeType
from
typing
import
Callable
,
List
from
typing
import
Callable
,
List
,
Optional
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
.levels
import
CompilationLevel
class
TorchCompileWrapperWithCustomDispatcher
:
class
TorchCompileWrapperWithCustomDispatcher
:
"""
"""
...
@@ -23,7 +25,26 @@ class TorchCompileWrapperWithCustomDispatcher:
...
@@ -23,7 +25,26 @@ class TorchCompileWrapperWithCustomDispatcher:
`torch.compile` over the forward method.
`torch.compile` over the forward method.
"""
"""
def
__init__
(
self
,
compiled_callable
:
Callable
):
def
__init__
(
self
,
compiled_callable
:
Optional
[
Callable
]
=
None
):
if
compiled_callable
is
None
:
# default compilation settings
# compiling the forward method
# choose the compile backend
# if the user has set the backend, use it
from
vllm.plugins
import
get_torch_compile_backend
backend
=
get_torch_compile_backend
()
if
backend
is
None
:
from
vllm.compilation.backends
import
select_default_backend
backend
=
select_default_backend
(
envs
.
VLLM_TORCH_COMPILE_LEVEL
)
compiled_callable
=
torch
.
compile
(
self
.
forward
,
fullgraph
=
envs
.
VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE
,
backend
=
backend
)
self
.
compiled_callable
=
compiled_callable
self
.
compiled_callable
=
compiled_callable
self
.
original_code_object
=
self
.
__class__
.
forward
.
__code__
self
.
original_code_object
=
self
.
__class__
.
forward
.
__code__
self
.
compiled_codes
:
List
[
CodeType
]
=
[]
self
.
compiled_codes
:
List
[
CodeType
]
=
[]
...
@@ -33,7 +54,7 @@ class TorchCompileWrapperWithCustomDispatcher:
...
@@ -33,7 +54,7 @@ class TorchCompileWrapperWithCustomDispatcher:
# subclasses can use this to switch between the custom dispatcher
# subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism.
# and the default Dynamo guard mechanism.
self
.
use_custom_dispatcher
:
bool
=
\
self
.
use_custom_dispatcher
:
bool
=
\
envs
.
VLLM_
DYNAMO_USE_CUSTOM_DISPATCHER
envs
.
VLLM_
TORCH_COMPILE_LEVEL
>=
CompilationLevel
.
DYNAMO_ONCE
def
__call__
(
self
,
*
args
,
**
kwargs
):
def
__call__
(
self
,
*
args
,
**
kwargs
):
"""Implement the dispatch logic here, beyond the torch.compile level.
"""Implement the dispatch logic here, beyond the torch.compile level.
...
...
vllm/envs.py
View file @
e4d652ea
...
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
...
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1
:
bool
=
False
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1
:
bool
=
False
VLLM_TORCH_COMPILE_LEVEL
:
int
=
0
def
get_default_cache_root
():
def
get_default_cache_root
():
...
@@ -198,23 +199,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -198,23 +199,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
(
"true"
,
"1"
)),
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE"
,
"0"
)),
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER"
,
"True"
).
lower
()
in
(
"true"
,
"1"
)),
# Internal flag to control whether we use custom op,
# or use the native pytorch implementation
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS"
,
"0"
)),
# Internal flag to enable Dynamo fullgraph capture
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
lambda
:
bool
(
lambda
:
bool
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
,
"1"
)
!=
"0"
),
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
,
"1"
)
!=
"0"
),
"VLLM_TORCH_COMPILE_LEVEL"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TORCH_COMPILE_LEVEL"
,
"0"
)),
# local rank of the process in the distributed setting, used to determine
# local rank of the process in the distributed setting, used to determine
# the GPU device id
# the GPU device id
...
...
vllm/model_executor/custom_op.py
View file @
e4d652ea
import
torch.nn
as
nn
import
torch.nn
as
nn
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_cpu
,
is_hip
,
is_xpu
from
vllm.utils
import
is_cpu
,
is_hip
,
is_xpu
...
@@ -55,7 +56,7 @@ class CustomOp(nn.Module):
...
@@ -55,7 +56,7 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
# specific backend. Currently, we do not support dynamic dispatching.
if
envs
.
VLLM_T
EST
_COMPILE_
NO_CUSTOM_OPS
:
if
envs
.
VLLM_T
ORCH
_COMPILE_
LEVEL
>=
CompilationLevel
.
INDUCTOR
:
return
self
.
forward_native
return
self
.
forward_native
if
is_hip
():
if
is_hip
():
...
...
vllm/model_executor/models/gemma2.py
View file @
e4d652ea
...
@@ -21,6 +21,7 @@ from torch import nn
...
@@ -21,6 +21,7 @@ from torch import nn
from
transformers
import
Gemma2Config
from
transformers
import
Gemma2Config
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_compile_llama_style
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.distributed
import
get_pp_group
,
get_tensor_model_parallel_world_size
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -238,6 +239,7 @@ class Gemma2DecoderLayer(nn.Module):
...
@@ -238,6 +239,7 @@ class Gemma2DecoderLayer(nn.Module):
return
hidden_states
,
residual
return
hidden_states
,
residual
@
support_compile_llama_style
class
Gemma2Model
(
nn
.
Module
):
class
Gemma2Model
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
...
vllm/model_executor/models/llama.py
View file @
e4d652ea
...
@@ -28,6 +28,7 @@ from torch import nn
...
@@ -28,6 +28,7 @@ from torch import nn
from
transformers
import
LlamaConfig
from
transformers
import
LlamaConfig
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.compilation.decorators
import
support_compile_llama_style
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.config
import
CacheConfig
,
LoRAConfig
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
...
@@ -265,6 +266,7 @@ class LlamaDecoderLayer(nn.Module):
...
@@ -265,6 +266,7 @@ class LlamaDecoderLayer(nn.Module):
return
hidden_states
,
residual
return
hidden_states
,
residual
@
support_compile_llama_style
class
LlamaModel
(
nn
.
Module
):
class
LlamaModel
(
nn
.
Module
):
def
__init__
(
def
__init__
(
...
...
vllm/model_executor/models/llava.py
View file @
e4d652ea
...
@@ -365,6 +365,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -365,6 +365,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
input_ids
=
None
input_ids
=
None
inputs_embeds
=
None
inputs_embeds
=
None
else
:
else
:
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
image_input
=
self
.
_parse_and_validate_image_input
(
**
kwargs
)
if
image_input
is
not
None
:
if
image_input
is
not
None
:
...
@@ -375,10 +377,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
...
@@ -375,10 +377,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds
=
merge_multimodal_embeddings
(
inputs_embeds
=
merge_multimodal_embeddings
(
input_ids
,
inputs_embeds
,
vision_embeddings
,
input_ids
,
inputs_embeds
,
vision_embeddings
,
self
.
config
.
image_token_index
)
self
.
config
.
image_token_index
)
input_ids
=
None
else
:
else
:
inputs_embeds
=
None
inputs_embeds
=
self
.
language_model
.
model
.
get_input_embeddings
(
input_ids
)
input_ids
=
None
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
hidden_states
=
self
.
language_model
.
model
(
input_ids
,
positions
,
positions
,
...
...
vllm/platforms/tpu.py
View file @
e4d652ea
import
os
import
torch
import
torch
import
vllm.envs
as
envs
from
vllm.compilation.levels
import
CompilationLevel
from
vllm.plugins
import
set_torch_compile_backend
from
.interface
import
Platform
,
PlatformEnum
from
.interface
import
Platform
,
PlatformEnum
# Default to DYNAMO_ONCE unless a compile level is already set in the
# environment.
os.environ.setdefault("VLLM_TORCH_COMPILE_LEVEL",
                      str(CompilationLevel.DYNAMO_ONCE))

# Levels at or above INDUCTOR are rejected at import time.
assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR, \
    "TPU does not support Inductor."

set_torch_compile_backend("openxla")
class
TpuPlatform
(
Platform
):
class
TpuPlatform
(
Platform
):
_enum
=
PlatformEnum
.
TPU
_enum
=
PlatformEnum
.
TPU
...
...
vllm/plugins/__init__.py
View file @
e4d652ea
import
logging
import
logging
from
typing
import
Callable
,
Optional
,
Union
from
typing
import
Callable
,
Dict
,
Optional
,
Union
import
vllm.envs
as
envs
import
vllm.envs
as
envs
...
@@ -42,3 +42,15 @@ def set_torch_compile_backend(backend: Union[Callable, str]):
...
@@ -42,3 +42,15 @@ def set_torch_compile_backend(backend: Union[Callable, str]):
def
get_torch_compile_backend
()
->
Optional
[
Union
[
Callable
,
str
]]:
def
get_torch_compile_backend
()
->
Optional
[
Union
[
Callable
,
str
]]:
return
_torch_compile_backend
return
_torch_compile_backend
_inductor_additional_configs
:
Dict
=
{}
def set_inductor_additional_configs(configs: Dict):
    """Replace the module-level additional Inductor configs.

    The dict is stored by reference, not copied.
    """
    global _inductor_additional_configs
    _inductor_additional_configs = configs
def get_inductor_additional_configs() -> Dict:
    """Return the module-level additional Inductor configs.

    The stored dict itself is returned (no copy), so changes made by the
    caller are visible to every later caller.
    """
    return _inductor_additional_configs
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment