"examples/run_bert_classifier.py" did not exist on "5ee171689c07b05ef02ca6596f3ec7cfd247478e"
Commit 810079de authored by sshleifer

no ipdb

parent c203509d
@@ -688,6 +688,7 @@ class SelfAttention(nn.Module):
         static_kv: bool,
     ) -> Optional[Tensor]:
         # saved key padding masks have shape (bsz, seq_len)
         if prev_key_padding_mask is not None and static_kv:
             new_key_padding_mask = prev_key_padding_mask
         elif prev_key_padding_mask is not None and key_padding_mask is not None:
@@ -699,10 +700,6 @@ class SelfAttention(nn.Module):
             if prev_key_padding_mask.is_cuda:
                 filler = filler.to(prev_key_padding_mask.device)
             new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
-            print(new_key_padding_mask.device, new_key_padding_mask.dtype)
-            import ipdb
-            ipdb.set_trace()
         elif key_padding_mask is not None:
             filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
             if key_padding_mask.is_cuda:
...
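For context on the branch shown above: during incremental decoding the cached key padding mask only covers past positions, so it is extended with zeros (meaning "not padding") up to the current source length before concatenation. A minimal standalone sketch of that step, using made-up shapes rather than the library's exact helper:

import torch

# Hypothetical shapes: a cached mask covering 3 past positions, current src_len = 5.
batch_size, src_len = 2, 5
prev_key_padding_mask = torch.tensor([[0, 0, 1], [0, 1, 1]], dtype=torch.bool)

# Pad the cached mask with zeros for the new positions, mirroring the
# elif branch in the diff above.
filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
if prev_key_padding_mask.is_cuda:
    filler = filler.to(prev_key_padding_mask.device)
new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)

print(new_key_padding_mask)  # shape (2, 5); the last two columns are 0.0

The lines removed by this commit (print, import ipdb, ipdb.set_trace()) were debugging aids left in that branch; the concatenation logic itself is unchanged.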