"...python/git@developer.sourcefind.cn:OpenDAS/dynamo.git" did not exist on "3a5fe17db902203e2d4f1eb13673219dcd9b88f4"
Unverified Commit bc075004 authored by Jennifer Wei's avatar Jennifer Wei Committed by GitHub
Browse files

Merge pull request #404 from jnwei/multimer-small-edits

Type fixes and README changes for multimer branch
parents 5e0616b6 a2adb147
...@@ -5,7 +5,7 @@ jobs: ...@@ -5,7 +5,7 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-python@v4 - uses: actions/setup-python@v5
- run: pip install --upgrade pip - run: pip install --upgrade pip
- run: pip install flake8 - run: pip install flake8
- run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics - run: flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
...@@ -9,4 +9,4 @@ dist ...@@ -9,4 +9,4 @@ dist
data data
openfold/resources/ openfold/resources/
tests/test_data/ tests/test_data/
cutlass/
...@@ -13,7 +13,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/ ...@@ -13,7 +13,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/
RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git RUN apt-get update && apt-get install -y wget libxml2 cuda-minimal-build-11-3 libcusparse-dev-11-3 libcublas-dev-11-3 libcusolver-dev-11-3 git
RUN wget -P /tmp \ RUN wget -P /tmp \
"https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh" \ "https://github.com/conda-forge/miniforge/releases/download/23.3.1-1/Miniforge3-Linux-x86_64.sh" \
&& bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \ && bash /tmp/Miniforge3-Linux-x86_64.sh -b -p /opt/conda \
&& rm /tmp/Miniforge3-Linux-x86_64.sh && rm /tmp/Miniforge3-Linux-x86_64.sh
ENV PATH /opt/conda/bin:$PATH ENV PATH /opt/conda/bin:$PATH
......
...@@ -595,4 +595,4 @@ If you use OpenProteinSet, please also cite: ...@@ -595,4 +595,4 @@ If you use OpenProteinSet, please also cite:
primaryClass={q-bio.BM} primaryClass={q-bio.BM}
} }
``` ```
Any work that cites OpenFold should also cite AlphaFold. Any work that cites OpenFold should also cite [AlphaFold](https://www.nature.com/articles/s41586-021-03819-2) and [AlphaFold-Multimer](https://www.biorxiv.org/content/10.1101/2021.10.04.463034v1) if applicable.
...@@ -110,12 +110,12 @@ def make_sequence_features( ...@@ -110,12 +110,12 @@ def make_sequence_features(
) )
features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32) features["between_segment_residues"] = np.zeros((num_res,), dtype=np.int32)
features["domain_name"] = np.array( features["domain_name"] = np.array(
[description.encode("utf-8")], dtype=np.object_ [description.encode("utf-8")], dtype=object
) )
features["residue_index"] = np.array(range(num_res), dtype=np.int32) features["residue_index"] = np.array(range(num_res), dtype=np.int32)
features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32) features["seq_length"] = np.array([num_res] * num_res, dtype=np.int32)
features["sequence"] = np.array( features["sequence"] = np.array(
[sequence.encode("utf-8")], dtype=np.object_ [sequence.encode("utf-8")], dtype=object
) )
return features return features
...@@ -148,7 +148,7 @@ def make_mmcif_features( ...@@ -148,7 +148,7 @@ def make_mmcif_features(
) )
mmcif_feats["release_date"] = np.array( mmcif_feats["release_date"] = np.array(
[mmcif_object.header["release_date"].encode("utf-8")], dtype=np.object_ [mmcif_object.header["release_date"].encode("utf-8")], dtype=object
) )
mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32) mmcif_feats["is_distillation"] = np.array(0., dtype=np.float32)
......
...@@ -83,8 +83,8 @@ TEMPLATE_FEATURES = { ...@@ -83,8 +83,8 @@ TEMPLATE_FEATURES = {
"template_aatype": np.int64, "template_aatype": np.int64,
"template_all_atom_mask": np.float32, "template_all_atom_mask": np.float32,
"template_all_atom_positions": np.float32, "template_all_atom_positions": np.float32,
"template_domain_names": np.object, "template_domain_names": object,
"template_sequence": np.object, "template_sequence": object,
"template_sum_probs": np.float32, "template_sum_probs": np.float32,
} }
...@@ -101,8 +101,8 @@ def empty_template_feats(n_res): ...@@ -101,8 +101,8 @@ def empty_template_feats(n_res):
"template_all_atom_positions": np.zeros( "template_all_atom_positions": np.zeros(
(0, n_res, residue_constants.atom_type_num, 3), np.float32 (0, n_res, residue_constants.atom_type_num, 3), np.float32
), ),
"template_domain_names": np.array([''.encode()], dtype=np.object), "template_domain_names": np.array([''.encode()], dtype=object),
"template_sequence": np.array([''.encode()], dtype=np.object), "template_sequence": np.array([''.encode()], dtype=object),
"template_sum_probs": np.zeros((0, 1), dtype=np.float32), "template_sum_probs": np.zeros((0, 1), dtype=np.float32),
} }
......
...@@ -79,7 +79,7 @@ def assert_equal_nonterminal_atom_types( ...@@ -79,7 +79,7 @@ def assert_equal_nonterminal_atom_types(
"""Checks that pre- and post-minimized proteins have same atom set.""" """Checks that pre- and post-minimized proteins have same atom set."""
# Ignore any terminal OXT atoms which may have been added by minimization. # Ignore any terminal OXT atoms which may have been added by minimization.
oxt = residue_constants.atom_order["OXT"] oxt = residue_constants.atom_order["OXT"]
no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=np.bool) no_oxt_mask = np.ones(shape=atom_mask.shape, dtype=bool)
no_oxt_mask[..., oxt] = False no_oxt_mask[..., oxt] = False
np.testing.assert_almost_equal( np.testing.assert_almost_equal(
ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask] ref_atom_mask[no_oxt_mask], atom_mask[no_oxt_mask]
......
...@@ -90,15 +90,15 @@ def get_optimal_transform( ...@@ -90,15 +90,15 @@ def get_optimal_transform(
def get_least_asym_entity_or_longest_length(batch, input_asym_id): def get_least_asym_entity_or_longest_length(batch, input_asym_id):
""" """
First check how many subunit(s) one sequence has, if there is no tie, e.g. AABBB then select First check how many subunit(s) one sequence has. Select the subunit that is less
one of the A as anchor common, e.g. if the protein was AABBB then select one of the A as anchor
If there is a tie, e.g. AABB, first check which sequence is the longer/longest, If there is a tie, e.g. AABB, first check which sequence is the longer/longest,
then choose one of the corresponding subunits as anchor then choose one of the corresponding subunits as anchor
Args: Args:
batch: in this funtion batch is the full ground truth features batch: in this function batch is the full ground truth features
input_asym_id: A list of aym_ids that are in the cropped input features input_asym_id: A list of asym_ids that are in the cropped input features
Return: Return:
anchor_gt_asym_id: Tensor(int) selected ground truth asym_id anchor_gt_asym_id: Tensor(int) selected ground truth asym_id
...@@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id): ...@@ -126,7 +126,7 @@ def get_least_asym_entity_or_longest_length(batch, input_asym_id):
min_asym_count = min(entity_asym_count.values()) min_asym_count = min(entity_asym_count.values())
least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count] least_asym_entities = [entity for entity, count in entity_asym_count.items() if count == min_asym_count]
# If multiple entities have the least asym_id count, return those with the shortest length # If multiple entities have the least asym_id count, return those with the longest length
if len(least_asym_entities) > 1: if len(least_asym_entities) > 1:
max_length = max([entity_length[entity] for entity in least_asym_entities]) max_length = max([entity_length[entity] for entity in least_asym_entities])
least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length] least_asym_entities = [entity for entity in least_asym_entities if entity_length[entity] == max_length]
......
...@@ -123,7 +123,7 @@ def parse_fasta(data): ...@@ -123,7 +123,7 @@ def parse_fasta(data):
][1:] ][1:]
tags, seqs = lines[::2], lines[1::2] tags, seqs = lines[::2], lines[1::2]
tags = [t.split()[0] for t in tags] tags = [re.split('\W| \|', t)[0] for t in tags]
return tags, seqs return tags, seqs
......
...@@ -63,10 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args): ...@@ -63,10 +63,7 @@ def precompute_alignments(tags, seqs, alignment_dir, args):
with open(tmp_fasta_path, "w") as fp: with open(tmp_fasta_path, "w") as fp:
fp.write(f">{tag}\n{seq}") fp.write(f">{tag}\n{seq}")
local_alignment_dir = os.path.join( local_alignment_dir = os.path.join(alignment_dir, tag),
alignment_dir,
os.path.join(alignment_dir, tag),
)
if args.use_precomputed_alignments is None: if args.use_precomputed_alignments is None:
logger.info(f"Generating alignments for {tag}...") logger.info(f"Generating alignments for {tag}...")
......
...@@ -113,10 +113,10 @@ else: ...@@ -113,10 +113,10 @@ else:
setup( setup(
name='openfold', name='openfold',
version='1.0.1', version='2.0.0',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2', description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz & DeepMind', author='OpenFold Team',
author_email='gahdritz@gmail.com', author_email='jennifer.wei@omsf.io',
license='Apache License, Version 2.0', license='Apache License, Version 2.0',
url='https://github.com/aqlaboratory/openfold', url='https://github.com/aqlaboratory/openfold',
packages=find_packages(exclude=["tests", "scripts"]), packages=find_packages(exclude=["tests", "scripts"]),
......
...@@ -6,6 +6,7 @@ import sys ...@@ -6,6 +6,7 @@ import sys
import unittest import unittest
import numpy as np import numpy as np
import torch
from openfold.config import model_config from openfold.config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
...@@ -119,3 +120,20 @@ def fetch_alphafold_module_weights(weight_path): ...@@ -119,3 +120,20 @@ def fetch_alphafold_module_weights(weight_path):
"Make sure to call import_alphafold before running this function" "Make sure to call import_alphafold before running this function"
) )
return params return params
def _assert_abs_diff_small_base(compare_func, expected, actual, eps):
# Helper function for comparing absolute differences of two torch tensors.
abs_diff = torch.abs(expected - actual)
err = compare_func(abs_diff)
zero_tensor = torch.tensor(0, dtype=err.dtype)
rtol = 1.6e-2 if err.dtype == torch.bfloat16 else 1.3e-6
torch.testing.assert_close(err, zero_tensor, atol=eps, rtol=rtol)
def assert_max_abs_diff_small(expected, actual, eps):
_assert_abs_diff_small_base(torch.max, expected, actual, eps)
def assert_mean_abs_diff_small(expected, actual, eps):
_assert_abs_diff_small_base(torch.mean, expected, actual, eps)
import ml_collections as mlc import ml_collections as mlc
consts = mlc.ConfigDict(
monomer_consts = mlc.ConfigDict(
{
"model": "model_1_ptm", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": False, # monomer: False, multimer: True
"chunk_size": 4,
"batch_size": 2,
"n_res": 22,
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
"n_heads_extra_msa": 8,
"eps": 5e-4,
# For compatibility with DeepMind's pretrained weights, it's easiest for
# everyone if these take their real values.
"c_m": 256,
"c_z": 128,
"c_s": 384,
"c_t": 64,
"c_e": 64,
"msa_logits": 23, # monomer: 23, multimer: 22
"template_mmcif_dir": None # Set for test_multimer_datamodule
}
)
multimer_consts = mlc.ConfigDict(
{ {
"model": "model_1_multimer_v3", # monomer:model_1_ptm, multimer: model_1_multimer_v3 "model": "model_1_multimer_v3", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": True, # monomer: False, multimer: True "is_multimer": True, # monomer: False, multimer: True
...@@ -24,6 +49,8 @@ consts = mlc.ConfigDict( ...@@ -24,6 +49,8 @@ consts = mlc.ConfigDict(
} }
) )
consts = monomer_consts
config = mlc.ConfigDict( config = mlc.ConfigDict(
{ {
"data": { "data": {
......
...@@ -244,9 +244,6 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -244,9 +244,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32) pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32) pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()} batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
template_feats = { template_feats = {
k: v for k, v in batch.items() if k.startswith("template_") k: v for k, v in batch.items() if k.startswith("template_")
...@@ -276,8 +273,7 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -276,8 +273,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
) )
out_repro_ds = out_repro_ds["template_pair_embedding"].cpu() out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
err = torch.max(torch.abs(out_repro - out_repro_ds)) compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, eps)
self.assertTrue(err < eps, f'Error {err}')
def test_compare_model(self): def test_compare_model(self):
""" """
...@@ -310,7 +306,8 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -310,7 +306,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
batch["residx_atom37_to_atom14"] = batch[ batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14" "residx_atom37_to_atom14"
].long() ].long()
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], 21).to(torch.float32) # print(batch["target_feat"].shape)
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
batch["template_all_atom_mask"] = batch["template_all_atom_masks"] batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update( batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch) data_transforms.atom37_to_torsion_angles("template_")(batch)
...@@ -335,8 +332,7 @@ class TestDeepSpeedKernel(unittest.TestCase): ...@@ -335,8 +332,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1].squeeze(0) out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0) out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
err = torch.mean(torch.abs(out_repro - out_repro_ds)) compare_utils.assert_mean_abs_diff_small(out_repro, out_repro_ds, eps)
self.assertTrue(err < eps, f'Error: {err}')
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -178,7 +178,7 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -178,7 +178,7 @@ class TestEvoformerStack(unittest.TestCase):
params = compare_utils.fetch_alphafold_module_weights( params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration" "alphafold/alphafold_iteration/evoformer/evoformer_iteration"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
key = jax.random.PRNGKey(42) key = jax.random.PRNGKey(42)
out_gt = f.apply(params, key, activations, masks) out_gt = f.apply(params, key, activations, masks)
...@@ -200,8 +200,8 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -200,8 +200,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu() out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu() out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps) compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
# Inplace version # Inplace version
out_repro_msa, out_repro_pair = model.evoformer.blocks[0]( out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
...@@ -217,8 +217,8 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -217,8 +217,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu() out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu() out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps) compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
class TestExtraMSAStack(unittest.TestCase): class TestExtraMSAStack(unittest.TestCase):
...@@ -339,7 +339,7 @@ class TestMSATransition(unittest.TestCase): ...@@ -339,7 +339,7 @@ class TestMSATransition(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_transition" + "msa_transition"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
...@@ -354,8 +354,7 @@ class TestMSATransition(unittest.TestCase): ...@@ -354,8 +354,7 @@ class TestMSATransition(unittest.TestCase):
.cpu() .cpu()
) )
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -386,7 +386,7 @@ class TestFeats(unittest.TestCase): ...@@ -386,7 +386,7 @@ class TestFeats(unittest.TestCase):
torch.tensor(restype_atom14_rigid_group_positions).cuda(), torch.tensor(restype_atom14_rigid_group_positions).cuda(),
).cpu() ).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps)) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -79,7 +79,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase): ...@@ -79,7 +79,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_row_attention" + "msa_row_attention"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply( out_gt = f.apply(
params, None, msa_act, msa_mask, pair_act params, None, msa_act, msa_mask, pair_act
...@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase): ...@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
) )
).cpu() ).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
class TestMSAColumnAttention(unittest.TestCase): class TestMSAColumnAttention(unittest.TestCase):
...@@ -144,7 +144,7 @@ class TestMSAColumnAttention(unittest.TestCase): ...@@ -144,7 +144,7 @@ class TestMSAColumnAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_column_attention" + "msa_column_attention"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
...@@ -158,7 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase): ...@@ -158,7 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase):
) )
).cpu() ).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
class TestMSAColumnGlobalAttention(unittest.TestCase): class TestMSAColumnGlobalAttention(unittest.TestCase):
...@@ -207,7 +207,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase): ...@@ -207,7 +207,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/extra_msa_stack/" "alphafold/alphafold_iteration/evoformer/extra_msa_stack/"
+ "msa_column_global_attention" + "msa_column_global_attention"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready())) out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
...@@ -222,7 +222,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase): ...@@ -222,7 +222,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
.cpu() .cpu()
) )
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps)) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -74,7 +74,7 @@ class TestOuterProductMean(unittest.TestCase): ...@@ -74,7 +74,7 @@ class TestOuterProductMean(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/" "alphafold/alphafold_iteration/evoformer/"
+ "evoformer_iteration/outer_product_mean" + "evoformer_iteration/outer_product_mean"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready() out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
...@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase): ...@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
# Even when correct, OPM has large, precision-related errors. It gets # Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps. # a special pass from consts.eps.
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, 5e-4)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -69,7 +69,7 @@ class TestPairTransition(unittest.TestCase): ...@@ -69,7 +69,7 @@ class TestPairTransition(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "pair_transition" + "pair_transition"
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready())) out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
......
...@@ -21,7 +21,6 @@ from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym ...@@ -21,7 +21,6 @@ from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym
merge_labels) merge_labels)
@unittest.skip("Tests need to be fixed post-refactor")
class TestPermutation(unittest.TestCase): class TestPermutation(unittest.TestCase):
def setUp(self): def setUp(self):
""" """
...@@ -65,10 +64,16 @@ class TestPermutation(unittest.TestCase): ...@@ -65,10 +64,16 @@ class TestPermutation(unittest.TestCase):
'seq_length': torch.tensor([57]) 'seq_length': torch.tensor([57])
} }
anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id']) anchor_gt_asym, anchor_pred_asym = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])
self.assertIn(int(anchor_gt_asym), [1, 2]) anchor_gt_asym = int(anchor_gt_asym)
self.assertNotIn(int(anchor_gt_asym), [3, 4, 5]) anchor_pred_asym = {int(i) for i in anchor_pred_asym}
self.assertIn(int(anchor_pred_asym), [1, 2]) expected_anchors = {1, 2}
self.assertNotIn(int(anchor_pred_asym), [3, 4, 5]) expected_non_anchors = {3, 4, 5}
self.assertIn(anchor_gt_asym, expected_anchors)
self.assertNotIn(anchor_gt_asym, expected_non_anchors)
# Check that predicted anchors are within expected anchor set
self.assertEqual(anchor_pred_asym, expected_anchors & anchor_pred_asym)
self.assertEqual(set(), anchor_pred_asym & expected_non_anchors)
def test_2_permutation_pentamer(self): def test_2_permutation_pentamer(self):
batch = { batch = {
...@@ -114,6 +119,7 @@ class TestPermutation(unittest.TestCase): ...@@ -114,6 +119,7 @@ class TestPermutation(unittest.TestCase):
self.assertIn(aligns, possible_outcome) self.assertIn(aligns, possible_outcome)
self.assertNotIn(aligns, wrong_outcome) self.assertNotIn(aligns, wrong_outcome)
@unittest.skip("Test needs to be fixed post-refactor")
def test_3_merge_labels(self): def test_3_merge_labels(self):
nres_pad = 325 - 57 # suppose the cropping size is 325 nres_pad = 325 - 57 # suppose the cropping size is 325
batch = { batch = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment