Unverified Commit a400fe89 authored by Patrick von Platen, committed by GitHub

[LED Test] fix common inputs pt for flaky pt-tf led test (#9459)

* fix common inputs pt flakey led

* fix other tests correspondingly
parent ae5a32bb
@@ -110,7 +110,8 @@ class LEDModelTester:
         # because its local attention only attends to `self.attention_window + 1` locations
         # (assuming no token with global attention, otherwise the last dimension of attentions
         # is x + self.attention_window + 1, where x is the number of tokens with global attention)
-        self.encoder_key_length = self.attention_window + 1
+        # x is set to 1
+        self.encoder_key_length = self.attention_window + 2
         # because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
         # the `test_attention_outputs` and `test_hidden_states_output` tests
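Spelling out the rule in the comment: with no global tokens the last attention dimension is `attention_window + 1`; with `x` tokens carrying global attention it becomes `x + attention_window + 1`, so the single global token added in `prepare_config_and_inputs_for_common` below gives `attention_window + 2`. A minimal sketch of that arithmetic (the helper name and the window value of 4 are illustrative, not taken from the test file):

```python
def expected_encoder_key_length(attention_window: int, num_global_tokens: int = 0) -> int:
    # Last dimension of the local-attention weights: `attention_window + 1` local
    # slots plus one extra slot per token with global attention.
    return num_global_tokens + attention_window + 1


assert expected_encoder_key_length(4, num_global_tokens=0) == 5  # old expectation (x = 0)
assert expected_encoder_key_length(4, num_global_tokens=1) == 6  # new expectation (x = 1)
```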
@@ -149,6 +150,10 @@ class LEDModelTester:
     def prepare_config_and_inputs_for_common(self):
         config, inputs_dict = self.prepare_config_and_inputs()
+        global_attention_mask = torch.zeros_like(inputs_dict["input_ids"])
+        global_attention_mask[:, -1] = 1
+        inputs_dict["global_attention_mask"] = global_attention_mask
         return config, inputs_dict

     def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
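The three added lines give every batch element exactly one token with global attention (the last position), which is what makes `x = 1` in the shape comment above. A hedged, standalone sketch of the same pattern against a tiny randomly initialized LED encoder (the config values are illustrative, not the tester's):

```python
import torch
from transformers import LEDConfig, LEDModel

attention_window = 4  # illustrative value, kept as a plain int for the assertion below
config = LEDConfig(
    d_model=16, encoder_layers=1, decoder_layers=1, encoder_attention_heads=2,
    decoder_attention_heads=2, encoder_ffn_dim=32, decoder_ffn_dim=32,
    attention_window=attention_window,
)
encoder = LEDModel(config).get_encoder().eval()

input_ids = torch.randint(0, config.vocab_size, (2, 8))
attention_mask = torch.ones_like(input_ids)

# Same pattern as the diff: global attention on the last token of every sequence.
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, -1] = 1

with torch.no_grad():
    out = encoder(
        input_ids,
        attention_mask=attention_mask,
        global_attention_mask=global_attention_mask,
        output_attentions=True,
    )

# With one global token, the last attention dimension is attention_window + 2.
assert out.attentions[0].shape[-1] == attention_window + 2
```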
@@ -196,9 +201,11 @@ class LEDModelTester:
             encoder.save_pretrained(tmpdirname)
             encoder = LEDEncoder.from_pretrained(tmpdirname).to(torch_device)

-        encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
-            0
-        ]
+        encoder_last_hidden_state_2 = encoder(
+            inputs_dict["input_ids"],
+            attention_mask=inputs_dict["attention_mask"],
+            global_attention_mask=inputs_dict["global_attention_mask"],
+        )[0]
         self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
@@ -390,7 +397,8 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
             )
             out_len = len(outputs)

-            correct_outlen = 5
+            # global attention outputs are added as well => so +1 here
+            correct_outlen = 6

             # loss is at first position
             if "labels" in inputs_dict:
...
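The `correct_outlen` bump from 5 to 6 follows from the same change: once `global_attention_mask` is part of the common inputs, the encoder also returns its global-attention weights, so the model output gains one field and the output tuple counted by `test_attention_outputs` grows by one. A hedged sketch showing the extra field (tiny illustrative config, not the tester's values):

```python
import torch
from transformers import LEDConfig, LEDModel

config = LEDConfig(
    d_model=16, encoder_layers=1, decoder_layers=1, encoder_attention_heads=2,
    decoder_attention_heads=2, encoder_ffn_dim=32, decoder_ffn_dim=32, attention_window=4,
)
model = LEDModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 8))
global_attention_mask = torch.zeros_like(input_ids)
global_attention_mask[:, -1] = 1

with torch.no_grad():
    outputs = model(
        input_ids,
        attention_mask=torch.ones_like(input_ids),
        decoder_input_ids=input_ids,  # passed explicitly to keep the sketch simple
        global_attention_mask=global_attention_mask,
        output_attentions=True,
    )

# The extra output element the test now expects: one global-attention tensor per encoder layer.
assert outputs.encoder_global_attentions is not None
assert len(outputs.encoder_global_attentions) == config.encoder_layers
```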