Unverified commit 90ec0069, authored by Yongye Zhu and committed by GitHub
Browse files

[gpt-oss] flashinfer attention sink init (#22330)


Signed-off-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: LiuXiaoxuanPKU <lilyliupku@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
Co-authored-by: Chen Zhang <zhangch99@outlook.com>
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com>
Co-authored-by: Minseok Lee <47620120+minseokl@users.noreply.github.com>
parent a47e6ffe
......@@ -611,6 +611,7 @@ class FlashInferImpl(AttentionImpl):
logits_soft_cap: Optional[float] = None,
attn_type: AttentionType = AttentionType.DECODER,
kv_sharing_target_layer_name: Optional[int] = None,
sinks: Optional[torch.Tensor] = None,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
......@@ -635,6 +636,15 @@ class FlashInferImpl(AttentionImpl):
"are not implemented for "
"FlashInferImpl")
self.sinks: Optional[torch.Tensor] = None
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads "
"as the number of heads in the layer"
)
assert sinks.dtype == torch.float32, "Sinks must be of type float32"
self.sinks = sinks
def forward(
self,
layer: torch.nn.Module,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment