Commit fe054987 authored by laibao's avatar laibao
Browse files

增加VLLM_USE_OPT_RESHAPE_AND_CACHE 环境变量用来控制 优化reshape and cache kernel

parent a921f34c
...@@ -185,6 +185,7 @@ if TYPE_CHECKING: ...@@ -185,6 +185,7 @@ if TYPE_CHECKING:
VLLM_USE_ZERO_MTP: bool = False VLLM_USE_ZERO_MTP: bool = False
VLLM_USE_CUDA_GRAPH_SIZES: bool = False VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_FUSE_SILU_AND_MUL: bool = True VLLM_USE_FUSE_SILU_AND_MUL: bool = True
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = True
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1199,6 +1200,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1199,6 +1200,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSE_SILU_AND_MUL": "VLLM_USE_FUSE_SILU_AND_MUL":
lambda: (os.environ.get("VLLM_USE_FUSE_SILU_AND_MUL", "True").lower() in lambda: (os.environ.get("VLLM_USE_FUSE_SILU_AND_MUL", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use optimized reshape_and_cache kernel when enabled
"VLLM_USE_OPT_RESHAPE_AND_CACHE":
lambda:
(os.environ.get("VLLM_USE_OPT_RESHAPE_AND_CACHE", "True").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -565,6 +565,16 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -565,6 +565,16 @@ class FlashAttentionImpl(AttentionImpl):
layer._k_scale, layer._k_scale,
layer._v_scale, layer._v_scale,
) )
else:
from lightop import reshape_and_cache_cuda
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
reshape_and_cache_cuda(
key, value,
key_cache, value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale, layer._v_scale
)
else: else:
reshape_and_cache_cuda( reshape_and_cache_cuda(
key, key,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment