Commit c36fdc88 authored by sshleifer

tests pass

parent 7ac47bfe
@@ -15,11 +15,14 @@
 # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 import copy
-# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
-# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 import math
 import numpy as np
...
@@ -640,9 +640,10 @@ class SelfAttention(nn.Module):
             reshaped = key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool)
             attn_weights = attn_weights.masked_fill(reshaped, float("-inf"))
             attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)
-        attn_weights_float = F.softmax(attn_weights, dim=-1, dtype=torch.float32)
-        attn_weights = attn_weights_float.type_as(attn_weights)
+        attn_weights_float = F.softmax(attn_weights, dim=-1)
         attn_probs = F.dropout(attn_weights_float, p=self.dropout, training=self.training,)
+        attn_weights = attn_weights_float.type_as(attn_weights)
         assert v is not None
         attn_output = torch.bmm(attn_probs, v)
         assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
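The softmax change above is what makes half precision work: with `dtype=torch.float32`, F.softmax upcasts its fp16 input, so the later `torch.bmm(attn_probs, v)` mixes float32 probabilities with float16 values and fails. Without the argument, the probabilities stay in the input's dtype. A minimal standalone sketch of the difference (not from the commit; on older PyTorch builds these half-precision ops may need CUDA):

    import torch
    import torch.nn.functional as F

    attn_weights = torch.randn(2, 4, 4, dtype=torch.float16)
    v = torch.randn(2, 4, 8, dtype=torch.float16)

    upcast = F.softmax(attn_weights, dim=-1, dtype=torch.float32)  # float32 output
    inherit = F.softmax(attn_weights, dim=-1)                      # stays float16

    torch.bmm(inherit, v)  # ok: both operands are float16
    # torch.bmm(upcast, v) would raise a dtype mismatch error (Float vs Half)

Moving the `type_as` line below the dropout call does not change behavior: after the edit, `attn_weights_float` already has the input's dtype, so the cast back is a no-op.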
@@ -696,8 +697,12 @@ class SelfAttention(nn.Module):
         elif prev_key_padding_mask is not None:
             filler = torch.zeros(batch_size, src_len - prev_key_padding_mask.size(1))
             if prev_key_padding_mask.is_cuda:
-                filler = filler.cuda()
+                filler = filler.to(prev_key_padding_mask.device)
             new_key_padding_mask = torch.cat([prev_key_padding_mask.float(), filler.float()], dim=1)
+            print(new_key_padding_mask.device, new_key_padding_mask.dtype)
+            import ipdb
+            ipdb.set_trace()
         elif key_padding_mask is not None:
             filler = torch.zeros(batch_size, src_len - key_padding_mask.size(1))
             if key_padding_mask.is_cuda:
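Two notes on this hunk. `filler.to(prev_key_padding_mask.device)` is more portable than `filler.cuda()`, but it also makes the surrounding `is_cuda` branch redundant: the filler can simply be allocated on the mask's device up front. The `print` / `import ipdb` / `ipdb.set_trace()` lines look like leftover debug instrumentation; `set_trace()` would halt any non-interactive test run. A hedged sketch of the allocate-on-device alternative (`pad_mask_to_src_len` is a hypothetical helper name, not from this file):

    import torch

    def pad_mask_to_src_len(prev_key_padding_mask: torch.Tensor, src_len: int) -> torch.Tensor:
        # Allocate the filler directly on the mask's device with the mask's
        # dtype, so no is_cuda branch or after-the-fact .to() is needed.
        batch_size = prev_key_padding_mask.size(0)
        filler = torch.zeros(
            batch_size,
            src_len - prev_key_padding_mask.size(1),
            dtype=prev_key_padding_mask.dtype,
            device=prev_key_padding_mask.device,
        )
        return torch.cat([prev_key_padding_mask, filler], dim=1)

Unlike the code above, this keeps the mask's own dtype rather than casting both pieces to float; keep the `.float()` casts if downstream code relies on them.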
...
@@ -243,15 +243,15 @@ class BartHeadTests(unittest.TestCase):
             decoder_ffn_dim=32,
             max_position_embeddings=48,
         )
-        lm_model = BartForMaskedLM(config)
-        context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long()
-        summary = torch.Tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]]).long()
+        lm_model = BartForMaskedLM(config).to(torch_device)
+        context = _long_tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]])
+        summary = _long_tensor([[82, 71, 82, 18, 2], [58, 68, 2, 1, 1]])
         logits, enc_features = lm_model.forward(input_ids=context, decoder_input_ids=summary)
         expected_shape = (*summary.shape, config.vocab_size)
         self.assertEqual(logits.shape, expected_shape)
 
     def test_generate_beam_search(self):
-        input_ids = torch.Tensor([[71, 82, 2], [68, 34, 2]]).long()
+        input_ids = _long_tensor([[71, 82, 2], [68, 34, 2]])
         config = BartConfig(
             vocab_size=self.vocab_size,
             d_model=24,
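These tests now build their inputs with a `_long_tensor` helper and move the model to `torch_device`, so they run on GPU when one is available. The helper itself is outside this diff; presumably it is something along these lines (a sketch, assuming the test module defines `torch_device` as shown):

    import torch

    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

    def _long_tensor(tok_lst):
        # int64 token ids, created directly on the target device
        return torch.tensor(tok_lst, dtype=torch.long, device=torch_device)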
@@ -264,7 +264,7 @@ class BartHeadTests(unittest.TestCase):
             max_position_embeddings=48,
             output_past=True,
         )
-        lm_model = BartForMaskedLM(config)
+        lm_model = BartForMaskedLM(config).to(torch_device)
         lm_model.eval()
         new_input_ids = lm_model.generate(
@@ -294,6 +294,13 @@ class BartHeadTests(unittest.TestCase):
         bart_toks = tokenizer.encode(ex, return_tensors="pt")
         _assert_tensors_equal(desired_result.long(), bart_toks, prefix=ex)
 
+    @unittest.skipIf(torch_device == "cpu", "Can't do half precision")
+    def test_generate_fp16(self):
+        config, input_ids, batch_size = self._get_config_and_data(output_past=True)
+        attention_mask = input_ids.ne(1)
+        lm_model = BartForMaskedLM(config).eval().to(torch_device).half()
+        lm_model.generate(input_ids, attention_mask=attention_mask)
+
 
 def _assert_tensors_equal(a, b, atol=1e-12, prefix=""):
     """If tensors not close, or a and b aren't both tensors, raise a nice Assertion error."""
...
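The new `test_generate_fp16` ties the pieces together: the model is cast to half precision with `.half()`, the attention mask is derived with `input_ids.ne(1)` (BART's pad token id is 1), and `generate` now succeeds because the attention softmax no longer upcasts to float32. Outside the test harness, the same pattern looks roughly like this (the config values are illustrative, mirroring the tiny test config; a CUDA device is assumed, matching the skip condition above):

    import torch
    from transformers import BartConfig, BartForMaskedLM

    torch_device = "cuda"

    config = BartConfig(
        vocab_size=99,
        d_model=24,
        encoder_layers=2,
        decoder_layers=2,
        encoder_attention_heads=2,
        decoder_attention_heads=2,
        encoder_ffn_dim=32,
        decoder_ffn_dim=32,
        max_position_embeddings=48,
        output_past=True,
    )
    input_ids = torch.tensor([[71, 82, 2], [68, 34, 2]], dtype=torch.long, device=torch_device)
    attention_mask = input_ids.ne(1)  # True wherever the token is not padding

    model = BartForMaskedLM(config).eval().to(torch_device).half()
    summary_ids = model.generate(input_ids, attention_mask=attention_mask)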