"examples/offline_inference/vision_language_multi_image.py" did not exist on "aba8d6ee006b78149ac4514f460e4038b2d4f607"
Commit 2b84890b authored by zhuwenwen's avatar zhuwenwen
Browse files

update fa interface

parent 9c663e50
......@@ -1215,22 +1215,40 @@ class DualChunkFlashAttentionImpl(FlashAttentionImpl):
s_lse = s_lse.view(q_len, q_heads, 1).transpose(0, 2).float()
return res, s_lse
output, softmax_lse = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
softmax_scale=softmax_scale,
cu_seqlens_q=torch.tensor([0, query_states.shape[0]],
dtype=torch.int32,
device=query_states.device),
max_seqlen_q=query_states.shape[0],
cu_seqlens_k=torch.tensor([0, max_seqlen_k],
dtype=torch.int32,
device=query_states.device),
max_seqlen_k=max_seqlen_k,
causal=causal,
return_softmax_lse=True,
)
if not current_platform.is_rocm():
output, softmax_lse = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
softmax_scale=softmax_scale,
cu_seqlens_q=torch.tensor([0, query_states.shape[0]],
dtype=torch.int32,
device=query_states.device),
max_seqlen_q=query_states.shape[0],
cu_seqlens_k=torch.tensor([0, max_seqlen_k],
dtype=torch.int32,
device=query_states.device),
max_seqlen_k=max_seqlen_k,
causal=causal,
return_softmax_lse=True,
)
else:
output, softmax_lse = flash_attn_varlen_func(
q=query_states,
k=key_states,
v=value_states,
softmax_scale=softmax_scale,
cu_seqlens_q=torch.tensor([0, query_states.shape[0]],
dtype=torch.int32,
device=query_states.device),
max_seqlen_q=query_states.shape[0],
cu_seqlens_k=torch.tensor([0, max_seqlen_k],
dtype=torch.int32,
device=query_states.device),
max_seqlen_k=max_seqlen_k,
causal=causal,
return_attn_probs=True,
)
softmax_lse = softmax_lse.view(q_len, q_heads, 1).transpose(0,
2).float()
return output, softmax_lse
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment