Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
05d1f8c9
Unverified
Commit
05d1f8c9
authored
Nov 25, 2024
by
youkaichao
Committed by
GitHub
Nov 25, 2024
Browse files
[misc] move functions to config.py (#10624)
Signed-off-by:
youkaichao
<
youkaichao@gmail.com
>
parent
25d806e9
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
62 additions
and
73 deletions
+62
-73
tests/compile/piecewise/test_simple.py
tests/compile/piecewise/test_simple.py
+2
-2
tests/compile/piecewise/test_toy_llama.py
tests/compile/piecewise/test_toy_llama.py
+2
-2
tests/kernels/test_encoder_decoder_attn.py
tests/kernels/test_encoder_decoder_attn.py
+1
-2
tests/model_executor/test_enabled_custom_ops.py
tests/model_executor/test_enabled_custom_ops.py
+1
-2
vllm/attention/layer.py
vllm/attention/layer.py
+1
-2
vllm/compilation/wrapper.py
vllm/compilation/wrapper.py
+1
-2
vllm/config.py
vllm/config.py
+51
-0
vllm/model_executor/custom_op.py
vllm/model_executor/custom_op.py
+1
-1
vllm/model_executor/model_loader/loader.py
vllm/model_executor/model_loader/loader.py
+1
-2
vllm/model_executor/model_loader/tensorizer.py
vllm/model_executor/model_loader/tensorizer.py
+1
-2
vllm/plugins/__init__.py
vllm/plugins/__init__.py
+0
-56
No files found.
tests/compile/piecewise/test_simple.py
View file @
05d1f8c9
...
@@ -10,8 +10,8 @@ from torch.library import Library
...
@@ -10,8 +10,8 @@ from torch.library import Library
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
from
vllm.plugins
import
set_current_vllm_config
set_current_vllm_config
)
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
global_counter
=
0
global_counter
=
0
...
...
tests/compile/piecewise/test_toy_llama.py
View file @
05d1f8c9
...
@@ -16,8 +16,8 @@ from torch.library import Library
...
@@ -16,8 +16,8 @@ from torch.library import Library
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.compile_context
import
set_compile_context
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.counter
import
compilation_counter
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.compilation.decorators
import
support_torch_compile
from
vllm.config
import
CompilationConfig
,
CompilationLevel
,
VllmConfig
from
vllm.config
import
(
CompilationConfig
,
CompilationLevel
,
VllmConfig
,
from
vllm.plugins
import
set_current_vllm_config
set_current_vllm_config
)
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
# create a library to hold the custom op
# create a library to hold the custom op
...
...
tests/kernels/test_encoder_decoder_attn.py
View file @
05d1f8c9
...
@@ -18,10 +18,9 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
...
@@ -18,10 +18,9 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.attention.backends.utils
import
STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from
vllm.attention.selector
import
(
_Backend
,
_cached_get_attn_backend
,
from
vllm.attention.selector
import
(
_Backend
,
_cached_get_attn_backend
,
global_force_attn_backend_context_manager
)
global_force_attn_backend_context_manager
)
from
vllm.config
import
VllmConfig
from
vllm.config
import
VllmConfig
,
set_current_vllm_config
from
vllm.forward_context
import
set_forward_context
from
vllm.forward_context
import
set_forward_context
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.plugins
import
set_current_vllm_config
# List of support backends for encoder/decoder models
# List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]
LIST_ENC_DEC_SUPPORTED_BACKENDS
=
[
_Backend
.
XFORMERS
,
_Backend
.
FLASH_ATTN
]
...
...
tests/model_executor/test_enabled_custom_ops.py
View file @
05d1f8c9
...
@@ -2,13 +2,12 @@ from typing import List
...
@@ -2,13 +2,12 @@ from typing import List
import
pytest
import
pytest
from
vllm.config
import
CompilationConfig
,
VllmConfig
from
vllm.config
import
CompilationConfig
,
VllmConfig
,
set_current_vllm_config
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.custom_op
import
CustomOp
from
vllm.model_executor.layers.activation
import
(
GeluAndMul
,
from
vllm.model_executor.layers.activation
import
(
GeluAndMul
,
ReLUSquaredActivation
,
ReLUSquaredActivation
,
SiluAndMul
)
SiluAndMul
)
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.model_executor.layers.layernorm
import
RMSNorm
from
vllm.plugins
import
set_current_vllm_config
# Registered subclass for test
# Registered subclass for test
...
...
vllm/attention/layer.py
View file @
05d1f8c9
...
@@ -7,13 +7,12 @@ import torch.nn as nn
...
@@ -7,13 +7,12 @@ import torch.nn as nn
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.attention
import
AttentionMetadata
,
AttentionType
from
vllm.attention
import
AttentionMetadata
,
AttentionType
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.attention.selector
import
backend_name_to_enum
,
get_attn_backend
from
vllm.config
import
CacheConfig
from
vllm.config
import
CacheConfig
,
get_current_vllm_config
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.forward_context
import
ForwardContext
,
get_forward_context
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.model_executor.layers.quantization.kv_cache
import
BaseKVCacheMethod
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.plugins
import
get_current_vllm_config
from
vllm.utils
import
direct_register_custom_op
from
vllm.utils
import
direct_register_custom_op
...
...
vllm/compilation/wrapper.py
View file @
05d1f8c9
...
@@ -8,7 +8,7 @@ from typing import Callable, List, Optional
...
@@ -8,7 +8,7 @@ from typing import Callable, List, Optional
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
CompilationLevel
from
vllm.config
import
CompilationLevel
,
get_current_vllm_config
class
TorchCompileWrapperWithCustomDispatcher
:
class
TorchCompileWrapperWithCustomDispatcher
:
...
@@ -32,7 +32,6 @@ class TorchCompileWrapperWithCustomDispatcher:
...
@@ -32,7 +32,6 @@ class TorchCompileWrapperWithCustomDispatcher:
# default compilation settings
# default compilation settings
# compiling the forward method
# compiling the forward method
from
vllm.plugins
import
get_current_vllm_config
backend
=
get_current_vllm_config
(
backend
=
get_current_vllm_config
(
).
compilation_config
.
init_backend
()
).
compilation_config
.
init_backend
()
...
...
vllm/config.py
View file @
05d1f8c9
...
@@ -3,6 +3,7 @@ import enum
...
@@ -3,6 +3,7 @@ import enum
import
hashlib
import
hashlib
import
json
import
json
import
warnings
import
warnings
from
contextlib
import
contextmanager
from
dataclasses
import
dataclass
,
field
,
replace
from
dataclasses
import
dataclass
,
field
,
replace
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Counter
,
Dict
,
from
typing
import
(
TYPE_CHECKING
,
Any
,
Callable
,
ClassVar
,
Counter
,
Dict
,
...
@@ -2450,3 +2451,53 @@ class VllmConfig:
...
@@ -2450,3 +2451,53 @@ class VllmConfig:
self
.
cache_config
.
enable_prefix_caching
,
self
.
cache_config
.
enable_prefix_caching
,
self
.
model_config
.
use_async_output_proc
,
self
.
model_config
.
use_async_output_proc
,
self
.
model_config
.
mm_processor_kwargs
)
self
.
model_config
.
mm_processor_kwargs
)
# Config exposed to the rest of the process while a
# `set_current_vllm_config` block is active; None otherwise.
_current_vllm_config: Optional[VllmConfig] = None


@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig):
    """
    Temporarily set the current VLLM config.
    Used during model initialization.
    We save the current VLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the VLLM config to determine how to dispatch.

    Args:
        vllm_config: the config returned by `get_current_vllm_config`
            while the body of the `with` statement runs.

    On exit, whether normal or through an exception, the config that was
    active before entering is put back, so nested uses unwind correctly.
    """
    global _current_vllm_config
    old_vllm_config = _current_vllm_config
    # Function-level import kept from the original
    # (presumably to avoid an import cycle -- TODO confirm).
    from vllm.compilation.counter import compilation_counter
    # Snapshot the counter so that, on exit, we can tell whether any model
    # was registered while this config was active.
    num_models_seen = compilation_counter.num_models_seen
    try:
        _current_vllm_config = vllm_config
        yield
    finally:
        try:
            compilation_config = vllm_config.compilation_config
            logger.debug("enabled custom ops: %s",
                         compilation_config.enabled_custom_ops)
            logger.debug("disabled custom ops: %s",
                         compilation_config.disabled_custom_ops)
            if (compilation_config.level == CompilationLevel.PIECEWISE
                    and compilation_counter.num_models_seen
                    == num_models_seen):
                # If the model supports compilation,
                # compilation_counter.num_models_seen should be increased
                # by at least 1.
                # If it is not increased, it means the model does not support
                # compilation (does not have @support_torch_compile decorator).
                logger.warning(
                    "`torch.compile` is turned on, but the model %s"
                    " does not support it. Please open an issue on GitHub"
                    " if you want it to be supported.",
                    vllm_config.model_config.model)
        finally:
            # Restore unconditionally: a failure in the diagnostics above
            # must not leave a stale config installed.
            _current_vllm_config = old_vllm_config
def get_current_vllm_config() -> VllmConfig:
    """Return the VLLM config installed by `set_current_vllm_config`.

    Returns:
        The module-level `_current_vllm_config` when one is set. Otherwise
        a warning is logged and a new `VllmConfig()` built with default
        arguments is returned; that default is not stored, so every such
        call builds a fresh instance.
    """
    if _current_vllm_config is not None:
        return _current_vllm_config
    # in ci, usually when we test custom ops/modules directly,
    # we don't set the vllm config. In that case, we set a default
    # config.
    logger.warning("Current VLLM config is not set.")
    # `VllmConfig` is defined in this module, so no import is needed here.
    return VllmConfig()
vllm/model_executor/custom_op.py
View file @
05d1f8c9
...
@@ -2,9 +2,9 @@ from typing import Dict, Type
...
@@ -2,9 +2,9 @@ from typing import Dict, Type
import
torch.nn
as
nn
import
torch.nn
as
nn
from
vllm.config
import
get_current_vllm_config
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.plugins
import
get_current_vllm_config
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
...
vllm/model_executor/model_loader/loader.py
View file @
05d1f8c9
...
@@ -23,7 +23,7 @@ from transformers import AutoModelForCausalLM
...
@@ -23,7 +23,7 @@ from transformers import AutoModelForCausalLM
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
transformers.utils
import
SAFE_WEIGHTS_INDEX_NAME
from
vllm.config
import
(
LoadConfig
,
LoadFormat
,
ModelConfig
,
ParallelConfig
,
from
vllm.config
import
(
LoadConfig
,
LoadFormat
,
ModelConfig
,
ParallelConfig
,
VllmConfig
)
VllmConfig
,
set_current_vllm_config
)
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
get_tensor_model_parallel_world_size
)
from
vllm.envs
import
VLLM_USE_MODELSCOPE
from
vllm.envs
import
VLLM_USE_MODELSCOPE
...
@@ -47,7 +47,6 @@ from vllm.model_executor.model_loader.weight_utils import (
...
@@ -47,7 +47,6 @@ from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator
)
safetensors_weights_iterator
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
from
vllm.plugins
import
set_current_vllm_config
from
vllm.utils
import
is_pin_memory_available
from
vllm.utils
import
is_pin_memory_available
...
...
vllm/model_executor/model_loader/tensorizer.py
View file @
05d1f8c9
...
@@ -13,13 +13,12 @@ from torch import nn
...
@@ -13,13 +13,12 @@ from torch import nn
from
transformers
import
PretrainedConfig
from
transformers
import
PretrainedConfig
import
vllm.envs
as
envs
import
vllm.envs
as
envs
from
vllm.config
import
ModelConfig
,
ParallelConfig
from
vllm.config
import
ModelConfig
,
ParallelConfig
,
set_current_vllm_config
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.arg_utils
import
EngineArgs
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.engine.llm_engine
import
LLMEngine
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
VocabParallelEmbedding
)
VocabParallelEmbedding
)
from
vllm.plugins
import
set_current_vllm_config
from
vllm.utils
import
FlexibleArgumentParser
from
vllm.utils
import
FlexibleArgumentParser
tensorizer_error_msg
=
None
tensorizer_error_msg
=
None
...
...
vllm/plugins/__init__.py
View file @
05d1f8c9
import
logging
import
logging
import
os
import
os
from
contextlib
import
contextmanager
from
typing
import
TYPE_CHECKING
,
Optional
import
torch
import
torch
import
vllm.envs
as
envs
import
vllm.envs
as
envs
if
TYPE_CHECKING
:
from
vllm.config
import
VllmConfig
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
# make sure one process only loads plugins once
# make sure one process only loads plugins once
...
@@ -64,54 +59,3 @@ def load_general_plugins():
...
@@ -64,54 +59,3 @@ def load_general_plugins():
logger
.
info
(
"plugin %s loaded."
,
plugin
.
name
)
logger
.
info
(
"plugin %s loaded."
,
plugin
.
name
)
except
Exception
:
except
Exception
:
logger
.
exception
(
"Failed to load plugin %s"
,
plugin
.
name
)
logger
.
exception
(
"Failed to load plugin %s"
,
plugin
.
name
)
# Config exposed to the rest of the process while a
# `set_current_vllm_config` block is active; None otherwise.
_current_vllm_config: Optional["VllmConfig"] = None


@contextmanager
def set_current_vllm_config(vllm_config: "VllmConfig"):
    """
    Temporarily set the current VLLM config.
    Used during model initialization.
    We save the current VLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the VLLM config to determine how to dispatch.

    Args:
        vllm_config: the config returned by `get_current_vllm_config`
            while the body of the `with` statement runs.

    On exit, whether normal or through an exception, the config that was
    active before entering is put back, so nested uses unwind correctly.
    """
    global _current_vllm_config
    old_vllm_config = _current_vllm_config
    # Function-level imports kept from the original; `vllm.config` is only
    # imported under TYPE_CHECKING at module level in this file.
    from vllm.compilation.counter import compilation_counter
    from vllm.config import CompilationLevel
    # Snapshot the counter so that, on exit, we can tell whether any model
    # was registered while this config was active.
    num_models_seen = compilation_counter.num_models_seen
    try:
        _current_vllm_config = vllm_config
        yield
    finally:
        try:
            compilation_config = vllm_config.compilation_config
            logger.debug("enabled custom ops: %s",
                         compilation_config.enabled_custom_ops)
            logger.debug("disabled custom ops: %s",
                         compilation_config.disabled_custom_ops)
            if (compilation_config.level == CompilationLevel.PIECEWISE
                    and compilation_counter.num_models_seen
                    == num_models_seen):
                # If the model supports compilation,
                # compilation_counter.num_models_seen should be increased
                # by at least 1.
                # If it is not increased, it means the model does not support
                # compilation (does not have @support_torch_compile decorator).
                logger.warning(
                    "`torch.compile` is turned on, but the model %s"
                    " does not support it. Please open an issue on GitHub"
                    " if you want it to be supported.",
                    vllm_config.model_config.model)
        finally:
            # Restore unconditionally: a failure in the diagnostics above
            # must not leave a stale config installed.
            _current_vllm_config = old_vllm_config
def get_current_vllm_config() -> "VllmConfig":
    """Return the VLLM config installed by `set_current_vllm_config`.

    Returns:
        The module-level `_current_vllm_config` when one is set. Otherwise
        a warning is logged and a new `VllmConfig()` built with default
        arguments is returned; that default is not stored.
    """
    if _current_vllm_config is not None:
        return _current_vllm_config
    # in ci, usually when we test custom ops/modules directly,
    # we don't set the vllm config. In that case, we set a default
    # config.
    logger.warning("Current VLLM config is not set.")
    # Imported at call time: this module only imports `VllmConfig`
    # under TYPE_CHECKING.
    from vllm.config import VllmConfig
    return VllmConfig()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment