"docs/source/vscode:/vscode.git/clone" did not exist on "fa661ce749b0d14ae1999d1b097866248624a842"
Commit 1360daca authored by sshleifer's avatar sshleifer
Browse files

cleanup deltas

parent 810079de
......@@ -640,9 +640,8 @@ class SelfAttention(nn.Module):
reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
attn_weights_float = F.softmax(attn_weights, dim=-1)
attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training,)
attn_weights = attn_weights_float.type_as(attn_weights)
attn_weights = F.softmax(attn_weights, dim=-1)
attn_probs = F.dropout(attn_weights, p=self.dropout, training=self.training,)
assert v is not None
attn_output = torch.bmm(attn_probs, v)
......
......@@ -243,15 +243,15 @@ class BartHeadTests(unittest.TestCase):
decoder_ffn_dim=32,
max_position_embeddings=48,
)
lm_model = BartForMaskedLM(config).to(torch_device)
context = _long_tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]])
summary = _long_tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]])
lm_model = BartForMaskedLM(config)
context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long()
summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long()
logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary)
expected_shape = (*summary.shape, config.vocab_size)
self.assertEqual(logits.shape, expected_shape)
def test_generate_beam_search(self):
input_ids = _long_tensor([[71, 82, 2], [68, 34, 2]])
input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
config = BartConfig(
vocab_size=self.vocab_size,
d_model=24,
......@@ -264,7 +264,7 @@ class BartHeadTests(unittest.TestCase):
max_position_embeddings=48,
output_past=True,
)
lm_model = BartForMaskedLM(config).to(torch_device)
lm_model = BartForMaskedLM(config)
lm_model.eval()
new_input_ids = lm_model.generate(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment