Commit 7fafc730 authored by zhuwenwen's avatar zhuwenwen
Browse files

choose ck or cutlass implementation based on the fa version

parent 5b62725e
......@@ -303,7 +303,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func
logger.debug("Using CK FA in ROCmBackend")
logger.debug("Using CK/CUTLASS FA in ROCmBackend")
except ModuleNotFoundError:
self.use_naive_attn = True
......@@ -453,7 +453,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks,
)
else:
if envs.VLLM_USE_CL_FLASH_ATTN:
import flash_attn
major, minor, _ = flash_attn.__version__.split('.')
if (major, minor) >= ('2', '6'):
out = self.attn_func(
q=query,
k=key,
......
......@@ -12,7 +12,6 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_CL_FLASH_ATTN: bool = False
VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0
......@@ -197,11 +196,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in
("true", "1")),
# flag to control if vllm should use cutlass flash attention
"VLLM_USE_CL_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_CL_FLASH_ATTN", "True").lower() in
("true", "1")),
# flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment