Commit e9e3fbdc authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix unit test

parent 7d442323
...@@ -92,7 +92,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -92,7 +92,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
out_repro = module( out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
_inference_mode=True, _inplace_chunk_size=4, _inplace=True, _inplace_chunk_size=4,
).cpu() ).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
...@@ -105,7 +105,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -105,7 +105,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def test_tri_mul_in_compare(self): def test_tri_mul_in_compare(self):
self._tri_mul_compare(incoming=True) self._tri_mul_compare(incoming=True)
def _tri_mul_inference_mode(self, incoming=False): def _tri_mul_inplace(self, incoming=False):
n_res = consts.n_res n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32) pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
...@@ -122,23 +122,23 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -122,23 +122,23 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
out_stock = module( out_stock = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
_inference_mode=False, _inplace=False,
).cpu() ).cpu()
# This has to come second because inference mode is in-place # This has to come second because inference mode is in-place
out_inference_mode = module( out_inplace = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
_inference_mode=True, _inplace_chunk_size=2, _inplace=True, _inplace_chunk_size=2,
).cpu() ).cpu()
self.assertTrue(torch.mean(torch.abs(out_stock - out_inference_mode)) < consts.eps) self.assertTrue(torch.mean(torch.abs(out_stock - out_inplace)) < consts.eps)
def test_tri_mul_out_inference(self): def test_tri_mul_out_inference(self):
self._tri_mul_inference_mode() self._tri_mul_inplace()
def test_tri_mul_in_inference(self): def test_tri_mul_in_inference(self):
self._tri_mul_inference_mode(incoming=True) self._tri_mul_inplace(incoming=True)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment