Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
30195c4a
"python-wheel/python/triton_distributed_rs/__init__.py" did not exist on "ffbc06ccf7c9abb40123f3d6ea047caff4609c6c"
Commit
30195c4a
authored
Jan 24, 2024
by
Jennifer
Browse files
Adds absolute error comparison function with better messaging.
parent
8f8b537d
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
36 additions
and
23 deletions
+36
-23
tests/compare_utils.py
tests/compare_utils.py
+18
-0
tests/test_deepspeed_evo_attention.py
tests/test_deepspeed_evo_attention.py
+2
-4
tests/test_evoformer.py
tests/test_evoformer.py
+5
-6
tests/test_feats.py
tests/test_feats.py
+1
-1
tests/test_msa.py
tests/test_msa.py
+3
-3
tests/test_outer_product_mean.py
tests/test_outer_product_mean.py
+1
-1
tests/test_structure_module.py
tests/test_structure_module.py
+2
-2
tests/test_template.py
tests/test_template.py
+2
-4
tests/test_triangular_attention.py
tests/test_triangular_attention.py
+1
-1
tests/test_triangular_multiplicative_update.py
tests/test_triangular_multiplicative_update.py
+1
-1
No files found.
tests/compare_utils.py
View file @
30195c4a
...
...
@@ -6,6 +6,7 @@ import sys
import
unittest
import
numpy
as
np
import
torch
from
openfold.config
import
model_config
from
openfold.model.model
import
AlphaFold
...
...
@@ -119,3 +120,20 @@ def fetch_alphafold_module_weights(weight_path):
"Make sure to call import_alphafold before running this function"
)
return
params
def
_assert_abs_diff_small_base
(
compare_func
,
expected
,
actual
,
eps
):
# Helper function for comparing absolute differences of two torch tensors.
abs_diff
=
torch
.
abs
(
expected
-
actual
)
err
=
compare_func
(
abs_diff
)
zero_tensor
=
torch
.
tensor
(
0
,
dtype
=
err
.
dtype
)
rtol
=
1.6e-2
if
err
.
dtype
==
torch
.
bfloat16
else
1.3e-6
torch
.
testing
.
assert_close
(
err
,
zero_tensor
,
atol
=
eps
,
rtol
=
rtol
)
def assert_max_abs_diff_small(expected, actual, eps):
    """Assert the largest absolute elementwise difference is within ``eps``.

    Thin wrapper around the shared comparison helper using a max reduction.
    """
    _assert_abs_diff_small_base(
        torch.max,
        expected,
        actual,
        eps,
    )
def assert_mean_abs_diff_small(expected, actual, eps):
    """Assert the mean absolute elementwise difference is within ``eps``.

    Thin wrapper around the shared comparison helper using a mean reduction.
    """
    _assert_abs_diff_small_base(
        torch.mean,
        expected,
        actual,
        eps,
    )
tests/test_deepspeed_evo_attention.py
View file @
30195c4a
...
...
@@ -276,8 +276,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
)
out_repro_ds
=
out_repro_ds
[
"template_pair_embedding"
].
cpu
()
err
=
torch
.
max
(
torch
.
abs
(
out_repro
-
out_repro_ds
))
self
.
assertTrue
(
err
<
eps
,
f
'Error
{
err
}
'
)
compare_utils
.
assert_max_abs_diff_small
(
out_repro
,
out_repro_ds
,
eps
)
def
test_compare_model
(
self
):
"""
...
...
@@ -335,8 +334,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
out_repro
=
out_repro
[
"sm"
][
"positions"
][
-
1
].
squeeze
(
0
)
out_repro_ds
=
out_repro_ds
[
"sm"
][
"positions"
][
-
1
].
squeeze
(
0
)
err
=
torch
.
mean
(
torch
.
abs
(
out_repro
-
out_repro_ds
))
self
.
assertTrue
(
err
<
eps
,
f
'Error:
{
err
}
'
)
compare_utils
.
assert_mean_abs_diff_small
(
out_repro
,
out_repro_ds
,
eps
)
if
__name__
==
"__main__"
:
...
...
tests/test_evoformer.py
View file @
30195c4a
...
...
@@ -200,8 +200,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa
=
out_repro_msa
.
cpu
()
out_repro_pair
=
out_repro_pair
.
cpu
()
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_repro
_msa
-
out_
gt_msa
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_repro
_pair
-
out_
gt
_pair
))
<
consts
.
eps
)
compare_utils
.
assert_mean_abs_diff_small
(
out_gt
_msa
,
out_
repro_msa
,
consts
.
eps
)
compare_utils
.
assert_max_abs_diff_small
(
out_gt
_pair
,
out_
repro
_pair
,
consts
.
eps
)
# Inplace version
out_repro_msa
,
out_repro_pair
=
model
.
evoformer
.
blocks
[
0
](
...
...
@@ -217,8 +217,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa
=
out_repro_msa
.
cpu
()
out_repro_pair
=
out_repro_pair
.
cpu
()
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_repro
_msa
-
out_
gt_msa
))
<
consts
.
eps
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_repro
_pair
-
out_
gt
_pair
))
<
consts
.
eps
)
compare_utils
.
assert_mean_abs_diff_small
(
out_gt
_msa
,
out_
repro_msa
,
consts
.
eps
)
compare_utils
.
assert_max_abs_diff_small
(
out_gt
_pair
,
out_
repro
_pair
,
consts
.
eps
)
class
TestExtraMSAStack
(
unittest
.
TestCase
):
...
...
@@ -354,8 +354,7 @@ class TestMSATransition(unittest.TestCase):
.
cpu
()
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
compare_utils
.
assert_max_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
if
__name__
==
"__main__"
:
unittest
.
main
()
tests/test_feats.py
View file @
30195c4a
...
...
@@ -386,7 +386,7 @@ class TestFeats(unittest.TestCase):
torch
.
tensor
(
restype_atom14_rigid_group_positions
).
cuda
(),
).
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
)
)
compare_utils
.
assert_max_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
if
__name__
==
"__main__"
:
...
...
tests/test_msa.py
View file @
30195c4a
...
...
@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
)
).
cpu
()
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
compare_utils
.
assert_mean_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
class
TestMSAColumnAttention
(
unittest
.
TestCase
):
...
...
@@ -158,7 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase):
)
).
cpu
()
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
compare_utils
.
assert_mean_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
class
TestMSAColumnGlobalAttention
(
unittest
.
TestCase
):
...
...
@@ -222,7 +222,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
.
cpu
()
)
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
)
<
consts
.
eps
)
)
compare_utils
.
assert_max_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
if
__name__
==
"__main__"
:
...
...
tests/test_outer_product_mean.py
View file @
30195c4a
...
...
@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
# Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps.
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
5e-4
)
compare_utils
.
assert_max_abs_diff_small
(
out_gt
,
out_repro
,
5e-4
)
if
__name__
==
"__main__"
:
...
...
tests/test_structure_module.py
View file @
30195c4a
...
...
@@ -197,7 +197,7 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data.
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
0.05
)
compare_utils
.
assert_mean_abs_diff_small
(
out_gt
,
out_repro
,
0.05
)
class
TestInvariantPointAttention
(
unittest
.
TestCase
):
...
...
@@ -321,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase):
torch
.
as_tensor
(
sample_mask
.
squeeze
(
-
1
)).
float
().
cuda
(),
).
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
compare_utils
.
assert_max_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
class
TestAngleResnet
(
unittest
.
TestCase
):
...
...
tests/test_template.py
View file @
30195c4a
...
...
@@ -191,9 +191,7 @@ class TestTemplatePairStack(unittest.TestCase):
_mask_trans
=
False
,
).
cpu
()
diff
=
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
self
.
assertTrue
(
diff
<
consts
.
eps
,
msg
=
f
"Found difference between ground truth and reproduction of
{
diff
}
"
)
compare_utils
.
assert_max_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
class
Template
(
unittest
.
TestCase
):
...
...
@@ -286,7 +284,7 @@ class Template(unittest.TestCase):
out_repro
=
out_repro_all
[
"template_pair_embedding"
]
out_repro
=
out_repro
.
cpu
()
self
.
assertTrue
(
torch
.
max
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
compare_utils
.
assert_mean_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
if
__name__
==
"__main__"
:
...
...
tests/test_triangular_attention.py
View file @
30195c4a
...
...
@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size
=
None
,
).
cpu
()
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
compare_utils
.
assert_mean_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_tri_att_end_compare
(
self
):
...
...
tests/test_triangular_multiplicative_update.py
View file @
30195c4a
...
...
@@ -103,7 +103,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
inplace_safe
=
True
,
_inplace_chunk_size
=
4
,
).
cpu
()
self
.
assertTrue
(
torch
.
mean
(
torch
.
abs
(
out_gt
-
out_repro
))
<
consts
.
eps
)
compare_utils
.
assert_mean_abs_diff_small
(
out_gt
,
out_repro
,
consts
.
eps
)
@
compare_utils
.
skip_unless_alphafold_installed
()
def
test_tri_mul_out_compare
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment