Unverified Commit 2298e69b authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[ci][bugfix] fix kernel tests (#10431)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
parent a03ea407
...@@ -6,9 +6,6 @@ import vllm.envs as envs ...@@ -6,9 +6,6 @@ import vllm.envs as envs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import CompilationConfig, VllmConfig from vllm.config import CompilationConfig, VllmConfig
else:
CompilationConfig = None
VllmConfig = None
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -50,23 +47,23 @@ def load_general_plugins(): ...@@ -50,23 +47,23 @@ def load_general_plugins():
logger.exception("Failed to load plugin %s", plugin.name) logger.exception("Failed to load plugin %s", plugin.name)
_compilation_config: Optional[CompilationConfig] = None _compilation_config: Optional["CompilationConfig"] = None
def set_compilation_config(config: Optional[CompilationConfig]): def set_compilation_config(config: Optional["CompilationConfig"]):
global _compilation_config global _compilation_config
_compilation_config = config _compilation_config = config
def get_compilation_config() -> Optional[CompilationConfig]: def get_compilation_config() -> Optional["CompilationConfig"]:
return _compilation_config return _compilation_config
_current_vllm_config: Optional[VllmConfig] = None _current_vllm_config: Optional["VllmConfig"] = None
@contextmanager @contextmanager
def set_current_vllm_config(vllm_config: VllmConfig): def set_current_vllm_config(vllm_config: "VllmConfig"):
""" """
Temporarily set the current VLLM config. Temporarily set the current VLLM config.
Used during model initialization. Used during model initialization.
...@@ -87,6 +84,12 @@ def set_current_vllm_config(vllm_config: VllmConfig): ...@@ -87,6 +84,12 @@ def set_current_vllm_config(vllm_config: VllmConfig):
_current_vllm_config = old_vllm_config _current_vllm_config = old_vllm_config
def get_current_vllm_config() -> VllmConfig: def get_current_vllm_config() -> "VllmConfig":
assert _current_vllm_config is not None, "Current VLLM config is not set." if _current_vllm_config is None:
# in ci, usually when we test custom ops/modules directly,
# we don't set the vllm config. In that case, we set a default
# config.
logger.warning("Current VLLM config is not set.")
from vllm.config import VllmConfig
return VllmConfig()
return _current_vllm_config return _current_vllm_config
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment