Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
803f37ea
Unverified
Commit
803f37ea
authored
Nov 19, 2024
by
youkaichao
Committed by
GitHub
Nov 19, 2024
Browse files
[6/N] torch.compile rollout to users (#10437)
Signed-off-by: youkaichao <youkaichao@gmail.com>
parent
fd9f1249
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing 15 changed files with 129 additions and 141 deletions
+129
-141
tests/compile/piecewise/piecewise_compilation_config.json
tests/compile/piecewise/piecewise_compilation_config.json
+0
-5
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+7
-11
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+17
-28
tests/compile/test_basic_correctness.py
tests/compile/test_basic_correctness.py
+9
-4
tests/compile/utils.py
tests/compile/utils.py
+2
-2
tests/model_executor/test_enabled_custom_ops.py
tests/model_executor/test_enabled_custom_ops.py
+1
-3
tests/tpu/test_compilation.py
tests/tpu/test_compilation.py
+35
-12
tests/tpu/test_custom_dispatcher.py
tests/tpu/test_custom_dispatcher.py
+6
-4
vllm/config.py
vllm/config.py
+20
-23
vllm/engine/arg_utils.py
vllm/engine/arg_utils.py
+23
-6
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+3
-1
vllm/envs.py
vllm/envs.py
+0
-8
vllm/platforms/tpu.py
vllm/platforms/tpu.py
+2
-2
vllm/plugins/__init__.py
vllm/plugins/__init__.py
+1
-13
vllm/v1/worker/gpu_model_runner.py
vllm/v1/worker/gpu_model_runner.py
+3
-19
No files found.
tests/compile/piecewise/piecewise_compilation_config.json
deleted
100644 → 0
View file @
fd9f1249
{
"use_cudagraph"
:
true
,
"non_cudagraph_ops"
:
[
"silly.attention"
],
"cudagraph_copy_inputs"
:
true
}
\ No newline at end of file
tests/compile/piecewise/test_simple.py
View file @
803f37ea
...
...
@@ -2,7 +2,6 @@
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import
os
import
torch
from
torch
import
nn
...
...
@@ -11,7 +10,7 @@ from torch.library import Library
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.plugins
import
set_current_vllm_config
from
vllm.utils
import
direct_register_custom_op
...
...
@@ -77,12 +76,12 @@ class SillyModel(nn.Module):
def
test_simple_piecewise_compile
():
directory
=
os
.
path
.
dirname
(
__file__
)
config
=
os
.
path
.
join
(
directory
,
"piecewise_compilation_config.json"
)
os
.
environ
[
"VLLM_TORCH_COMPILE_CONFIG"
]
=
config
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
vllm_config
=
VllmConfig
(
)
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
True
,
non_cudagraph_ops
=
[
"silly.attention"
],
cudagraph_copy_inputs
=
True
,
)
)
with
set_current_vllm_config
(
vllm_config
):
model
=
SillyModel
(
vllm_config
=
vllm_config
,
prefix
=
''
)
...
...
@@ -109,6 +108,3 @@ def test_simple_piecewise_compile():
output
=
model
(
input
)
assert
global_counter
==
2
assert
torch
.
allclose
(
output
.
cpu
(),
torch
.
tensor
([
3.
,
1.
]))
# clean up to avoid side effects for other tests
del
os
.
environ
[
"VLLM_TORCH_COMPILE_CONFIG"
]
tests/compile/piecewise/test_toy_llama.py
View file @
803f37ea
...
...
@@ -6,7 +6,6 @@ This is a tractable model, the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
import
os
from
dataclasses
import
dataclass
from
typing
import
Optional
,
Tuple
...
...
@@ -18,7 +17,7 @@ from vllm.compilation.compile_context import set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.plugins
import
set_compilation_config
,
set_current_vllm_config
from
vllm.plugins
import
set_current_vllm_config
from
vllm.utils
import
direct_register_custom_op
# create a library to hold the custom op
...
...
@@ -254,23 +253,17 @@ def run_model(llama_config,
split_attn
:
bool
=
False
)
->
torch
.
Tensor
:
if
use_compile
:
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
use_cudagraph
=
True
,
)
if
split_attn
:
set_compilation_config
(
CompilationConfig
(
use_cudagraph
=
True
,
non_cudagraph_ops
=
[
"silly.attention"
],
))
else
:
set_compilation_config
(
CompilationConfig
(
use_cudagraph
=
True
,
))
compilation_config
.
non_cudagraph_ops
=
[
"silly.attention"
]
else
:
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
NO_COMPILATION
)
set_compilation_config
(
None
)
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
NO_COMPILATION
,
)
vllm_config
=
VllmConfig
()
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
)
with
set_current_vllm_config
(
vllm_config
):
model
=
LlamaModel
(
config
=
llama_config
,
vllm_config
=
vllm_config
,
...
...
@@ -288,10 +281,6 @@ def run_model(llama_config,
input_ids
[:
2
].
zero_
()
output
=
model
(
input_ids
[:
2
],
positions
[:
2
])
# manual cleanup
del
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
set_compilation_config
(
None
)
output
=
output
.
cpu
()
if
llama_config
.
tractable_init
:
...
...
@@ -361,7 +350,6 @@ def test_toy_llama():
@
torch
.
inference_mode
def
benchmark
():
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
CompilationLevel
.
PIECEWISE
)
from
triton.testing
import
do_bench
# similar to llama 3.1-8B
...
...
@@ -387,15 +375,16 @@ def benchmark():
for
piecewise
in
[
False
,
True
]:
if
piecewise
:
set_
compilation_config
(
Compilation
Config
(
use_cudagraph
=
True
,
non_cudagraph_ops
=
[
"silly.attention"
],
)
)
compilation_config
=
CompilationConfig
(
level
=
Compilation
Level
.
PIECEWISE
,
use_cudagraph
=
True
,
non_cudagraph_ops
=
[
"silly.attention"
],
)
else
:
set_compilation_config
(
None
)
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
PIECEWISE
,
)
vllm_config
=
VllmConfig
()
vllm_config
=
VllmConfig
(
compilation_config
=
compilation_config
)
with
set_current_vllm_config
(
vllm_config
):
model
=
LlamaModel
(
config
=
llama_config
,
vllm_config
=
vllm_config
,
...
...
tests/compile/test_basic_correctness.py
View file @
803f37ea
...
...
@@ -96,31 +96,36 @@ def test_compile_correctness(test_setting: TestSetting):
final_args
=
[
"--enforce-eager"
]
+
model_args
+
[
"-pp"
,
str
(
pp_size
)]
+
\
[
"-tp"
,
str
(
tp_size
)]
all_args
:
List
[
List
[
str
]]
=
[]
all_envs
:
List
[
Optional
[
Dict
[
str
,
str
]]]
=
[]
for
level
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
PIECEWISE
,
]:
all_envs
.
append
({
"VLLM_TORCH_COMPILE_LEVEL"
:
str
(
level
)})
all_args
.
append
(
final_args
+
[
"-O"
,
str
(
level
)])
all_envs
.
append
({})
# inductor will change the output, so we only compare if the output
# is close, not exactly the same.
compare_all_settings
(
model
,
[
final_args
]
*
2
,
model
,
all_args
,
all_envs
,
method
=
method
if
method
!=
"generate"
else
"generate_close"
)
all_envs
.
clear
()
all_args
.
clear
()
for
level
in
[
CompilationLevel
.
NO_COMPILATION
,
CompilationLevel
.
DYNAMO_AS_IS
,
CompilationLevel
.
DYNAMO_ONCE
,
]:
all_envs
.
append
({
"VLLM_TORCH_COMPILE_LEVEL"
:
str
(
level
)})
all_args
.
append
(
final_args
+
[
"-O"
,
str
(
level
)])
all_envs
.
append
({})
if
level
!=
CompilationLevel
.
DYNAMO_ONCE
and
not
fullgraph
:
# "DYNAMO_ONCE" will always use fullgraph
all_envs
[
-
1
][
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
]
=
"0"
# type: ignore
compare_all_settings
(
model
,
[
fin
al_args
]
*
3
,
all_envs
,
method
=
method
)
compare_all_settings
(
model
,
a
l
l_args
*
3
,
all_envs
,
method
=
method
)
tests/compile/utils.py
View file @
803f37ea
...
...
@@ -4,7 +4,7 @@ import torch
from
tests.quantization.utils
import
is_quant_method_supported
from
vllm
import
LLM
,
SamplingParams
from
vllm.config
import
CompilationLevel
from
vllm.config
import
CompilationConfig
,
CompilationLevel
from
vllm.platforms
import
current_platform
TEST_MODELS
=
[
...
...
@@ -65,7 +65,6 @@ def check_full_graph_support(model,
optimization_level
,
tp_size
=
1
):
# make sure these models can be captured in full graph mode
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
optimization_level
)
os
.
environ
[
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
]
=
"1"
# The base meta llama uses too much memory.
...
...
@@ -86,6 +85,7 @@ def check_full_graph_support(model,
enforce_eager
=
True
,
tensor_parallel_size
=
tp_size
,
disable_custom_all_reduce
=
True
,
compilation_config
=
CompilationConfig
(
level
=
optimization_level
),
**
model_kwargs
)
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
...
...
tests/model_executor/test_enabled_custom_ops.py
View file @
803f37ea
import
os
from
typing
import
List
import
pytest
...
...
@@ -53,9 +52,8 @@ class Relu3(ReLUSquaredActivation):
])
def
test_enabled_ops
(
env
:
str
,
torch_level
:
int
,
ops_enabled
:
List
[
int
],
default_on
:
bool
):
os
.
environ
[
"VLLM_TORCH_COMPILE_LEVEL"
]
=
str
(
torch_level
)
vllm_config
=
VllmConfig
(
compilation_config
=
CompilationConfig
(
custom_ops
=
env
.
split
(
","
)))
level
=
torch_level
,
custom_ops
=
env
.
split
(
","
)))
with
set_current_vllm_config
(
vllm_config
):
assert
CustomOp
.
default_on
()
==
default_on
...
...
tests/tpu/test_compilation.py
View file @
803f37ea
import
glob
import
os
import
runpy
import
tempfile
import
depyf
from
vllm.config
import
CompilationLevel
# disable custom dispatcher, let Dynamo take over
# all the control
os
.
environ
[
'VLLM_TORCH_COMPILE_LEVEL'
]
=
str
(
CompilationLevel
.
DYNAMO_AS_IS
)
from
vllm.config
import
CompilationConfig
,
CompilationLevel
temp_dir
=
tempfile
.
mkdtemp
()
with
depyf
.
prepare_debug
(
temp_dir
):
cur_dir
=
os
.
path
.
dirname
(
__file__
)
parent_dir
=
os
.
path
.
dirname
(
cur_dir
)
root_dir
=
os
.
path
.
dirname
(
parent_dir
)
example_file
=
os
.
path
.
join
(
root_dir
,
"examples"
,
"offline_inference_tpu.py"
)
runpy
.
run_path
(
example_file
)
from
vllm
import
LLM
,
SamplingParams
prompts
=
[
"A robot may not injure a human being"
,
"It is only with the heart that one can see rightly;"
,
"The greatest glory in living lies not in never falling,"
,
]
answers
=
[
" or, through inaction, allow a human being to come to harm."
,
" what is essential is invisible to the eye."
,
" but in rising every time we fall."
,
]
N
=
1
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
sampling_params
=
SamplingParams
(
temperature
=
0.7
,
top_p
=
1.0
,
n
=
N
,
max_tokens
=
16
)
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforce_eager` should be `False`.
# disable custom dispatcher, let Dynamo take over
# all the control
llm
=
LLM
(
model
=
"google/gemma-2b"
,
enforce_eager
=
True
,
compilation_config
=
CompilationConfig
(
level
=
CompilationLevel
.
DYNAMO_AS_IS
))
outputs
=
llm
.
generate
(
prompts
,
sampling_params
)
for
output
,
answer
in
zip
(
outputs
,
answers
):
prompt
=
output
.
prompt
generated_text
=
output
.
outputs
[
0
].
text
print
(
f
"Prompt:
{
prompt
!
r
}
, Generated text:
{
generated_text
!
r
}
"
)
assert
generated_text
.
startswith
(
answer
)
compiled_code
=
sorted
(
glob
.
glob
(
os
.
path
.
join
(
temp_dir
,
"__transformed_code*.py"
)))
...
...
tests/tpu/test_custom_dispatcher.py
View file @
803f37ea
...
...
@@ -13,7 +13,9 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
def
test_custom_dispatcher
():
compare_two_settings
(
"google/gemma-2b"
,
arg1
=
[
"--enforce-eager"
],
arg2
=
[
"--enforce-eager"
],
env1
=
{
"VLLM_TORCH_COMPILE_LEVEL"
:
str
(
CompilationLevel
.
DYNAMO_ONCE
)},
env2
=
{
"VLLM_TORCH_COMPILE_LEVEL"
:
str
(
CompilationLevel
.
DYNAMO_AS_IS
)})
arg1
=
[
"--enforce-eager"
,
"-O"
,
str
(
CompilationLevel
.
DYNAMO_ONCE
)],
arg2
=
[
"--enforce-eager"
,
"-O"
,
str
(
CompilationLevel
.
DYNAMO_AS_IS
)],
env1
=
{},
env2
=
{})
vllm/config.py
View file @
803f37ea
...
...
@@ -2174,8 +2174,14 @@ class CompilationConfig(BaseModel):
enabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
disabled_custom_ops
:
Counter
[
str
]
=
PrivateAttr
@
classmethod
def
from_cli
(
cls
,
cli_value
:
str
)
->
"CompilationConfig"
:
"""Parse the CLI value for the compilation config."""
if
cli_value
in
[
"0"
,
"1"
,
"2"
,
"3"
]:
return
cls
(
level
=
int
(
cli_value
))
return
CompilationConfig
.
model_validate_json
(
cli_value
)
def
model_post_init
(
self
,
__context
:
Any
)
->
None
:
self
.
level
=
envs
.
VLLM_TORCH_COMPILE_LEVEL
count_none
=
self
.
custom_ops
.
count
(
"none"
)
count_all
=
self
.
custom_ops
.
count
(
"all"
)
...
...
@@ -2249,26 +2255,6 @@ class CompilationConfig(BaseModel):
"inductor_specialize_for_cudagraph_no_more_than is None"
)
self
.
compile_sizes
=
self
.
inductor_compile_sizes
@
staticmethod
def
select_and_init_config
()
->
"CompilationConfig"
:
"""The order of selecting config is:
1. Use the config specified in environment variable.
2. Use the config specified in plugins.
3. Use the default config.
"""
config_path
=
envs
.
VLLM_TORCH_COMPILE_CONFIG
if
config_path
is
not
None
:
with
open
(
config_path
)
as
json_file
:
config
=
CompilationConfig
.
model_validate_json
(
json_file
.
read
())
else
:
from
vllm.plugins
import
get_compilation_config
predefined_config
=
get_compilation_config
()
config
=
predefined_config
if
predefined_config
is
not
None
else
(
CompilationConfig
())
return
config
@
dataclass
class
VllmConfig
:
...
...
@@ -2354,8 +2340,19 @@ class VllmConfig:
self
.
model_config
,
self
.
load_config
)
if
self
.
compilation_config
is
None
:
self
.
compilation_config
=
CompilationConfig
.
select_and_init_config
(
)
self
.
compilation_config
=
CompilationConfig
()
if
envs
.
VLLM_USE_V1
:
# NOTE(woosuk): Currently, we use inductor because the piecewise
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
self
.
compilation_config
.
custom_ops
=
[
"none"
]
self
.
compilation_config
.
use_cudagraph
=
True
self
.
compilation_config
.
non_cudagraph_ops
=
[
"vllm.unified_v1_flash_attention"
]
self
.
compilation_config
.
use_inductor
=
True
self
.
compilation_config
.
enable_fusion
=
False
current_platform
.
check_and_update_config
(
self
)
...
...
vllm/engine/arg_utils.py
View file @
803f37ea
...
...
@@ -8,12 +8,13 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
import
torch
import
vllm.envs
as
envs
from
vllm.config
import
(
CacheConfig
,
ConfigFormat
,
DecodingConfig
,
DeviceConfig
,
HfOverrides
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PoolerConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TaskOption
,
TokenizerPoolConfig
,
VllmConfig
)
from
vllm.config
import
(
CacheConfig
,
CompilationConfig
,
ConfigFormat
,
DecodingConfig
,
DeviceConfig
,
HfOverrides
,
LoadConfig
,
LoadFormat
,
LoRAConfig
,
ModelConfig
,
ObservabilityConfig
,
ParallelConfig
,
PoolerConfig
,
PromptAdapterConfig
,
SchedulerConfig
,
SpeculativeConfig
,
TaskOption
,
TokenizerPoolConfig
,
VllmConfig
)
from
vllm.executor.executor_base
import
ExecutorBase
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
...
...
@@ -189,6 +190,7 @@ class EngineArgs:
override_neuron_config
:
Optional
[
Dict
[
str
,
Any
]]
=
None
override_pooler_config
:
Optional
[
PoolerConfig
]
=
None
compilation_config
:
Optional
[
CompilationConfig
]
=
None
def
__post_init__
(
self
):
if
not
self
.
tokenizer
:
...
...
@@ -868,6 +870,20 @@ class EngineArgs:
help
=
"Override or set the pooling method in the embedding model. "
"e.g. {
\"
pooling_type
\"
:
\"
mean
\"
,
\"
normalize
\"
: false}.'"
)
parser
.
add_argument
(
'--compilation-config'
,
'-O'
,
type
=
CompilationConfig
.
from_cli
,
default
=
None
,
help
=
'torch.compile configuration for the model.'
'When it is a number (0, 1, 2, 3), it will be '
'interpreted as the optimization level.
\n
'
'NOTE: level 0 is the default level without '
'any optimization. level 1 and 2 are for internal '
'testing only. level 3 is the recommended level '
'for production.
\n
'
'To specify the full compilation config, '
'use a JSON string.'
)
return
parser
@
classmethod
...
...
@@ -1142,6 +1158,7 @@ class EngineArgs:
decoding_config
=
decoding_config
,
observability_config
=
observability_config
,
prompt_adapter_config
=
prompt_adapter_config
,
compilation_config
=
self
.
compilation_config
,
)
...
...
vllm/engine/llm_engine.py
View file @
803f37ea
...
...
@@ -262,7 +262,8 @@ class LLMEngine:
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"mm_processor_kwargs=%s, pooler_config=%r)"
,
"mm_processor_kwargs=%s, pooler_config=%r,"
"compilation_config=%r"
,
VLLM_VERSION
,
model_config
.
model
,
speculative_config
,
...
...
@@ -297,6 +298,7 @@ class LLMEngine:
use_cached_outputs
,
model_config
.
mm_processor_kwargs
,
model_config
.
pooler_config
,
vllm_config
.
compilation_config
,
)
# TODO(woosuk): Print more configs in debug mode.
self
.
model_config
=
model_config
...
...
vllm/envs.py
View file @
803f37ea
...
...
@@ -67,8 +67,6 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_AWQ
:
bool
=
False
VLLM_ALLOW_RUNTIME_LORA_UPDATING
:
bool
=
False
VLLM_SKIP_P2P_CHECK
:
bool
=
False
VLLM_TORCH_COMPILE_LEVEL
:
int
=
0
VLLM_TORCH_COMPILE_CONFIG
:
Optional
[
str
]
=
None
VLLM_DISABLED_KERNELS
:
List
[
str
]
=
[]
VLLM_USE_V1
:
bool
=
False
VLLM_ENABLE_V1_MULTIPROCESSING
:
bool
=
False
...
...
@@ -209,12 +207,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
:
lambda
:
bool
(
os
.
environ
.
get
(
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"
,
"1"
)
!=
"0"
),
"VLLM_TORCH_COMPILE_LEVEL"
:
lambda
:
int
(
os
.
environ
.
get
(
"VLLM_TORCH_COMPILE_LEVEL"
,
"0"
)),
# Path to the config file for torch compile
"VLLM_TORCH_COMPILE_CONFIG"
:
lambda
:
os
.
environ
.
get
(
"VLLM_TORCH_COMPILE_CONFIG"
,
None
),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
...
...
vllm/platforms/tpu.py
View file @
803f37ea
import
os
from
typing
import
TYPE_CHECKING
import
torch
...
...
@@ -40,7 +39,8 @@ class TpuPlatform(Platform):
def
check_and_update_config
(
cls
,
vllm_config
:
VllmConfig
)
->
None
:
from
vllm.config
import
CompilationLevel
compilation_config
=
vllm_config
.
compilation_config
if
"VLLM_TORCH_COMPILE_LEVEL"
not
in
os
.
environ
:
if
compilation_config
.
level
==
CompilationLevel
.
NO_COMPILATION
:
# TPU does not support NO_COMPILATION
compilation_config
.
level
=
CompilationLevel
.
DYNAMO_ONCE
assert
compilation_config
.
level
<
CompilationLevel
.
PIECEWISE
,
\
"TPU does not support Inductor."
...
...
vllm/plugins/__init__.py
View file @
803f37ea
...
...
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
import
vllm.envs
as
envs
if
TYPE_CHECKING
:
from
vllm.config
import
CompilationConfig
,
VllmConfig
from
vllm.config
import
VllmConfig
logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -54,18 +54,6 @@ def load_general_plugins():
logger
.
exception
(
"Failed to load plugin %s"
,
plugin
.
name
)
_compilation_config
:
Optional
[
"CompilationConfig"
]
=
None
def
set_compilation_config
(
config
:
Optional
[
"CompilationConfig"
]):
global
_compilation_config
_compilation_config
=
config
def
get_compilation_config
()
->
Optional
[
"CompilationConfig"
]:
return
_compilation_config
_current_vllm_config
:
Optional
[
"VllmConfig"
]
=
None
...
...
vllm/v1/worker/gpu_model_runner.py
View file @
803f37ea
...
...
@@ -8,13 +8,12 @@ import torch.distributed
import
torch.nn
as
nn
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.config
import
CompilationLevel
,
VllmConfig
from
vllm.forward_context
import
set_forward_context
from
vllm.inputs
import
INPUT_REGISTRY
,
InputRegistry
from
vllm.logger
import
init_logger
from
vllm.model_executor.model_loader
import
get_model
from
vllm.multimodal
import
MultiModalKwargs
from
vllm.plugins
import
set_compilation_config
from
vllm.sampling_params
import
SamplingParams
,
SamplingType
from
vllm.utils
import
(
STR_DTYPE_TO_TORCH_DTYPE
,
DeviceMemoryProfiler
,
cdiv
,
is_pin_memory_available
)
...
...
@@ -508,20 +507,6 @@ class GPUModelRunner:
return
model_runner_output
def
load_model
(
self
)
->
None
:
if
self
.
use_cuda_graph
:
# NOTE(woosuk): Currently, we use inductor because the piecewise
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
set_compilation_config
(
CompilationConfig
(
custom_ops
=
[
"none"
],
use_cudagraph
=
True
,
non_cudagraph_ops
=
[
"vllm.unified_v1_flash_attention"
],
use_inductor
=
True
,
enable_fusion
=
False
,
))
logger
.
info
(
"Starting to load model %s..."
,
self
.
model_config
.
model
)
with
DeviceMemoryProfiler
()
as
m
:
# noqa: SIM117
self
.
model
=
get_model
(
vllm_config
=
self
.
vllm_config
)
...
...
@@ -562,9 +547,8 @@ class GPUModelRunner:
def
capture_model
(
self
)
->
None
:
if
not
self
.
use_cuda_graph
:
logger
.
warning
(
"Skipping CUDA graph capture. Please set "
"VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs."
,
CompilationLevel
.
PIECEWISE
)
"Skipping CUDA graph capture. Please add "
"-O 3 to use CUDA graphs."
,
CompilationLevel
.
PIECEWISE
)
return
start_time
=
time
.
perf_counter
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment