Commit 54ddee7f authored by zhuwenwen
Browse files

remove ck fa interface

parent 38dc43cd
...@@ -372,7 +372,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -372,7 +372,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
try: try:
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func self.attn_func = flash_attn_varlen_func
logger.debug("Using CK/CUTLASS FA in ROCmBackend") logger.debug("Using CUTLASS FA in ROCmBackend")
except ModuleNotFoundError: except ModuleNotFoundError:
self.use_naive_attn = True self.use_naive_attn = True
...@@ -522,34 +522,19 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -522,34 +522,19 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks, attn_masks,
) )
else: else:
import flash_attn out = self.attn_func(
major, minor, _ = flash_attn.__version__.split('.') q=query,
if (major, minor) >= ('2', '6'): k=key,
out = self.attn_func( v=value,
q=query, cu_seqlens_q=prefill_meta.seq_start_loc,
k=key, cu_seqlens_k=prefill_meta.seq_start_loc,
v=value, max_seqlen_q=prefill_meta.max_prefill_seq_len,
cu_seqlens_q=prefill_meta.seq_start_loc, max_seqlen_k=prefill_meta.max_prefill_seq_len,
cu_seqlens_k=prefill_meta.seq_start_loc, softmax_scale=self.scale,
max_seqlen_q=prefill_meta.max_prefill_seq_len, causal=True,
max_seqlen_k=prefill_meta.max_prefill_seq_len, window_size=self.sliding_window,
softmax_scale=self.scale, alibi_slopes=self.alibi_slopes,
causal=True, )
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
# common code for prefill # common code for prefill
assert output[:num_prefill_tokens].shape == out.shape assert output[:num_prefill_tokens].shape == out.shape
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment