Commit 54ddee7f authored by zhuwenwen's avatar zhuwenwen
Browse files

remove ck fa interface

parent 38dc43cd
......@@ -372,7 +372,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
self.attn_func = flash_attn_varlen_func
logger.debug("Using CK/CUTLASS FA in ROCmBackend")
logger.debug("Using CUTLASS FA in ROCmBackend")
except ModuleNotFoundError:
self.use_naive_attn = True
......@@ -522,34 +522,19 @@ class ROCmFlashAttentionImpl(AttentionImpl):
attn_masks,
)
else:
import flash_attn
major, minor, _ = flash_attn.__version__.split('.')
if (major, minor) >= ('2', '6'):
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
else:
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
)
out = self.attn_func(
q=query,
k=key,
v=value,
cu_seqlens_q=prefill_meta.seq_start_loc,
cu_seqlens_k=prefill_meta.seq_start_loc,
max_seqlen_q=prefill_meta.max_prefill_seq_len,
max_seqlen_k=prefill_meta.max_prefill_seq_len,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
alibi_slopes=self.alibi_slopes,
)
# common code for prefill
assert output[:num_prefill_tokens].shape == out.shape
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment