"vscode:/vscode.git/clone" did not exist on "308f099b4e31a0cc61ba787e5997ec38d5dcf921"
Unverified Commit c13dbd5c authored by Will Berman, committed by GitHub
Browse files

fix attention mask pad check (#3531)

parent bde2cb5d
@@ -381,12 +381,7 @@ class Attention(nn.Module):
             return attention_mask

         current_length: int = attention_mask.shape[-1]
-        if current_length > target_length:
-            # we *could* trim the mask with:
-            #   attention_mask = attention_mask[:,:target_length]
-            # but this is weird enough that it's more likely to be a mistake than a shortcut
-            raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
-        elif current_length < target_length:
+        if current_length != target_length:
             if attention_mask.device.type == "mps":
                 # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                 # Instead, we can manually construct the padding tensor.
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.