Unverified Commit 4b3e4474 authored by ZiTian.Zhao's avatar ZiTian.Zhao Committed by GitHub
Browse files

Optimize configuration access with LRU cache in custom ops (#22204)


Signed-off-by: default avatarzitian zhao <zitian.zhao@tencentmusic.com>
parent bd3db7f4
......@@ -15,7 +15,7 @@ from collections.abc import Mapping
from contextlib import contextmanager
from dataclasses import (MISSING, Field, asdict, field, fields, is_dataclass,
replace)
from functools import cached_property
from functools import cached_property, lru_cache
from importlib.util import find_spec
from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Literal, Optional,
Protocol, TypeVar, Union, cast, get_args)
......@@ -5123,6 +5123,14 @@ def set_current_vllm_config(vllm_config: VllmConfig,
finally:
_current_vllm_config = old_vllm_config
_current_prefix = old_prefix
# Clear the compilation config cache when context changes
get_cached_compilation_config.cache_clear()
@lru_cache(maxsize=1)
def get_cached_compilation_config():
"""Cache config to avoid repeated calls to get_current_vllm_config()"""
return get_current_vllm_config().compilation_config
def get_current_vllm_config() -> VllmConfig:
......
......@@ -5,7 +5,7 @@ from typing import Optional
import torch.nn as nn
from vllm.config import get_current_vllm_config
from vllm.config import get_cached_compilation_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
......@@ -86,7 +86,7 @@ class CustomOp(nn.Module):
def dispatch_forward(self):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
compilation_config = get_current_vllm_config().compilation_config
compilation_config = get_cached_compilation_config()
enabled = self.enabled()
if enabled:
compilation_config.enabled_custom_ops.update([self.__class__.name])
......@@ -115,7 +115,7 @@ class CustomOp(nn.Module):
@classmethod
def enabled(cls) -> bool:
# if no name, then it was not registered
compilation_config = get_current_vllm_config().compilation_config
compilation_config = get_cached_compilation_config()
custom_ops = compilation_config.custom_ops
if not hasattr(cls, "name"):
logger.warning_once(
......@@ -138,7 +138,7 @@ class CustomOp(nn.Module):
Specifying 'all' or 'none' in custom_op takes precedence.
"""
from vllm.config import CompilationLevel
compilation_config = get_current_vllm_config().compilation_config
compilation_config = get_cached_compilation_config()
default_on = (compilation_config.level < CompilationLevel.PIECEWISE
or not compilation_config.use_inductor)
count_none = compilation_config.custom_ops.count("none")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment