"deploy/dynemo/operator/pkg/compoundai/reqcli/http.go" did not exist on "5ddc7f7df5ab77c4efae9fd6ca299c3040c91533"
Commit 56d5e39c authored by Geoffrey Yu

Merge remote-tracking branch 'upstream/multimer' into multimer

Parents: 56b86074 51556d52
setup.py
@@ -75,6 +75,8 @@ for major, minor in list(compute_capabilities):
     extra_cuda_flags += cc_flag
+cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
 if bare_metal_major != -1:
     modules = [CUDAExtension(
         name="attn_core_inplace_cuda",
...
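The hunk above resets `cc_flag` to an `sm_70` default after the per-capability flags are accumulated. For readers less familiar with nvcc: each `-gencode` pair asks the compiler to emit code for one compute capability. A rough sketch of how such flag lists are typically assembled (hypothetical helper, not this project's build code):

    def gencode_flags(compute_capabilities, fallback=(7, 0)):
        """Build one nvcc -gencode pair per (major, minor) capability."""
        flags = []
        for major, minor in compute_capabilities:
            arch = f"{major}{minor}"
            flags += ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]
        if not flags:
            # No capability detected: fall back to a safe baseline,
            # mirroring the hard-coded sm_70 default in the hunk above.
            arch = f"{fallback[0]}{fallback[1]}"
            flags = ["-gencode", f"arch=compute_{arch},code=sm_{arch}"]
        return flags

    # e.g. gencode_flags([(8, 0), (8, 6)])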
tests/compare_utils.py
@@ -46,26 +46,26 @@ def import_alphafold():
 def get_alphafold_config():
-    config = alphafold.model.config.model_config("model_1_ptm") # noqa
+    config = alphafold.model.config.model_config(consts.model) # noqa
     config.model.global_config.deterministic = True
     return config
-_param_path = "openfold/resources/params/params_model_1_ptm.npz"
+_param_path = f"openfold/resources/params/params_{consts.model}.npz"
 _model = None
 def get_global_pretrained_openfold():
     global _model
     if _model is None:
-        _model = AlphaFold(model_config("model_1_ptm"))
+        _model = AlphaFold(model_config(consts.model))
         _model = _model.eval()
         if not os.path.exists(_param_path):
             raise FileNotFoundError(
                 """Cannot load pretrained parameters. Make sure to run the
                 installation script before running tests."""
             )
-        import_jax_weights_(_model, _param_path, version="model_1_ptm")
+        import_jax_weights_(_model, _param_path, version=consts.model)
         _model = _model.cuda()
     return _model
...
tests/config.py
@@ -2,8 +2,11 @@ import ml_collections as mlc
 consts = mlc.ConfigDict(
     {
+        "model": "model_1_multimer_v3",  # monomer: model_1_ptm, multimer: model_1_multimer_v3
+        "is_multimer": True,  # monomer: False, multimer: True
+        "chunk_size": 4,
         "batch_size": 2,
-        "n_res": 11,
+        "n_res": 22,
         "n_seq": 13,
         "n_templ": 3,
         "n_extra": 17,
@@ -16,6 +19,7 @@ consts = mlc.ConfigDict(
         "c_s": 384,
         "c_t": 64,
         "c_e": 64,
+        "msa_logits": 22  # monomer: 23, multimer: 22
     }
 )
...
tests/data_utils.py
@@ -12,9 +12,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from random import randint
 import numpy as np
 from scipy.spatial.transform import Rotation
+from tests.config import consts
+def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
+    n_chain = randint(1, n_res // min_chain_len) if consts.is_multimer else 1
+    if not split_chains:
+        return [0] * n_res
+    assert n_res >= n_chain
+    pieces = []
+    asym_ids = []
+    final_idx = n_chain - 1
+    for idx in range(n_chain - 1):
+        n_stop = (n_res - sum(pieces) - n_chain + idx - min_chain_len)
+        if n_stop <= min_chain_len:
+            final_idx = idx
+            break
+        piece = randint(min_chain_len, n_stop)
+        pieces.append(piece)
+        asym_ids.extend(piece * [idx])
+    asym_ids.extend((n_res - sum(pieces)) * [final_idx])
+    return np.array(asym_ids).astype(np.int64)
 def random_template_feats(n_templ, n, batch_size=None):
     b = []
@@ -39,6 +66,11 @@ def random_template_feats(n_templ, n, batch_size=None):
     }
     batch = {k: v.astype(np.float32) for k, v in batch.items()}
     batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
+    if consts.is_multimer:
+        asym_ids = np.array(random_asym_ids(n))
+        batch["asym_id"] = np.tile(asym_ids[np.newaxis, :], (*b, n_templ, 1))
     return batch
...
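A quick way to see what the new `random_asym_ids` helper produces (hypothetical usage; with `consts.is_multimer` set, the chain count itself is random):

    import numpy as np
    from tests.data_utils import random_asym_ids

    ids = random_asym_ids(22)
    print(ids)  # e.g. [0 0 0 0 0 0 0 1 1 1 1 2 2 2 2 2 2 2 2 2 2 2]
    assert ids.shape == (22,) and ids.dtype == np.int64
    # Chain labels form contiguous blocks in increasing order:
    assert (np.diff(ids) >= 0).all()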
tests/test_data_pipeline.py
@@ -15,19 +15,13 @@
 import pickle
 import shutil
-import torch
 import numpy as np
 import unittest
 from openfold.data.data_pipeline import DataPipeline
-from openfold.data.templates import TemplateHitFeaturizer
+from openfold.data.templates import HhsearchHitFeaturizer, HmmsearchHitFeaturizer
-from openfold.model.embedders import (
-    InputEmbedder,
-    RecyclingEmbedder,
-    TemplateAngleEmbedder,
-    TemplatePairEmbedder,
-)
 import tests.compare_utils as compare_utils
+from tests.config import consts
 if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
@@ -45,13 +39,29 @@ class TestDataPipeline(unittest.TestCase):
         with open("tests/test_data/alphafold_feature_dict.pickle", "rb") as fp:
             alphafold_feature_dict = pickle.load(fp)
-        template_featurizer = TemplateHitFeaturizer(
-            mmcif_dir="tests/test_data/mmcifs",
-            max_template_date="2021-12-20",
-            max_hits=20,
-            kalign_binary_path=shutil.which("kalign"),
-            _zero_center_positions=False,
-        )
+        if consts.is_multimer:
+            # template_featurizer = HmmsearchHitFeaturizer(
+            #     mmcif_dir="tests/test_data/mmcifs",
+            #     max_template_date="2021-12-20",
+            #     max_hits=20,
+            #     kalign_binary_path=shutil.which("kalign"),
+            #     _zero_center_positions=False,
+            # )
+            template_featurizer = HhsearchHitFeaturizer(
+                mmcif_dir="tests/test_data/mmcifs",
+                max_template_date="2021-12-20",
+                max_hits=20,
+                kalign_binary_path=shutil.which("kalign"),
+                _zero_center_positions=False,
+            )
+        else:
+            template_featurizer = HhsearchHitFeaturizer(
+                mmcif_dir="tests/test_data/mmcifs",
+                max_template_date="2021-12-20",
+                max_hits=20,
+                kalign_binary_path=shutil.which("kalign"),
+                _zero_center_positions=False,
+            )
         data_pipeline = DataPipeline(
             template_featurizer=template_featurizer,
...
tests/test_data_transforms.py
-import copy
 import gzip
-import os
 import pickle
 import numpy as np
@@ -181,7 +177,7 @@ class TestDataTransforms(unittest.TestCase):
         }
         protein = make_hhblits_profile(protein)
         masked_msa_config = config.data.common.masked_msa
-        protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15)
+        protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15, seed=42)
         assert 'bert_mask' in protein
         assert 'true_msa' in protein
         assert 'msa' in protein
...
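Two details in the hunk above are easy to miss: OpenFold's data transforms are wrapped in a small currying decorator, so the test reaches the raw function through `__wrapped__`, and the new `seed` argument makes the random masking reproducible. A minimal sketch of the decorator pattern (hypothetical, not OpenFold's exact `curry1`):

    import functools

    def curry1(f):
        """Defer all arguments after the first, AlphaFold-style."""
        @functools.wraps(f)  # exposes the raw function as fc.__wrapped__
        def fc(*args, **kwargs):
            return lambda x: f(x, *args, **kwargs)
        return fc

    @curry1
    def make_masked_msa(protein, config, replace_fraction, seed=None):
        ...  # transform body elided

    # In a pipeline: make_masked_msa(cfg, 0.15, seed=42)(protein)
    # In a test, bypass the currying:
    # make_masked_msa.__wrapped__(protein, cfg, 0.15, seed=42)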
tests/test_embedders.py
@@ -12,14 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import random
 import torch
+import numpy as np
 import unittest
+from tests.config import consts
+from tests.data_utils import random_asym_ids
 from openfold.model.embedders import (
     InputEmbedder,
+    InputEmbedderMultimer,
     RecyclingEmbedder,
     TemplateAngleEmbedder,
-    TemplatePairEmbedder,
+    TemplatePairEmbedder
 )
@@ -35,13 +38,30 @@ class TestInputEmbedder(unittest.TestCase):
         n_res = 17
         n_clust = 19
+        max_relative_chain = 2
+        max_relative_idx = 32
+        use_chain_relative = True
         tf = torch.rand((b, n_res, tf_dim))
         ri = torch.rand((b, n_res))
         msa = torch.rand((b, n_clust, n_res, msa_dim))
+        asym_ids_flat = torch.Tensor(random_asym_ids(n_res))
+        asym_id = torch.tile(asym_ids_flat.unsqueeze(0), (b, 1))
+        entity_id = asym_id
+        sym_id = torch.zeros_like(entity_id)
-        ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
-        msa_emb, pair_emb = ie(tf, ri, msa)
+        if consts.is_multimer:
+            ie = InputEmbedderMultimer(tf_dim, msa_dim, c_z, c_m,
+                                       max_relative_idx=max_relative_idx,
+                                       use_chain_relative=use_chain_relative,
+                                       max_relative_chain=max_relative_chain)
+            batch = {"target_feat": tf, "residue_index": ri, "msa_feat": msa,
+                     "asym_id": asym_id, "entity_id": entity_id, "sym_id": sym_id}
+            msa_emb, pair_emb = ie(batch)
+        else:
+            ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
+            msa_emb, pair_emb = ie(tf=tf, ri=ri, msa=msa, inplace_safe=False)
         self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m))
         self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))
...
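`InputEmbedderMultimer` consumes per-residue chain labels (`asym_id`, `entity_id`, `sym_id`) instead of only a flat residue index. A rough sketch of the underlying idea, simplified from AlphaFold-Multimer's relative-position encoding (plain PyTorch; the real embedder also encodes entity and symmetry relations):

    import torch

    def relpos_bins(residue_index, asym_id, max_relative_idx=32):
        """Bucketize pairwise residue offsets; offsets are only meaningful
        within a chain, so all cross-chain pairs share one extra bin."""
        offset = residue_index[..., :, None] - residue_index[..., None, :]
        clipped = torch.clamp(offset + max_relative_idx, 0, 2 * max_relative_idx)
        same_chain = asym_id[..., :, None] == asym_id[..., None, :]
        final_bin = torch.full_like(clipped, 2 * max_relative_idx + 1)
        return torch.where(same_chain, clipped, final_bin)  # later one-hot + linear

    ri = torch.arange(6)
    asym = torch.tensor([0, 0, 0, 1, 1, 1])
    print(relpos_bins(ri, asym))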
tests/test_evoformer.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 import torch
 import numpy as np
 import unittest
@@ -48,6 +49,8 @@ class TestEvoformerStack(unittest.TestCase):
         transition_n = 2
         msa_dropout = 0.15
         pair_stack_dropout = 0.25
+        opm_first = consts.is_multimer
+        fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
         inf = 1e9
         eps = 1e-10
@@ -65,6 +68,8 @@ class TestEvoformerStack(unittest.TestCase):
             transition_n,
             msa_dropout,
             pair_stack_dropout,
+            opm_first,
+            fuse_projection_weights,
             blocks_per_ckpt=None,
             inf=inf,
             eps=eps,
@@ -174,6 +179,8 @@ class TestExtraMSAStack(unittest.TestCase):
         transition_n = 5
         msa_dropout = 0.15
         pair_stack_dropout = 0.25
+        opm_first = consts.is_multimer
+        fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
         inf = 1e9
         eps = 1e-10
@@ -190,6 +197,8 @@ class TestExtraMSAStack(unittest.TestCase):
             transition_n,
             msa_dropout,
             pair_stack_dropout,
+            opm_first,
+            fuse_projection_weights,
             ckpt=False,
             inf=inf,
             eps=eps,
@@ -277,7 +286,7 @@ class TestMSATransition(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
-            model.evoformer.blocks[0].core.msa_transition(
+            model.evoformer.blocks[0].msa_transition(
                 torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
                 mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
             )
...
tests/test_feats.py
@@ -25,13 +25,16 @@ from openfold.np.residue_constants import (
 )
 import openfold.utils.feats as feats
 from openfold.utils.rigid_utils import Rotation, Rigid
+from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
+from openfold.utils.geometry.rotation_matrix import Rot3Array
+from openfold.utils.geometry.vector import Vec3Array
 from openfold.utils.tensor_utils import (
     tree_map,
     tensor_tree_map,
 )
 import tests.compare_utils as compare_utils
 from tests.config import consts
-from tests.data_utils import random_affines_4x4
+from tests.data_utils import random_affines_4x4, random_asym_ids
 if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
@@ -40,6 +43,19 @@ if compare_utils.alphafold_is_installed():
 class TestFeats(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if consts.is_multimer:
+            cls.am_atom = alphafold.model.all_atom_multimer
+            cls.am_fold = alphafold.model.folding_multimer
+            cls.am_modules = alphafold.model.modules_multimer
+            cls.am_rigid = alphafold.model.geometry
+        else:
+            cls.am_atom = alphafold.model.all_atom
+            cls.am_fold = alphafold.model.folding
+            cls.am_modules = alphafold.model.modules
+            cls.am_rigid = alphafold.model.r3
     @compare_utils.skip_unless_alphafold_installed()
     def test_pseudo_beta_fn_compare(self):
         def test_pbf(aatype, all_atom_pos, all_atom_mask):
@@ -131,7 +147,9 @@ class TestFeats(unittest.TestCase):
     @compare_utils.skip_unless_alphafold_installed()
     def test_atom37_to_frames_compare(self):
         def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
-            return alphafold.model.all_atom.atom37_to_frames(
+            if consts.is_multimer:
+                all_atom_positions = self.am_rigid.Vec3Array.from_array(all_atom_positions)
+            return self.am_atom.atom37_to_frames(
                 aatype, all_atom_positions, all_atom_mask
             )
@@ -150,9 +168,23 @@ class TestFeats(unittest.TestCase):
         }
         out_gt = f.apply({}, None, **batch)
-        to_tensor = lambda t: torch.tensor(np.array(t))
+        if consts.is_multimer:
+            batch["asym_id"] = random_asym_ids(n_res)
+            to_tensor = (lambda t: torch.tensor(np.array(t))
+                         if not isinstance(t, self.am_rigid.Rigid3Array)
+                         else torch.tensor(np.array(t.to_array())))
+        else:
+            to_tensor = lambda t: torch.tensor(np.array(t))
         out_gt = {k: to_tensor(v) for k, v in out_gt.items()}
+        def rigid3x4_to_4x4(rigid3arr):
+            four_by_four = torch.zeros(*rigid3arr.shape[:-2], 4, 4)
+            four_by_four[..., :3, :4] = rigid3arr
+            four_by_four[..., 3, 3] = 1
+            return four_by_four
         def flat12_to_4x4(flat12):
             rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
             trans = flat12[..., 9:]
@@ -164,10 +196,12 @@ class TestFeats(unittest.TestCase):
             return four_by_four
-        out_gt["rigidgroups_gt_frames"] = flat12_to_4x4(
+        convert_func = rigid3x4_to_4x4 if consts.is_multimer else flat12_to_4x4
+        out_gt["rigidgroups_gt_frames"] = convert_func(
             out_gt["rigidgroups_gt_frames"]
         )
-        out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4(
+        out_gt["rigidgroups_alt_gt_frames"] = convert_func(
             out_gt["rigidgroups_alt_gt_frames"]
         )
@@ -187,7 +221,13 @@ class TestFeats(unittest.TestCase):
         n = 5
         rots = torch.rand((batch_size, n, 3, 3))
         trans = torch.rand((batch_size, n, 3))
-        ts = Rigid(Rotation(rot_mats=rots), trans)
+        if consts.is_multimer:
+            rotation = Rot3Array.from_array(rots)
+            translation = Vec3Array.from_array(trans)
+            ts = Rigid3Array(rotation, translation)
+        else:
+            ts = Rigid(Rotation(rot_mats=rots), trans)
         angles = torch.rand((batch_size, n, 7, 2))
@@ -208,7 +248,7 @@ class TestFeats(unittest.TestCase):
         def run_torsion_angles_to_frames(
             aatype, backb_to_global, torsion_angles_sin_cos
         ):
-            return alphafold.model.all_atom.torsion_angles_to_frames(
+            return self.am_atom.torsion_angles_to_frames(
                 aatype,
                 backb_to_global,
                 torsion_angles_sin_cos,
@@ -221,10 +261,17 @@ class TestFeats(unittest.TestCase):
         aatype = np.random.randint(0, 21, size=(n_res,))
         affines = random_affines_4x4((n_res,))
-        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
-        transformations = Rigid.from_tensor_4x4(
-            torch.as_tensor(affines).float()
-        )
+        if consts.is_multimer:
+            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
+            transformations = Rigid3Array.from_tensor_4x4(
+                torch.as_tensor(affines).float()
+            )
+        else:
+            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
+            transformations = Rigid.from_tensor_4x4(
+                torch.as_tensor(affines).float()
+            )
         torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)
@@ -240,13 +287,21 @@ class TestFeats(unittest.TestCase):
         )
         # Convert the Rigids to 4x4 transformation tensors
-        rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot))
-        trans_gt = list(
-            map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
-        )
-        rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
-        rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
-        trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)
+        out_gt_rot = out_gt.rot if not consts.is_multimer else out_gt.rotation.to_array()
+        out_gt_trans = out_gt.trans if not consts.is_multimer else out_gt.translation.to_array()
+        if consts.is_multimer:
+            rots_gt = torch.as_tensor(np.array(out_gt_rot))
+            trans_gt = torch.as_tensor(np.array(out_gt_trans))
+        else:
+            rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt_rot))
+            trans_gt = list(
+                map(lambda x: torch.as_tensor(np.array(x)), out_gt_trans)
+            )
+            rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
+            rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
+            trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)
         transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
         bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
         bottom_row[..., 3] = 1
@@ -264,7 +319,13 @@ class TestFeats(unittest.TestCase):
         rots = torch.rand((batch_size, n_res, 8, 3, 3))
         trans = torch.rand((batch_size, n_res, 8, 3))
-        ts = Rigid(Rotation(rot_mats=rots), trans)
+        if consts.is_multimer:
+            rotation = Rot3Array.from_array(rots)
+            translation = Vec3Array.from_array(trans)
+            ts = Rigid3Array(rotation, translation)
+        else:
+            ts = Rigid(Rotation(rot_mats=rots), trans)
         f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()
@@ -282,8 +343,7 @@ class TestFeats(unittest.TestCase):
     @compare_utils.skip_unless_alphafold_installed()
     def test_frames_and_literature_positions_to_atom14_pos_compare(self):
         def run_f(aatype, affines):
-            am = alphafold.model
-            return am.all_atom.frames_and_literature_positions_to_atom14_pos(
+            return self.am_atom.frames_and_literature_positions_to_atom14_pos(
                 aatype, affines
             )
@@ -294,16 +354,27 @@ class TestFeats(unittest.TestCase):
         aatype = np.random.randint(0, 21, size=(n_res,))
         affines = random_affines_4x4((n_res, 8))
-        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
-        transformations = Rigid.from_tensor_4x4(
-            torch.as_tensor(affines).float()
-        )
+        if consts.is_multimer:
+            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
+            transformations = Rigid3Array.from_tensor_4x4(
+                torch.as_tensor(affines).float()
+            )
+        else:
+            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
+            transformations = Rigid.from_tensor_4x4(
+                torch.as_tensor(affines).float()
+            )
         out_gt = f.apply({}, None, aatype, rigids)
         jax.tree_map(lambda x: x.block_until_ready(), out_gt)
-        out_gt = torch.stack(
-            [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
-        )
+        if consts.is_multimer:
+            out_gt = torch.as_tensor(np.array(out_gt.to_array()))
+        else:
+            out_gt = torch.stack(
+                [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
+            )
         out_repro = feats.frames_and_literature_positions_to_atom14_pos(
             transformations.cuda(),
...
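The new `rigid3x4_to_4x4` helper above pads a (..., 3, 4) rigid transform into a homogeneous 4x4 matrix, after which composition is plain matrix multiplication. A hypothetical standalone illustration of the same idea:

    import torch

    rigid = torch.randn(5, 3, 4)   # rotation (3x3) and translation column side by side
    hom = torch.zeros(5, 4, 4)
    hom[..., :3, :4] = rigid
    hom[..., 3, 3] = 1.0           # bottom row becomes [0, 0, 0, 1]
    composed = hom @ hom           # composing transforms = matmul
    assert composed.shape == (5, 4, 4)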
tests/test_import_weights.py
@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
                 )
             ][1].transpose(-1, -2)
             ),
-            model.evoformer.blocks[1].core.outer_product_mean.linear_1.weight,
+            model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
         ),
     ]
...
tests/test_loss.py
@@ -13,18 +13,18 @@
 # limitations under the License.
 import os
-import math
 import torch
 import numpy as np
+from pathlib import Path
 import unittest
 import ml_collections as mlc
 from openfold.data import data_transforms
+from openfold.np import residue_constants
 from openfold.utils.rigid_utils import (
     Rotation,
     Rigid,
 )
-import openfold.utils.feats as feats
 from openfold.utils.loss import (
     torsion_angle_loss,
     compute_fape,
@@ -43,6 +43,8 @@ from openfold.utils.loss import (
     sidechain_loss,
     tm_loss,
     compute_plddt,
+    compute_tm,
+    chain_center_of_mass_loss
 )
 from openfold.utils.tensor_utils import (
     tree_map,
@@ -51,7 +53,7 @@ from openfold.utils.tensor_utils import (
 )
 import tests.compare_utils as compare_utils
 from tests.config import consts
-from tests.data_utils import random_affines_vector, random_affines_4x4
+from tests.data_utils import random_affines_vector, random_affines_4x4, random_asym_ids
 if compare_utils.alphafold_is_installed():
     alphafold = compare_utils.import_alphafold()
@@ -64,7 +66,30 @@ def affine_vector_to_4x4(affine):
     return r.to_tensor_4x4()
+def affine_vector_to_rigid(am_rigid, affine):
+    rigid_flat = np.split(affine, 7, axis=-1)
+    rigid_flat = [r.squeeze(-1) for r in rigid_flat]
+    qw, qx, qy, qz = rigid_flat[:4]
+    trans = rigid_flat[4:]
+    rotation = am_rigid.Rot3Array.from_quaternion(qw, qx, qy, qz, normalize=True)
+    translation = am_rigid.Vec3Array(*trans)
+    return am_rigid.Rigid3Array(rotation, translation)
 class TestLoss(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if consts.is_multimer:
+            cls.am_atom = alphafold.model.all_atom_multimer
+            cls.am_fold = alphafold.model.folding_multimer
+            cls.am_modules = alphafold.model.modules_multimer
+            cls.am_rigid = alphafold.model.geometry
+        else:
+            cls.am_atom = alphafold.model.all_atom
+            cls.am_fold = alphafold.model.folding
+            cls.am_modules = alphafold.model.modules
+            cls.am_rigid = alphafold.model.r3
     def test_run_torsion_angle_loss(self):
         batch_size = consts.batch_size
         n_res = consts.n_res
@@ -127,7 +152,10 @@ class TestLoss(unittest.TestCase):
     @compare_utils.skip_unless_alphafold_installed()
     def test_between_residue_bond_loss_compare(self):
         def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
-            return alphafold.model.all_atom.between_residue_bond_loss(
+            if consts.is_multimer:
+                pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
+            return self.am_atom.between_residue_bond_loss(
                 pred_pos,
                 pred_atom_mask,
                 residue_index,
@@ -184,12 +212,22 @@ class TestLoss(unittest.TestCase):
     @compare_utils.skip_unless_alphafold_installed()
     def test_between_residue_clash_loss_compare(self):
-        def run_brcl(pred_pos, atom_exists, atom_radius, res_ind):
-            return alphafold.model.all_atom.between_residue_clash_loss(
+        def run_brcl(pred_pos, atom_exists, atom_radius, res_ind, asym_id):
+            if consts.is_multimer:
+                pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
+                return self.am_atom.between_residue_clash_loss(
+                    pred_pos,
+                    atom_exists,
+                    atom_radius,
+                    res_ind,
+                    asym_id
+                )
+            return self.am_atom.between_residue_clash_loss(
                 pred_pos,
                 atom_exists,
                 atom_radius,
-                res_ind,
+                res_ind
             )
         f = hk.transform(run_brcl)
@@ -198,10 +236,24 @@ class TestLoss(unittest.TestCase):
         pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
         atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
-        atom_radius = np.random.rand(n_res, 14).astype(np.float32)
         res_ind = np.arange(
             n_res,
         )
+        residx_atom14_to_atom37 = np.random.randint(0, 37, (n_res, 14)).astype(np.int64)
+        atomtype_radius = [
+            residue_constants.van_der_waals_radius[name[0]]
+            for name in residue_constants.atom_types
+        ]
+        atomtype_radius = np.array(atomtype_radius).astype(np.float32)
+        atom_radius = (
+            atom_exists
+            * atomtype_radius[residx_atom14_to_atom37]
+        )
+        asym_id = None
+        if consts.is_multimer:
+            asym_id = random_asym_ids(n_res)
         out_gt = f.apply(
             {},
@@ -210,6 +262,7 @@ class TestLoss(unittest.TestCase):
             atom_exists,
             atom_radius,
             res_ind,
+            asym_id
         )
         out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
         out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
@@ -219,6 +272,7 @@ class TestLoss(unittest.TestCase):
             torch.tensor(atom_exists).cuda(),
             torch.tensor(atom_radius).cuda(),
             torch.tensor(res_ind).cuda(),
+            torch.tensor(asym_id).cuda() if asym_id is not None else None,
         )
         out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)
@@ -242,6 +296,36 @@ class TestLoss(unittest.TestCase):
             torch.max(torch.abs(out_gt - out_repro)) < consts.eps
         )
+    @compare_utils.skip_unless_alphafold_installed()
+    def test_compute_ptm_compare(self):
+        n_res = consts.n_res
+        max_bin = 31
+        no_bins = 64
+        logits = np.random.rand(n_res, n_res, no_bins)
+        boundaries = np.linspace(0, max_bin, num=(no_bins - 1))
+        ptm_gt = alphafold.common.confidence.predicted_tm_score(logits, boundaries)
+        ptm_gt = torch.tensor(ptm_gt)
+        logits_t = torch.tensor(logits)
+        ptm_repro = compute_tm(logits_t, no_bins=no_bins, max_bin=max_bin)
+        self.assertTrue(
+            torch.max(torch.abs(ptm_gt - ptm_repro)) < consts.eps
+        )
+        if consts.is_multimer:
+            asym_id = random_asym_ids(n_res)
+            iptm_gt = alphafold.common.confidence.predicted_tm_score(logits, boundaries,
+                                                                     asym_id=asym_id, interface=True)
+            iptm_gt = torch.tensor(iptm_gt)
+            iptm_repro = compute_tm(logits_t, no_bins=no_bins, max_bin=max_bin,
+                                    asym_id=torch.tensor(asym_id), interface=True)
+            self.assertTrue(
+                torch.max(torch.abs(iptm_gt - iptm_repro)) < consts.eps
+            )
     def test_find_structural_violations(self):
         n = consts.n_res
@@ -265,8 +349,21 @@ class TestLoss(unittest.TestCase):
     def test_find_structural_violations_compare(self):
         def run_fsv(batch, pos, config):
             cwd = os.getcwd()
-            os.chdir("tests/test_data")
-            loss = alphafold.model.folding.find_structural_violations(
+            fpath = Path(__file__).parent.resolve() / "test_data"
+            os.chdir(str(fpath))
+            if consts.is_multimer:
+                atom14_pred_pos = self.am_rigid.Vec3Array.from_array(pos)
+                return self.am_fold.find_structural_violations(
+                    batch['aatype'],
+                    batch['residue_index'],
+                    batch['atom14_atom_exists'],
+                    atom14_pred_pos,
+                    config,
+                    batch['asym_id']
+                )
+            loss = self.am_fold.find_structural_violations(
                 batch,
                 pos,
                 config,
@@ -287,6 +384,9 @@ class TestLoss(unittest.TestCase):
             ).astype(np.int64),
         }
+        if consts.is_multimer:
+            batch["asym_id"] = random_asym_ids(n_res)
         pred_pos = np.random.rand(n_res, 14, 3)
         config = mlc.ConfigDict(
@@ -380,14 +480,14 @@ class TestLoss(unittest.TestCase):
         n_seq = consts.n_seq
         value = {
-            "logits": np.random.rand(n_res, n_seq, 23).astype(np.float32),
+            "logits": np.random.rand(n_res, n_seq, consts.msa_logits).astype(np.float32),
         }
         batch = {
             "true_msa": np.random.randint(0, 21, (n_res, n_seq)),
             "bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(
                 np.float32
-            ),
+            )
         }
         out_gt = f.apply({}, None, value, batch)["loss"]
@@ -399,7 +499,9 @@ class TestLoss(unittest.TestCase):
         with torch.no_grad():
             out_repro = masked_msa_loss(
                 value["logits"],
-                **batch,
+                batch["true_msa"],
+                batch["bert_mask"],
+                consts.msa_logits
             )
         out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
@@ -506,10 +608,28 @@ class TestLoss(unittest.TestCase):
         c_chi_loss = config.model.heads.structure_module
         def run_supervised_chi_loss(value, batch):
+            if consts.is_multimer:
+                pred_angles = np.reshape(
+                    value['sidechains']['angles_sin_cos'], [-1, consts.n_res, 7, 2])
+                unnormed_angles = np.reshape(
+                    value['sidechains']['unnormalized_angles_sin_cos'], [-1, consts.n_res, 7, 2])
+                chi_loss, _, _ = self.am_fold.supervised_chi_loss(
+                    batch['seq_mask'],
+                    batch['chi_mask'],
+                    batch['aatype'],
+                    batch['chi_angles'],
+                    pred_angles,
+                    unnormed_angles,
+                    c_chi_loss
+                )
+                return chi_loss
             ret = {
                 "loss": jax.numpy.array(0.0),
             }
-            alphafold.model.folding.supervised_chi_loss(
+            self.am_fold.supervised_chi_loss(
                 ret, batch, value, c_chi_loss
             )
             return ret["loss"]
@@ -561,6 +681,40 @@ class TestLoss(unittest.TestCase):
         self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
+    @compare_utils.skip_unless_alphafold_installed()
+    def test_violation_loss(self):
+        config = compare_utils.get_alphafold_config()
+        c_viol = config.model.heads.structure_module
+        n_res = consts.n_res
+        batch = {
+            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
+            "residue_index": np.arange(n_res),
+            "aatype": np.random.randint(0, 21, (n_res,)),
+        }
+        if consts.is_multimer:
+            batch["asym_id"] = random_asym_ids(n_res)
+        batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)
+        atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
+        atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()
+        batch = data_transforms.make_atom14_masks(batch)
+        loss_sum_clash = violation_loss(
+            find_structural_violations(batch, atom14_pred_pos, **c_viol),
+            average_clashes=False, **batch
+        )
+        loss_sum_clash = loss_sum_clash.cpu()
+        loss_avg_clash = violation_loss(
+            find_structural_violations(batch, atom14_pred_pos, **c_viol),
+            average_clashes=True, **batch
+        )
+        loss_avg_clash = loss_avg_clash.cpu()
     @compare_utils.skip_unless_alphafold_installed()
     def test_violation_loss_compare(self):
         config = compare_utils.get_alphafold_config()
@@ -570,15 +724,31 @@ class TestLoss(unittest.TestCase):
             ret = {
                 "loss": np.array(0.0).astype(np.float32),
             }
+            if consts.is_multimer:
+                atom14_pred_pos = self.am_rigid.Vec3Array.from_array(atom14_pred_pos)
+                viol = self.am_fold.find_structural_violations(
+                    batch['aatype'],
+                    batch['residue_index'],
+                    batch['atom14_atom_exists'],
+                    atom14_pred_pos,
+                    c_viol,
+                    batch['asym_id']
+                )
+                return self.am_fold.structural_violation_loss(mask=batch['atom14_atom_exists'],
+                                                              violations=viol,
+                                                              config=c_viol)
             value = {}
             value[
                 "violations"
-            ] = alphafold.model.folding.find_structural_violations(
+            ] = self.am_fold.find_structural_violations(
                 batch,
                 atom14_pred_pos,
                 c_viol,
             )
-            alphafold.model.folding.structural_violation_loss(
+            self.am_fold.structural_violation_loss(
                 ret,
                 batch,
                 value,
@@ -593,13 +763,17 @@ class TestLoss(unittest.TestCase):
         batch = {
             "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
             "residue_index": np.arange(n_res),
-            "aatype": np.random.randint(0, 21, (n_res,)),
+            "aatype": np.random.randint(0, 21, (n_res,))
         }
-        alphafold.model.tf.data_transforms.make_atom14_masks(batch)
-        batch = {k: np.array(v) for k, v in batch.items()}
+        if consts.is_multimer:
+            batch["asym_id"] = random_asym_ids(n_res)
         atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
+        alphafold.model.tf.data_transforms.make_atom14_masks(batch)
+        batch = {k: np.array(v) for k, v in batch.items()}
         out_gt = f.apply({}, None, batch, atom14_pred_pos)
         out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
@@ -676,10 +850,31 @@ class TestLoss(unittest.TestCase):
         c_sm = config.model.heads.structure_module
         def run_bb_loss(batch, value):
+            if consts.is_multimer:
+                intra_chain_mask = (batch["asym_id"][..., None] == batch["asym_id"][..., None, :]).astype(np.float32)
+                gt_rigid = affine_vector_to_rigid(self.am_rigid, batch["backbone_affine_tensor"])
+                target_rigid = affine_vector_to_rigid(self.am_rigid, value['traj'])
+                intra_chain_bb_loss, intra_chain_fape = self.am_fold.backbone_loss(
+                    gt_rigid=gt_rigid,
+                    gt_frames_mask=batch["backbone_affine_mask"],
+                    gt_positions_mask=batch["backbone_affine_mask"],
+                    target_rigid=target_rigid,
+                    config=c_sm.intra_chain_fape,
+                    pair_mask=intra_chain_mask)
+                interface_bb_loss, interface_fape = self.am_fold.backbone_loss(
+                    gt_rigid=gt_rigid,
+                    gt_frames_mask=batch["backbone_affine_mask"],
+                    gt_positions_mask=batch["backbone_affine_mask"],
+                    target_rigid=target_rigid,
+                    config=c_sm.interface_fape,
+                    pair_mask=1. - intra_chain_mask)
+                return intra_chain_bb_loss + interface_bb_loss
             ret = {
                 "loss": np.array(0.0),
             }
-            alphafold.model.folding.backbone_loss(ret, batch, value, c_sm)
+            self.am_fold.backbone_loss(ret, batch, value, c_sm)
             return ret["loss"]
         f = hk.transform(run_bb_loss)
@@ -691,7 +886,7 @@ class TestLoss(unittest.TestCase):
             "backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
                 np.float32
             ),
-            "use_clamped_fape": np.array(0.0),
+            "use_clamped_fape": np.array(0.0)
         }
         value = {
@@ -703,6 +898,9 @@ class TestLoss(unittest.TestCase):
             ),
         }
+        if consts.is_multimer:
+            batch["asym_id"] = random_asym_ids(n_res)
         out_gt = f.apply({}, None, batch, value)
         out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
@@ -715,8 +913,18 @@ class TestLoss(unittest.TestCase):
         )
         batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
-        out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
-        out_repro = out_repro.cpu()
+        if consts.is_multimer:
+            intra_chain_mask = (batch["asym_id"][..., None]
+                                == batch["asym_id"][..., None, :]).to(dtype=value["traj"].dtype)
+            intra_chain_out = backbone_loss(traj=value["traj"], pair_mask=intra_chain_mask,
+                                            **{**batch, **c_sm.intra_chain_fape})
+            interface_out = backbone_loss(traj=value["traj"], pair_mask=1. - intra_chain_mask,
+                                          **{**batch, **c_sm.interface_fape})
+            out_repro = intra_chain_out + interface_out
+            out_repro = out_repro.cpu()
+        else:
+            out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
+            out_repro = out_repro.cpu()
         self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@@ -726,9 +934,29 @@ class TestLoss(unittest.TestCase):
         c_sm = config.model.heads.structure_module
         def run_sidechain_loss(batch, value, atom14_pred_positions):
+            if consts.is_multimer:
+                atom14_pred_positions = self.am_rigid.Vec3Array.from_array(atom14_pred_positions)
+                all_atom_positions = self.am_rigid.Vec3Array.from_array(batch["all_atom_positions"])
+                gt_positions, gt_mask, alt_naming_is_better = self.am_fold.compute_atom14_gt(
+                    aatype=batch["aatype"], all_atom_positions=all_atom_positions,
+                    all_atom_mask=batch["all_atom_mask"], pred_pos=atom14_pred_positions)
+                pred_frames = self.am_rigid.Rigid3Array.from_array4x4(value["sidechains"]["frames"])
+                pred_positions = self.am_rigid.Vec3Array.from_array(value["sidechains"]["atom_pos"])
+                gt_sc_frames, gt_sc_frames_mask = self.am_fold.compute_frames(
+                    aatype=batch["aatype"],
+                    all_atom_positions=all_atom_positions,
+                    all_atom_mask=batch["all_atom_mask"],
+                    use_alt=alt_naming_is_better)
+                return self.am_fold.sidechain_loss(gt_sc_frames,
+                                                   gt_sc_frames_mask,
+                                                   gt_positions,
+                                                   gt_mask,
+                                                   pred_frames,
+                                                   pred_positions,
+                                                   c_sm)['loss']
             batch = {
                 **batch,
-                **alphafold.model.all_atom.atom37_to_frames(
+                **self.am_atom.atom37_to_frames(
                     batch["aatype"],
                     batch["all_atom_positions"],
                     batch["all_atom_mask"],
@@ -738,21 +966,21 @@ class TestLoss(unittest.TestCase):
             v["sidechains"] = {}
             v["sidechains"][
                 "frames"
-            ] = alphafold.model.r3.rigids_from_tensor4x4(
+            ] = self.am_rigid.rigids_from_tensor4x4(
                 value["sidechains"]["frames"]
             )
-            v["sidechains"]["atom_pos"] = alphafold.model.r3.vecs_from_tensor(
+            v["sidechains"]["atom_pos"] = self.am_rigid.vecs_from_tensor(
                 value["sidechains"]["atom_pos"]
            )
            v.update(
-                alphafold.model.folding.compute_renamed_ground_truth(
+                self.am_fold.compute_renamed_ground_truth(
                    batch,
                    atom14_pred_positions,
                )
            )
            value = v
-            ret = alphafold.model.folding.sidechain_loss(batch, value, c_sm)
+            ret = self.am_fold.sidechain_loss(batch, value, c_sm)
            return ret["loss"]
        f = hk.transform(run_sidechain_loss)
@@ -816,6 +1044,7 @@ class TestLoss(unittest.TestCase):
        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
    @compare_utils.skip_unless_alphafold_installed()
+    @unittest.skipIf(not consts.is_multimer and "ptm" not in consts.model, "Not enabled for non-ptm models.")
    def test_tm_loss_compare(self):
        config = compare_utils.get_alphafold_config()
        c_tm = config.model.heads.predicted_aligned_error
@@ -882,6 +1111,33 @@ class TestLoss(unittest.TestCase):
        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
+    @compare_utils.skip_unless_alphafold_installed()
+    def test_chain_center_of_mass_loss(self):
+        batch_size = consts.batch_size
+        n_res = consts.n_res
+        batch = {
+            "all_atom_positions": np.random.rand(batch_size, n_res, 37, 3).astype(np.float32) * 10.0,
+            "all_atom_mask": np.random.randint(0, 2, (batch_size, n_res, 37)).astype(np.float32),
+            "asym_id": np.stack([random_asym_ids(n_res) for _ in range(batch_size)])
+        }
+        config = {
+            "weight": 0.05,
+            "clamp_distance": -4.0,
+        }
+        final_atom_positions = torch.rand(batch_size, n_res, 37, 3).cuda()
+        to_tensor = lambda t: torch.tensor(t).cuda()
+        batch = tree_map(to_tensor, batch, np.ndarray)
+        out_repro = chain_center_of_mass_loss(
+            all_atom_pred_pos=final_atom_positions,
+            **{**batch, **config},
+        )
+        out_repro = out_repro.cpu()
 if __name__ == "__main__":
     unittest.main()
tests/test_model.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from pathlib import Path
 import pickle
 import torch
 import torch.nn as nn
@@ -20,8 +21,7 @@ import unittest
 from openfold.config import model_config
 from openfold.data import data_transforms
 from openfold.model.model import AlphaFold
-import openfold.utils.feats as feats
-from openfold.utils.tensor_utils import tree_map, tensor_tree_map
+from openfold.utils.tensor_utils import tensor_tree_map
 import tests.compare_utils as compare_utils
 from tests.config import consts
 from tests.data_utils import (
@@ -36,13 +36,26 @@ if compare_utils.alphafold_is_installed():
 class TestModel(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if consts.is_multimer:
+            cls.am_atom = alphafold.model.all_atom_multimer
+            cls.am_fold = alphafold.model.folding_multimer
+            cls.am_modules = alphafold.model.modules_multimer
+            cls.am_rigid = alphafold.model.geometry
+        else:
+            cls.am_atom = alphafold.model.all_atom
+            cls.am_fold = alphafold.model.folding
+            cls.am_modules = alphafold.model.modules
+            cls.am_rigid = alphafold.model.r3
     def test_dry_run(self):
         n_seq = consts.n_seq
         n_templ = consts.n_templ
         n_res = consts.n_res
         n_extra_seq = consts.n_extra
-        c = model_config("model_1")
+        c = model_config(consts.model, train=True)
         c.model.evoformer_stack.no_blocks = 4  # no need to go overboard here
         c.model.evoformer_stack.blocks_per_ckpt = None  # don't want to set up
         # deepspeed for this test
@@ -56,6 +69,7 @@ class TestModel(unittest.TestCase):
         ).float()
         batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
         batch["residue_index"] = torch.arange(n_res)
         batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
         t_feats = random_template_feats(n_templ, n_res)
         batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
@@ -68,6 +82,12 @@ class TestModel(unittest.TestCase):
         batch.update(data_transforms.make_atom14_masks(batch))
         batch["no_recycling_iters"] = torch.tensor(2.)
+        if consts.is_multimer:
+            batch["asym_id"] = torch.randint(0, 1, size=(n_res,))
+            batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
+            batch["sym_id"] = torch.randint(0, 1, size=(n_res,))
+            batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
         add_recycling_dims = lambda t: (
             t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
         )
@@ -77,10 +97,14 @@ class TestModel(unittest.TestCase):
         out = model(batch)
     @compare_utils.skip_unless_alphafold_installed()
+    @unittest.skipIf(consts.is_multimer, "Additional changes required for multimer.")
     def test_compare(self):
+        # TODO: Fix test data for multimer MSA features
         def run_alphafold(batch):
             config = compare_utils.get_alphafold_config()
-            model = alphafold.model.modules.AlphaFold(config.model)
+            model = self.am_modules.AlphaFold(config.model)
             return model(
                 batch=batch,
                 is_training=False,
@@ -91,7 +115,8 @@ class TestModel(unittest.TestCase):
         params = compare_utils.fetch_alphafold_module_weights("")
-        with open("tests/test_data/sample_feats.pickle", "rb") as fp:
+        fpath = Path(__file__).parent.resolve() / "test_data/sample_feats.pickle"
+        with open(str(fpath), "rb") as fp:
             batch = pickle.load(fp)
         out_gt = f.apply(params, jax.random.PRNGKey(42), batch)
@@ -100,7 +125,8 @@ class TestModel(unittest.TestCase):
         # atom37_to_atom14 doesn't like batches
         batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
         batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
-        out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
+        out_gt = self.am_atom.atom37_to_atom14(out_gt, batch)
         out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
         batch["no_recycling_iters"] = np.array([3., 3., 3., 3.,])
...
tests/test_outer_product_mean.py
@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
-            model.evoformer.blocks[0].core
+            model.evoformer.blocks[0]
             .outer_product_mean(
                 torch.as_tensor(msa_act).cuda(),
                 chunk_size=4,
...
tests/test_pair_transition.py
@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         out_repro = (
-            model.evoformer.blocks[0].core
+            model.evoformer.blocks[0].pair_stack
             .pair_transition(
                 torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
                 chunk_size=4,
...
tests/test_primitives.py
@@ -13,20 +13,17 @@
 # limitations under the License.
 import torch
-import numpy as np
 import unittest
-from openfold.model.primitives import (
-    Attention,
-)
+from openfold.model.primitives import Attention
 from tests.config import consts
 class TestLMA(unittest.TestCase):
     def test_lma_vs_attention(self):
         batch_size = consts.batch_size
         c_hidden = 32
-        n = 2**12
+        n = 2 ** 12
         no_heads = 4
         q = torch.rand(batch_size, n, c_hidden).cuda()
@@ -34,20 +31,17 @@ class TestLMA(unittest.TestCase):
         bias = [torch.rand(no_heads, 1, n)]
         bias = [b.cuda() for b in bias]
-        gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
-        o_fill = torch.rand(c_hidden, c_hidden * no_heads)
         a = Attention(
             c_hidden, c_hidden, c_hidden, c_hidden, no_heads
         ).cuda()
         with torch.no_grad():
             l = a(q, kv, biases=bias, use_lma=True)
             real = a(q, kv, biases=bias)
         self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
 if __name__ == "__main__":
     unittest.main()
unittest.main() unittest.main()
tests/test_structure_module.py
@@ -18,21 +18,19 @@ import unittest
 from openfold.data.data_transforms import make_atom14_masks_np
 from openfold.np.residue_constants import (
-    restype_rigid_group_default_frame,
-    restype_atom14_to_rigid_group,
     restype_atom14_mask,
-    restype_atom14_rigid_group_positions,
     restype_atom37_mask,
 )
 from openfold.model.structure_module import (
     StructureModule,
     StructureModuleTransition,
-    BackboneUpdate,
     AngleResnet,
     InvariantPointAttention,
 )
-import openfold.utils.feats as feats
 from openfold.utils.rigid_utils import Rotation, Rigid
+from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
+from openfold.utils.geometry.rotation_matrix import Rot3Array
+from openfold.utils.geometry.vector import Vec3Array
 import tests.compare_utils as compare_utils
 from tests.config import consts
 from tests.data_utils import (
@@ -46,6 +44,19 @@ if compare_utils.alphafold_is_installed():
 class TestStructureModule(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if consts.is_multimer:
+            cls.am_atom = alphafold.model.all_atom_multimer
+            cls.am_fold = alphafold.model.folding_multimer
+            cls.am_modules = alphafold.model.modules_multimer
+            cls.am_rigid = alphafold.model.geometry
+        else:
+            cls.am_atom = alphafold.model.all_atom
+            cls.am_fold = alphafold.model.folding
+            cls.am_modules = alphafold.model.modules
+            cls.am_rigid = alphafold.model.r3
     def test_structure_module_shape(self):
         batch_size = consts.batch_size
         n = consts.n_res
...@@ -81,6 +92,7 @@ class TestStructureModule(unittest.TestCase): ...@@ -81,6 +92,7 @@ class TestStructureModule(unittest.TestCase):
trans_scale_factor, trans_scale_factor,
ar_epsilon, ar_epsilon,
inf, inf,
is_multimer=consts.is_multimer
) )
s = torch.rand((batch_size, n, c_s)) s = torch.rand((batch_size, n, c_s))
...@@ -89,7 +101,11 @@ class TestStructureModule(unittest.TestCase): ...@@ -89,7 +101,11 @@ class TestStructureModule(unittest.TestCase):
out = sm({"single": s, "pair": z}, f) out = sm({"single": s, "pair": z}, f)
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7)) if consts.is_multimer:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
else:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
self.assertTrue( self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2) out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
) )
@@ -121,11 +137,14 @@ class TestStructureModule(unittest.TestCase):
         c_global = config.model.global_config

         def run_sm(representations, batch):
-            sm = alphafold.model.folding.StructureModule(c_sm, c_global)
+            sm = self.am_fold.StructureModule(c_sm, c_global)
             representations = {
                 k: jax.lax.stop_gradient(v) for k, v in representations.items()
             }
             batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
+            if consts.is_multimer:
+                return sm(representations, batch, is_training=False, compute_loss=True)
             return sm(representations, batch, is_training=False)

         f = hk.transform(run_sm)
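All of these comparison tests follow the same Haiku pattern: the function that builds and calls an AlphaFold module is wrapped in hk.transform, then driven with explicit parameters and an RNG key. A minimal self-contained sketch of that pattern with a toy module (the names here are illustrative, not from the diff):

import haiku as hk
import jax
import jax.numpy as jnp

def forward(x):
    # Module construction must happen inside the transformed function.
    return hk.Linear(8)(x)

f = hk.transform(forward)
rng = jax.random.PRNGKey(42)
x = jnp.ones((4, 16))
params = f.init(rng, x)        # the tests substitute pretrained AlphaFold weights here
out = f.apply(params, rng, x)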
@@ -181,6 +200,19 @@ class TestStructureModule(unittest.TestCase):
 class TestInvariantPointAttention(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if consts.is_multimer:
+            cls.am_atom = alphafold.model.all_atom_multimer
+            cls.am_fold = alphafold.model.folding_multimer
+            cls.am_modules = alphafold.model.modules_multimer
+            cls.am_rigid = alphafold.model.geometry
+        else:
+            cls.am_atom = alphafold.model.all_atom
+            cls.am_fold = alphafold.model.folding
+            cls.am_modules = alphafold.model.modules
+            cls.am_rigid = alphafold.model.r3
+
     def test_shape(self):
         c_m = 13
         c_z = 17
@@ -197,13 +229,18 @@ class TestInvariantPointAttention(unittest.TestCase):
         mask = torch.ones((batch_size, n_res))

         rot_mats = torch.rand((batch_size, n_res, 3, 3))
-        rots = Rotation(rot_mats=rot_mats, quats=None)
         trans = torch.rand((batch_size, n_res, 3))

-        r = Rigid(rots, trans)
+        if consts.is_multimer:
+            rotation = Rot3Array.from_array(rot_mats)
+            translation = Vec3Array.from_array(trans)
+            r = Rigid3Array(rotation, translation)
+        else:
+            rots = Rotation(rot_mats=rot_mats, quats=None)
+            r = Rigid(rots, trans)
         ipa = InvariantPointAttention(
-            c_m, c_z, c_hidden, no_heads, no_qp, no_vp
+            c_m, c_z, c_hidden, no_heads, no_qp, no_vp, is_multimer=consts.is_multimer
         )

         shape_before = s.shape
@@ -215,16 +252,26 @@ class TestInvariantPointAttention(unittest.TestCase):
     def test_ipa_compare(self):
         def run_ipa(act, static_feat_2d, mask, affine):
             config = compare_utils.get_alphafold_config()
-            ipa = alphafold.model.folding.InvariantPointAttention(
+            ipa = self.am_fold.InvariantPointAttention(
                 config.model.heads.structure_module,
                 config.model.global_config,
             )
-            attn = ipa(
-                inputs_1d=act,
-                inputs_2d=static_feat_2d,
-                mask=mask,
-                affine=affine,
-            )
+            if consts.is_multimer:
+                attn = ipa(
+                    inputs_1d=act,
+                    inputs_2d=static_feat_2d,
+                    mask=mask,
+                    rigid=affine
+                )
+            else:
+                attn = ipa(
+                    inputs_1d=act,
+                    inputs_2d=static_feat_2d,
+                    mask=mask,
+                    affine=affine
+                )
             return attn
         f = hk.transform(run_ipa)
@@ -238,13 +285,20 @@ class TestInvariantPointAttention(unittest.TestCase):
         sample_mask = np.ones((n_res, 1))
         affines = random_affines_4x4((n_res,))
-        rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
-        quats = alphafold.model.r3.rigids_to_quataffine(rigids)
-        transformations = Rigid.from_tensor_4x4(
-            torch.as_tensor(affines).float().cuda()
-        )
-        sample_affine = quats
+        if consts.is_multimer:
+            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
+            transformations = Rigid3Array.from_tensor_4x4(
+                torch.as_tensor(affines).float().cuda()
+            )
+            sample_affine = rigids
+        else:
+            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
+            quats = self.am_rigid.rigids_to_quataffine(rigids)
+            transformations = Rigid.from_tensor_4x4(
+                torch.as_tensor(affines).float().cuda()
+            )
+            sample_affine = quats
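Both branches above unpack the same input: random_affines_4x4 produces rigids as homogeneous matrices, which Rigid3Array.from_array4x4 and Rigid.from_tensor_4x4 then split into rotation and translation. A sketch of that layout for a single transform (helper name ours; scipy's Rotation is the same one the test utilities already import):

import numpy as np
from scipy.spatial.transform import Rotation

def random_affine_4x4() -> np.ndarray:
    t = np.eye(4)
    t[:3, :3] = Rotation.random().as_matrix()  # top-left 3x3: rotation
    t[:3, 3] = np.random.rand(3)               # last column: translation
    return t                                   # bottom row stays (0, 0, 0, 1)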
         ipa_params = compare_utils.fetch_alphafold_module_weights(
             "alphafold/alphafold_iteration/structure_module/"
...
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import re
 import torch
 import numpy as np
 import unittest
@@ -19,7 +20,6 @@ from openfold.model.template import (
     TemplatePointwiseAttention,
     TemplatePairStack,
 )
-from openfold.utils.tensor_utils import tree_map
 import tests.compare_utils as compare_utils
 from tests.config import consts
 from tests.data_utils import random_template_feats
@@ -54,6 +54,19 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
 class TestTemplatePairStack(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if consts.is_multimer:
+            cls.am_atom = alphafold.model.all_atom_multimer
+            cls.am_fold = alphafold.model.folding_multimer
+            cls.am_modules = alphafold.model.modules_multimer
+            cls.am_rigid = alphafold.model.geometry
+        else:
+            cls.am_atom = alphafold.model.all_atom
+            cls.am_fold = alphafold.model.folding
+            cls.am_modules = alphafold.model.modules
+            cls.am_rigid = alphafold.model.r3
+
     def test_shape(self):
         batch_size = consts.batch_size
         c_t = consts.c_t
@@ -65,6 +78,8 @@ class TestTemplatePairStack(unittest.TestCase):
         dropout = 0.25
         n_templ = consts.n_templ
         n_res = consts.n_res
+        tri_mul_first = consts.is_multimer
+        fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
         blocks_per_ckpt = None
         chunk_size = 4
         inf = 1e7
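The fuse_projection_weights flag added above keys off the model name: only the v3 multimer parameter sets store the triangle-multiplication projections as fused weights. A quick illustration of the regex gate; note that re.fullmatch already anchors the pattern, so the ^/$ are redundant, and bool(...) is the more idiomatic spelling of the True if ... else False form used in the diff:

import re

for name in ["model_1_ptm", "model_3_multimer_v3", "model_1_multimer_v2"]:
    fused = bool(re.fullmatch("model_[1-5]_multimer_v3", name))
    print(name, fused)
# model_1_ptm False
# model_3_multimer_v3 True
# model_1_multimer_v2 False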
@@ -78,6 +93,8 @@ class TestTemplatePairStack(unittest.TestCase):
             no_heads=no_heads,
             pair_transition_n=pt_inner_dim,
             dropout_rate=dropout,
+            tri_mul_first=tri_mul_first,
+            fuse_projection_weights=fuse_projection_weights,
             blocks_per_ckpt=None,
             inf=inf,
             eps=eps,
@@ -96,12 +113,40 @@ class TestTemplatePairStack(unittest.TestCase):
         def run_template_pair_stack(pair_act, pair_mask):
             config = compare_utils.get_alphafold_config()
             c_ee = config.model.embeddings_and_evoformer
-            tps = alphafold.model.modules.TemplatePairStack(
-                c_ee.template.template_pair_stack,
-                config.model.global_config,
-                name="template_pair_stack",
-            )
-            act = tps(pair_act, pair_mask, is_training=False)
+            if consts.is_multimer:
+                safe_key = alphafold.model.prng.SafeKey(hk.next_rng_key())
+                template_iteration = self.am_modules.TemplateEmbeddingIteration(
+                    c_ee.template.template_pair_stack,
+                    config.model.global_config,
+                    name='template_embedding_iteration')
+
+                def template_iteration_fn(x):
+                    act, safe_key = x
+                    safe_key, safe_subkey = safe_key.split()
+                    act = template_iteration(
+                        act=act,
+                        pair_mask=pair_mask,
+                        is_training=False,
+                        safe_key=safe_subkey)
+                    return (act, safe_key)
+
+                if config.model.global_config.use_remat:
+                    template_iteration_fn = hk.remat(template_iteration_fn)
+
+                safe_key, safe_subkey = safe_key.split()
+                template_stack = alphafold.model.layer_stack.layer_stack(
+                    c_ee.template.template_pair_stack.num_block)(
+                    template_iteration_fn)
+                act, _ = template_stack((pair_act, safe_subkey))
+            else:
+                tps = self.am_modules.TemplatePairStack(
+                    c_ee.template.template_pair_stack,
+                    config.model.global_config,
+                    name="template_pair_stack",
+                )
+                act = tps(pair_act, pair_mask, is_training=False)
             ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
             act = ln(act)
             return act
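In the multimer branch above, layer_stack plays the role the monomer TemplatePairStack plays internally: it applies one iteration function num_block times, threading the activations (and here a PRNG SafeKey) through as carried state. Conceptually it behaves like the loop below; this is a sketch of the semantics only, ignoring the parameter-stacking and scan machinery that makes the real implementation efficient, and the helper name is ours:

def layer_stack_like(num_block, iteration_fn, init_state):
    # Equivalent unrolled semantics: state = fn(fn(...fn(state)...))
    state = init_state
    for _ in range(num_block):
        state = iteration_fn(state)
    return state

# Usage mirroring the test, where the state is (activations, rng_key):
# act, _ = layer_stack_like(num_block, template_iteration_fn, (pair_act, safe_subkey))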
@@ -115,10 +160,16 @@ class TestTemplatePairStack(unittest.TestCase):
             low=0, high=2, size=(n_res, n_res)
         ).astype(np.float32)

-        params = compare_utils.fetch_alphafold_module_weights(
-            "alphafold/alphafold_iteration/evoformer/template_embedding/"
-            + "single_template_embedding/template_pair_stack"
-        )
+        if consts.is_multimer:
+            params = compare_utils.fetch_alphafold_module_weights(
+                "alphafold/alphafold_iteration/evoformer/template_embedding/"
+                + "single_template_embedding/template_embedding_iteration"
+            )
+        else:
+            params = compare_utils.fetch_alphafold_module_weights(
+                "alphafold/alphafold_iteration/evoformer/template_embedding/"
+                + "single_template_embedding/template_pair_stack"
+            )
         params.update(
             compare_utils.fetch_alphafold_module_weights(
                 "alphafold/alphafold_iteration/evoformer/template_embedding/"
@@ -132,7 +183,7 @@ class TestTemplatePairStack(unittest.TestCase):
         out_gt = torch.as_tensor(np.array(out_gt))

         model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.template_pair_stack(
+        out_repro = model.template_embedder.template_pair_stack(
             torch.as_tensor(pair_act).unsqueeze(-4).cuda(),
             torch.as_tensor(pair_mask).unsqueeze(-3).cuda(),
             chunk_size=None,
@@ -143,15 +194,32 @@ class TestTemplatePairStack(unittest.TestCase):
 class Template(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        if consts.is_multimer:
+            cls.am_atom = alphafold.model.all_atom_multimer
+            cls.am_fold = alphafold.model.folding_multimer
+            cls.am_modules = alphafold.model.modules_multimer
+            cls.am_rigid = alphafold.model.geometry
+        else:
+            cls.am_atom = alphafold.model.all_atom
+            cls.am_fold = alphafold.model.folding
+            cls.am_modules = alphafold.model.modules
+            cls.am_rigid = alphafold.model.r3
+
     @compare_utils.skip_unless_alphafold_installed()
     def test_compare(self):
-        def test_template_embedding(pair, batch, mask_2d):
+        def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):
             config = compare_utils.get_alphafold_config()
-            te = alphafold.model.modules.TemplateEmbedding(
+            te = self.am_modules.TemplateEmbedding(
                 config.model.embeddings_and_evoformer.template,
                 config.model.global_config,
             )
-            act = te(pair, batch, mask_2d, is_training=False)
+            if consts.is_multimer:
+                act = te(pair, batch, mask_2d, multichain_mask_2d=mc_mask_2d, is_training=False)
+            else:
+                act = te(pair, batch, mask_2d, is_training=False)
             return act

         f = hk.transform(test_template_embedding)
@@ -162,6 +230,14 @@ class Template(unittest.TestCase):
         pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
         batch = random_template_feats(n_templ, n_res)
         batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
+
+        multichain_mask_2d = None
+        if consts.is_multimer:
+            asym_id = batch['asym_id'][0]
+            multichain_mask_2d = (
+                asym_id[..., None] == asym_id[..., None, :]
+            ).astype(np.float32)
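The broadcast above builds an (n_res, n_res) mask that is 1 exactly where residues i and j carry the same asym_id, i.e. belong to the same chain; the multimer template embedder only trusts template geometry within a chain. A small worked example with illustrative values:

import numpy as np

asym_id = np.array([0, 0, 1, 1, 1])  # two chains, lengths 2 and 3
mask = (asym_id[..., None] == asym_id[..., None, :]).astype(np.float32)
print(mask)
# [[1. 1. 0. 0. 0.]
#  [1. 1. 0. 0. 0.]
#  [0. 0. 1. 1. 1.]
#  [0. 0. 1. 1. 1.]
#  [0. 0. 1. 1. 1.]]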
         pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)

         # Fetch pretrained parameters (but only from one block)
         params = compare_utils.fetch_alphafold_module_weights(
@@ -169,7 +245,7 @@ class Template(unittest.TestCase):
         )
         out_gt = f.apply(
-            params, jax.random.PRNGKey(42), pair_act, batch, pair_mask
+            params, jax.random.PRNGKey(42), pair_act, batch, pair_mask, multichain_mask_2d
         ).block_until_ready()
         out_gt = torch.as_tensor(np.array(out_gt))
@@ -177,13 +253,30 @@ class Template(unittest.TestCase):
         batch["target_feat"] = np.eye(22)[inds]
         model = compare_utils.get_global_pretrained_openfold()
-        out_repro = model.embed_templates(
-            {k: torch.as_tensor(v).cuda() for k, v in batch.items()},
-            torch.as_tensor(pair_act).cuda(),
-            torch.as_tensor(pair_mask).cuda(),
-            templ_dim=0,
-            inplace_safe=False
-        )
+        template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
+        if consts.is_multimer:
+            out_repro = model.template_embedder(
+                template_feats,
+                torch.as_tensor(pair_act).cuda(),
+                torch.as_tensor(pair_mask).cuda(),
+                templ_dim=0,
+                chunk_size=consts.chunk_size,
+                multichain_mask_2d=torch.as_tensor(multichain_mask_2d).cuda(),
+                use_lma=False,
+                inplace_safe=False
+            )
+        else:
+            out_repro = model.template_embedder(
+                template_feats,
+                torch.as_tensor(pair_act).cuda(),
+                torch.as_tensor(pair_mask).cuda(),
+                templ_dim=0,
+                chunk_size=consts.chunk_size,
+                use_lma=False,
+                inplace_safe=False
+            )
         out_repro = out_repro["template_pair_embedding"]
         out_repro = out_repro.cpu()
...
@@ -86,9 +86,9 @@ class TestTriangularAttention(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         module = (
-            model.evoformer.blocks[0].core.tri_att_start
+            model.evoformer.blocks[0].pair_stack.tri_att_start
             if starting
-            else model.evoformer.blocks[0].core.tri_att_end
+            else model.evoformer.blocks[0].pair_stack.tri_att_end
         )
         # To save memory, the full model transposes inputs outside of the
...
@@ -13,6 +13,7 @@
 # limitations under the License.
 import torch
+import re
 import numpy as np
 import unittest
 from openfold.model.triangular_multiplicative_update import *
@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
         c_z = consts.c_z
         c = 11

-        tm = TriangleMultiplicationOutgoing(
-            c_z,
-            c,
-        )
+        if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model):
+            tm = FusedTriangleMultiplicationOutgoing(
+                c_z,
+                c,
+            )
+        else:
+            tm = TriangleMultiplicationOutgoing(
+                c_z,
+                c,
+            )
         n_res = consts.c_z
         batch_size = consts.batch_size
@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
                 config.model.global_config,
                 name=name,
             )
-            act = tri_mul(act=pair_act, mask=pair_mask)
+            act = tri_mul(pair_act, pair_mask)
             return act

         f = hk.transform(run_tri_mul)
@@ -85,10 +92,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
         model = compare_utils.get_global_pretrained_openfold()
         module = (
-            model.evoformer.blocks[0].core.tri_mul_in
+            model.evoformer.blocks[0].pair_stack.tri_mul_in
             if incoming
-            else model.evoformer.blocks[0].core.tri_mul_out
+            else model.evoformer.blocks[0].pair_stack.tri_mul_out
         )
         out_repro = module(
             torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
             mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
@@ -112,12 +120,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
         pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
         pair_mask = pair_mask.astype(np.float32)
         model = compare_utils.get_global_pretrained_openfold()
         module = (
-            model.evoformer.blocks[0].core.tri_mul_in
+            model.evoformer.blocks[0].pair_stack.tri_mul_in
             if incoming
-            else model.evoformer.blocks[0].core.tri_mul_out
+            else model.evoformer.blocks[0].pair_stack.tri_mul_out
         )
         out_stock = module(
             torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
...
@@ -23,7 +23,7 @@ from openfold.utils.rigid_utils import (
     quat_to_rot,
     rot_to_quat,
 )
-from openfold.utils.tensor_utils import chunk_layer, _chunk_slice
+from openfold.utils.chunk_utils import chunk_layer, _chunk_slice
 import tests.compare_utils as compare_utils
 from tests.config import consts
...
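Only the import path changes here; chunk_layer itself still does what it did under tensor_utils: run a layer over slices of its batch dimensions to bound peak memory. A conceptual sketch of that chunking idea over a single leading dimension (helper name and shape ours, not the real signature):

import torch

def chunked_apply(layer, x: torch.Tensor, chunk_size: int) -> torch.Tensor:
    # Apply `layer` to slices of the leading dimension and re-concatenate.
    outs = [layer(x[i:i + chunk_size]) for i in range(0, x.shape[0], chunk_size)]
    return torch.cat(outs, dim=0)

# e.g. chunked_apply(torch.nn.Linear(8, 8), torch.rand(100, 8), chunk_size=16)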