Commit f5516805 authored by Julien Chaumond's avatar Julien Chaumond
Browse files

Fix bart slow test

parent 5bc99e7f
...@@ -314,7 +314,7 @@ class BartModelIntegrationTest(unittest.TestCase): ...@@ -314,7 +314,7 @@ class BartModelIntegrationTest(unittest.TestCase):
output = model.forward(**inputs_dict)[0] output = model.forward(**inputs_dict)[0]
expected_shape = torch.Size((1, 11, 1024)) expected_shape = torch.Size((1, 11, 1024))
self.assertEqual(output.shape, expected_shape) self.assertEqual(output.shape, expected_shape)
expected_slice = torch.Tensor( expected_slice = torch.tensor(
[[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device [[0.7144, 0.8143, -1.2813], [0.7144, 0.8143, -1.2813], [-0.0467, 2.5911, -2.1845]], device=torch_device
) )
self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE)) self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=TOLERANCE))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment