Commit 8b1077ba authored by zhuwenwen's avatar zhuwenwen
Browse files

update _forward_encoder_attention interface

parent 98f111f9
...@@ -736,25 +736,47 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -736,25 +736,47 @@ class FlashAttentionImpl(AttentionImpl):
self.num_kv_heads) self.num_kv_heads)
# Call flash attention directly on Q, K, V tensors # Call flash attention directly on Q, K, V tensors
flash_attn_varlen_func( if not current_platform.is_rocm():
q=query, flash_attn_varlen_func(
k=key, q=query,
v=value, k=key,
out=output, v=value,
cu_seqlens_q=cu_seqlens_q, out=output,
cu_seqlens_k=cu_seqlens_k, cu_seqlens_q=cu_seqlens_q,
max_seqlen_q=max_seqlen_q, cu_seqlens_k=cu_seqlens_k,
max_seqlen_k=max_seqlen_k, max_seqlen_q=max_seqlen_q,
softmax_scale=self.scale, max_seqlen_k=max_seqlen_k,
causal=False, # Encoder attention is bidirectional softmax_scale=self.scale,
alibi_slopes=self.alibi_slopes, causal=False, # Encoder attention is bidirectional
window_size=self.sliding_window, alibi_slopes=self.alibi_slopes,
softcap=self.logits_soft_cap, window_size=self.sliding_window,
fa_version=self.vllm_flash_attn_version, softcap=self.logits_soft_cap,
q_descale=layer._q_scale.expand(descale_shape), fa_version=self.vllm_flash_attn_version,
k_descale=layer._k_scale.expand(descale_shape), q_descale=layer._q_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
) v_descale=layer._v_scale.expand(descale_shape),
)
else:
vllm_flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=output,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=False, # Encoder attention is bidirectional
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
is_prefix_cache=True,
)
return output return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment