"vscode:/vscode.git/clone" did not exist on "d696f86e7bdf23a6a4c212fee3522a589a460b24"
Unverified Commit 90ec0069 authored by Yongye Zhu, committed by GitHub
Browse files

[gpt-oss] flashinfer attention sink init (#22330)


Signed-off-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
parent a47e6ffe
...@@ -611,6 +611,7 @@ class FlashInferImpl(AttentionImpl): ...@@ -611,6 +611,7 @@ class FlashInferImpl(AttentionImpl):
logits_soft_cap: Optional[float] = None, logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER, attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None, kv_sharing_target_layer_name: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
) -> None: ) -> None:
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
...@@ -635,6 +636,15 @@ class FlashInferImpl(AttentionImpl): ...@@ -635,6 +636,15 @@ class FlashInferImpl(AttentionImpl):
"are not implemented for " "are not implemented for "
"FlashInferImpl") "FlashInferImpl")
self.sinks: Optional[torch.Tensor] = None
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads "
"as the number of heads in the layer"
)
assert sinks.dtype == torch.float32, "Sinks must be of type float32"
self.sinks = sinks
def forward( def forward(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment