Commit 810079de authored by sshleifer
Browse files

no ipdb

parent c203509d
......@@ -688,6 +688,7 @@ class SelfAttention(nn.Module):
static_kv: bool,
) -> Optional[Tensor]:
# saved key padding masks have shape (bsz, seq_len)
if prev_key_padding_mask is not None and static_kv:
new_key_padding_mask = prev_key_padding_mask
elif prev_key_padding_mask is not None and key_padding_mask is not None:
......@@ -699,10 +700,6 @@ class SelfAttention(nn.Module):
if prev_key_padding_mask.is_cuda:
filler = filler.to(prev_key_padding_mask.device)
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
print(new_key_padding_mask.device, new_key_padding_mask.dtype)
import ipdb
ipdb.set_trace()
elif key_padding_mask is not None:
filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
if key_padding_mask.is_cuda:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment