Commit 965934b8 authored by zhuwenwen's avatar zhuwenwen
Browse files

support sinks

parent 8b1077ba
...@@ -453,6 +453,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -453,6 +453,7 @@ class FlashAttentionImpl(AttentionImpl):
self.sinks = sinks self.sinks = sinks
if self.sinks is not None: if self.sinks is not None:
if not current_platform.is_rocm():
assert self.vllm_flash_attn_version == 3, ( assert self.vllm_flash_attn_version == 3, (
"Sinks are only supported in FlashAttention 3") "Sinks are only supported in FlashAttention 3")
assert self.sinks.shape[0] == num_heads, ( assert self.sinks.shape[0] == num_heads, (
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment