Unverified Commit 7375758b authored by NielsRogge's avatar NielsRogge Committed by GitHub
Browse files

Fix tests (#14703)

parent 68e53e6f
...@@ -860,7 +860,8 @@ class PerceiverModelIntegrationTest(unittest.TestCase): ...@@ -860,7 +860,8 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
self.assertEqual(logits.shape, expected_shape) self.assertEqual(logits.shape, expected_shape)
expected_slice = torch.tensor( expected_slice = torch.tensor(
[[-10.8609, -10.7651, -10.9187], [-12.1689, -11.9389, -12.1479], [-12.1518, -11.9707, -12.2073]] [[-10.8609, -10.7651, -10.9187], [-12.1689, -11.9389, -12.1479], [-12.1518, -11.9707, -12.2073]],
device=torch_device,
) )
self.assertTrue(torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4)) self.assertTrue(torch.allclose(logits[0, :3, :3], expected_slice, atol=1e-4))
...@@ -970,7 +971,7 @@ class PerceiverModelIntegrationTest(unittest.TestCase): ...@@ -970,7 +971,7 @@ class PerceiverModelIntegrationTest(unittest.TestCase):
# forward pass # forward pass
with torch.no_grad(): with torch.no_grad():
outputs = model(inputs=patches) outputs = model(inputs=patches.to(torch_device))
logits = outputs.logits logits = outputs.logits
# verify logits # verify logits
......
...@@ -99,17 +99,17 @@ class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase): ...@@ -99,17 +99,17 @@ class PerceiverTokenizationTest(TokenizerTesterMixin, unittest.TestCase):
# decoding # decoding
decoded = tokenizer.decode(encoded_ids) decoded = tokenizer.decode(encoded_ids)
self.assertEqual(decoded, "<cls>Unicode €.<sep>") self.assertEqual(decoded, "[CLS]Unicode €.[SEP]")
encoded = tokenizer("e è é ê ë") encoded = tokenizer("e è é ê ë")
encoded_ids = [4, 107, 38, 201, 174, 38, 201, 175, 38, 201, 176, 38, 201, 177, 5] encoded_ids = [4, 107, 38, 201, 174, 38, 201, 175, 38, 201, 176, 38, 201, 177, 5]
self.assertEqual(encoded["input_ids"], encoded_ids) self.assertEqual(encoded["input_ids"], encoded_ids)
# decoding # decoding
decoded = tokenizer.decode(encoded_ids) decoded = tokenizer.decode(encoded_ids)
self.assertEqual(decoded, "<cls>e è é ê ë<sep>") self.assertEqual(decoded, "[CLS]e è é ê ë[SEP]")
# encode/decode, but with `encode` instead of `__call__` # encode/decode, but with `encode` instead of `__call__`
self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "<cls>e è é ê ë<sep>") self.assertEqual(tokenizer.decode(tokenizer.encode("e è é ê ë")), "[CLS]e è é ê ë[SEP]")
def test_prepare_batch_integration(self): def test_prepare_batch_integration(self):
tokenizer = self.perceiver_tokenizer tokenizer = self.perceiver_tokenizer
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment