Commit fe054987 authored by laibao's avatar laibao
Browse files

增加VLLM_USE_OPT_RESHAPE_AND_CACHE 环境变量用来控制 优化reshape and cache kernel

parent a921f34c
......@@ -185,6 +185,7 @@ if TYPE_CHECKING:
VLLM_USE_ZERO_MTP: bool = False
VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_FUSE_SILU_AND_MUL: bool = True
VLLM_USE_OPT_RESHAPE_AND_CACHE: bool = True
def get_default_cache_root():
return os.getenv(
......@@ -1199,6 +1200,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_FUSE_SILU_AND_MUL":
lambda: (os.environ.get("VLLM_USE_FUSE_SILU_AND_MUL", "True").lower() in
("true", "1")),
# vLLM will use optimized reshape_and_cache kernel when enabled
"VLLM_USE_OPT_RESHAPE_AND_CACHE":
lambda:
(os.environ.get("VLLM_USE_OPT_RESHAPE_AND_CACHE", "True").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -566,16 +566,26 @@ class FlashAttentionImpl(AttentionImpl):
layer._v_scale,
)
else:
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
from lightop import reshape_and_cache_cuda
if envs.VLLM_USE_OPT_RESHAPE_AND_CACHE:
reshape_and_cache_cuda(
key, value,
key_cache, value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale, layer._v_scale
)
else:
reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
)
if self.kv_cache_dtype.startswith("fp8"):
key_cache = key_cache.view(torch.float8_e4m3fn)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment