Unverified Commit bb3f51e5 authored by Christina Floristean's avatar Christina Floristean Committed by GitHub
Browse files

Merge pull request #405 from aqlaboratory/multimer

Full multimer merge
parents ce211367 c33a0bd6
>query
MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
>tr|A0A2W3M096|A0A2W3M096_STAAU Plasmid segregation protein ParR OS=Staphylococcus aureus OX=1280 GN=C7Q70_14145 PE=4 SV=1
-------------------MDKKETQHLLKIKKQDYPQIFNFLEGLPKGTKTAHIREALMRYIAEEGNTP
>tr|A0A0Q9XW80|A0A0Q9XW80_9STAP Uncharacterized protein OS=Staphylococcus sp. NAM3COL9 GN=ACA31_00310 PE=4 SV=1
-------------------MSKQETNHLLKIKKKDYPQIFEFLEGVPKGTKTAHIREALLRYIEELGAPP
>tr|A0A1E5U0W4|A0A1E5U0W4_STAXY Uncharacterized protein OS=Staphylococcus xylosus GN=AST15_04830 PE=4 SV=1
-------------------MSKQETNHLLKIKKKDYPQIFDFLENVPKGTKTAHIREALIRYINDLGDTpP
This source diff could not be displayed because it is too large. You can view the blob instead.
# STOCKHOLM 1.0
#=GS MGYP000048211747/1-51 DE [subseq from] PL=00 UP=0 BIOMES=0000000011000
#=GS MGYP000256545448/1-51 DE [subseq from] PL=00 UP=0 BIOMES=0000000011000
#=GS MGYP000517307434/104-157 DE [subseq from] PL=11 UP=0 BIOMES=0000000011000
#=GS MGYP000971940026/195-224 DE [subseq from] PL=10 UP=0 BIOMES=0110000000000
#=GS MGYP000859660985/46-74 DE [subseq from] PL=10 UP=0 BIOMES=0110000000000
#=GS MGYP000859660985/83-111 DE [subseq from] PL=10 UP=0 BIOMES=0110000000000
query MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
MGYP000048211747/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
MGYP000256545448/1-51 -------------------MKKKETQHLLKIKKEDYPQIFDFLEGLPRGTKTAHIREALLRYIADEGENP
MGYP000517307434/104-157 ----------------GDLLRQKETQHLLKIKKEDYPQIFDFLEGLPRGTKTAHIREALLRYIADEGENP
MGYP000971940026/195-224 ------------------------------VKKSDLGQVTSFLKEVPEGKKQDVLDEVLK----------
MGYP000859660985/46-74 ------------------------------IKKSDLGQVASFLKEVPEGQKQEVLDQVL-----------
MGYP000859660985/83-111 ------------------------------IKKSDLGQVASFLKEVPEGQKQEVLDQVL-----------
#=GC RF xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
//
# STOCKHOLM 1.0
#=GS tr|A0A0K0ME10|A0A0K0ME10_9STAP/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus schleiferi OX=1295 GN=NP71_p00120 PE=4 SV=1
#=GS tr|A0A0C5BVQ8|A0A0C5BVQ8_STAAU/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus aureus OX=1280 GN=parR PE=4 SV=1
#=GS tr|A0A0D4ZYK6|A0A0D4ZYK6_STAEP/1-51 DE [subseq from] DNA-binding protein ParR OS=Staphylococcus epidermidis OX=1282 GN=parR PE=4 SV=1
#=GS tr|A0A0H2XKQ4|A0A0H2XKQ4_STAA3/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus aureus (strain USA300) OX=367830 GN=SAUSA300_pUSA030035 PE=4 SV=1
#=GS tr|A0A0N9NJL4|A0A0N9NJL4_STAPS/1-51 DE [subseq from] Putative plasmid segregation protein ParR OS=Staphylococcus pseudintermedius OX=283734 GN=parR PE=4 SV=1
#=GS tr|A0A0U2CJ65|A0A0U2CJ65_STAEP/1-51 DE [subseq from] Plasmid segregation protein OS=Staphylococcus epidermidis OX=1282 GN=parR PE=4 SV=1
#=GS tr|A0A133QXU6|A0A133QXU6_STASI/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus simulans OX=1286 GN=HMPREF3215_00002 PE=4 SV=1
#=GS tr|A0A141BHY3|A0A141BHY3_STAXY/1-51 DE [subseq from] DNA-binding protein OS=Staphylococcus xylosus OX=1288 GN=p11 PE=4 SV=1
#=GS tr|A0A141HMK9|A0A141HMK9_STAA8/1-51 DE [subseq from] ParR OS=Staphylococcus aureus subsp. aureus RN4220 OX=561307 GN=pGO400_p33 PE=4 SV=1
#=GS tr|A0A1B1UXS0|A0A1B1UXS0_STALU/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus lugdunensis OX=28035 GN=parR PE=4 SV=1
#=GS tr|A0A1S7BGJ1|A0A1S7BGJ1_STAAU/1-51 DE [subseq from] DNA-binding protein ParR OS=Staphylococcus aureus OX=1280 GN=parR PE=4 SV=1
#=GS tr|A0A418HED5|A0A418HED5_STAGA/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus gallinarum OX=1293 GN=BUY97_07835 PE=4 SV=1
#=GS tr|A0A507SJ94|A0A507SJ94_9STAP/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus sp. SKL71187 OX=2497688 GN=EKV43_01520 PE=4 SV=1
#=GS tr|A0A6N0I4W4|A0A6N0I4W4_STAHO/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus hominis OX=1290 GN=FOB69_12695 PE=4 SV=1
#=GS tr|A0A7G3T6L6|A0A7G3T6L6_9STAP/1-51 DE [subseq from] Plasmid segregation protein OS=Staphylococcus equorum OX=246432 PE=4 SV=1
#=GS tr|A0A848F022|A0A848F022_STACP/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus capitis OX=29388 GN=HHM13_04665 PE=4 SV=1
#=GS tr|O87365|O87365_STAAU/1-51 DE [subseq from] Conserved domain protein OS=Staphylococcus aureus OX=1280 GN=parR PE=1 SV=1
#=GS tr|A0A7G3L2E1|A0A7G3L2E1_STAAU/1-51 DE [subseq from] Plasmid segregation protein ParR OS=Staphylococcus aureus OX=1280 PE=4 SV=1
#=GS tr|E4PYH1|E4PYH1_STAAU/1-39 DE [subseq from] DUF655 domain-containing protein OS=Staphylococcus aureus OX=1280 GN=SUM_0041p2 PE=4 SV=1
#=GS tr|A0A0Q9XW80|A0A0Q9XW80_9STAP/1-51 DE [subseq from] RHH_1 domain-containing protein OS=Staphylococcus sp. NAM3COL9 OX=1667172 GN=ACA31_00310 PE=4 SV=1
query MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0K0ME10|A0A0K0ME10_9STAP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0C5BVQ8|A0A0C5BVQ8_STAAU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0D4ZYK6|A0A0D4ZYK6_STAEP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0H2XKQ4|A0A0H2XKQ4_STAA3/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0N9NJL4|A0A0N9NJL4_STAPS/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A0U2CJ65|A0A0U2CJ65_STAEP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A133QXU6|A0A133QXU6_STASI/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A141BHY3|A0A141BHY3_STAXY/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A141HMK9|A0A141HMK9_STAA8/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A1B1UXS0|A0A1B1UXS0_STALU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A1S7BGJ1|A0A1S7BGJ1_STAAU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A418HED5|A0A418HED5_STAGA/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A507SJ94|A0A507SJ94_9STAP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A6N0I4W4|A0A6N0I4W4_STAHO/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A7G3T6L6|A0A7G3T6L6_9STAP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A848F022|A0A848F022_STACP/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|O87365|O87365_STAAU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|A0A7G3L2E1|A0A7G3L2E1_STAAU/1-51 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
tr|E4PYH1|E4PYH1_STAAU/1-39 -------------------MDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREA------------
tr|A0A0Q9XW80|A0A0Q9XW80_9STAP/1-51 -------------------MSKQETNHLLKIKKKDYPQIFEFLEGVPKGTKTAHIREALLRYIEELGAPP
#=GC RF xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
//
# STOCKHOLM 1.0
#=GS UniRef90_A0A141BHY3/1-51 DE [subseq from] DNA-binding protein n=37 Tax=Staphylococcaceae TaxID=90964 RepID=A0A141BHY3_STAXY
#=GS UniRef90_UPI000A061283/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Mammaliicoccus sciuri TaxID=1296 RepID=UPI000A061283
#=GS UniRef90_UPI001E649B27/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Mammaliicoccus sciuri TaxID=1296 RepID=UPI001E649B27
#=GS UniRef90_UPI00201A2D50/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Staphylococcus aureus TaxID=1280 RepID=UPI00201A2D50
#=GS UniRef90_UPI0018EDBA69/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Staphylococcus aureus TaxID=1280 RepID=UPI0018EDBA69
#=GS UniRef90_UPI0005E12F5A/1-51 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Staphylococcus TaxID=1279 RepID=UPI0005E12F5A
#=GS UniRef90_UPI00207B21F3/1-51 DE [subseq from] plasmid segregation protein ParR n=2 Tax=Staphylococcus TaxID=1279 RepID=UPI00207B21F3
#=GS UniRef90_UPI0009836679/1-51 DE [subseq from] plasmid segregation protein ParR n=2 Tax=Staphylococcus aureus TaxID=1280 RepID=UPI0009836679
#=GS UniRef90_UPI001F5439CD/1-51 DE [subseq from] plasmid segregation protein ParR n=11 Tax=Staphylococcaceae TaxID=90964 RepID=UPI001F5439CD
#=GS UniRef90_UPI000DA9B884/1-51 DE [subseq from] plasmid segregation protein ParR n=3 Tax=Bacillales TaxID=1385 RepID=UPI000DA9B884
#=GS UniRef90_A0A0Q9XW80/1-51 DE [subseq from] RHH_1 domain-containing protein n=1 Tax=Staphylococcus sp. NAM3COL9 TaxID=1667172 RepID=A0A0Q9XW80_9STAP
#=GS UniRef90_UPI001CCC4088/3-48 DE [subseq from] plasmid segregation protein ParR n=1 Tax=Macrococcus armenti TaxID=2875764 RepID=UPI001CCC4088
#=GS UniRef90_UPI0014612D4C/1-49 DE [subseq from] De novo designed WSHC6 n=2 Tax=synthetic construct TaxID=32630 RepID=UPI0014612D4C
#=GS UniRef90_UPI000B802FE5/1-42 DE [subseq from] HEEH_rd4_0097 n=1 Tax=Escherichia coli TaxID=562 RepID=UPI000B802FE5
#=GS UniRef90_UPI001E281CEB/1-54 DE [subseq from] Network hallucinated protein 0738_mod n=1 Tax=synthetic construct TaxID=32630 RepID=UPI001E281CEB
query MGSSHHHHHHSSGLVP-GSHMDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
UniRef90_A0A141BHY3/1-51 --------------------MDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP
UniRef90_UPI000A061283/1-51 --------------------MDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEMGENP
UniRef90_UPI001E649B27/1-51 --------------------MDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEMGDNP
UniRef90_UPI00201A2D50/1-51 --------------------MEKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEMGDNP
UniRef90_UPI0018EDBA69/1-51 --------------------MDKKE-TKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALLRYIEEFGENP
UniRef90_UPI0005E12F5A/1-51 --------------------MKKKE-TQHLLKIKKEDYPQIFDFLEGLPRGTKTAHIREALLRYIADEGENP
UniRef90_UPI00207B21F3/1-51 --------------------MSKQE-TNHLLKIKKEDYPQIFDFLENVPKGTKTAHIREALIRYINDLGGSP
UniRef90_UPI0009836679/1-51 --------------------MDKKE-TQHLLKIKKQDYPQIFNFLEGLPKGTKTAHIREALMRYIAEEGQNP
UniRef90_UPI001F5439CD/1-51 --------------------MSKQE-TNHLLKIKKKDYPQIFDFLENVPKGTKTAHIREALIRYINDLGGTP
UniRef90_UPI000DA9B884/1-51 --------------------MDKKE-TQHLLKIKKQDYPQIFNFLEGLPKGTKTAHIREALMRYIAEEGNTP
UniRef90_A0A0Q9XW80/1-51 --------------------MSKQE-TNHLLKIKKKDYPQIFEFLEGVPKGTKTAHIREALLRYIEELGAPP
UniRef90_UPI001CCC4088/3-48 ----------------------KEV-NQTLLKIDKAEYPEIYDFLENVPRGTKTAHIREALIRYINDIN---
UniRef90_UPI0014612D4C/1-49 MGSSHHHHHHSSGLVPRGSHMTEDE-IRKLRKLLEEAEKKLYKLEDKTRR----------------------
UniRef90_UPI000B802FE5/1-42 MGSSHHHHHHSSGLVPRGSHMDVEEQIRRLEEVLKKNQPVTW------------------------------
UniRef90_UPI001E281CEB/1-54 MGSSHHHHHHSSGLVPRGSHMNIQV-SLQWE---DPKKGKVFSHTVNIPPGGTAEQIA--------------
#=GC RF xxxxxxxxxxxxxxxx.xxxxxxxx.xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx
//
{"2q2k": {"release_date": "2008-02-05", "chain_ids": ["A", "B"], "seqs": ["MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP", "MGSSHHHHHHSSGLVPGSHMDKKETKHLLKIKKEDYPQIFDFLENVPRGTKTAHIREALRRYIEEIGENP"], "no_chains": 2, "resolution": 3.0}}
\ No newline at end of file
......@@ -15,19 +15,13 @@
import pickle
import shutil
import torch
import numpy as np
import unittest
from openfold.data.data_pipeline import DataPipeline
from openfold.data.templates import TemplateHitFeaturizer
from openfold.model.embedders import (
InputEmbedder,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
)
from openfold.data.templates import HhsearchHitFeaturizer, HmmsearchHitFeaturizer
import tests.compare_utils as compare_utils
from tests.config import consts
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
......@@ -45,7 +39,23 @@ class TestDataPipeline(unittest.TestCase):
with open("tests/test_data/alphafold_feature_dict.pickle", "rb") as fp:
alphafold_feature_dict = pickle.load(fp)
template_featurizer = TemplateHitFeaturizer(
if consts.is_multimer:
# template_featurizer = HmmsearchHitFeaturizer(
# mmcif_dir="tests/test_data/mmcifs",
# max_template_date="2021-12-20",
# max_hits=20,
# kalign_binary_path=shutil.which("kalign"),
# _zero_center_positions=False,
# )
template_featurizer = HhsearchHitFeaturizer(
mmcif_dir="tests/test_data/mmcifs",
max_template_date="2021-12-20",
max_hits=20,
kalign_binary_path=shutil.which("kalign"),
_zero_center_positions=False,
)
else:
template_featurizer = HhsearchHitFeaturizer(
mmcif_dir="tests/test_data/mmcifs",
max_template_date="2021-12-20",
max_hits=20,
......
import copy
import gzip
import os
import pickle
import numpy as np
......@@ -181,7 +177,7 @@ class TestDataTransforms(unittest.TestCase):
}
protein = make_hhblits_profile(protein)
masked_msa_config = config.data.common.masked_msa
protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15)
protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15, seed=42)
assert 'bert_mask' in protein
assert 'true_msa' in protein
assert 'msa' in protein
......
......@@ -236,19 +236,25 @@ class TestDeepSpeedKernel(unittest.TestCase):
n_res = 20
eps = 2e-2
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
if consts.is_multimer:
batch["asym_id"] = batch['asym_id'][0]
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
template_feats = {
k: v for k, v in batch.items() if k.startswith("template_")
}
with torch.no_grad():
model = compare_utils.get_global_pretrained_openfold()
model.globals.use_deepspeed_evo_attention = False
out_repro = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
template_feats,
batch,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
......@@ -258,7 +264,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
model.globals.use_deepspeed_evo_attention = True
out_repro_ds = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
template_feats,
batch,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
......@@ -266,15 +273,14 @@ class TestDeepSpeedKernel(unittest.TestCase):
)
out_repro_ds = out_repro_ds["template_pair_embedding"].cpu()
err = torch.max(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error {err}')
compare_utils.assert_max_abs_diff_small(out_repro, out_repro_ds, eps)
def test_compare_model(self):
"""
Run full model with and without using DeepSpeed Evoformer attention kernel
and compare output coordinates.
"""
eps = 0.5
eps = 0.2
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
batch = pickle.load(fp)
......@@ -283,6 +289,15 @@ class TestDeepSpeedKernel(unittest.TestCase):
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
batch["no_recycling_iters"] = np.array([3., 3., 3., 3., ])
if consts.is_multimer:
n_res = batch['aatype'].shape[1]
n_extra_seq = batch['extra_msa'].shape[1]
batch["asym_id"] = np.ones((4, n_res))
batch["entity_id"] = np.ones((4, n_res))
batch["sym_id"] = np.ones((4, n_res))
batch["extra_deletion_matrix"] = np.random.randint(0, 2, size=(4, n_extra_seq, n_res))
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
batch["aatype"] = batch["aatype"].long()
......@@ -291,6 +306,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
batch["residx_atom37_to_atom14"] = batch[
"residx_atom37_to_atom14"
].long()
# print(batch["target_feat"].shape)
batch["target_feat"] = torch.nn.functional.one_hot(batch["aatype"], consts.msa_logits - 1).to(torch.float32)
batch["template_all_atom_mask"] = batch["template_all_atom_masks"]
batch.update(
data_transforms.atom37_to_torsion_angles("template_")(batch)
......@@ -299,7 +316,6 @@ class TestDeepSpeedKernel(unittest.TestCase):
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch)
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
model = compare_utils.get_global_pretrained_openfold()
......@@ -316,8 +332,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
out_repro = out_repro["sm"]["positions"][-1].squeeze(0)
out_repro_ds = out_repro_ds["sm"]["positions"][-1].squeeze(0)
err = torch.max(torch.abs(out_repro - out_repro_ds))
self.assertTrue(err < eps, f'Error: {err}')
compare_utils.assert_mean_abs_diff_small(out_repro, out_repro_ds, eps)
if __name__ == "__main__":
......
......@@ -12,15 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch
import numpy as np
import unittest
from tests.config import consts
from tests.data_utils import random_asym_ids
from openfold.model.embedders import (
InputEmbedder,
InputEmbedderMultimer,
PreembeddingEmbedder,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
TemplateSingleEmbedder,
TemplatePairEmbedder
)
......@@ -36,13 +39,30 @@ class TestInputEmbedder(unittest.TestCase):
n_res = 17
n_clust = 19
max_relative_chain = 2
max_relative_idx = 32
use_chain_relative = True
tf = torch.rand((b, n_res, tf_dim))
ri = torch.rand((b, n_res))
msa = torch.rand((b, n_clust, n_res, msa_dim))
asym_ids_flat = torch.Tensor(random_asym_ids(n_res))
asym_id = torch.tile(asym_ids_flat.unsqueeze(0), (b, 1))
entity_id = asym_id
sym_id = torch.zeros_like(entity_id)
if consts.is_multimer:
ie = InputEmbedderMultimer(tf_dim, msa_dim, c_z, c_m,
max_relative_idx=max_relative_idx,
use_chain_relative=use_chain_relative,
max_relative_chain=max_relative_chain)
batch = {"target_feat": tf, "residue_index": ri, "msa_feat": msa,
"asym_id": asym_id, "entity_id": entity_id, "sym_id": sym_id}
msa_emb, pair_emb = ie(batch)
else:
ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
msa_emb, pair_emb = ie(tf=tf, ri=ri, msa=msa, inplace_safe=False)
msa_emb, pair_emb = ie(tf, ri, msa)
self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m))
self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))
......@@ -99,7 +119,7 @@ class TestTemplateAngleEmbedder(unittest.TestCase):
n_templ = 4
n_res = 256
tae = TemplateAngleEmbedder(
tae = TemplateSingleEmbedder(
template_angle_dim,
c_m,
)
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import torch
import numpy as np
import unittest
......@@ -48,6 +49,8 @@ class TestEvoformerStack(unittest.TestCase):
transition_n = 2
msa_dropout = 0.15
pair_stack_dropout = 0.25
opm_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
inf = 1e9
eps = 1e-10
......@@ -65,8 +68,10 @@ class TestEvoformerStack(unittest.TestCase):
transition_n,
msa_dropout,
pair_stack_dropout,
blocks_per_ckpt=None,
no_column_attention=False,
opm_first=opm_first,
fuse_projection_weights=fuse_projection_weights,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
).eval()
......@@ -121,8 +126,10 @@ class TestEvoformerStack(unittest.TestCase):
transition_n,
msa_dropout,
pair_stack_dropout,
blocks_per_ckpt=None,
no_column_attention=True,
opm_first=False,
fuse_projection_weights=False,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
).eval()
......@@ -171,7 +178,7 @@ class TestEvoformerStack(unittest.TestCase):
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
key = jax.random.PRNGKey(42)
out_gt = f.apply(params, key, activations, masks)
......@@ -193,8 +200,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
# Inplace version
out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
......@@ -210,8 +217,8 @@ class TestEvoformerStack(unittest.TestCase):
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
self.assertTrue(torch.mean(torch.abs(out_repro_msa - out_gt_msa)) < consts.eps)
self.assertTrue(torch.max(torch.abs(out_repro_pair - out_gt_pair)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt_msa, out_repro_msa, consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt_pair, out_repro_pair, consts.eps)
class TestExtraMSAStack(unittest.TestCase):
......@@ -231,6 +238,8 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n = 5
msa_dropout = 0.15
pair_stack_dropout = 0.25
opm_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
inf = 1e9
eps = 1e-10
......@@ -247,6 +256,8 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n,
msa_dropout,
pair_stack_dropout,
opm_first,
fuse_projection_weights,
ckpt=False,
inf=inf,
eps=eps,
......@@ -328,7 +339,7 @@ class TestMSATransition(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_transition"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
......@@ -336,15 +347,14 @@ class TestMSATransition(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0].core.msa_transition(
model.evoformer.blocks[0].msa_transition(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
)
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
unittest.main()
......@@ -25,13 +25,16 @@ from openfold.np.residue_constants import (
)
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4
from tests.data_utils import random_affines_4x4, random_asym_ids
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
......@@ -40,6 +43,20 @@ if compare_utils.alphafold_is_installed():
class TestFeats(unittest.TestCase):
@classmethod
def setUpClass(cls):
if compare_utils.alphafold_is_installed():
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
@compare_utils.skip_unless_alphafold_installed()
def test_pseudo_beta_fn_compare(self):
def test_pbf(aatype, all_atom_pos, all_atom_mask):
......@@ -131,7 +148,9 @@ class TestFeats(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_atom37_to_frames_compare(self):
def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
return alphafold.model.all_atom.atom37_to_frames(
if consts.is_multimer:
all_atom_positions = self.am_rigid.Vec3Array.from_array(all_atom_positions)
return self.am_atom.atom37_to_frames(
aatype, all_atom_positions, all_atom_mask
)
......@@ -150,9 +169,23 @@ class TestFeats(unittest.TestCase):
}
out_gt = f.apply({}, None, **batch)
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
to_tensor = (lambda t: torch.tensor(np.array(t))
if not isinstance(t, self.am_rigid.Rigid3Array)
else torch.tensor(np.array(t.to_array())))
else:
to_tensor = lambda t: torch.tensor(np.array(t))
out_gt = {k: to_tensor(v) for k, v in out_gt.items()}
def rigid3x4_to_4x4(rigid3arr):
four_by_four = torch.zeros(*rigid3arr.shape[:-2], 4, 4)
four_by_four[..., :3, :4] = rigid3arr
four_by_four[..., 3, 3] = 1
return four_by_four
def flat12_to_4x4(flat12):
rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
trans = flat12[..., 9:]
......@@ -164,10 +197,12 @@ class TestFeats(unittest.TestCase):
return four_by_four
out_gt["rigidgroups_gt_frames"] = flat12_to_4x4(
convert_func = rigid3x4_to_4x4 if consts.is_multimer else flat12_to_4x4
out_gt["rigidgroups_gt_frames"] = convert_func(
out_gt["rigidgroups_gt_frames"]
)
out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4(
out_gt["rigidgroups_alt_gt_frames"] = convert_func(
out_gt["rigidgroups_alt_gt_frames"]
)
......@@ -187,6 +222,12 @@ class TestFeats(unittest.TestCase):
n = 5
rots = torch.rand((batch_size, n, 3, 3))
trans = torch.rand((batch_size, n, 3))
if consts.is_multimer:
rotation = Rot3Array.from_array(rots)
translation = Vec3Array.from_array(trans)
ts = Rigid3Array(rotation, translation)
else:
ts = Rigid(Rotation(rot_mats=rots), trans)
angles = torch.rand((batch_size, n, 7, 2))
......@@ -208,7 +249,7 @@ class TestFeats(unittest.TestCase):
def run_torsion_angles_to_frames(
aatype, backb_to_global, torsion_angles_sin_cos
):
return alphafold.model.all_atom.torsion_angles_to_frames(
return self.am_atom.torsion_angles_to_frames(
aatype,
backb_to_global,
torsion_angles_sin_cos,
......@@ -221,7 +262,14 @@ class TestFeats(unittest.TestCase):
aatype = np.random.randint(0, 21, size=(n_res,))
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float()
)
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
......@@ -240,13 +288,21 @@ class TestFeats(unittest.TestCase):
)
# Convert the Rigids to 4x4 transformation tensors
rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot))
out_gt_rot = out_gt.rot if not consts.is_multimer else out_gt.rotation.to_array()
out_gt_trans = out_gt.trans if not consts.is_multimer else out_gt.translation.to_array()
if consts.is_multimer:
rots_gt = torch.as_tensor(np.array(out_gt_rot))
trans_gt = torch.as_tensor(np.array(out_gt_trans))
else:
rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt_rot))
trans_gt = list(
map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
map(lambda x: torch.as_tensor(np.array(x)), out_gt_trans)
)
rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)
transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
bottom_row[..., 3] = 1
......@@ -264,6 +320,12 @@ class TestFeats(unittest.TestCase):
rots = torch.rand((batch_size, n_res, 8, 3, 3))
trans = torch.rand((batch_size, n_res, 8, 3))
if consts.is_multimer:
rotation = Rot3Array.from_array(rots)
translation = Vec3Array.from_array(trans)
ts = Rigid3Array(rotation, translation)
else:
ts = Rigid(Rotation(rot_mats=rots), trans)
f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()
......@@ -282,8 +344,7 @@ class TestFeats(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_frames_and_literature_positions_to_atom14_pos_compare(self):
def run_f(aatype, affines):
am = alphafold.model
return am.all_atom.frames_and_literature_positions_to_atom14_pos(
return self.am_atom.frames_and_literature_positions_to_atom14_pos(
aatype, affines
)
......@@ -294,13 +355,24 @@ class TestFeats(unittest.TestCase):
aatype = np.random.randint(0, 21, size=(n_res,))
affines = random_affines_4x4((n_res, 8))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float()
)
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
out_gt = f.apply({}, None, aatype, rigids)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
if consts.is_multimer:
out_gt = torch.as_tensor(np.array(out_gt.to_array()))
else:
out_gt = torch.stack(
[torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
)
......@@ -314,7 +386,7 @@ class TestFeats(unittest.TestCase):
torch.tensor(restype_atom14_rigid_group_positions).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
......
......@@ -12,26 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import torch
import numpy as np
import unittest
from pathlib import Path
from tests.config import consts
from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_
from openfold.utils.import_weights import import_jax_weights_, import_openfold_weights_
class TestImportWeights(unittest.TestCase):
def test_import_jax_weights_(self):
npz_path = "openfold/resources/params/params_model_1_ptm.npz"
npz_path = Path(__file__).parent.resolve() / f"../openfold/resources/params/params_{consts.model}.npz"
c = model_config("model_1_ptm")
c = model_config(consts.model)
c.globals.blocks_per_ckpt = None
model = AlphaFold(c)
model.eval()
import_jax_weights_(
model,
npz_path,
version=consts.model
)
data = np.load(npz_path)
......@@ -65,9 +70,26 @@ class TestImportWeights(unittest.TestCase):
)
][1].transpose(-1, -2)
),
model.evoformer.blocks[1].core.outer_product_mean.linear_1.weight,
model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
),
]
for w_alpha, w_repro in test_pairs:
self.assertTrue(torch.all(w_alpha == w_repro))
def test_import_openfold_weights_(self):
model_name = 'initial_training'
pt_path = Path(__file__).parent.resolve() / f"../openfold/resources/openfold_params/{model_name}.pt"
if os.path.exists(pt_path):
c = model_config(model_name)
c.globals.blocks_per_ckpt = None
model = AlphaFold(c)
model.eval()
d = torch.load(pt_path)
import_openfold_weights_(
model=model,
state_dict=d,
)
......@@ -13,18 +13,18 @@
# limitations under the License.
import os
import math
import torch
import numpy as np
from pathlib import Path
import unittest
import ml_collections as mlc
from openfold.data import data_transforms
from openfold.np import residue_constants
from openfold.utils.rigid_utils import (
Rotation,
Rigid,
)
import openfold.utils.feats as feats
from openfold.utils.loss import (
torsion_angle_loss,
compute_fape,
......@@ -43,6 +43,8 @@ from openfold.utils.loss import (
sidechain_loss,
tm_loss,
compute_plddt,
compute_tm,
chain_center_of_mass_loss
)
from openfold.utils.tensor_utils import (
tree_map,
......@@ -51,7 +53,7 @@ from openfold.utils.tensor_utils import (
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_vector, random_affines_4x4
from tests.data_utils import random_affines_vector, random_affines_4x4, random_asym_ids
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
......@@ -64,7 +66,31 @@ def affine_vector_to_4x4(affine):
return r.to_tensor_4x4()
def affine_vector_to_rigid(am_rigid, affine):
    """Convert a 7-component affine vector (quaternion w,x,y,z followed by a
    translation x,y,z) into a Rigid3Array of the given AlphaFold-Multimer
    geometry module ``am_rigid``."""
    components = [c.squeeze(-1) for c in np.split(affine, 7, axis=-1)]
    quat, trans = components[:4], components[4:]
    rot = am_rigid.Rot3Array.from_quaternion(*quat, normalize=True)
    return am_rigid.Rigid3Array(rot, am_rigid.Vec3Array(*trans))
class TestLoss(unittest.TestCase):
@classmethod
def setUpClass(cls):
    """Bind the AlphaFold reference modules matching the model flavor under
    test (multimer vs. monomer), when AlphaFold is available."""
    if not compare_utils.alphafold_is_installed():
        return
    multimer = consts.is_multimer
    cls.am_atom = alphafold.model.all_atom_multimer if multimer else alphafold.model.all_atom
    cls.am_fold = alphafold.model.folding_multimer if multimer else alphafold.model.folding
    cls.am_modules = alphafold.model.modules_multimer if multimer else alphafold.model.modules
    cls.am_rigid = alphafold.model.geometry if multimer else alphafold.model.r3
def test_run_torsion_angle_loss(self):
batch_size = consts.batch_size
n_res = consts.n_res
......@@ -127,7 +153,10 @@ class TestLoss(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_between_residue_bond_loss_compare(self):
def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
return alphafold.model.all_atom.between_residue_bond_loss(
if consts.is_multimer:
pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
return self.am_atom.between_residue_bond_loss(
pred_pos,
pred_atom_mask,
residue_index,
......@@ -184,12 +213,22 @@ class TestLoss(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_between_residue_clash_loss_compare(self):
def run_brcl(pred_pos, atom_exists, atom_radius, res_ind):
return alphafold.model.all_atom.between_residue_clash_loss(
def run_brcl(pred_pos, atom_exists, atom_radius, res_ind, asym_id):
if consts.is_multimer:
pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
return self.am_atom.between_residue_clash_loss(
pred_pos,
atom_exists,
atom_radius,
res_ind,
asym_id
)
return self.am_atom.between_residue_clash_loss(
pred_pos,
atom_exists,
atom_radius,
res_ind
)
f = hk.transform(run_brcl)
......@@ -198,10 +237,24 @@ class TestLoss(unittest.TestCase):
pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
atom_radius = np.random.rand(n_res, 14).astype(np.float32)
res_ind = np.arange(
n_res,
)
residx_atom14_to_atom37 = np.random.randint(0, 37, (n_res, 14)).astype(np.int64)
atomtype_radius = [
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
]
atomtype_radius = np.array(atomtype_radius).astype(np.float32)
atom_radius = (
atom_exists
* atomtype_radius[residx_atom14_to_atom37]
)
asym_id = None
if consts.is_multimer:
asym_id = random_asym_ids(n_res)
out_gt = f.apply(
{},
......@@ -210,6 +263,7 @@ class TestLoss(unittest.TestCase):
atom_exists,
atom_radius,
res_ind,
asym_id
)
out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
......@@ -219,6 +273,7 @@ class TestLoss(unittest.TestCase):
torch.tensor(atom_exists).cuda(),
torch.tensor(atom_radius).cuda(),
torch.tensor(res_ind).cuda(),
torch.tensor(asym_id).cuda() if asym_id is not None else None,
)
out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)
......@@ -242,6 +297,36 @@ class TestLoss(unittest.TestCase):
torch.max(torch.abs(out_gt - out_repro)) < consts.eps
)
@compare_utils.skip_unless_alphafold_installed()
def test_compute_ptm_compare(self):
    """Compare compute_tm against AlphaFold's predicted_tm_score, including
    the interface (ipTM) variant for multimer models."""
    n_res = consts.n_res
    max_bin = 31
    no_bins = 64

    logits = np.random.rand(n_res, n_res, no_bins)
    boundaries = np.linspace(0, max_bin, num=(no_bins - 1))

    reference = torch.tensor(
        alphafold.common.confidence.predicted_tm_score(logits, boundaries)
    )
    logits_t = torch.tensor(logits)
    ours = compute_tm(logits_t, no_bins=no_bins, max_bin=max_bin)
    self.assertTrue(torch.max(torch.abs(reference - ours)) < consts.eps)

    if consts.is_multimer:
        asym_id = random_asym_ids(n_res)
        reference_iptm = torch.tensor(
            alphafold.common.confidence.predicted_tm_score(
                logits, boundaries, asym_id=asym_id, interface=True
            )
        )
        ours_iptm = compute_tm(
            logits_t, no_bins=no_bins, max_bin=max_bin,
            asym_id=torch.tensor(asym_id), interface=True
        )
        self.assertTrue(
            torch.max(torch.abs(reference_iptm - ours_iptm)) < consts.eps
        )
def test_find_structural_violations(self):
n = consts.n_res
......@@ -265,8 +350,21 @@ class TestLoss(unittest.TestCase):
def test_find_structural_violations_compare(self):
def run_fsv(batch, pos, config):
cwd = os.getcwd()
os.chdir("tests/test_data")
loss = alphafold.model.folding.find_structural_violations(
fpath = Path(__file__).parent.resolve() / "test_data"
os.chdir(str(fpath))
if consts.is_multimer:
atom14_pred_pos = self.am_rigid.Vec3Array.from_array(pos)
return self.am_fold.find_structural_violations(
batch['aatype'],
batch['residue_index'],
batch['atom14_atom_exists'],
atom14_pred_pos,
config,
batch['asym_id']
)
loss = self.am_fold.find_structural_violations(
batch,
pos,
config,
......@@ -287,6 +385,9 @@ class TestLoss(unittest.TestCase):
).astype(np.int64),
}
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
pred_pos = np.random.rand(n_res, 14, 3)
config = mlc.ConfigDict(
......@@ -380,14 +481,14 @@ class TestLoss(unittest.TestCase):
n_seq = consts.n_seq
value = {
"logits": np.random.rand(n_res, n_seq, 23).astype(np.float32),
"logits": np.random.rand(n_res, n_seq, consts.msa_logits).astype(np.float32),
}
batch = {
"true_msa": np.random.randint(0, 21, (n_res, n_seq)),
"bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(
np.float32
),
)
}
out_gt = f.apply({}, None, value, batch)["loss"]
......@@ -399,7 +500,9 @@ class TestLoss(unittest.TestCase):
with torch.no_grad():
out_repro = masked_msa_loss(
value["logits"],
**batch,
batch["true_msa"],
batch["bert_mask"],
consts.msa_logits
)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
......@@ -506,10 +609,28 @@ class TestLoss(unittest.TestCase):
c_chi_loss = config.model.heads.structure_module
def run_supervised_chi_loss(value, batch):
if consts.is_multimer:
pred_angles = np.reshape(
value['sidechains']['angles_sin_cos'], [-1, consts.n_res, 7, 2])
unnormed_angles = np.reshape(
value['sidechains']['unnormalized_angles_sin_cos'], [-1, consts.n_res, 7, 2])
chi_loss, _, _ = self.am_fold.supervised_chi_loss(
batch['seq_mask'],
batch['chi_mask'],
batch['aatype'],
batch['chi_angles'],
pred_angles,
unnormed_angles,
c_chi_loss
)
return chi_loss
ret = {
"loss": jax.numpy.array(0.0),
}
alphafold.model.folding.supervised_chi_loss(
self.am_fold.supervised_chi_loss(
ret, batch, value, c_chi_loss
)
return ret["loss"]
......@@ -561,6 +682,40 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_violation_loss(self):
    """Run violation_loss in both clash-averaging modes on random inputs.

    Previously this test computed both losses without checking them; it now
    asserts that each variant yields a finite, non-negative value.
    """
    config = compare_utils.get_alphafold_config()
    c_viol = config.model.heads.structure_module
    n_res = consts.n_res

    batch = {
        "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
        "residue_index": np.arange(n_res),
        "aatype": np.random.randint(0, 21, (n_res,)),
    }

    if consts.is_multimer:
        batch["asym_id"] = random_asym_ids(n_res)

    batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)

    atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
    atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()

    batch = data_transforms.make_atom14_masks(batch)

    loss_sum_clash = violation_loss(
        find_structural_violations(batch, atom14_pred_pos, **c_viol),
        average_clashes=False, **batch
    )
    loss_sum_clash = loss_sum_clash.cpu()

    loss_avg_clash = violation_loss(
        find_structural_violations(batch, atom14_pred_pos, **c_viol),
        average_clashes=True, **batch
    )
    loss_avg_clash = loss_avg_clash.cpu()

    # Violation losses are means/sums of non-negative penalty terms, so both
    # modes must produce finite, non-negative scalars.
    for loss in (loss_sum_clash, loss_avg_clash):
        self.assertTrue(torch.isfinite(loss).all())
        self.assertTrue((loss >= 0).all())
@compare_utils.skip_unless_alphafold_installed()
def test_violation_loss_compare(self):
config = compare_utils.get_alphafold_config()
......@@ -570,15 +725,31 @@ class TestLoss(unittest.TestCase):
ret = {
"loss": np.array(0.0).astype(np.float32),
}
if consts.is_multimer:
atom14_pred_pos = self.am_rigid.Vec3Array.from_array(atom14_pred_pos)
viol = self.am_fold.find_structural_violations(
batch['aatype'],
batch['residue_index'],
batch['atom14_atom_exists'],
atom14_pred_pos,
c_viol,
batch['asym_id']
)
return self.am_fold.structural_violation_loss(mask=batch['atom14_atom_exists'],
violations=viol,
config=c_viol)
value = {}
value[
"violations"
] = alphafold.model.folding.find_structural_violations(
] = self.am_fold.find_structural_violations(
batch,
atom14_pred_pos,
c_viol,
)
alphafold.model.folding.structural_violation_loss(
self.am_fold.structural_violation_loss(
ret,
batch,
value,
......@@ -593,13 +764,17 @@ class TestLoss(unittest.TestCase):
batch = {
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"residue_index": np.arange(n_res),
"aatype": np.random.randint(0, 21, (n_res,)),
"aatype": np.random.randint(0, 21, (n_res,))
}
alphafold.model.tf.data_transforms.make_atom14_masks(batch)
batch = {k: np.array(v) for k, v in batch.items()}
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
alphafold.model.tf.data_transforms.make_atom14_masks(batch)
batch = {k: np.array(v) for k, v in batch.items()}
out_gt = f.apply({}, None, batch, atom14_pred_pos)
out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
......@@ -676,10 +851,31 @@ class TestLoss(unittest.TestCase):
c_sm = config.model.heads.structure_module
def run_bb_loss(batch, value):
if consts.is_multimer:
intra_chain_mask = (batch["asym_id"][..., None] == batch["asym_id"][..., None, :]).astype(np.float32)
gt_rigid = affine_vector_to_rigid(self.am_rigid, batch["backbone_affine_tensor"])
target_rigid = affine_vector_to_rigid(self.am_rigid, value['traj'])
intra_chain_bb_loss, intra_chain_fape = self.am_fold.backbone_loss(
gt_rigid=gt_rigid,
gt_frames_mask=batch["backbone_affine_mask"],
gt_positions_mask=batch["backbone_affine_mask"],
target_rigid=target_rigid,
config=c_sm.intra_chain_fape,
pair_mask=intra_chain_mask)
interface_bb_loss, interface_fape = self.am_fold.backbone_loss(
gt_rigid=gt_rigid,
gt_frames_mask=batch["backbone_affine_mask"],
gt_positions_mask=batch["backbone_affine_mask"],
target_rigid=target_rigid,
config=c_sm.interface_fape,
pair_mask=1. - intra_chain_mask)
return intra_chain_bb_loss + interface_bb_loss
ret = {
"loss": np.array(0.0),
}
alphafold.model.folding.backbone_loss(ret, batch, value, c_sm)
self.am_fold.backbone_loss(ret, batch, value, c_sm)
return ret["loss"]
f = hk.transform(run_bb_loss)
......@@ -691,7 +887,7 @@ class TestLoss(unittest.TestCase):
"backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
np.float32
),
"use_clamped_fape": np.array(0.0),
"use_clamped_fape": np.array(0.0)
}
value = {
......@@ -703,6 +899,9 @@ class TestLoss(unittest.TestCase):
),
}
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
out_gt = f.apply({}, None, batch, value)
out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
......@@ -715,6 +914,16 @@ class TestLoss(unittest.TestCase):
)
batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
if consts.is_multimer:
intra_chain_mask = (batch["asym_id"][..., None]
== batch["asym_id"][..., None, :]).to(dtype=value["traj"].dtype)
intra_chain_out = backbone_loss(traj=value["traj"], pair_mask=intra_chain_mask,
**{**batch, **c_sm.intra_chain_fape})
interface_out = backbone_loss(traj=value["traj"], pair_mask=1. - intra_chain_mask,
**{**batch, **c_sm.interface_fape})
out_repro = intra_chain_out + interface_out
out_repro = out_repro.cpu()
else:
out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
out_repro = out_repro.cpu()
......@@ -726,9 +935,29 @@ class TestLoss(unittest.TestCase):
c_sm = config.model.heads.structure_module
def run_sidechain_loss(batch, value, atom14_pred_positions):
if consts.is_multimer:
atom14_pred_positions = self.am_rigid.Vec3Array.from_array(atom14_pred_positions)
all_atom_positions = self.am_rigid.Vec3Array.from_array(batch["all_atom_positions"])
gt_positions, gt_mask, alt_naming_is_better = self.am_fold.compute_atom14_gt(
aatype=batch["aatype"], all_atom_positions=all_atom_positions,
all_atom_mask=batch["all_atom_mask"], pred_pos=atom14_pred_positions)
pred_frames = self.am_rigid.Rigid3Array.from_array4x4(value["sidechains"]["frames"])
pred_positions = self.am_rigid.Vec3Array.from_array(value["sidechains"]["atom_pos"])
gt_sc_frames, gt_sc_frames_mask = self.am_fold.compute_frames(
aatype=batch["aatype"],
all_atom_positions=all_atom_positions,
all_atom_mask=batch["all_atom_mask"],
use_alt=alt_naming_is_better)
return self.am_fold.sidechain_loss(gt_sc_frames,
gt_sc_frames_mask,
gt_positions,
gt_mask,
pred_frames,
pred_positions,
c_sm)['loss']
batch = {
**batch,
**alphafold.model.all_atom.atom37_to_frames(
**self.am_atom.atom37_to_frames(
batch["aatype"],
batch["all_atom_positions"],
batch["all_atom_mask"],
......@@ -738,21 +967,21 @@ class TestLoss(unittest.TestCase):
v["sidechains"] = {}
v["sidechains"][
"frames"
] = alphafold.model.r3.rigids_from_tensor4x4(
] = self.am_rigid.rigids_from_tensor4x4(
value["sidechains"]["frames"]
)
v["sidechains"]["atom_pos"] = alphafold.model.r3.vecs_from_tensor(
v["sidechains"]["atom_pos"] = self.am_rigid.vecs_from_tensor(
value["sidechains"]["atom_pos"]
)
v.update(
alphafold.model.folding.compute_renamed_ground_truth(
self.am_fold.compute_renamed_ground_truth(
batch,
atom14_pred_positions,
)
)
value = v
ret = alphafold.model.folding.sidechain_loss(batch, value, c_sm)
ret = self.am_fold.sidechain_loss(batch, value, c_sm)
return ret["loss"]
f = hk.transform(run_sidechain_loss)
......@@ -816,6 +1045,7 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
@unittest.skipIf(not consts.is_multimer and "ptm" not in consts.model, "Not enabled for non-ptm models.")
def test_tm_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_tm = config.model.heads.predicted_aligned_error
......@@ -882,6 +1112,33 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_chain_center_of_mass_loss(self):
    """Smoke-test chain_center_of_mass_loss on random multimer coordinates.

    Previously the result was computed and discarded; it is now checked to
    be finite.
    """
    batch_size = consts.batch_size
    n_res = consts.n_res

    batch = {
        "all_atom_positions": np.random.rand(batch_size, n_res, 37, 3).astype(np.float32) * 10.0,
        "all_atom_mask": np.random.randint(0, 2, (batch_size, n_res, 37)).astype(np.float32),
        "asym_id": np.stack([random_asym_ids(n_res) for _ in range(batch_size)])
    }

    # NOTE(review): negative clamp_distance mirrors the values used here
    # originally — confirm against the training config if it changes.
    config = {
        "weight": 0.05,
        "clamp_distance": -4.0,
    }

    final_atom_positions = torch.rand(batch_size, n_res, 37, 3).cuda()

    to_tensor = lambda t: torch.tensor(t).cuda()
    batch = tree_map(to_tensor, batch, np.ndarray)

    out_repro = chain_center_of_mass_loss(
        all_atom_pred_pos=final_atom_positions,
        **{**batch, **config},
    )
    out_repro = out_repro.cpu()

    # Minimal sanity check: the loss must be a finite value.
    self.assertTrue(torch.isfinite(out_repro).all())
# Allow running this test module directly (outside a pytest/unittest runner).
if __name__ == "__main__":
    unittest.main()
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import pickle
import torch
import torch.nn as nn
......@@ -20,13 +21,13 @@ import unittest
from openfold.config import model_config
from openfold.data import data_transforms
from openfold.model.model import AlphaFold
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
from openfold.utils.tensor_utils import tensor_tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import (
random_template_feats,
random_extra_msa_feats,
random_asym_ids
)
if compare_utils.alphafold_is_installed():
......@@ -36,13 +37,27 @@ if compare_utils.alphafold_is_installed():
class TestModel(unittest.TestCase):
@classmethod
def setUpClass(cls):
    """Select the AlphaFold reference modules (multimer or monomer flavor)
    used by the comparison tests, when AlphaFold is installed."""
    if not compare_utils.alphafold_is_installed():
        return
    if consts.is_multimer:
        cls.am_atom = alphafold.model.all_atom_multimer
        cls.am_fold = alphafold.model.folding_multimer
        cls.am_modules = alphafold.model.modules_multimer
        cls.am_rigid = alphafold.model.geometry
    else:
        cls.am_atom = alphafold.model.all_atom
        cls.am_fold = alphafold.model.folding
        cls.am_modules = alphafold.model.modules
        cls.am_rigid = alphafold.model.r3
def test_dry_run(self):
n_seq = consts.n_seq
n_templ = consts.n_templ
n_res = consts.n_res
n_extra_seq = consts.n_extra
c = model_config("model_1")
c = model_config(consts.model)
c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test
......@@ -51,29 +66,39 @@ class TestModel(unittest.TestCase):
model.eval()
batch = {}
tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,)).cuda()
tf = torch.randint(c.model.input_embedder.tf_dim - 1, size=(n_res,))
batch["target_feat"] = nn.functional.one_hot(
tf, c.model.input_embedder.tf_dim
).float().cuda()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1).cuda()
batch["residue_index"] = torch.arange(n_res).cuda()
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim)).cuda()
).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
t_feats = random_template_feats(n_templ, n_res)
batch.update({k: torch.tensor(v).cuda() for k, v in t_feats.items()})
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
extra_feats = random_extra_msa_feats(n_extra_seq, n_res)
batch.update({k: torch.tensor(v).cuda() for k, v in extra_feats.items()})
batch.update({k: torch.tensor(v) for k, v in extra_feats.items()})
batch["msa_mask"] = torch.randint(
low=0, high=2, size=(n_seq, n_res)
).float().cuda()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float().cuda()
).float()
batch["seq_mask"] = torch.randint(low=0, high=2, size=(n_res,)).float()
batch.update(data_transforms.make_atom14_masks(batch))
batch["no_recycling_iters"] = torch.tensor(2.).cuda()
batch["no_recycling_iters"] = torch.tensor(2.)
if consts.is_multimer:
batch["asym_id"] = torch.as_tensor(random_asym_ids(n_res))
batch["entity_id"] = batch["asym_id"].clone()
batch["sym_id"] = torch.ones(n_res)
batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
)
batch = tensor_tree_map(add_recycling_dims, batch)
to_cuda_device = lambda t: t.cuda()
batch = tensor_tree_map(to_cuda_device, batch)
with torch.no_grad():
out = model(batch)
......@@ -118,10 +143,14 @@ class TestModel(unittest.TestCase):
out = model(batch)
@compare_utils.skip_unless_alphafold_installed()
@unittest.skipIf(consts.is_multimer, "Additional changes required for multimer.")
def test_compare(self):
#TODO: Fix test data for multimer MSA features
def run_alphafold(batch):
config = compare_utils.get_alphafold_config()
model = alphafold.model.modules.AlphaFold(config.model)
model = self.am_modules.AlphaFold(config.model)
return model(
batch=batch,
is_training=False,
......@@ -132,7 +161,8 @@ class TestModel(unittest.TestCase):
params = compare_utils.fetch_alphafold_module_weights("")
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
fpath = Path(__file__).parent.resolve() / "test_data/sample_feats.pickle"
with open(str(fpath), "rb") as fp:
batch = pickle.load(fp)
out_gt = f.apply(params, jax.random.PRNGKey(42), batch)
......@@ -141,7 +171,8 @@ class TestModel(unittest.TestCase):
# atom37_to_atom14 doesn't like batches
batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
out_gt = self.am_atom.atom37_to_atom14(out_gt, batch)
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
batch["no_recycling_iters"] = np.array([3., 3., 3., 3.,])
......
......@@ -79,7 +79,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_row_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(
params, None, msa_act, msa_mask, pair_act
......@@ -96,7 +96,7 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
)
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
class TestMSAColumnAttention(unittest.TestCase):
......@@ -144,7 +144,7 @@ class TestMSAColumnAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "msa_column_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
......@@ -158,7 +158,7 @@ class TestMSAColumnAttention(unittest.TestCase):
)
).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps)
compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
class TestMSAColumnGlobalAttention(unittest.TestCase):
......@@ -207,7 +207,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/extra_msa_stack/"
+ "msa_column_global_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
......@@ -222,7 +222,7 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
.cpu()
)
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__":
......
# Copyright 2021 AlQuraishi Laboratory
# Dingquan Yu @ EMBL-Hamburg Kosinski group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import shutil
import torch
import unittest
from openfold.utils.tensor_utils import tensor_tree_map
from openfold.config import model_config
from openfold.data.data_modules import OpenFoldMultimerDataModule
from openfold.model.model import AlphaFold
from openfold.utils.loss import AlphaFoldLoss
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from tests.config import consts
import logging
logger = logging.getLogger(__name__)
@unittest.skipIf(not consts.is_multimer or consts.template_mmcif_dir is None, "Template mmcif dir required.")
class TestMultimerDataModule(unittest.TestCase):
    def setUp(self):
        """
        Set up the multimer data module plus a small model/loss pair.
        use model_1_multimer_v3 for now
        """
        from pathlib import Path

        # Resolve test fixtures relative to this file rather than the current
        # working directory, consistent with the other tests, so the suite
        # also passes when pytest is launched outside the repository root.
        test_data_dir = Path(__file__).parent.resolve() / "test_data"

        self.config = model_config(
            consts.model,
            train=True,
            low_prec=True)
        self.data_module = OpenFoldMultimerDataModule(
            config=self.config.data,
            batch_seed=42,
            train_epoch_len=100,
            template_mmcif_dir=consts.template_mmcif_dir,
            template_release_dates_cache_path=str(test_data_dir / "mmcif_cache.json"),
            max_template_date="2500-01-01",
            train_data_dir=str(test_data_dir / "mmcifs"),
            train_alignment_dir=str(test_data_dir / "alignments"),
            kalign_binary_path=shutil.which('kalign'),
            train_mmcif_data_cache_path=str(test_data_dir / "train_mmcifs_cache.json"),
            train_chain_data_cache_path=str(test_data_dir / "train_chain_data_cache.json"),
        )

        # Set up a deliberately small model so the test stays fast.
        self.c = model_config(consts.model, train=True)
        # Override: the multimer masked-MSA head emits 22 classes here —
        # TODO(review): confirm why the stock loss config disagrees.
        self.c.loss.masked_msa.num_classes = 22
        self.c.model.evoformer_stack.no_blocks = 4  # no need to go overboard here
        self.c.model.evoformer_stack.blocks_per_ckpt = None  # don't want to set up
        # deepspeed for this test
        self.model = AlphaFold(self.c)
        self.loss = AlphaFoldLoss(self.c.loss)

    def testPrepareData(self):
        """End-to-end: build the dataset, run the model on one sample,
        permutation-align the chains, and compute the training loss."""
        self.data_module.prepare_data()
        self.data_module.setup()
        train_dataset = self.data_module.train_dataset
        all_chain_features = train_dataset[1]
        add_batch_size_dimension = lambda t: (
            t.unsqueeze(0)
        )
        all_chain_features = tensor_tree_map(add_batch_size_dimension, all_chain_features)
        with torch.no_grad():
            ground_truth = all_chain_features.pop('gt_features', None)
            # Run the model
            out = self.model(all_chain_features)
            # Remove the recycling dimension
            all_chain_features = tensor_tree_map(lambda t: t[..., -1], all_chain_features)
            all_chain_features = multi_chain_permutation_align(out=out,
                                                              features=all_chain_features,
                                                              ground_truth=ground_truth)
            self.loss(out, all_chain_features)
\ No newline at end of file
......@@ -74,14 +74,14 @@ class TestOuterProductMean(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/"
+ "evoformer_iteration/outer_product_mean"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, msa_act, msa_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0].core
model.evoformer.blocks[0]
.outer_product_mean(
torch.as_tensor(msa_act).cuda(),
chunk_size=4,
......@@ -92,7 +92,7 @@ class TestOuterProductMean(unittest.TestCase):
# Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps.
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < 5e-4)
compare_utils.assert_max_abs_diff_small(out_gt, out_repro, 5e-4)
if __name__ == "__main__":
......
......@@ -69,14 +69,14 @@ class TestPairTransition(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ "pair_transition"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0].core
model.evoformer.blocks[0].pair_stack
.pair_transition(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
chunk_size=4,
......
# Copyright 2021 AlQuraishi Laboratory
# Dingquan Yu @ EMBL-Hamburg Kosinski group
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import unittest
from openfold.utils.multi_chain_permutation import (pad_features, get_least_asym_entity_or_longest_length,
compute_permutation_alignment, split_ground_truth_labels,
merge_labels)
class TestPermutation(unittest.TestCase):
def setUp(self):
"""
create fake input structure features
and rotation matrices
"""
theta = math.pi / 4
device = 'cpu'
self.rotation_matrix_z = torch.tensor([
[math.cos(theta), -math.sin(theta), 0],
[math.sin(theta), math.cos(theta), 0],
[0, 0, 1]
], device=device)
self.rotation_matrix_x = torch.tensor([
[1, 0, 0],
[0, math.cos(theta), -math.sin(theta)],
[0, math.sin(theta), math.cos(theta)],
], device=device)
self.rotation_matrix_y = torch.tensor([
[math.cos(theta), 0, math.sin(theta)],
[0, 1, 0],
[-math.sin(theta), 1, math.cos(theta)],
], device=device)
self.chain_a_num_res = 9
self.chain_b_num_res = 13
# below create default fake ground truth structures for a hetero-pentamer A2B3
self.residue_index = list(range(self.chain_a_num_res)) * 2 + list(range(self.chain_b_num_res)) * 3
self.num_res = self.chain_a_num_res * 2 + self.chain_b_num_res * 3
self.asym_id = torch.tensor([[1] * self.chain_a_num_res + [2] * self.chain_a_num_res + [
3] * self.chain_b_num_res + [4] * self.chain_b_num_res + [5] * self.chain_b_num_res], device=device)
self.sym_id = self.asym_id
self.entity_id = torch.tensor([[1] * (self.chain_a_num_res * 2) + [2] * (self.chain_b_num_res * 3)],
device=device)
def test_1_selecting_anchors(self):
    """Anchor selection must pick chains of entity A (asym ids 1-2), the
    entity with the fewest symmetric copies, never entity B (3-5)."""
    batch = {
        'asym_id': self.asym_id,
        'sym_id': self.sym_id,
        'entity_id': self.entity_id,
        'seq_length': torch.tensor([57])
    }
    gt_anchor, pred_anchors = get_least_asym_entity_or_longest_length(batch, batch['asym_id'])
    gt_anchor = int(gt_anchor)
    pred_anchors = {int(a) for a in pred_anchors}

    chain_a_ids = {1, 2}
    chain_b_ids = {3, 4, 5}

    self.assertIn(gt_anchor, chain_a_ids)
    self.assertNotIn(gt_anchor, chain_b_ids)
    # Predicted anchors must lie entirely within the expected anchor set.
    self.assertEqual(pred_anchors, chain_a_ids & pred_anchors)
    self.assertEqual(set(), pred_anchors & chain_b_ids)
def test_2_permutation_pentamer(self):
    """Permutation alignment on a fake A2B3 pentamer whose predicted chain
    order was deliberately shuffled relative to the ground truth.

    Fix: removes a leftover debug ``print`` of the computed alignment.
    """
    batch = {
        'asym_id': self.asym_id,
        'sym_id': self.sym_id,
        'entity_id': self.entity_id,
        'seq_length': torch.tensor([57]),
        'aatype': torch.randint(21, size=(1, 57))
    }
    batch['asym_id'] = batch['asym_id'].reshape(1, self.num_res)
    batch["residue_index"] = torch.tensor([self.residue_index])

    # Fake ground-truth coordinates: A2, B2 and B3 are rigid transforms of
    # A1/B1, so chains within an entity are interchangeable up to alignment.
    chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37),
                                 dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3)
    chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10
    chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37),
                                 dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3)
    chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10
    chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30

    # Below permutate predicted chain positions to (A2, A1, B2, B3, B1).
    pred_atom_position = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1)
    pred_atom_mask = torch.ones((1, self.num_res, 37))
    out = {
        'final_atom_positions': pred_atom_position,
        'final_atom_mask': pred_atom_mask
    }

    true_atom_position = torch.cat((chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1)
    # All ground-truth atoms are marked present.
    true_atom_mask = torch.cat((torch.ones((1, self.chain_a_num_res, 37)),
                                torch.ones((1, self.chain_a_num_res, 37)),
                                torch.ones((1, self.chain_b_num_res, 37)),
                                torch.ones((1, self.chain_b_num_res, 37)),
                                torch.ones((1, self.chain_b_num_res, 37))), dim=1)
    batch['all_atom_positions'] = true_atom_position
    batch['all_atom_mask'] = true_atom_mask

    aligns, _ = compute_permutation_alignment(out, batch,
                                              batch)

    possible_outcome = [[(0, 1), (1, 0), (2, 3), (3, 4), (4, 2)], [(0, 0), (1, 1), (2, 3), (3, 4), (4, 2)]]
    wrong_outcome = [[(0, 1), (1, 0), (2, 4), (3, 2), (4, 3)], [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]]
    self.assertIn(aligns, possible_outcome)
    self.assertNotIn(aligns, wrong_outcome)
@unittest.skip("Test needs to be fixed post-refactor")
def test_3_merge_labels(self):
    """Check that merge_labels re-orders padded ground-truth labels.

    Builds a fake 5-chain complex (A1, A2, B1, B2, B3), pads every feature
    up to a crop size of 325 residues, permutes the predicted chain order
    relative to the ground truth, and verifies that
    compute_permutation_alignment + merge_labels permute the padded
    ground-truth atom positions (and residue index) back into the
    prediction's chain order.
    """
    nres_pad = 325 - 57  # suppose the cropping size is 325
    batch = {
        'asym_id': pad_features(self.asym_id, nres_pad, pad_dim=1),
        'sym_id': pad_features(self.sym_id, nres_pad, pad_dim=1),
        'entity_id': pad_features(self.entity_id, nres_pad, pad_dim=1),
        'aatype': torch.randint(21, size=(1, 325)),
        'seq_length': torch.tensor([57])
    }
    batch['asym_id'] = batch['asym_id'].reshape(1, 325)
    batch["residue_index"] = pad_features(torch.tensor(self.residue_index).reshape(1, 57), nres_pad, pad_dim=1)
    # Create fake ground-truth atom positions: A2 is a rotated + shifted copy
    # of A1, and B2/B3 are rotated + shifted copies of B1, so chains within
    # an entity are distinguishable only by their placement.
    chain_a1_pos = torch.randint(15, (self.chain_a_num_res, 3 * 37),
                                 dtype=torch.float).reshape(1, self.chain_a_num_res, 37, 3)
    chain_a2_pos = torch.matmul(chain_a1_pos, self.rotation_matrix_x) + 10
    chain_b1_pos = torch.randint(low=15, high=30, size=(self.chain_b_num_res, 3 * 37),
                                 dtype=torch.float).reshape(1, self.chain_b_num_res, 37, 3)
    chain_b2_pos = torch.matmul(chain_b1_pos, self.rotation_matrix_y) + 10
    chain_b3_pos = torch.matmul(torch.matmul(chain_b1_pos, self.rotation_matrix_z), self.rotation_matrix_x) + 30
    # Below, permute the predicted chain positions relative to the ground truth
    pred_atom_position = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos), dim=1)
    pred_atom_mask = torch.ones((1, self.num_res, 37))
    pred_atom_position = pad_features(pred_atom_position, nres_pad, pad_dim=1)
    pred_atom_mask = pad_features(pred_atom_mask, nres_pad, pad_dim=1)
    out = {
        'final_atom_positions': pred_atom_position,
        'final_atom_mask': pred_atom_mask
    }
    # Ground truth keeps the canonical chain order (A1, A2, B1, B2, B3).
    true_atom_position = torch.cat((chain_a1_pos, chain_a2_pos, chain_b1_pos, chain_b2_pos, chain_b3_pos), dim=1)
    true_atom_mask = torch.cat((torch.ones((1, self.chain_a_num_res, 37)),
                                torch.ones((1, self.chain_a_num_res, 37)),
                                torch.ones((1, self.chain_b_num_res, 37)),
                                torch.ones((1, self.chain_b_num_res, 37)),
                                torch.ones((1, self.chain_b_num_res, 37))), dim=1)
    batch['all_atom_positions'] = pad_features(true_atom_position, nres_pad, pad_dim=1)
    batch['all_atom_mask'] = pad_features(true_atom_mask, nres_pad=nres_pad, pad_dim=1)
    aligns, per_asym_residue_index = compute_permutation_alignment(out,
                                                                  batch,
                                                                  batch)
    labels = split_ground_truth_labels(batch)
    labels = merge_labels(per_asym_residue_index, labels, aligns,
                          original_nres=batch['aatype'].shape[-1])
    # After merging, the labels must follow the prediction's (permuted)
    # chain order, with padding preserved.
    self.assertTrue(torch.equal(labels['residue_index'], batch['residue_index']))
    expected_permutated_gt_pos = torch.cat((chain_a2_pos, chain_a1_pos, chain_b2_pos, chain_b3_pos, chain_b1_pos),
                                           dim=1)
    expected_permutated_gt_pos = pad_features(expected_permutated_gt_pos, nres_pad, pad_dim=1)
    self.assertTrue(torch.equal(labels['all_atom_positions'], expected_permutated_gt_pos))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment