Commit 8b1077ba authored by zhuwenwen
Browse files

update _forward_encoder_attention interface

parent 98f111f9
......@@ -736,6 +736,7 @@ class FlashAttentionImpl(AttentionImpl):
self.num_kv_heads)
# Call flash attention directly on Q, K, V tensors
if not current_platform.is_rocm():
flash_attn_varlen_func(
q=query,
k=key,
......@@ -755,6 +756,27 @@ class FlashAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape),
)
else:
vllm_flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=output,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=False, # Encoder attention is bidirectional
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
is_prefix_cache=True,
)
return output
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment