Unverified Commit 7772ddb4 authored by Patrick von Platen, committed by GitHub

fix big bird gpu test (#10967)

parent 86026437
@@ -556,34 +556,30 @@ class BigBirdModelTest(ModelTesterMixin, unittest.TestCase):
         config_and_inputs[0].attention_type = type
         self.model_tester.create_and_check_model(*config_and_inputs)
 
+    @unittest.skipIf(torch_device == "cpu", "Fast integration only compatible on GPU")
     def test_fast_integration(self):
-        torch.manual_seed(0)
-        input_ids = torch.randint(
-            self.model_tester.vocab_size,
-            (self.model_tester.batch_size, self.model_tester.seq_length),
+        # fmt: off
+        input_ids = torch.tensor(
+            [[6, 117, 33, 36, 70, 22, 63, 31, 71, 72, 88, 58, 109, 49, 48, 116, 92, 6, 19, 95, 118, 100, 80, 111, 93, 2, 31, 84, 26, 5, 6, 82, 46, 96, 109, 4, 39, 19, 109, 13, 92, 31, 36, 90, 111, 18, 75, 6, 56, 74, 16, 42, 56, 92, 69, 108, 127, 81, 82, 41, 106, 19, 44, 24, 82, 121, 120, 65, 36, 26, 72, 13, 36, 98, 43, 64, 8, 53, 100, 92, 51, 122, 66, 17, 61, 50, 104, 127, 26, 35, 94, 23, 110, 71, 80, 67, 109, 111, 44, 19, 51, 41, 86, 71, 76, 44, 18, 68, 44, 77, 107, 81, 98, 126, 100, 2, 49, 98, 84, 39, 23, 98, 52, 46, 10, 82, 121, 73],[6, 117, 33, 36, 70, 22, 63, 31, 71, 72, 88, 58, 109, 49, 48, 116, 92, 6, 19, 95, 118, 100, 80, 111, 93, 2, 31, 84, 26, 5, 6, 82, 46, 96, 109, 4, 39, 19, 109, 13, 92, 31, 36, 90, 111, 18, 75, 6, 56, 74, 16, 42, 56, 92, 69, 108, 127, 81, 82, 41, 106, 19, 44, 24, 82, 121, 120, 65, 36, 26, 72, 13, 36, 98, 43, 64, 8, 53, 100, 92, 51, 12, 66, 17, 61, 50, 104, 127, 26, 35, 94, 23, 110, 71, 80, 67, 109, 111, 44, 19, 51, 41, 86, 71, 76, 28, 18, 68, 44, 77, 107, 81, 98, 126, 100, 2, 49, 18, 84, 39, 23, 98, 52, 46, 10, 82, 121, 73]],  # noqa: E231
+            dtype=torch.long,
             device=torch_device,
         )
-        attention_mask = torch.ones((self.model_tester.batch_size, self.model_tester.seq_length), device=torch_device)
+        # fmt: on
+        input_ids = input_ids % self.model_tester.vocab_size
+        input_ids[1] = input_ids[1] - 1
+        attention_mask = torch.ones((input_ids.shape), device=torch_device)
         attention_mask[:, :-10] = 0
-        token_type_ids = torch.randint(
-            self.model_tester.type_vocab_size,
-            (self.model_tester.batch_size, self.model_tester.seq_length),
-            device=torch_device,
-        )
 
         config, _, _, _, _, _, _ = self.model_tester.prepare_config_and_inputs()
-        model = BigBirdModel(config).to(torch_device).eval()
+        torch.manual_seed(0)
+        model = BigBirdModel(config).eval().to(torch_device)
         with torch.no_grad():
-            hidden_states = model(
-                input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask
-            ).last_hidden_state
+            hidden_states = model(input_ids, attention_mask=attention_mask).last_hidden_state
             self.assertTrue(
                 torch.allclose(
                     hidden_states[0, 0, :5],
-                    torch.tensor([-0.6326, 0.6124, -0.0844, 0.6698, -1.7155], device=torch_device),
+                    torch.tensor([1.4943, 0.0928, 0.8254, -0.2816, -0.9788], device=torch_device),
                     atol=1e-3,
                 )
             )
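The diff replaces randomly sampled inputs (torch.randint) with a hard-coded input_ids tensor and restricts the test to GPU. A minimal sketch of the underlying issue, not part of the commit: torch.manual_seed() seeds both the CPU and CUDA generators, but the two backends use different RNG algorithms, so the same seed does not produce the same draw on both devices, and expected outputs recorded on one device will not match inputs sampled on the other.

import torch

torch.manual_seed(0)
cpu_ids = torch.randint(100, (2, 8))  # drawn from the CPU generator

if torch.cuda.is_available():
    torch.manual_seed(0)
    gpu_ids = torch.randint(100, (2, 8), device="cuda")  # drawn from the CUDA generator
    # Generally prints False: identically seeded CPU and CUDA generators
    # yield different samples, so a test that samples its own inputs cannot
    # share hard-coded expected hidden states across devices.
    print(torch.equal(cpu_ids, gpu_ids.cpu()))

Hard-coding input_ids sidesteps this entirely; the remaining torch.manual_seed(0) call only pins the model's random weight initialization, which is device-independent.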