Unverified commit 31f5dc5b, authored by Yongye Zhu and committed by GitHub
Browse files

[gpt-oss] Enhance error msg on attention sink init (#22335)


Signed-off-by: simon-mo <xmo@berkeley.edu>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: simon-mo <xmo@berkeley.edu>
parent ec7cb192
@@ -638,11 +638,15 @@ class FlashInferImpl(AttentionImpl):
 
         self.sinks: Optional[torch.Tensor] = None
         if sinks is not None:
-            assert sinks.shape[0] == num_heads, (
-                "Sinks must have the same number of heads "
-                "as the number of heads in the layer"
-            )
-            assert sinks.dtype == torch.float32, "Sinks must be of type float32"
+            if sinks.shape[0] != num_heads:
+                raise ValueError(
+                    "Sinks must have the same number of heads as the number of "
+                    f"heads in the layer. Expected {num_heads}, but got "
+                    f"{sinks.shape[0]}."
+                )
+            if sinks.dtype != torch.float32:
+                raise ValueError("Sinks must be of type float32, but got "
+                                 f"{sinks.dtype}.")
             self.sinks = sinks
 
     def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment