Commit bf3d75f4 authored by zhuwenwen
Browse files

update fa interface

parent f6324f60
...@@ -1221,23 +1221,42 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl): ...@@ -1221,23 +1221,42 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float() s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float()
return res, s_lse return res, s_lse
output, softmax_lse = flash_attn_varlen_func( if not current_platform.is_rocm():
q=query_states, output, softmax_lse = flash_attn_varlen_func(
k=key_states, q=query_states,
v=value_states, k=key_states,
softmax_scale=softmax_scale, v=value_states,
cu_seqlens_q=torch.tensor([0, query_states.shape[0]], softmax_scale=softmax_scale,
dtype=torch.int32, cu_seqlens_q=torch.tensor([0, query_states.shape[0]],
device=query_states.device), dtype=torch.int32,
max_seqlen_q=query_states.shape[0], device=query_states.device),
cu_seqlens_k=torch.tensor([0, max_seqlen_k], max_seqlen_q=query_states.shape[0],
dtype=torch.int32, cu_seqlens_k=torch.tensor([0, max_seqlen_k],
device=query_states.device), dtype=torch.int32,
max_seqlen_k=max_seqlen_k, device=query_states.device),
causal=causal, max_seqlen_k=max_seqlen_k,
block_table=block_table.unsqueeze(0), causal=causal,
return_softmax_lse=True, block_table=block_table.unsqueeze(0),
) return_softmax_lse=True,
)
else:
output, softmax_lse = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
softmax_scale=softmax_scale,
cu_seqlens_q=torch.tensor([0, query_states.shape[0]],
dtype=torch.int32,
device=query_states.device),
max_seqlen_q=query_states.shape[0],
cu_seqlens_k=torch.tensor([0, max_seqlen_k],
dtype=torch.int32,
device=query_states.device),
max_seqlen_k=max_seqlen_k,
causal=causal,
block_table=block_table.unsqueeze(0),
return_attn_probs=True,
)
softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0, softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0,
2).float() 2).float()
return output, softmax_lse return output, softmax_lse
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment