"python-wheel/python/triton_distributed_rs/__init__.py" did not exist on "ffbc06ccf7c9abb40123f3d6ea047caff4609c6c"
Commit 30195c4a authored by Jennifer's avatar Jennifer
Browse files

Adds absolute error comparison function with better messaging.

parent 8f8b537d
......@@ -6,6 +6,7 @@ import sys
import unittest
import numpy as np
import torch
from openfold.config import model_config
from openfold.model.model import AlphaFold
......@@ -119,3 +120,20 @@ def fetch_alphafold_module_weights(weight_path):
"Make sure to call import_alphafold before running this function"
)
return params
def _assert_abs_diff_small_base(compare_func, expected, actual, eps):
# Helper function for comparing absolute differences of two torch tensors.
abs_diff = torch.abs(expected - actual)
err = compare_func(abs_diff)
zero_tensor = torch.tensor(0, dtype=err.dtype)
rtol = 1.6e-2 if err.dtype == torch.bfloat16 else 1.3e-6
torch.testing.assert_close(err, zero_tensor, atol=eps, rtol=rtol)
def assert_max_abs_diff_small(expected, actual, eps):
_assert_abs_diff_small_base(torch.max, expected, actual, eps)
def assert_mean_abs_diff_small(expected, actual, eps):
_assert_abs_diff_small_base(torch.mean, expected, actual, eps)
......@@ -276,8 +276,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
)
out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
err = torch.max(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error {err}')
compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, eps)
def test_compare_model(self):
"""
......@@ -335,8 +334,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
err = torch.mean(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error: {err}')
compare_utils.assert_mean_abs_diff_small(out_repro, out_repro_ds, eps)
if __name__ == "__main__":
......
......@@ -200,8 +200,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
# Inplace version
out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
......@@ -217,8 +217,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
class TestExtraMSAStack(unittest.TestCase):
......@@ -354,8 +354,7 @@ class TestMSATransition(unittest.TestCase):
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
unittest.main()
......@@ -386,7 +386,7 @@ class TestFeats(unittest.TestCase):
torch.tensor(restype_atom14_rigid_group_positions).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
......
......@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
)
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
class TestMSAColumnAttention(unittest.TestCase):
......@@ -158,7 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase):
)
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
class TestMSAColumnGlobalAttention(unittest.TestCase):
......@@ -222,7 +222,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
......
......@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
# Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps.
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, 5e-4)
if __name__ == "__main__":
......
......@@ -197,7 +197,7 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data.
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 0.05)
class TestInvariantPointAttention(unittest.TestCase):
......@@ -321,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase):
torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class TestAngleResnet(unittest.TestCase):
......
......@@ -191,9 +191,7 @@ class TestTemplatePairStack(unittest.TestCase):
_mask_trans=False,
).cpu()
diff = torch.max(torch.abs(out_gt - out_repro))
self.assertTrue(diff < consts.eps,
msg=f"Found difference between ground truth and reproduction of {diff}")
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class Template(unittest.TestCase):
......@@ -286,7 +284,7 @@ class Template(unittest.TestCase):
out_repro = out_repro_all["template_pair_embedding"]
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
......
......@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size=None,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tri_att_end_compare(self):
......
......@@ -103,7 +103,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
inplace_safe=True, _inplace_chunk_size=4,
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment