Commit f48aca62 authored by zhuwenwen
Browse files

update _forward_encoder_attention interface and support sinks

parent c5ac5cf7
......@@ -596,6 +596,7 @@ class FlashAttentionImpl(AttentionImpl):
self.sinks = sinks
if self.sinks is not None:
if not current_platform.is_rocm():
assert flash_attn_supports_sinks(), (
"Sinks are only supported in FlashAttention 3"
)
......@@ -812,7 +813,7 @@ class FlashAttentionImpl(AttentionImpl):
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# num_splits=attn_metadata.max_num_splits,
# s_aux=self.sinks,
s_aux=self.sinks,
is_prefix_cache=True,
)
return output
......@@ -869,7 +870,7 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale,
# k_descale=layer._k_scale,
# v_descale=layer._v_scale,
# s_aux=self.sinks,
s_aux=self.sinks,
)
return output
......@@ -990,6 +991,7 @@ class FlashAttentionImpl(AttentionImpl):
)
# Call flash attention directly on Q, K, V tensors
if not current_platform.is_rocm():
flash_attn_varlen_func(
q=query,
k=key,
......@@ -1010,6 +1012,28 @@ class FlashAttentionImpl(AttentionImpl):
v_descale=layer._v_scale.expand(descale_shape),
num_splits=1 if self.batch_invariant_enabled else 0,
)
else:
vllm_flash_attn_varlen_func(
q=query,
k=key,
v=value,
out=output,
cu_seqlens_q=cu_seqlens_q,
cu_seqlens_k=cu_seqlens_k,
max_seqlen_q=max_seqlen_q,
max_seqlen_k=max_seqlen_k,
softmax_scale=self.scale,
causal=False, # Encoder attention is bidirectional
alibi_slopes=self.alibi_slopes,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
# num_splits=1 if self.batch_invariant_enabled else 0,
is_prefix_cache=True,
)
return output
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment