Unverified Commit c13dbd5c authored by Will Berman, committed by GitHub

fix attention mask pad check (#3531)

parent bde2cb5d
@@ -381,12 +381,7 @@ class Attention(nn.Module):
             return attention_mask
 
         current_length: int = attention_mask.shape[-1]
-        if current_length > target_length:
-            # we *could* trim the mask with:
-            #   attention_mask = attention_mask[:,:target_length]
-            # but this is weird enough that it's more likely to be a mistake than a shortcut
-            raise ValueError(f"mask's length ({current_length}) exceeds the sequence length ({target_length}).")
-        elif current_length < target_length:
+        if current_length != target_length:
             if attention_mask.device.type == "mps":
                 # HACK: MPS: Does not support padding by greater than dimension of input tensor.
                 # Instead, we can manually construct the padding tensor.
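For context, the change collapses the two length checks into a single `!=` test: any mismatch between the mask length and the target length now takes the padding path, rather than raising for over-length masks. Below is a minimal standalone sketch of what that padding path does, based only on the logic visible in this hunk. It is illustrative, not the exact `Attention.prepare_attention_mask` code; the function name `pad_attention_mask`, the exact pad amounts, and the example shapes are assumptions.

```python
import torch
import torch.nn.functional as F

def pad_attention_mask(attention_mask: torch.Tensor, target_length: int) -> torch.Tensor:
    """Right-pad the last dimension of an attention mask out to target_length.

    Illustrative sketch only; assumes current_length <= target_length.
    """
    current_length = attention_mask.shape[-1]
    if current_length != target_length:
        if attention_mask.device.type == "mps":
            # MPS does not support padding by more than the size of the input
            # dimension, so construct the padding tensor manually and concatenate.
            padding_shape = attention_mask.shape[:-1] + (target_length - current_length,)
            padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
            attention_mask = torch.cat([attention_mask, padding], dim=-1)
        else:
            # F.pad takes (left, right) amounts for the last dimension.
            attention_mask = F.pad(attention_mask, (0, target_length - current_length), value=0.0)
    return attention_mask

# Example: a (batch, heads, seq) mask of length 5 padded out to length 8.
mask = torch.zeros(2, 1, 5)
assert pad_attention_mask(mask, target_length=8).shape == (2, 1, 8)
```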