Unverified Commit a400fe89 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[LED Test] fix common inputs pt for flaky pt-tf led test (#9459)

* fix common inputs pt flakey led

* fix other tests correspondingly
parent ae5a32bb
......@@ -110,7 +110,8 @@ class LEDModelTester:
# because its local attention only attends to `self.attention_window + 1` locations
# (assuming no token with global attention, otherwise the last dimension of attentions
# is x + self.attention_window + 1, where x is the number of tokens with global attention)
self.encoder_key_length = self.attention_window + 1
# x is set to 1
self.encoder_key_length = self.attention_window + 2
# because of padding `encoder_seq_length`, is different from `seq_length`. Relevant for
# the `test_attention_outputs` and `test_hidden_states_output` tests
......@@ -149,6 +150,10 @@ class LEDModelTester:
def prepare_config_and_inputs_for_common(self):
config, inputs_dict = self.prepare_config_and_inputs()
global_attention_mask = torch.zeros_like(inputs_dict["input_ids"])
global_attention_mask[:, -1] = 1
inputs_dict["global_attention_mask"] = global_attention_mask
return config, inputs_dict
def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
......@@ -196,9 +201,11 @@ class LEDModelTester:
encoder.save_pretrained(tmpdirname)
encoder = LEDEncoder.from_pretrained(tmpdirname).to(torch_device)
encoder_last_hidden_state_2 = encoder(inputs_dict["input_ids"], attention_mask=inputs_dict["attention_mask"])[
0
]
encoder_last_hidden_state_2 = encoder(
inputs_dict["input_ids"],
attention_mask=inputs_dict["attention_mask"],
global_attention_mask=inputs_dict["global_attention_mask"],
)[0]
self.parent.assertTrue((encoder_last_hidden_state_2 - encoder_last_hidden_state).abs().max().item() < 1e-3)
......@@ -390,7 +397,8 @@ class LEDModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
)
out_len = len(outputs)
correct_outlen = 5
# global attention outputs are added as well => so +1 here
correct_outlen = 6
# loss is at first position
if "labels" in inputs_dict:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment