Commit b016dd16 authored by thomwolf

fix tests on python 3.5

parent 169fea68
@@ -338,7 +338,7 @@ class T5Attention(nn.Module):
                 raise ValueError("No position_bias provided and no weights to compute position_bias")
             position_bias = self.compute_bias(qlen, klen)
             if mask is not None:
-                position_bias += mask  # (bs, n_heads, qlen, klen)
+                position_bias = position_bias + mask  # (bs, n_heads, qlen, klen)
 
         scores += position_bias
         weights = F.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
...
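Note: the modeling change above swaps an in-place `+=` for an out-of-place addition. A plausible reason, not stated in the commit, is that PyTorch refuses to broadcast the left-hand side of an in-place op to a larger result shape, and the computed bias has no batch dimension while the mask typically does; the out-of-place form also leaves the original tensor untouched. A minimal sketch with made-up shapes:

# Sketch with illustrative shapes (not taken from the model): an in-place "+=" cannot
# grow the left-hand tensor, while "x = x + y" broadcasts to the larger result shape.
import torch

position_bias = torch.randn(1, 8, 4, 4)   # e.g. (1, n_heads, qlen, klen), no batch dim
mask = torch.zeros(2, 1, 4, 4)            # e.g. (bs, 1, qlen, klen), batched

out = position_bias + mask                # out-of-place add broadcasts to (2, 8, 4, 4)
print(out.shape)

try:
    position_bias += mask                 # in-place add would have to resize position_bias
except RuntimeError as err:
    print("in-place add failed:", err)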
@@ -138,8 +138,8 @@ class CommonTestCases:
             self.assertListEqual(
                 list(attentions[0].shape[-3:]),
                 [self.model_tester.num_attention_heads,
-                 self.model_tester.seq_length,
-                 self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+                 self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length,
+                 self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length])
             out_len = len(outputs)
 
             if self.is_encoder_decoder:
@@ -151,8 +151,8 @@ class CommonTestCases:
                 self.assertListEqual(
                     list(decoder_attentions[0].shape[-3:]),
                     [self.model_tester.num_attention_heads,
-                     self.model_tester.seq_length,
-                     self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+                     self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length,
+                     self.model_tester.decoder_seq_length if hasattr(self.model_tester, 'decoder_seq_length') else self.model_tester.seq_length])
 
             # Check attention is always last and order is fine
             config.output_attentions = True
@@ -169,8 +169,8 @@ class CommonTestCases:
             self.assertListEqual(
                 list(self_attentions[0].shape[-3:]),
                 [self.model_tester.num_attention_heads,
-                 self.model_tester.seq_length,
-                 self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+                 self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length,
+                 self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length])
 
         def test_torchscript(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
@@ -440,7 +440,8 @@ class CommonTestCases:
             self.assertEqual(len(hidden_states), self.model_tester.num_hidden_layers + 1)
             self.assertListEqual(
                 list(hidden_states[0].shape[-2:]),
-                [self.model_tester.seq_length, self.model_tester.hidden_size])
+                [self.model_tester.encoder_seq_length if hasattr(self.model_tester, 'encoder_seq_length') else self.model_tester.seq_length,
+                 self.model_tester.hidden_size])
 
         def test_resize_tokens_embeddings(self):
             original_config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
...
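Note: the test changes above all follow the same pattern: prefer an `encoder_seq_length` / `decoder_seq_length` attribute when the model tester defines one (encoder-decoder models such as T5), and fall back to `seq_length` otherwise. The `x if hasattr(obj, 'name') else y` spelling is equivalent to `getattr` with a default; a small sketch, with class and attribute values made up rather than taken from the test suite:

# Illustrative sketch of the fallback pattern used in the tests above.
class EncoderDecoderTester:
    encoder_seq_length = 7       # encoder-decoder testers expose separate lengths
    decoder_seq_length = 5

class EncoderOnlyTester:
    seq_length = 9               # encoder-only testers keep a single seq_length

def expected_encoder_len(tester):
    # Same behaviour as:
    # tester.encoder_seq_length if hasattr(tester, 'encoder_seq_length') else tester.seq_length
    return getattr(tester, 'encoder_seq_length', getattr(tester, 'seq_length', None))

print(expected_encoder_len(EncoderDecoderTester()))  # 7
print(expected_encoder_len(EncoderOnlyTester()))     # 9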
@@ -134,7 +134,7 @@ class T5Tokenizer(PreTrainedTokenizer):
         """ Converts a token (str/unicode) in an id using the vocab. """
         if token.startswith(u"<extra_id_"):
             l = re.match(r'<extra_id_(\d+)>', token)
-            num = int(l[1])
+            num = int(l.group(1))
             return self.vocab_size - num - 1
         return self.sp_model.piece_to_id(token)
...
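Note: the tokenizer change is likely the fix the commit message refers to: indexing a match object (`l[1]`) relies on `re.Match.__getitem__`, which was only added in Python 3.6, while `l.group(1)` also works on Python 3.5. A quick sketch:

# Sketch: match-object indexing needs Python >= 3.6; .group() works everywhere.
import re

m = re.match(r'<extra_id_(\d+)>', '<extra_id_27>')
print(int(m.group(1)))   # 27 on Python 3.5 and later
# print(int(m[1]))       # TypeError on Python 3.5: match object is not subscriptable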