Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
803f37ea
Unverified
Commit
803f37ea
authored
Nov 19, 2024
by
youkaichao
Committed by
GitHub
Nov 19, 2024
Browse files
[6/N] torch.compile rollout to users (#10437)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
fd9f1249
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
129 additions
and
141 deletions
+129
-141
tests/compile/piecewise/piecewise_compilation_config.json
tests/compile/piecewise/piecewise_compilation_config.json
+0
-5
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+7
-11
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+17
-28
tests/compile/test_basic_correctness.py
tests/compile/test_basic_correctness.py
+9
-4
tests/compile/utils.py
tests/compile/utils.py
+2
-2
tests/model_executor/test_enabled_custom_ops.py
tests/model_executor/test_enabled_custom_ops.py
+1
-3
tests/tpu/test_compilation.py
tests/tpu/test_compilation.py
+35
-12
tests/tpu/test_custom_dispatcher.py
tests/tpu/test_custom_dispatcher.py
+6
-4
vllm/config.py
vllm/config.py
+20
-23
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+23
-6
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+3
-1
vllm/envs.py
vllm/envs.py
+0
-8
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+2
-2
vllm/plugins/__init__.py
vllm/plugins/__init__.py
+1
-13
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-19
No files found.
tests/compile/piecewise/piecewise_compilation_config.json
deleted
100644 → 0
View file @
fd9f1249
{
"use_cudagraph"
:
true
,
"non_cudagraph_ops"
:
[
"silly.attention"
],
"cudagraph_copy_inputs"
:
true
}
\ No newline at end of file
tests/compile/piecewise/test_simple.py
View file @
803f37ea
...
@@ -2,7 +2,6 @@
...
@@ -2,7 +2,6 @@
Test the piecewise compilation with a simple model so that we
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
can exactly calculate the expected output and side effects.
"""
"""
import
os
import
torch
import
torch
from
torch
import
nn
from
torch
import
nn
...
@@ -11,7 +10,7 @@ from torch.library import Library
...
@@ -11,7 +10,7 @@ from torch.library import Library
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.plugins
import
set_current_vllm_config
from
vllm.plugins
import
set_current_vllm_config
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
@@ -77,12 +76,12 @@ class SillyModel(nn.Module):
...
@@ -77,12 +76,12 @@ class SillyModel(nn.Module):
def
test_simple_piecewise_compile
():
def
test_simple_piecewise_compile
():
directory
=
os
.
path
.
dirname
(
__file__
)
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
config
=
os
.
path
.
join
(
directory
,
"piecewise_compilation_config.json"
)
level
=
CompilationLevel
.
PIECEWISE
,
os
.
environ
[
"VLLM_TORCH_COMPILE_CONFIG"
]
=
config
use_cudagraph
=
True
,
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
non_cudagraph_ops
=
[
"silly.attention"
],
cudagraph_copy_inputs
=
True
,
vllm_config
=
VllmConfig
(
)
)
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
SillyModel
(
vllm_config
=
vllm_config
,
prefix
=
''
)
model
=
SillyModel
(
vllm_config
=
vllm_config
,
prefix
=
''
)
...
@@ -109,6 +108,3 @@ def test_simple_piecewise_compile():
...
@@ -109,6 +108,3 @@ def test_simple_piecewise_compile():
output
=
model
(
input
)
output
=
model
(
input
)
assert
global_counter
==
2
assert
global_counter
==
2
assert
torch
.
allclose
(
output
.
cpu
(),
torch
.
tensor
([
3.
,
1.
]))
assert
torch
.
allclose
(
output
.
cpu
(),
torch
.
tensor
([
3.
,
1.
]))
# clean up to avoid side effects for other tests
del
os
.
environ
[
"VLLM_TORCH_COMPILE_CONFIG"
]
tests/compile/piecewise/test_toy_llama.py
View file @
803f37ea
...
@@ -6,7 +6,6 @@ This is a tractable model, the weights and computation are specially designed
...
@@ -6,7 +6,6 @@ This is a tractable model, the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
initialized randomly with a fixed seed.
"""
"""
import
os
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
...
@@ -18,7 +17,7 @@ from vllm.compilation.compile_context import set_compile_context
...
@@ -18,7 +17,7 @@ from vllm.compilation.compile_context import set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.plugins
import
set_compilation_config
,
set_current_vllm_config
from
vllm.plugins
import
set_current_vllm_config
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
# create a library to hold the custom op
# create a library to hold the custom op
...
@@ -254,23 +253,17 @@ def run_model(llama_config,
...
@@ -254,23 +253,17 @@ def run_model(llama_config,
split_attn
:
bool
=
False
)
->
torch
.
Tensor
:
split_attn
:
bool
=
False
)
->
torch
.
Tensor
:
if
use_compile
:
if
use_compile
:
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
compilation_config
=
CompilationConfig
(
CompilationLevel
.
PIECEWISE
)
level
=
CompilationLevel
.
PIECEWISE
,
if
split_attn
:
set_compilation_config
(
CompilationConfig
(
use_cudagraph
=
True
,
use_cudagraph
=
True
,
non_cudagraph_ops
=
[
"silly.attention"
],
)
))
if
split_attn
:
else
:
compilation_config
.
non_cudagraph_ops
=
[
"silly.attention"
]
set_compilation_config
(
CompilationConfig
(
use_cudagraph
=
True
,
))
else
:
else
:
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
compilation_config
=
CompilationConfig
(
CompilationLevel
.
NO_COMPILATION
)
level
=
CompilationLevel
.
NO_COMPILATION
,
)
set_compilation_config
(
None
)
vllm_config
=
VllmConfig
()
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
LlamaModel
(
config
=
llama_config
,
model
=
LlamaModel
(
config
=
llama_config
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
...
@@ -288,10 +281,6 @@ def run_model(llama_config,
...
@@ -288,10 +281,6 @@ def run_model(llama_config,
input_ids
[:
2
].
zero_
()
input_ids
[:
2
].
zero_
()
output
=
model
(
input_ids
[:
2
],
positions
[:
2
])
output
=
model
(
input_ids
[:
2
],
positions
[:
2
])
# manual cleanup
del
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
set_compilation_config
(
None
)
output
=
output
.
cpu
()
output
=
output
.
cpu
()
if
llama_config
.
tractable_init
:
if
llama_config
.
tractable_init
:
...
@@ -361,7 +350,6 @@ def test_toy_llama():
...
@@ -361,7 +350,6 @@ def test_toy_llama():
@
torch
.
inference_mode
@
torch
.
inference_mode
def
benchmark
():
def
benchmark
():
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
from
triton.testing
import
do_bench
from
triton.testing
import
do_bench
# similar to llama 3.1-8B
# similar to llama 3.1-8B
...
@@ -387,15 +375,16 @@ def benchmark():
...
@@ -387,15 +375,16 @@ def benchmark():
for
piecewise
in
[
False
,
True
]:
for
piecewise
in
[
False
,
True
]:
if
piecewise
:
if
piecewise
:
set_
compilation_config
(
compilation_config
=
CompilationConfig
(
Compilation
Config
(
level
=
Compilation
Level
.
PIECEWISE
,
use_cudagraph
=
True
,
use_cudagraph
=
True
,
non_cudagraph_ops
=
[
"silly.attention"
],
non_cudagraph_ops
=
[
"silly.attention"
],
)
)
)
else
:
else
:
set_compilation_config
(
None
)
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
)
vllm_config
=
VllmConfig
()
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
)
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
model
=
LlamaModel
(
config
=
llama_config
,
model
=
LlamaModel
(
config
=
llama_config
,
vllm_config
=
vllm_config
,
vllm_config
=
vllm_config
,
...
...
tests/compile/test_basic_correctness.py
View file @
803f37ea
...
@@ -96,31 +96,36 @@ def test_compile_correctness(test_setting: TestSetting):
...
@@ -96,31 +96,36 @@ def test_compile_correctness(test_setting: TestSetting):
final_args
=
[
"--enforce-eager"
]
+
model_args
+
[
"-pp"
,
str
(
pp_size
)]
+
\
final_args
=
[
"--enforce-eager"
]
+
model_args
+
[
"-pp"
,
str
(
pp_size
)]
+
\
[
"-tp"
,
str
(
tp_size
)]
[
"-tp"
,
str
(
tp_size
)]
all_args
:
List
[
List
[
str
]]
=
[]
all_envs
:
List
[
Optional
[
Dict
[
str
,
str
]]]
=
[]
all_envs
:
List
[
Optional
[
Dict
[
str
,
str
]]]
=
[]
for
level
in
[
for
level
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
PIECEWISE
,
CompilationLevel
.
PIECEWISE
,
]:
]:
all_envs
.
append
({
"VLLM_TORCH_COMPILE_LEVEL"
:
str
(
level
)})
all_args
.
append
(
final_args
+
[
"-O"
,
str
(
level
)])
all_envs
.
append
({})
# inductor will change the output, so we only compare if the output
# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
# is close, not exactly the same.
compare_all_settings
(
compare_all_settings
(
model
,
[
final_args
]
*
2
,
model
,
all_args
,
all_envs
,
all_envs
,
method
=
method
if
method
!=
"generate"
else
"generate_close"
)
method
=
method
if
method
!=
"generate"
else
"generate_close"
)
all_envs
.
clear
()
all_envs
.
clear
()
all_args
.
clear
()
for
level
in
[
for
level
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
,
CompilationLevel
.
DYNAMO_AS_IS
,
CompilationLevel
.
DYNAMO_ONCE
,
CompilationLevel
.
DYNAMO_ONCE
,
]:
]:
all_envs
.
append
({
"VLLM_TORCH_COMPILE_LEVEL"
:
str
(
level
)})
all_args
.
append
(
final_args
+
[
"-O"
,
str
(
level
)])
all_envs
.
append
({})
if
level
!=
CompilationLevel
.
DYNAMO_ONCE
and
not
fullgraph
:
if
level
!=
CompilationLevel
.
DYNAMO_ONCE
and
not
fullgraph
:
# "DYNAMO_ONCE" will always use fullgraph
# "DYNAMO_ONCE" will always use fullgraph
all_envs
[
-
1
][
all_envs
[
-
1
][
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
]
=
"0"
# type: ignore
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
]
=
"0"
# type: ignore
compare_all_settings
(
model
,
[
fin
al_args
]
*
3
,
all_envs
,
method
=
method
)
compare_all_settings
(
model
,
a
l
l_args
*
3
,
all_envs
,
method
=
method
)
tests/compile/utils.py
View file @
803f37ea
...
@@ -4,7 +4,7 @@ import torch
...
@@ -4,7 +4,7 @@ import torch
from
tests.quantization.utils
import
is_quant_method_supported
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
LLM
,
SamplingParams
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationLevel
from
vllm.config
import
CompilationConfig
,
CompilationLevel
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
TEST_MODELS
=
[
TEST_MODELS
=
[
...
@@ -65,7 +65,6 @@ def check_full_graph_support(model,
...
@@ -65,7 +65,6 @@ def check_full_graph_support(model,
optimization_level
,
optimization_level
,
tp_size
=
1
):
tp_size
=
1
):
# make sure these models can be captured in full graph mode
# make sure these models can be captured in full graph mode
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
optimization_level
)
os
.
environ
[
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
]
=
"1"
os
.
environ
[
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
]
=
"1"
# The base meta llama uses too much memory.
# The base meta llama uses too much memory.
...
@@ -86,6 +85,7 @@ def check_full_graph_support(model,
...
@@ -86,6 +85,7 @@ def check_full_graph_support(model,
enforce_eager
=
True
,
enforce_eager
=
True
,
tensor_parallel_size
=
tp_size
,
tensor_parallel_size
=
tp_size
,
disable_custom_all_reduce
=
True
,
disable_custom_all_reduce
=
True
,
compilation_config
=
CompilationConfig
(
level
=
optimization_level
),
**
model_kwargs
)
**
model_kwargs
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
...
...
tests/model_executor/test_enabled_custom_ops.py
View file @
803f37ea
import
os
from
typing
import
List
from
typing
import
List
import
pytest
import
pytest
...
@@ -53,9 +52,8 @@ class Relu3(ReLUSquaredActivation):
...
@@ -53,9 +52,8 @@ class Relu3(ReLUSquaredActivation):
])
])
def
test_enabled_ops
(
env
:
str
,
torch_level
:
int
,
ops_enabled
:
List
[
int
],
def
test_enabled_ops
(
env
:
str
,
torch_level
:
int
,
ops_enabled
:
List
[
int
],
default_on
:
bool
):
default_on
:
bool
):
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
torch_level
)
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
env
.
split
(
","
)))
level
=
torch_level
,
custom_ops
=
env
.
split
(
","
)))
with
set_current_vllm_config
(
vllm_config
):
with
set_current_vllm_config
(
vllm_config
):
assert
CustomOp
.
default_on
()
==
default_on
assert
CustomOp
.
default_on
()
==
default_on
...
...
tests/tpu/test_compilation.py
View file @
803f37ea
import
glob
import
glob
import
os
import
os
import
runpy
import
tempfile
import
tempfile
import
depyf
import
depyf
from
vllm.config
import
CompilationLevel
from
vllm.config
import
CompilationConfig
,
CompilationLevel
# disable custom dispatcher, let Dynamo take over
# all the control
os
.
environ
[
'VLLM_TORCH_COMPILE_LEVEL'
]
=
str
(
CompilationLevel
.
DYNAMO_AS_IS
)
temp_dir
=
tempfile
.
mkdtemp
()
temp_dir
=
tempfile
.
mkdtemp
()
with
depyf
.
prepare_debug
(
temp_dir
):
with
depyf
.
prepare_debug
(
temp_dir
):
cur_dir
=
os
.
path
.
dirname
(
__file__
)
from
vllm
import
LLM
,
SamplingParams
parent_dir
=
os
.
path
.
dirname
(
cur_dir
)
root_dir
=
os
.
path
.
dirname
(
parent_dir
)
prompts
=
[
example_file
=
os
.
path
.
join
(
root_dir
,
"examples"
,
"A robot may not injure a human being"
,
"offline_inference_tpu.py"
)
"It is only with the heart that one can see rightly;"
,
runpy
.
run_path
(
example_file
)
"The greatest glory in living lies not in never falling,"
,
]
answers
=
[
" or, through inaction, allow a human being to come to harm."
,
" what is essential is invisible to the eye."
,
" but in rising every time we fall."
,
]
N
=
1
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
sampling_params
=
SamplingParams
(
temperature
=
0.7
,
top_p
=
1.0
,
n
=
N
,
max_tokens
=
16
)
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforce_eager` should be `False`.
# disable custom dispatcher, let Dynamo take over
# all the control
llm
=
LLM
(
model
=
"google/gemma-2b"
,
enforce_eager
=
True
,
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
DYNAMO_AS_IS
))
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
for
output
,
answer
in
zip
(
outputs
,
answers
):
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
assert
generated_text
.
startswith
(
answer
)
compiled_code
=
sorted
(
compiled_code
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
temp_dir
,
"__transformed_code*.py"
)))
glob
.
glob
(
os
.
path
.
join
(
temp_dir
,
"__transformed_code*.py"
)))
...
...
tests/tpu/test_custom_dispatcher.py
View file @
803f37ea
...
@@ -13,7 +13,9 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
...
@@ -13,7 +13,9 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
def
test_custom_dispatcher
():
def
test_custom_dispatcher
():
compare_two_settings
(
compare_two_settings
(
"google/gemma-2b"
,
"google/gemma-2b"
,
arg1
=
[
"--enforce-eager"
],
arg1
=
[
"--enforce-eager"
,
"-O"
,
arg2
=
[
"--enforce-eager"
],
str
(
CompilationLevel
.
DYNAMO_ONCE
)],
env1
=
{
"VLLM_TORCH_COMPILE_LEVEL"
:
str
(
CompilationLevel
.
DYNAMO_ONCE
)},
arg2
=
[
"--enforce-eager"
,
"-O"
,
env2
=
{
"VLLM_TORCH_COMPILE_LEVEL"
:
str
(
CompilationLevel
.
DYNAMO_AS_IS
)})
str
(
CompilationLevel
.
DYNAMO_AS_IS
)],
env1
=
{},
env2
=
{})
vllm/config.py
View file @
803f37ea
...
@@ -2174,8 +2174,14 @@ class CompilationConfig(BaseModel):
...
@@ -2174,8 +2174,14 @@ class CompilationConfig(BaseModel):
enabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
enabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
disabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
disabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
@
classmethod
def
from_cli
(
cls
,
cli_value
:
str
)
->
"CompilationConfig"
:
"""Parse the CLI value for the compilation config."""
if
cli_value
in
[
"0"
,
"1"
,
"2"
,
"3"
]:
return
cls
(
level
=
int
(
cli_value
))
return
CompilationConfig
.
model_validate_json
(
cli_value
)
def
model_post_init
(
self
,
__context
:
Any
)
->
None
:
def
model_post_init
(
self
,
__context
:
Any
)
->
None
:
self
.
level
=
envs
.
VLLM_TORCH_COMPILE_LEVEL
count_none
=
self
.
custom_ops
.
count
(
"none"
)
count_none
=
self
.
custom_ops
.
count
(
"none"
)
count_all
=
self
.
custom_ops
.
count
(
"all"
)
count_all
=
self
.
custom_ops
.
count
(
"all"
)
...
@@ -2249,26 +2255,6 @@ class CompilationConfig(BaseModel):
...
@@ -2249,26 +2255,6 @@ class CompilationConfig(BaseModel):
"inductor_specialize_for_cudagraph_no_more_than is None"
)
"inductor_specialize_for_cudagraph_no_more_than is None"
)
self
.
compile_sizes
=
self
.
inductor_compile_sizes
self
.
compile_sizes
=
self
.
inductor_compile_sizes
@
staticmethod
def
select_and_init_config
()
->
"CompilationConfig"
:
"""The order of selecting config is:
1. Use the config specified in environment variable.
2. Use the config specified in plugins.
3. Use the default config.
"""
config_path
=
envs
.
VLLM_TORCH_COMPILE_CONFIG
if
config_path
is
not
None
:
with
open
(
config_path
)
as
json_file
:
config
=
CompilationConfig
.
model_validate_json
(
json_file
.
read
())
else
:
from
vllm.plugins
import
get_compilation_config
predefined_config
=
get_compilation_config
()
config
=
predefined_config
if
predefined_config
is
not
None
else
(
CompilationConfig
())
return
config
@
dataclass
@
dataclass
class
VllmConfig
:
class
VllmConfig
:
...
@@ -2354,8 +2340,19 @@ class VllmConfig:
...
@@ -2354,8 +2340,19 @@ class VllmConfig:
self
.
model_config
,
self
.
load_config
)
self
.
model_config
,
self
.
load_config
)
if
self
.
compilation_config
is
None
:
if
self
.
compilation_config
is
None
:
self
.
compilation_config
=
CompilationConfig
.
select_and_init_config
(
self
.
compilation_config
=
CompilationConfig
()
)
if
envs
.
VLLM_USE_V1
:
# NOTE(woosuk): Currently, we use inductor because the piecewise
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
self
.
compilation_config
.
custom_ops
=
[
"none"
]
self
.
compilation_config
.
use_cudagraph
=
True
self
.
compilation_config
.
non_cudagraph_ops
=
[
"vllm.unified_v1_flash_attention"
]
self
.
compilation_config
.
use_inductor
=
True
self
.
compilation_config
.
enable_fusion
=
False
current_platform
.
check_and_update_config
(
self
)
current_platform
.
check_and_update_config
(
self
)
...
...
vllm/engine/arg_utils.py
View file @
803f37ea
...
@@ -8,12 +8,13 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
...
@@ -8,12 +8,13 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
ConfigFormat
,
DecodingConfig
,
from
vllm.config
import
(
CacheConfig
,
CompilationConfig
,
ConfigFormat
,
DeviceConfig
,
HfOverrides
,
LoadConfig
,
LoadFormat
,
DecodingConfig
,
DeviceConfig
,
HfOverrides
,
LoadConfig
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ParallelConfig
,
PoolerConfig
,
PromptAdapterConfig
,
ObservabilityConfig
,
ParallelConfig
,
PoolerConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TaskOption
,
PromptAdapterConfig
,
SchedulerConfig
,
TokenizerPoolConfig
,
VllmConfig
)
SpeculativeConfig
,
TaskOption
,
TokenizerPoolConfig
,
VllmConfig
)
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
...
@@ -189,6 +190,7 @@ class EngineArgs:
...
@@ -189,6 +190,7 @@ class EngineArgs:
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
override_pooler_config
:
Optional
[
PoolerConfig
]
=
None
override_pooler_config
:
Optional
[
PoolerConfig
]
=
None
compilation_config
:
Optional
[
CompilationConfig
]
=
None
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
not
self
.
tokenizer
:
if
not
self
.
tokenizer
:
...
@@ -868,6 +870,20 @@ class EngineArgs:
...
@@ -868,6 +870,20 @@ class EngineArgs:
help
=
"Override or set the pooling method in the embedding model. "
help
=
"Override or set the pooling method in the embedding model. "
"e.g. {
\"
pooling_type
\"
:
\"
mean
\"
,
\"
normalize
\"
: false}.'"
)
"e.g. {
\"
pooling_type
\"
:
\"
mean
\"
,
\"
normalize
\"
: false}.'"
)
parser
.
add_argument
(
'--compilation-config'
,
'-O'
,
type
=
CompilationConfig
.
from_cli
,
default
=
None
,
help
=
'torch.compile configuration for the model.'
'When it is a number (0, 1, 2, 3), it will be '
'interpreted as the optimization level.
\n
'
'NOTE: level 0 is the default level without '
'any optimization. level 1 and 2 are for internal '
'testing only. level 3 is the recommended level '
'for production.
\n
'
'To specify the full compilation config, '
'use a JSON string.'
)
return
parser
return
parser
@
classmethod
@
classmethod
...
@@ -1142,6 +1158,7 @@ class EngineArgs:
...
@@ -1142,6 +1158,7 @@ class EngineArgs:
decoding_config
=
decoding_config
,
decoding_config
=
decoding_config
,
observability_config
=
observability_config
,
observability_config
=
observability_config
,
prompt_adapter_config
=
prompt_adapter_config
,
prompt_adapter_config
=
prompt_adapter_config
,
compilation_config
=
self
.
compilation_config
,
)
)
...
...
vllm/engine/llm_engine.py
View file @
803f37ea
...
@@ -262,7 +262,8 @@ class LLMEngine:
...
@@ -262,7 +262,8 @@ class LLMEngine:
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"mm_processor_kwargs=%s, pooler_config=%r)"
,
"mm_processor_kwargs=%s, pooler_config=%r,"
"compilation_config=%r"
,
VLLM_VERSION
,
VLLM_VERSION
,
model_config
.
model
,
model_config
.
model
,
speculative_config
,
speculative_config
,
...
@@ -297,6 +298,7 @@ class LLMEngine:
...
@@ -297,6 +298,7 @@ class LLMEngine:
use_cached_outputs
,
use_cached_outputs
,
model_config
.
mm_processor_kwargs
,
model_config
.
mm_processor_kwargs
,
model_config
.
pooler_config
,
model_config
.
pooler_config
,
vllm_config
.
compilation_config
,
)
)
# TODO(woosuk): Print more configs in debug mode.
# TODO(woosuk): Print more configs in debug mode.
self
.
model_config
=
model_config
self
.
model_config
=
model_config
...
...
vllm/envs.py
View file @
803f37ea
...
@@ -67,8 +67,6 @@ if TYPE_CHECKING:
...
@@ -67,8 +67,6 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_AWQ
:
bool
=
False
VLLM_USE_TRITON_AWQ
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_TORCH_COMPILE_LEVEL
:
int
=
0
VLLM_TORCH_COMPILE_CONFIG
:
Optional
[
str
]
=
None
VLLM_DISABLED_KERNELS
:
List
[
str
]
=
[]
VLLM_DISABLED_KERNELS
:
List
[
str
]
=
[]
VLLM_USE_V1
:
bool
=
False
VLLM_USE_V1
:
bool
=
False
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
False
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
False
...
@@ -209,12 +207,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
...
@@ -209,12 +207,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
lambda
:
bool
(
lambda
:
bool
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
,
"1"
)
!=
"0"
),
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
,
"1"
)
!=
"0"
),
"VLLM_TORCH_COMPILE_LEVEL"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TORCH_COMPILE_LEVEL"
,
"0"
)),
# Path to the config file for torch compile
"VLLM_TORCH_COMPILE_CONFIG"
:
lambda
:
os
.
environ
.
get
(
"VLLM_TORCH_COMPILE_CONFIG"
,
None
),
# local rank of the process in the distributed setting, used to determine
# local rank of the process in the distributed setting, used to determine
# the GPU device id
# the GPU device id
...
...
vllm/platforms/tpu.py
View file @
803f37ea
import
os
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
import
torch
import
torch
...
@@ -40,7 +39,8 @@ class TpuPlatform(Platform):
...
@@ -40,7 +39,8 @@ class TpuPlatform(Platform):
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
from
vllm.config
import
CompilationLevel
from
vllm.config
import
CompilationLevel
compilation_config
=
vllm_config
.
compilation_config
compilation_config
=
vllm_config
.
compilation_config
if
"VLLM_TORCH_COMPILE_LEVEL"
not
in
os
.
environ
:
if
compilation_config
.
level
==
CompilationLevel
.
NO_COMPILATION
:
# TPU does not support NO_COMPILATION
compilation_config
.
level
=
CompilationLevel
.
DYNAMO_ONCE
compilation_config
.
level
=
CompilationLevel
.
DYNAMO_ONCE
assert
compilation_config
.
level
<
CompilationLevel
.
PIECEWISE
,
\
assert
compilation_config
.
level
<
CompilationLevel
.
PIECEWISE
,
\
"TPU does not support Inductor."
"TPU does not support Inductor."
...
...
vllm/plugins/__init__.py
View file @
803f37ea
...
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
...
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
import
vllm.envs
as
envs
import
vllm.envs
as
envs
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
vllm.config
import
CompilationConfig
,
VllmConfig
from
vllm.config
import
VllmConfig
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
...
@@ -54,18 +54,6 @@ def load_general_plugins():
...
@@ -54,18 +54,6 @@ def load_general_plugins():
logger
.
exception
(
"Failed to load plugin %s"
,
plugin
.
name
)
logger
.
exception
(
"Failed to load plugin %s"
,
plugin
.
name
)
_compilation_config
:
Optional
[
"CompilationConfig"
]
=
None
def
set_compilation_config
(
config
:
Optional
[
"CompilationConfig"
]):
global
_compilation_config
_compilation_config
=
config
def
get_compilation_config
()
->
Optional
[
"CompilationConfig"
]:
return
_compilation_config
_current_vllm_config
:
Optional
[
"VllmConfig"
]
=
None
_current_vllm_config
:
Optional
[
"VllmConfig"
]
=
None
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
803f37ea
...
@@ -8,13 +8,12 @@ import torch.distributed
...
@@ -8,13 +8,12 @@ import torch.distributed
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.plugins
import
set_compilation_config
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
cdiv
,
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
cdiv
,
is_pin_memory_available
)
is_pin_memory_available
)
...
@@ -508,20 +507,6 @@ class GPUModelRunner:
...
@@ -508,20 +507,6 @@ class GPUModelRunner:
return
model_runner_output
return
model_runner_output
def
load_model
(
self
)
->
None
:
def
load_model
(
self
)
->
None
:
if
self
.
use_cuda_graph
:
# NOTE(woosuk): Currently, we use inductor because the piecewise
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
set_compilation_config
(
CompilationConfig
(
custom_ops
=
[
"none"
],
use_cudagraph
=
True
,
non_cudagraph_ops
=
[
"vllm.unified_v1_flash_attention"
],
use_inductor
=
True
,
enable_fusion
=
False
,
))
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
DeviceMemoryProfiler
()
as
m
:
# noqa: SIM117
with
DeviceMemoryProfiler
()
as
m
:
# noqa: SIM117
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
...
@@ -562,9 +547,8 @@ class GPUModelRunner:
...
@@ -562,9 +547,8 @@ class GPUModelRunner:
def
capture_model
(
self
)
->
None
:
def
capture_model
(
self
)
->
None
:
if
not
self
.
use_cuda_graph
:
if
not
self
.
use_cuda_graph
:
logger
.
warning
(
logger
.
warning
(
"Skipping CUDA graph capture. Please set "
"Skipping CUDA graph capture. Please add "
"VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs."
,
"-O 3 to use CUDA graphs."
,
CompilationLevel
.
PIECEWISE
)
CompilationLevel
.
PIECEWISE
)
return
return
start_time
=
time
.
perf_counter
()
start_time
=
time
.
perf_counter
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment