Unverified Commit 4d1e52ab authored by liwenju0's avatar liwenju0 Committed by GitHub
Browse files

Add an assertion to enhance the robustness of the operator (#5736)

parent 155890e4
...@@ -271,6 +271,8 @@ class VisionSdpaAttention(nn.Module): ...@@ -271,6 +271,8 @@ class VisionSdpaAttention(nn.Module):
Returns: Returns:
[b * s, h, head_size] [b * s, h, head_size]
""" """
if self.flatten_batch:
assert bsz == 1, "flatten_batch is True, bsz must be 1"
s = q.shape[0] // bsz s = q.shape[0] // bsz
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment