Commit f1d81db8 authored by alexeib's avatar alexeib Committed by Myle Ott
Browse files

fix tests

parent 97a6b139
...@@ -31,14 +31,18 @@ class TestCharacterTokenEmbedder(unittest.TestCase): ...@@ -31,14 +31,18 @@ class TestCharacterTokenEmbedder(unittest.TestCase):
embs = embedder(input) embs = embedder(input)
assert embs.size() == (len(test_sents), max_len + 2, 5) assert embs.size() == (len(test_sents), max_len + 2, 5)
assert embs[0][0].equal(embs[1][0]) self.assertAlmostEqual(embs[0][0], embs[1][0])
assert embs[0][0].equal(embs[0][-1]) self.assertAlmostEqual(embs[0][0], embs[0][-1])
assert embs[0][1].equal(embs[2][1]) self.assertAlmostEqual(embs[0][1], embs[2][1])
assert embs[0][3].equal(embs[1][1]) self.assertAlmostEqual(embs[0][3], embs[1][1])
embs.sum().backward() embs.sum().backward()
assert embedder.char_embeddings.weight.grad is not None assert embedder.char_embeddings.weight.grad is not None
def assertAlmostEqual(self, t1, t2):
self.assertEqual(t1.size(), t2.size(), "size mismatch")
self.assertLess((t1 - t2).abs().max(), 1e-6)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment