Unverified commit 05d1f8c9, authored by youkaichao and committed by GitHub
Browse files

[misc] move functions to config.py (#10624)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent 25d806e9
...@@ -10,8 +10,8 @@ from torch.library import Library ...@@ -10,8 +10,8 @@ from torch.library import Library
from vllm.compilation.compile_context import set_compile_context from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
from vllm.plugins import set_current_vllm_config set_current_vllm_config)
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
global_counter = 0 global_counter = 0
......
...@@ -16,8 +16,8 @@ from torch.library import Library ...@@ -16,8 +16,8 @@ from torch.library import Library
from vllm.compilation.compile_context import set_compile_context from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig,
from vllm.plugins import set_current_vllm_config set_current_vllm_config)
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
# create a library to hold the custom op # create a library to hold the custom op
......
...@@ -18,10 +18,9 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata, ...@@ -18,10 +18,9 @@ from vllm.attention import (Attention, AttentionBackend, AttentionMetadata,
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.attention.selector import (_Backend, _cached_get_attn_backend, from vllm.attention.selector import (_Backend, _cached_get_attn_backend,
global_force_attn_backend_context_manager) global_force_attn_backend_context_manager)
from vllm.config import VllmConfig from vllm.config import VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context from vllm.forward_context import set_forward_context
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
# List of support backends for encoder/decoder models # List of support backends for encoder/decoder models
LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN] LIST_ENC_DEC_SUPPORTED_BACKENDS = [_Backend.XFORMERS, _Backend.FLASH_ATTN]
......
...@@ -2,13 +2,12 @@ from typing import List ...@@ -2,13 +2,12 @@ from typing import List
import pytest import pytest
from vllm.config import CompilationConfig, VllmConfig from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import (GeluAndMul, from vllm.model_executor.layers.activation import (GeluAndMul,
ReLUSquaredActivation, ReLUSquaredActivation,
SiluAndMul) SiluAndMul)
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.plugins import set_current_vllm_config
# Registered subclass for test # Registered subclass for test
......
...@@ -7,13 +7,12 @@ import torch.nn as nn ...@@ -7,13 +7,12 @@ import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.attention import AttentionMetadata, AttentionType from vllm.attention import AttentionMetadata, AttentionType
from vllm.attention.selector import backend_name_to_enum, get_attn_backend from vllm.attention.selector import backend_name_to_enum, get_attn_backend
from vllm.config import CacheConfig from vllm.config import CacheConfig, get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
......
...@@ -8,7 +8,7 @@ from typing import Callable, List, Optional ...@@ -8,7 +8,7 @@ from typing import Callable, List, Optional
import torch import torch
import vllm.envs as envs import vllm.envs as envs
from vllm.config import CompilationLevel from vllm.config import CompilationLevel, get_current_vllm_config
class TorchCompileWrapperWithCustomDispatcher: class TorchCompileWrapperWithCustomDispatcher:
...@@ -32,7 +32,6 @@ class TorchCompileWrapperWithCustomDispatcher: ...@@ -32,7 +32,6 @@ class TorchCompileWrapperWithCustomDispatcher:
# default compilation settings # default compilation settings
# compiling the forward method # compiling the forward method
from vllm.plugins import get_current_vllm_config
backend = get_current_vllm_config( backend = get_current_vllm_config(
).compilation_config.init_backend() ).compilation_config.init_backend()
......
...@@ -3,6 +3,7 @@ import enum ...@@ -3,6 +3,7 @@ import enum
import hashlib import hashlib
import json import json
import warnings import warnings
from contextlib import contextmanager
from dataclasses import dataclass, field, replace from dataclasses import dataclass, field, replace
from pathlib import Path from pathlib import Path
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict, from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict,
...@@ -2450,3 +2451,53 @@ class VllmConfig: ...@@ -2450,3 +2451,53 @@ class VllmConfig:
self.cache_config.enable_prefix_caching, self.cache_config.enable_prefix_caching,
self.model_config.use_async_output_proc, self.model_config.use_async_output_proc,
self.model_config.mm_processor_kwargs) self.model_config.mm_processor_kwargs)
# Module-level slot holding the config installed by `set_current_vllm_config`;
# `None` means no config is currently set.
_current_vllm_config: Optional[VllmConfig] = None


@contextmanager
def set_current_vllm_config(vllm_config: VllmConfig):
    """
    Temporarily set the current VLLM config.
    Used during model initialization.
    We save the current VLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the VLLM config to determine how to dispatch.

    On exit the previously installed config (possibly ``None``) is put
    back, the enabled/disabled custom ops are logged at debug level, and
    a warning is emitted if piecewise compilation was requested but
    ``compilation_counter.num_models_seen`` did not change inside the
    ``with`` body.

    Args:
        vllm_config: the config to expose through
            `get_current_vllm_config` for the duration of the block.
    """
    global _current_vllm_config
    old_vllm_config = _current_vllm_config
    # Imported at call time rather than at module import time
    # (presumably to avoid a circular import -- TODO confirm).
    from vllm.compilation.counter import compilation_counter
    num_models_seen = compilation_counter.num_models_seen
    try:
        _current_vllm_config = vllm_config
        yield
    finally:
        # Restore the previous config before anything else, so that a
        # failure while building the log messages below (they dereference
        # attributes of `vllm_config`) cannot leave a stale global behind.
        _current_vllm_config = old_vllm_config
        logger.debug("enabled custom ops: %s",
                     vllm_config.compilation_config.enabled_custom_ops)
        logger.debug("disabled custom ops: %s",
                     vllm_config.compilation_config.disabled_custom_ops)
        if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
            and compilation_counter.num_models_seen == num_models_seen:
            # If the model supports compilation,
            # compilation_counter.num_models_seen should be increased
            # by at least 1.
            # If it is not increased, it means the model does not support
            # compilation (does not have @support_torch_compile decorator).
            logger.warning(
                "`torch.compile` is turned on, but the model %s"
                " does not support it. Please open an issue on GitHub"
                " if you want it to be supported.",
                vllm_config.model_config.model)
def get_current_vllm_config() -> VllmConfig:
    """
    Return the VLLM config installed by `set_current_vllm_config`.

    If no config is currently set, a warning is logged and a freshly
    constructed default ``VllmConfig()`` is returned instead; that
    default is not stored, so each such call builds a new instance.
    """
    if _current_vllm_config is None:
        # in ci, usually when we test custom ops/modules directly,
        # we don't set the vllm config. In that case, we set a default
        # config.
        logger.warning("Current VLLM config is not set.")
        # `VllmConfig` is defined in this module, so no import is needed.
        return VllmConfig()
    return _current_vllm_config
...@@ -2,9 +2,9 @@ from typing import Dict, Type ...@@ -2,9 +2,9 @@ from typing import Dict, Type
import torch.nn as nn import torch.nn as nn
from vllm.config import get_current_vllm_config
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.plugins import get_current_vllm_config
from vllm.utils import print_warning_once from vllm.utils import print_warning_once
logger = init_logger(__name__) logger = init_logger(__name__)
......
...@@ -23,7 +23,7 @@ from transformers import AutoModelForCausalLM ...@@ -23,7 +23,7 @@ from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig, from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig,
VllmConfig) VllmConfig, set_current_vllm_config)
from vllm.distributed import (get_tensor_model_parallel_rank, from vllm.distributed import (get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.envs import VLLM_USE_MODELSCOPE from vllm.envs import VLLM_USE_MODELSCOPE
...@@ -47,7 +47,6 @@ from vllm.model_executor.model_loader.weight_utils import ( ...@@ -47,7 +47,6 @@ from vllm.model_executor.model_loader.weight_utils import (
safetensors_weights_iterator) safetensors_weights_iterator)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.plugins import set_current_vllm_config
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
......
...@@ -13,13 +13,12 @@ from torch import nn ...@@ -13,13 +13,12 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
import vllm.envs as envs import vllm.envs as envs
from vllm.config import ModelConfig, ParallelConfig from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config
from vllm.engine.arg_utils import EngineArgs from vllm.engine.arg_utils import EngineArgs
from vllm.engine.llm_engine import LLMEngine from vllm.engine.llm_engine import LLMEngine
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding) VocabParallelEmbedding)
from vllm.plugins import set_current_vllm_config
from vllm.utils import FlexibleArgumentParser from vllm.utils import FlexibleArgumentParser
tensorizer_error_msg = None tensorizer_error_msg = None
......
import logging import logging
import os import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Optional
import torch import torch
import vllm.envs as envs import vllm.envs as envs
if TYPE_CHECKING:
from vllm.config import VllmConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# make sure one process only loads plugins once # make sure one process only loads plugins once
...@@ -64,54 +59,3 @@ def load_general_plugins(): ...@@ -64,54 +59,3 @@ def load_general_plugins():
logger.info("plugin %s loaded.", plugin.name) logger.info("plugin %s loaded.", plugin.name)
except Exception: except Exception:
logger.exception("Failed to load plugin %s", plugin.name) logger.exception("Failed to load plugin %s", plugin.name)
# Module-level slot holding the config installed by `set_current_vllm_config`;
# `None` means no config is currently set.
_current_vllm_config: Optional["VllmConfig"] = None


@contextmanager
def set_current_vllm_config(vllm_config: "VllmConfig"):
    """
    Temporarily set the current VLLM config.
    Used during model initialization.
    We save the current VLLM config in a global variable,
    so that all modules can access it, e.g. custom ops
    can access the VLLM config to determine how to dispatch.

    On exit the previously installed config (possibly ``None``) is put
    back, the enabled/disabled custom ops are logged at debug level, and
    a warning is emitted if piecewise compilation was requested but
    ``compilation_counter.num_models_seen`` did not change inside the
    ``with`` body.

    Args:
        vllm_config: the config to expose through
            `get_current_vllm_config` for the duration of the block.
    """
    global _current_vllm_config
    old_vllm_config = _current_vllm_config
    # Imported at call time rather than at module import time
    # (presumably to avoid a circular import -- TODO confirm).
    from vllm.compilation.counter import compilation_counter
    from vllm.config import CompilationLevel
    num_models_seen = compilation_counter.num_models_seen
    try:
        _current_vllm_config = vllm_config
        yield
    finally:
        # Restore the previous config before anything else, so that a
        # failure while building the log messages below (they dereference
        # attributes of `vllm_config`) cannot leave a stale global behind.
        _current_vllm_config = old_vllm_config
        logger.debug("enabled custom ops: %s",
                     vllm_config.compilation_config.enabled_custom_ops)
        logger.debug("disabled custom ops: %s",
                     vllm_config.compilation_config.disabled_custom_ops)
        if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \
            and compilation_counter.num_models_seen == num_models_seen:
            # If the model supports compilation,
            # compilation_counter.num_models_seen should be increased
            # by at least 1.
            # If it is not increased, it means the model does not support
            # compilation (does not have @support_torch_compile decorator).
            logger.warning(
                "`torch.compile` is turned on, but the model %s"
                " does not support it. Please open an issue on GitHub"
                " if you want it to be supported.",
                vllm_config.model_config.model)
def get_current_vllm_config() -> "VllmConfig":
    """
    Return the VLLM config installed by `set_current_vllm_config`.

    If no config is currently set, a warning is logged and a freshly
    constructed default ``VllmConfig()`` is returned instead; that
    default is not stored in the module-level slot.
    """
    if _current_vllm_config is not None:
        return _current_vllm_config

    # Nothing installed: in CI, custom ops/modules are usually tested
    # directly without setting a vllm config, so fall back to a default
    # one after warning about it.
    logger.warning("Current VLLM config is not set.")
    # Runtime import: at module level `VllmConfig` is only imported
    # under TYPE_CHECKING.
    from vllm.config import VllmConfig
    return VllmConfig()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment