Commit bf3d75f4 authored by zhuwenwen's avatar zhuwenwen
Browse files

update fa interface

parent f6324f60
...@@ -1221,6 +1221,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): ...@@ -1221,6 +1221,7 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float() s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float()
return res, s_lse return res, s_lse
if not current_platform.is_rocm():
output, softmax_lse = flash_attn_varlen_func( output, softmax_lse = flash_attn_varlen_func(
q=query_states, q=query_states,
k=key_states, k=key_states,
...@@ -1238,6 +1239,24 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): ...@@ -1238,6 +1239,24 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
block_table=block_table.unsqueeze(0), block_table=block_table.unsqueeze(0),
return_softmax_lse=True, return_softmax_lse=True,
) )
else:
output, softmax_lse = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
softmax_scale=softmax_scale,
cu_seqlens_q=torch.tensor([0, query_states.shape[0]],
dtype=torch.int32,
device=query_states.device),
max_seqlen_q=query_states.shape[0],
cu_seqlens_k=torch.tensor([0, max_seqlen_k],
dtype=torch.int32,
device=query_states.device),
max_seqlen_k=max_seqlen_k,
causal=causal,
block_table=block_table.unsqueeze(0),
return_attn_probs=True,
)
softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0,
2).float() 2).float()
return output, softmax_lse return output, softmax_lse
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment