"...git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "fdb800dce2ddbfc616e0418c34e2424fbfb96041"
Commit 30195c4a authored by Jennifer
Browse files

Adds absolute error comparison function with better messaging.

parent 8f8b537d
...@@ -6,6 +6,7 @@ import sys ...@@ -6,6 +6,7 @@ import sys
import unittest import unittest
import numpy as np import numpy as np
import torch
from openfold.config import model_config from openfold.config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
...@@ -119,3 +120,20 @@ def fetch_alphafold_module_weights(weight_path): ...@@ -119,3 +120,20 @@ def fetch_alphafold_module_weights(weight_path):
"Make sure to call import_alphafold before running this function" "Make sure to call import_alphafold before running this function"
) )
return params return params
def _assert_abs_diff_small_base(compare_func, expected, actual, eps):
# Helper function for comparing absolute differences of two torch tensors.
abs_diff = torch.abs(expected - actual)
err = compare_func(abs_diff)
zero_tensor = torch.tensor(0, dtype=err.dtype)
rtol = 1.6e-2 if err.dtype == torch.bfloat16 else 1.3e-6
torch.testing.assert_close(err, zero_tensor, atol=eps, rtol=rtol)
def assert_max_abs_diff_small(expected, actual, eps):
    """Assert that the worst-case (maximum) elementwise absolute
    difference between ``expected`` and ``actual`` stays within ``eps``."""
    worst = torch.max(torch.abs(expected - actual))
    rel_tol = 1.6e-2 if worst.dtype == torch.bfloat16 else 1.3e-6
    reference = torch.tensor(0, dtype=worst.dtype)
    torch.testing.assert_close(worst, reference, atol=eps, rtol=rel_tol)
def assert_mean_abs_diff_small(expected, actual, eps):
    """Assert that the average elementwise absolute difference between
    ``expected`` and ``actual`` stays within ``eps``."""
    avg_dev = torch.mean(torch.abs(expected - actual))
    rel_tol = 1.6e-2 if avg_dev.dtype == torch.bfloat16 else 1.3e-6
    reference = torch.tensor(0, dtype=avg_dev.dtype)
    torch.testing.assert_close(avg_dev, reference, atol=eps, rtol=rel_tol)
...@@ -276,8 +276,7 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -276,8 +276,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
) )
out_repro_ds = out_repro_ds["template_pair_embedding"].cpu() out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
err = torch.max(torch.abs(out_repro - out_repro_ds)) compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, eps)
self.assertTrue(err < eps, f'Error {err}')
def test_compare_model(self): def test_compare_model(self):
""" """
...@@ -335,8 +334,7 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -335,8 +334,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1].squeeze(0) out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0) out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
err = torch.mean(torch.abs(out_repro - out_repro_ds)) compare_utils.assert_mean_abs_diff_small(out_repro, out_repro_ds, eps)
self.assertTrue(err < eps, f'Error: {err}')
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -200,8 +200,8 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -200,8 +200,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu() out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu() out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps) compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
# Inplace version # Inplace version
out_repro_msa, out_repro_pair = model.evoformer.blocks[0]( out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
...@@ -217,8 +217,8 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -217,8 +217,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu() out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu() out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps) compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
class TestExtraMSAStack(unittest.TestCase): class TestExtraMSAStack(unittest.TestCase):
...@@ -354,8 +354,7 @@ class TestMSATransition(unittest.TestCase): ...@@ -354,8 +354,7 @@ class TestMSATransition(unittest.TestCase):
.cpu() .cpu()
) )
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -386,7 +386,7 @@ class TestFeats(unittest.TestCase): ...@@ -386,7 +386,7 @@ class TestFeats(unittest.TestCase):
torch.tensor(restype_atom14_rigid_group_positions).cuda(), torch.tensor(restype_atom14_rigid_group_positions).cuda(),
).cpu() ).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps)) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase): ...@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
) )
).cpu() ).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
class TestMSAColumnAttention(unittest.TestCase): class TestMSAColumnAttention(unittest.TestCase):
...@@ -158,7 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase): ...@@ -158,7 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase):
) )
).cpu() ).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
class TestMSAColumnGlobalAttention(unittest.TestCase): class TestMSAColumnGlobalAttention(unittest.TestCase):
...@@ -222,7 +222,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase): ...@@ -222,7 +222,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
.cpu() .cpu()
) )
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps)) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase): ...@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
# Even when correct, OPM has large, precision-related errors. It gets # Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps. # a special pass from consts.eps.
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, 5e-4)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -197,7 +197,7 @@ class TestStructureModule(unittest.TestCase): ...@@ -197,7 +197,7 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile # The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to # We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data. # have lower error in general on real rather than synthetic data.
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 0.05)
class TestInvariantPointAttention(unittest.TestCase): class TestInvariantPointAttention(unittest.TestCase):
...@@ -321,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -321,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase):
torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(), torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
).cpu() ).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class TestAngleResnet(unittest.TestCase): class TestAngleResnet(unittest.TestCase):
......
...@@ -191,9 +191,7 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -191,9 +191,7 @@ class TestTemplatePairStack(unittest.TestCase):
_mask_trans=False, _mask_trans=False,
).cpu() ).cpu()
diff = torch.max(torch.abs(out_gt - out_repro)) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
self.assertTrue(diff < consts.eps,
msg=f"Found difference between ground truth and reproduction of {diff}")
class Template(unittest.TestCase): class Template(unittest.TestCase):
...@@ -286,7 +284,7 @@ class Template(unittest.TestCase): ...@@ -286,7 +284,7 @@ class Template(unittest.TestCase):
out_repro = out_repro_all["template_pair_embedding"] out_repro = out_repro_all["template_pair_embedding"]
out_repro = out_repro.cpu() out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase): ...@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size=None, chunk_size=None,
).cpu() ).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_tri_att_end_compare(self): def test_tri_att_end_compare(self):
......
...@@ -103,7 +103,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -103,7 +103,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
inplace_safe=True, _inplace_chunk_size=4, inplace_safe=True, _inplace_chunk_size=4,
).cpu() ).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self): def test_tri_mul_out_compare(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment