Commit 2071c380 authored by zhuwenwen's avatar zhuwenwen
Browse files

add VLLM_USE_MERGE_ATTN_STATES_OPT to control merge_attn_states support

parent 48b4c41d
...@@ -5,6 +5,7 @@ from typing import Optional ...@@ -5,6 +5,7 @@ from typing import Optional
import torch import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm import envs
def merge_attn_states( def merge_attn_states(
...@@ -31,7 +32,7 @@ def merge_attn_states( ...@@ -31,7 +32,7 @@ def merge_attn_states(
return headdim % 4 == 0 return headdim % 4 == 0
return headdim % 8 == 0 return headdim % 8 == 0
if (current_platform.is_cuda() or current_platform.is_rocm() and supported_dtypes(output) if (current_platform.is_cuda() or envs.VLLM_USE_MERGE_ATTN_STATES_OPT and supported_dtypes(output)
and supported_headdim(output)): and supported_headdim(output)):
from vllm._custom_ops import merge_attn_states from vllm._custom_ops import merge_attn_states
return merge_attn_states(output, prefix_output, prefix_lse, return merge_attn_states(output, prefix_output, prefix_lse,
......
...@@ -166,6 +166,7 @@ if TYPE_CHECKING: ...@@ -166,6 +166,7 @@ if TYPE_CHECKING:
VLLM_USE_GLOBAL_CACHE13: bool = False VLLM_USE_GLOBAL_CACHE13: bool = False
VLLM_USE_LIGHT_OP: bool = False VLLM_USE_LIGHT_OP: bool = False
VLLM_USE_TRITON_CAT: bool = False VLLM_USE_TRITON_CAT: bool = False
VLLM_USE_MERGE_ATTN_STATES_OPT: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1099,6 +1100,10 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1099,6 +1100,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_TRITON_CAT": "VLLM_USE_TRITON_CAT":
lambda: (os.environ.get("VLLM_USE_TRITON_CAT", "True").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_CAT", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use opt merge_aatn_states,not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment