Commit 54ddee7f authored by zhuwenwen
Browse files

remove ck fa interface

parent 38dc43cd
...@@ -372,7 +372,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -372,7 +372,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
try: try:
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func self.attn_func = flash_attn_varlen_func
logger.debug("Using CK/CUTLASS FA in ROCmBackend") logger.debug("Using CUTLASS FA in ROCmBackend")
except ModuleNotFoundError: except ModuleNotFoundError:
self.use_naive_attn = True self.use_naive_attn = True
...@@ -522,9 +522,6 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -522,9 +522,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks, attn_masks,
) )
else: else:
import flash_attn
major, minor, _ = flash_attn.__version__.split('.')
if (major, minor) >= ('2', '6'):
out = self.attn_func( out = self.attn_func(
q=query, q=query,
k=key, k=key,
...@@ -538,18 +535,6 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -538,18 +535,6 @@ class ROCmFlashAttentionImpl(AttentionImpl):
window_size=self.sliding_window, window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
) )
else:
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
# common code for prefill # common code for prefill
assert output[:num_prefill_tokens].shape == out.shape assert output[:num_prefill_tokens].shape == out.shape
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment