Commit 7fafc730 authored by zhuwenwen
Browse files

Choose the CK or CUTLASS implementation based on the flash-attn (FA) version

parent 5b62725e
...@@ -303,7 +303,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -303,7 +303,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
try: try:
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func self.attn_func = flash_attn_varlen_func
logger.debug("Using CK FA in ROCmBackend") logger.debug("Using CK/CUTLASS FA in ROCmBackend")
except ModuleNotFoundError: except ModuleNotFoundError:
self.use_naive_attn = True self.use_naive_attn = True
...@@ -453,20 +453,22 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -453,20 +453,22 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks, attn_masks,
) )
else: else:
if envs.VLLM_USE_CL_FLASH_ATTN: import flash_attn
out = self.attn_func( major, minor, _ = flash_attn.__version__.split('.')
q=query, if (major, minor) >= ('2', '6'):
k=key, out = self.attn_func(
v=value, q=query,
cu_seqlens_q=prefill_meta.seq_start_loc, k=key,
cu_seqlens_k=prefill_meta.seq_start_loc, v=value,
max_seqlen_q=prefill_meta.max_prefill_seq_len, cu_seqlens_q=prefill_meta.seq_start_loc,
max_seqlen_k=prefill_meta.max_prefill_seq_len, cu_seqlens_k=prefill_meta.seq_start_loc,
softmax_scale=self.scale, max_seqlen_q=prefill_meta.max_prefill_seq_len,
causal=True, max_seqlen_k=prefill_meta.max_prefill_seq_len,
window_size=self.sliding_window, softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes, causal=True,
) window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else: else:
out = self.attn_func( out = self.attn_func(
q=query, q=query,
......
...@@ -12,7 +12,6 @@ if TYPE_CHECKING: ...@@ -12,7 +12,6 @@ if TYPE_CHECKING:
VLLM_NCCL_SO_PATH: Optional[str] = None VLLM_NCCL_SO_PATH: Optional[str] = None
LD_LIBRARY_PATH: Optional[str] = None LD_LIBRARY_PATH: Optional[str] = None
VLLM_USE_TRITON_FLASH_ATTN: bool = False VLLM_USE_TRITON_FLASH_ATTN: bool = False
VLLM_USE_CL_FLASH_ATTN: bool = False
VLLM_USE_OPT_OP: bool = False VLLM_USE_OPT_OP: bool = False
VLLM_USE_PA_PRINT_PARAM: bool = False VLLM_USE_PA_PRINT_PARAM: bool = False
LOCAL_RANK: int = 0 LOCAL_RANK: int = 0
...@@ -197,11 +196,6 @@ environment_variables: Dict[str, Callable[[], Any]] = { ...@@ -197,11 +196,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "False").lower() in
("true", "1")), ("true", "1")),
# flag to control if vllm should use cutlass flash attention
"VLLM_USE_CL_FLASH_ATTN":
lambda: (os.environ.get("VLLM_USE_CL_FLASH_ATTN", "True").lower() in
("true", "1")),
# flag to control vllm to use optimized kernels # flag to control vllm to use optimized kernels
"VLLM_USE_OPT_OP": "VLLM_USE_OPT_OP":
lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in lambda: (os.environ.get("VLLM_USE_OPT_OP", "True").lower() in
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment