Unverified Commit 31f5dc5b authored by Yongye Zhu's avatar Yongye Zhu Committed by GitHub
Browse files

[gpt-oss] Enhance error msg on attention sink init (#22335)


Signed-off-by: default avatarsimon-mo <xmo@berkeley.edu>
Signed-off-by: default avatarYongye Zhu <zyy1102000@gmail.com>
Co-authored-by: default avatarsimon-mo <xmo@berkeley.edu>
parent ec7cb192
......@@ -638,11 +638,15 @@ class FlashInferImpl(AttentionImpl):
self.sinks: Optional[torch.Tensor] = None
if sinks is not None:
assert sinks.shape[0] == num_heads, (
"Sinks must have the same number of heads "
"as the number of heads in the layer"
if sinks.shape[0] != num_heads:
raise ValueError(
"Sinks must have the same number of heads as the number of "
f"heads in the layer. Expected {num_heads}, but got "
f"{sinks.shape[0]}."
)
assert sinks.dtype == torch.float32, "Sinks must be of type float32"
if sinks.dtype != torch.float32:
raise ValueError("Sinks must be of type float32, but got "
f"{sinks.dtype}.")
self.sinks = sinks
def forward(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment