Unverified Commit 73893fc7 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[BigBird Pegasus] Make tests faster (#11744)



* improve tests

* remove bogus file

* make style
Co-authored-by: default avatarPatrick von Platen <patrick@huggingface.co>
parent a0531c8a
......@@ -368,17 +368,24 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
self._check_batched_forward(attn_type="block_sparse", tolerance=1e-1)
def _check_batched_forward(self, attn_type, tolerance=1e-3):
config = BigBirdPegasusConfig(block_size=16, attention_type=attn_type)
config, _ = self.model_tester.prepare_config_and_inputs()
config.max_position_embeddings = 128
config.block_size = 16
config.attention_type = attn_type
model = BigBirdPegasusForConditionalGeneration(config).to(torch_device)
model.eval()
sample_with_padding = [3, 8, 11] * 128 + [0] * 128
sample_without_padding = [4, 7, 9, 13] * 128
chunk_length = 32
sample_with_padding = [3, 8, 11] * chunk_length + [0] * chunk_length
sample_without_padding = [4, 7, 9, 13] * chunk_length
target_ids_without_padding = [2, 3] * 8
target_ids_with_padding = [7, 8] * 6 + 4 * [-100]
attention_mask = torch.tensor(
[[1] * 3 * 128 + [0] * 128, [1] * 4 * 128], device=torch_device, dtype=torch.long
[[1] * 3 * chunk_length + [0] * chunk_length, [1] * 4 * chunk_length],
device=torch_device,
dtype=torch.long,
)
input_ids = torch.tensor([sample_with_padding, sample_without_padding], device=torch_device, dtype=torch.long)
......@@ -390,7 +397,7 @@ class BigBirdPegasusModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.
logits_batched = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels).logits
with torch.no_grad():
logits_single_first = model(input_ids=input_ids[:1, :-128], labels=labels[:1]).logits
logits_single_first = model(input_ids=input_ids[:1, :-chunk_length], labels=labels[:1]).logits
self.assertTrue(torch.allclose(logits_batched[0, -3:], logits_single_first[0, -3:], atol=tolerance))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment