Unverified Commit 1073a2bd authored by Sylvain Gugger, committed by GitHub

Switch `return_dict` to `True` by default. (#8530)

* Use the CI to identify failing tests

* Remove from all examples and tests

* More default switch

* Fixes

* More test fixes

* More fixes

* Last fixes hopefully

* Run on the real suite

* Fix slow tests
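The diff below removes explicit `return_dict=True` arguments from tests and testers throughout the suite, since model calls now return `ModelOutput` objects by default. A minimal sketch of what the new default means for callers (the config values here are arbitrary illustrations, not taken from this commit):

```python
import torch
from transformers import BertConfig, BertModel

# Tiny, randomly initialized model purely for illustration.
config = BertConfig(vocab_size=100, hidden_size=32, num_hidden_layers=2,
                    num_attention_heads=2, intermediate_size=64)
model = BertModel(config)
input_ids = torch.randint(0, 100, (1, 8))

outputs = model(input_ids)                 # now a ModelOutput, not a plain tuple
hidden = outputs.last_hidden_state         # attribute access
same = outputs["last_hidden_state"]        # key access
first = outputs[0]                         # positional access still works

# The old tuple behavior can still be requested explicitly.
tuple_outputs = model(input_ids, return_dict=False)
assert isinstance(tuple_outputs, tuple)
```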
@@ -101,7 +101,6 @@ class ElectraModelTester:
 type_vocab_size=self.type_vocab_size,
 is_decoder=False,
 initializer_range=self.initializer_range,
-return_dict=True,
 )
 return (
...
@@ -85,7 +85,6 @@ class EncoderDecoderMixin:
 decoder_input_ids=decoder_input_ids,
 attention_mask=attention_mask,
 decoder_attention_mask=decoder_attention_mask,
-return_dict=True,
 )
 self.assertEqual(
@@ -117,7 +116,6 @@ class EncoderDecoderMixin:
 decoder_input_ids=decoder_input_ids,
 attention_mask=attention_mask,
 decoder_attention_mask=decoder_attention_mask,
-return_dict=True,
 )
 self.assertEqual(
 outputs_encoder_decoder["logits"].shape, (decoder_input_ids.shape + (decoder_config.vocab_size,))
@@ -132,7 +130,6 @@ class EncoderDecoderMixin:
 decoder_input_ids=decoder_input_ids,
 attention_mask=attention_mask,
 decoder_attention_mask=decoder_attention_mask,
-return_dict=True,
 )
 self.assertEqual(
@@ -278,7 +275,6 @@ class EncoderDecoderMixin:
 attention_mask=attention_mask,
 decoder_attention_mask=decoder_attention_mask,
 labels=labels,
-return_dict=True,
 )
 loss = outputs_encoder_decoder["loss"]
@@ -313,7 +309,6 @@ class EncoderDecoderMixin:
 attention_mask=attention_mask,
 decoder_attention_mask=decoder_attention_mask,
 output_attentions=True,
-return_dict=True,
 )
 encoder_attentions = outputs_encoder_decoder["encoder_attentions"]
...
@@ -113,7 +113,6 @@ class FlaubertModelTester(object):
 initializer_range=self.initializer_range,
 summary_type=self.summary_type,
 use_proj=self.use_proj,
-return_dict=True,
 )
 return (
...
@@ -29,7 +29,7 @@ class FlaxBertModelTest(unittest.TestCase):
 # Check for simple input
 pt_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.PYTORCH)
 fx_inputs = tokenizer.encode_plus("This is a simple input", return_tensors=TensorType.JAX)
-pt_outputs = pt_model(**pt_inputs)
+pt_outputs = pt_model(**pt_inputs).to_tuple()
 fx_outputs = fx_model(**fx_inputs)
 self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
...
@@ -34,7 +34,7 @@ class FlaxRobertaModelTest(unittest.TestCase):
 self.assertEqual(len(fx_outputs), len(pt_outputs), "Output lengths differ between Flax and PyTorch")
-for fx_output, pt_output in zip(fx_outputs, pt_outputs):
+for fx_output, pt_output in zip(fx_outputs, pt_outputs.to_tuple()):
 self.assert_almost_equals(fx_output, pt_output.numpy(), 5e-4)
 def assert_almost_equals(self, a: ndarray, b: ndarray, tol: float):
...
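In the two Flax tests above, the PyTorch outputs now have to be converted with `.to_tuple()` before comparison, because PyTorch models return `ModelOutput` objects by default while the Flax models still return tuples. A hedged sketch of that comparison pattern (the helper name and tolerance are illustrative, not part of this commit):

```python
import numpy as np

def compare_pt_and_flax_outputs(pt_outputs, fx_outputs, tol=5e-4):
    # ModelOutput -> tuple of tensors so it can be zipped with the Flax outputs.
    pt_tuple = pt_outputs.to_tuple()
    assert len(pt_tuple) == len(fx_outputs), "Output lengths differ between Flax and PyTorch"
    for pt_tensor, fx_tensor in zip(pt_tuple, fx_outputs):
        diff = np.abs(np.asarray(fx_tensor) - pt_tensor.detach().numpy()).max()
        assert diff <= tol, f"Outputs differ by {diff}"
```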
@@ -259,7 +259,6 @@ class FSMTHeadTests(unittest.TestCase):
 eos_token_id=2,
 pad_token_id=1,
 bos_token_id=0,
-return_dict=True,
 )
 def _get_config_and_data(self):
...
@@ -140,7 +140,6 @@ class FunnelModelTester:
 activation_dropout=self.activation_dropout,
 max_position_embeddings=self.max_position_embeddings,
 type_vocab_size=self.type_vocab_size,
-return_dict=True,
 )
 return (
...
@@ -131,7 +131,6 @@ class GPT2ModelTester:
 bos_token_id=self.bos_token_id,
 eos_token_id=self.eos_token_id,
 pad_token_id=self.pad_token_id,
-return_dict=True,
 gradient_checkpointing=gradient_checkpointing,
 )
...
@@ -125,7 +125,6 @@ class LayoutLMModelTester:
 max_position_embeddings=self.max_position_embeddings,
 type_vocab_size=self.type_vocab_size,
 initializer_range=self.initializer_range,
-return_dict=True,
 )
 return config, input_ids, bbox, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...
@@ -113,7 +113,6 @@ class LongformerModelTester:
 type_vocab_size=self.type_vocab_size,
 initializer_range=self.initializer_range,
 attention_window=self.attention_window,
-return_dict=True,
 )
 return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...
@@ -282,7 +282,6 @@ class LxmertModelTester:
 attention_mask=input_mask,
 labels=ans,
 output_attentions=output_attentions,
-return_dict=True,
 )
 result = model(input_ids, visual_feats, bounding_boxes, labels=ans)
 result = model(
@@ -302,7 +301,6 @@ class LxmertModelTester:
 attention_mask=input_mask,
 labels=ans,
 output_attentions=not output_attentions,
-return_dict=True,
 )
 self.parent.assertEqual(result.question_answering_score.shape, (self.batch_size, self.num_qa_labels))
@@ -335,7 +333,6 @@ class LxmertModelTester:
 matched_label=matched_label,
 ans=ans,
 output_attentions=output_attentions,
-return_dict=True,
 )
 result = model(
 input_ids,
@@ -390,7 +387,6 @@ class LxmertModelTester:
 matched_label=matched_label,
 ans=ans,
 output_attentions=not output_attentions,
-return_dict=True,
 )
 self.parent.assertEqual(result.prediction_logits.shape, (self.batch_size, self.seq_length, self.vocab_size))
@@ -427,7 +423,6 @@ class LxmertModelTester:
 token_type_ids=token_type_ids,
 attention_mask=input_mask,
 ans=ans,
-return_dict=True,
 )
 result_qa = model_qa(
@@ -437,7 +432,6 @@ class LxmertModelTester:
 labels=ans,
 token_type_ids=token_type_ids,
 attention_mask=input_mask,
-return_dict=True,
 )
 model_pretrain.resize_num_qa_labels(num_small_labels)
@@ -450,7 +444,6 @@ class LxmertModelTester:
 token_type_ids=token_type_ids,
 attention_mask=input_mask,
 ans=less_labels_ans,
-return_dict=True,
 )
 result_qa_less = model_qa(
@@ -460,7 +453,6 @@ class LxmertModelTester:
 labels=less_labels_ans,
 token_type_ids=token_type_ids,
 attention_mask=input_mask,
-return_dict=True,
 )
 model_pretrain.resize_num_qa_labels(num_large_labels)
@@ -473,7 +465,6 @@ class LxmertModelTester:
 token_type_ids=token_type_ids,
 attention_mask=input_mask,
 ans=more_labels_ans,
-return_dict=True,
 )
 result_qa_more = model_qa(
@@ -483,7 +474,6 @@ class LxmertModelTester:
 labels=more_labels_ans,
 token_type_ids=token_type_ids,
 attention_mask=input_mask,
-return_dict=True,
 )
 model_qa_labels = model_qa.num_qa_labels
...
@@ -50,7 +50,6 @@ class ModelTester:
 decoder_ffn_dim=32,
 max_position_embeddings=48,
 add_final_layer_norm=True,
-return_dict=True,
 )
 def prepare_config_and_inputs_for_common(self):
...
@@ -37,7 +37,6 @@ class ModelTester:
 decoder_ffn_dim=32,
 max_position_embeddings=48,
 add_final_layer_norm=True,
-return_dict=True,
 )
 def prepare_config_and_inputs_for_common(self):
@@ -132,7 +131,6 @@ class MBartEnroIntegrationTest(AbstractSeq2SeqIntegrationTest):
 decoder_ffn_dim=32,
 max_position_embeddings=48,
 add_final_layer_norm=True,
-return_dict=True,
 )
 lm_model = MBartForConditionalGeneration(config).to(torch_device)
 context = torch.Tensor([[71, 82, 18, 33, 46, 91, 2], [68, 34, 26, 58, 30, 2, 1]]).long().to(torch_device)
...
@@ -124,7 +124,6 @@ class MobileBertModelTester:
 type_vocab_size=self.type_vocab_size,
 is_decoder=False,
 initializer_range=self.initializer_range,
-return_dict=True,
 )
 return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...
@@ -94,7 +94,6 @@ class OpenAIGPTModelTester:
 # type_vocab_size=self.type_vocab_size,
 # initializer_range=self.initializer_range
 pad_token_id=self.pad_token_id,
-return_dict=True,
 )
 head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
...
@@ -33,7 +33,6 @@ class ModelTester:
 decoder_ffn_dim=32,
 max_position_embeddings=48,
 add_final_layer_norm=True,
-return_dict=True,
 )
 def prepare_config_and_inputs_for_common(self):
...
@@ -142,7 +142,6 @@ class ProphetNetModelTester:
 disable_ngram_loss=self.disable_ngram_loss,
 max_position_embeddings=self.max_position_embeddings,
 is_encoder_decoder=self.is_encoder_decoder,
-return_dict=True,
 )
 return (
@@ -344,7 +343,6 @@ class ProphetNetModelTester:
 decoder_input_ids=decoder_input_ids,
 attention_mask=attention_mask,
 decoder_attention_mask=decoder_attention_mask,
-return_dict=True,
 )
 tied_model_result = tied_model(
@@ -352,7 +350,6 @@ class ProphetNetModelTester:
 decoder_input_ids=decoder_input_ids,
 attention_mask=attention_mask,
 decoder_attention_mask=decoder_attention_mask,
-return_dict=True,
 )
 # check that models has less parameters
@@ -419,7 +416,6 @@ class ProphetNetModelTester:
 attention_mask=attention_mask,
 decoder_attention_mask=decoder_attention_mask,
 labels=lm_labels,
-return_dict=True,
 )
 self.parent.assertTrue(torch.allclose(result.loss, torch.tensor(128.2925, device=torch_device), atol=1e-3))
@@ -433,9 +429,7 @@ class ProphetNetModelTester:
 model.to(torch_device)
 model.eval()
-outputs_no_mask = model(
-input_ids=input_ids[:, :5], decoder_input_ids=decoder_input_ids[:, :5], return_dict=True
-)
+outputs_no_mask = model(input_ids=input_ids[:, :5], decoder_input_ids=decoder_input_ids[:, :5])
 attention_mask = torch.ones_like(input_ids)
 decoder_attention_mask = torch.ones_like(decoder_input_ids)
@@ -446,7 +440,6 @@ class ProphetNetModelTester:
 attention_mask=attention_mask,
 decoder_input_ids=decoder_input_ids,
 decoder_attention_mask=decoder_attention_mask,
-return_dict=True,
 )
 # check encoder
@@ -524,7 +517,6 @@ class ProphetNetStandaloneDecoderModelTester:
 bos_token_id=1,
 eos_token_id=2,
 ngram=2,
-return_dict=True,
 num_buckets=32,
 relative_max_distance=128,
 disable_ngram_loss=False,
@@ -562,7 +554,6 @@ class ProphetNetStandaloneDecoderModelTester:
 self.max_position_embeddings = max_position_embeddings
 self.add_cross_attention = add_cross_attention
 self.is_encoder_decoder = is_encoder_decoder
-self.return_dict = return_dict
 self.scope = None
 self.decoder_key_length = decoder_seq_length
@@ -602,7 +593,6 @@ class ProphetNetStandaloneDecoderModelTester:
 max_position_embeddings=self.max_position_embeddings,
 add_cross_attention=self.add_cross_attention,
 is_encoder_decoder=self.is_encoder_decoder,
-return_dict=self.return_dict,
 )
 return (
@@ -757,7 +747,6 @@ class ProphetNetStandaloneEncoderModelTester:
 pad_token_id=0,
 bos_token_id=1,
 eos_token_id=2,
-return_dict=True,
 num_buckets=32,
 relative_max_distance=128,
 disable_ngram_loss=False,
@@ -794,7 +783,6 @@ class ProphetNetStandaloneEncoderModelTester:
 self.max_position_embeddings = max_position_embeddings
 self.add_cross_attention = add_cross_attention
 self.is_encoder_decoder = is_encoder_decoder
-self.return_dict = return_dict
 self.scope = None
 self.decoder_key_length = decoder_seq_length
@@ -829,7 +817,6 @@ class ProphetNetStandaloneEncoderModelTester:
 max_position_embeddings=self.max_position_embeddings,
 add_cross_attention=self.add_cross_attention,
 is_encoder_decoder=self.is_encoder_decoder,
-return_dict=self.return_dict,
 )
 return (
@@ -919,7 +906,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
 # methods overwrite method in `test_modeling_common.py`
 def test_attention_outputs(self):
 config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
-config.return_dict = True
 seq_len = getattr(self.model_tester, "seq_length", None)
 decoder_seq_length = getattr(self.model_tester, "decoder_seq_length", seq_len)
@@ -933,7 +919,6 @@ class ProphetNetModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.Test
 for model_class in self.all_model_classes:
 inputs_dict["output_attentions"] = True
 inputs_dict["output_hidden_states"] = False
-config.return_dict = True
 model = model_class(config)
 model.to(torch_device)
 model.eval()
@@ -1121,7 +1106,6 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
 attention_mask=None,
 encoder_outputs=None,
 decoder_input_ids=decoder_prev_ids,
-return_dict=True,
 )
 output_predited_logits = output[0]
 expected_shape = torch.Size((1, 12, 30522))
@@ -1143,9 +1127,7 @@ class ProphetNetModelIntegrationTest(unittest.TestCase):
 assert torch.allclose(encoder_outputs[:, :3, :3], expected_encoder_outputs_slice, atol=1e-4)
 # decoder outputs
-decoder_outputs = model.prophetnet.decoder(
-decoder_prev_ids, encoder_hidden_states=encoder_outputs, return_dict=True
-)
+decoder_outputs = model.prophetnet.decoder(decoder_prev_ids, encoder_hidden_states=encoder_outputs)
 predicting_streams = decoder_outputs[1].view(1, model.config.ngram, 12, -1)
 predicting_streams_logits = model.lm_head(predicting_streams)
 next_first_stream_logits = predicting_streams_logits[:, 0]
...
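The ProphetNet integration tests above keep positional indexing such as `output[0]` and `decoder_outputs[1]` even after dropping `return_dict=True`; this works because `ModelOutput` supports integer indexing over its non-`None` fields in addition to attribute and key access. A small illustrative sketch (the `ToyOutput` class is a made-up stand-in for the library's output dataclasses):

```python
from dataclasses import dataclass
from typing import Optional

import torch
from transformers.file_utils import ModelOutput  # moved to transformers.utils in later releases

@dataclass
class ToyOutput(ModelOutput):
    logits: Optional[torch.Tensor] = None
    extra: Optional[torch.Tensor] = None

out = ToyOutput(logits=torch.zeros(2, 3))
assert torch.equal(out[0], out.logits)          # integer index hits the first non-None field
assert torch.equal(out["logits"], out.logits)   # key access also works
assert len(out.to_tuple()) == 1                 # None fields are dropped from the tuple
```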
@@ -174,7 +174,6 @@ class ReformerModelTester:
 attn_layers=self.attn_layers,
 pad_token_id=self.pad_token_id,
 hash_seed=self.hash_seed,
-return_dict=True,
 )
 return (
...
@@ -103,7 +103,6 @@ class RobertaModelTester:
 max_position_embeddings=self.max_position_embeddings,
 type_vocab_size=self.type_vocab_size,
 initializer_range=self.initializer_range,
-return_dict=True,
 )
 return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
...
@@ -131,7 +131,6 @@ if is_torch_available():
 post_attention_groups=self.post_attention_groups,
 intermediate_groups=self.intermediate_groups,
 output_groups=self.output_groups,
-return_dict=True,
 )
 return config, input_ids, input_mask, sequence_labels, token_labels, choice_labels
...