Commit 56d5e39c authored by Geoffrey Yu

Merge remote-tracking branch 'upstream/multimer' into multimer

parents 56b86074 51556d52
......@@ -75,6 +75,8 @@ for major, minor in list(compute_capabilities):
extra_cuda_flags += cc_flag
cc_flag = ['-gencode', 'arch=compute_70,code=sm_70']
if bare_metal_major != -1:
modules = [CUDAExtension(
name="attn_core_inplace_cuda",
......
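Annotation (not part of the commit): the setup.py hunk above accumulates per-compute-capability `-gencode` flags and registers the `attn_core_inplace_cuda` extension. A hedged sketch of that pattern follows; the capability list and source path are assumptions, not taken from the diff.

```python
# Sketch only: how '-gencode' pairs are typically accumulated and handed to
# a CUDAExtension. 'attn_core_inplace_cuda' is from the hunk above; the
# capability list and source path are assumed for illustration.
from torch.utils.cpp_extension import CUDAExtension

compute_capabilities = [(7, 0), (8, 0)]  # assumed example values

extra_cuda_flags = []
for major, minor in compute_capabilities:
    # one 'arch=compute_XY,code=sm_XY' pair per capability
    extra_cuda_flags += [
        "-gencode", f"arch=compute_{major}{minor},code=sm_{major}{minor}"
    ]

modules = [
    CUDAExtension(
        name="attn_core_inplace_cuda",
        sources=["openfold/utils/kernel/csrc/softmax_cuda.cpp"],  # assumed path
        extra_compile_args={"nvcc": ["-O3"] + extra_cuda_flags},
    )
]
```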
......@@ -46,26 +46,26 @@ def import_alphafold():
def get_alphafold_config():
config = alphafold.model.config.model_config("model_1_ptm") # noqa
config = alphafold.model.config.model_config(consts.model) # noqa
config.model.global_config.deterministic = True
return config
_param_path = "openfold/resources/params/params_model_1_ptm.npz"
_param_path = f"openfold/resources/params/params_{consts.model}.npz"
_model = None
def get_global_pretrained_openfold():
global _model
if _model is None:
_model = AlphaFold(model_config("model_1_ptm"))
_model = AlphaFold(model_config(consts.model))
_model = _model.eval()
if not os.path.exists(_param_path):
raise FileNotFoundError(
"""Cannot load pretrained parameters. Make sure to run the
installation script before running tests."""
)
import_jax_weights_(_model, _param_path, version="model_1_ptm")
import_jax_weights_(_model, _param_path, version=consts.model)
_model = _model.cuda()
return _model
......
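Annotation: the net effect of the two changes above is that both the AlphaFold test config and the checkpoint path are now keyed off `consts.model` instead of a hard-coded `model_1_ptm`. A hypothetical spot-check of the path templating:

```python
# Hypothetical check of the parameter-path templating above.
model = "model_1_multimer_v3"  # what tests/config sets for multimer runs
path = f"openfold/resources/params/params_{model}.npz"
assert path == "openfold/resources/params/params_model_1_multimer_v3.npz"
```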
......@@ -2,8 +2,11 @@ import ml_collections as mlc
consts = mlc.ConfigDict(
{
"model": "model_1_multimer_v3", # monomer:model_1_ptm, multimer: model_1_multimer_v3
"is_multimer": True, # monomer: False, multimer: True
"chunk_size": 4,
"batch_size": 2,
"n_res": 11,
"n_res": 22,
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
......@@ -16,6 +19,7 @@ consts = mlc.ConfigDict(
"c_s": 384,
"c_t": 64,
"c_e": 64,
"msa_logits": 22 # monomer: 23, multimer: 22
}
)
......
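Annotation: a short sketch of how the suite is expected to consume these constants; the monomer values are read off the inline comments above, not asserted anywhere by the commit itself.

```python
# Sketch: the two supported configurations of tests/config.py, per the
# inline comments above (monomer values shown for contrast).
from tests.config import consts  # assumed import path

if consts.is_multimer:
    assert consts.model == "model_1_multimer_v3" and consts.msa_logits == 22
else:
    assert consts.model == "model_1_ptm" and consts.msa_logits == 23
```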
......@@ -12,9 +12,36 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from random import randint
import numpy as np
from scipy.spatial.transform import Rotation
from tests.config import consts
def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
n_chain = randint(1, n_res // min_chain_len) if consts.is_multimer else 1
if not split_chains:
return [0] * n_res
assert n_res >= n_chain
pieces = []
asym_ids = []
final_idx = n_chain - 1
for idx in range(n_chain - 1):
n_stop = (n_res - sum(pieces) - n_chain + idx - min_chain_len)
if n_stop <= min_chain_len:
final_idx = idx
break
piece = randint(min_chain_len, n_stop)
pieces.append(piece)
asym_ids.extend(piece * [idx])
asym_ids.extend((n_res - sum(pieces)) * [final_idx])
return np.array(asym_ids).astype(np.int64)
def random_template_feats(n_templ, n, batch_size=None):
b = []
......@@ -39,6 +66,11 @@ def random_template_feats(n_templ, n, batch_size=None):
}
batch = {k: v.astype(np.float32) for k, v in batch.items()}
batch["template_aatype"] = batch["template_aatype"].astype(np.int64)
if consts.is_multimer:
asym_ids = np.array(random_asym_ids(n))
batch["asym_id"] = np.tile(asym_ids[np.newaxis, :], (*b, n_templ, 1))
return batch
......
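Annotation: a minimal usage sketch for `random_asym_ids`. Chain IDs come back as a contiguous, non-decreasing `int64` vector with one entry per residue; the chain count is random when `consts.is_multimer` is true.

```python
# Sketch: sanity-checking random_asym_ids output (assumes multimer mode).
import numpy as np
from tests.data_utils import random_asym_ids

ids = random_asym_ids(22, min_chain_len=4)
assert ids.shape == (22,) and ids.dtype == np.int64
assert np.all(np.diff(ids) >= 0)  # chains are contiguous, ordered blocks
```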
......@@ -15,19 +15,13 @@
import pickle
import shutil
import torch
import numpy as np
import unittest
from openfold.data.data_pipeline import DataPipeline
from openfold.data.templates import TemplateHitFeaturizer
from openfold.model.embedders import (
InputEmbedder,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
)
from openfold.data.templates import HhsearchHitFeaturizer, HmmsearchHitFeaturizer
import tests.compare_utils as compare_utils
from tests.config import consts
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
......@@ -45,13 +39,29 @@ class TestDataPipeline(unittest.TestCase):
with open("tests/test_data/alphafold_feature_dict.pickle", "rb") as fp:
alphafold_feature_dict = pickle.load(fp)
template_featurizer = TemplateHitFeaturizer(
mmcif_dir="tests/test_data/mmcifs",
max_template_date="2021-12-20",
max_hits=20,
kalign_binary_path=shutil.which("kalign"),
_zero_center_positions=False,
)
if consts.is_multimer:
# template_featurizer = HmmsearchHitFeaturizer(
# mmcif_dir="tests/test_data/mmcifs",
# max_template_date="2021-12-20",
# max_hits=20,
# kalign_binary_path=shutil.which("kalign"),
# _zero_center_positions=False,
# )
template_featurizer = HhsearchHitFeaturizer(
mmcif_dir="tests/test_data/mmcifs",
max_template_date="2021-12-20",
max_hits=20,
kalign_binary_path=shutil.which("kalign"),
_zero_center_positions=False,
)
else:
template_featurizer = HhsearchHitFeaturizer(
mmcif_dir="tests/test_data/mmcifs",
max_template_date="2021-12-20",
max_hits=20,
kalign_binary_path=shutil.which("kalign"),
_zero_center_positions=False,
)
data_pipeline = DataPipeline(
template_featurizer=template_featurizer,
......
import copy
import gzip
import os
import pickle
import numpy as np
......@@ -181,7 +177,7 @@ class TestDataTransforms(unittest.TestCase):
}
protein = make_hhblits_profile(protein)
masked_msa_config = config.data.common.masked_msa
protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15)
protein = make_masked_msa.__wrapped__(protein, masked_msa_config, replace_fraction=0.15, seed=42)
assert 'bert_mask' in protein
assert 'true_msa' in protein
assert 'msa' in protein
......
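Annotation: the new `seed=42` argument pins down the randomness of the BERT-style masking so the assertions are reproducible. A toy illustration of the mechanism (not `make_masked_msa`'s internals):

```python
# Toy illustration of seeded masking: the same seed always selects the
# same ~15% of MSA positions to corrupt.
import torch

g = torch.Generator().manual_seed(42)
bert_mask = torch.rand((13, 22), generator=g) < 0.15  # (n_seq, n_res) assumed
g2 = torch.Generator().manual_seed(42)
assert bert_mask.equal(torch.rand((13, 22), generator=g2) < 0.15)
```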
......@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch
import numpy as np
import unittest
from tests.config import consts
from tests.data_utils import random_asym_ids
from openfold.model.embedders import (
InputEmbedder,
InputEmbedderMultimer,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
TemplatePairEmbedder
)
......@@ -35,13 +38,30 @@ class TestInputEmbedder(unittest.TestCase):
n_res = 17
n_clust = 19
max_relative_chain = 2
max_relative_idx = 32
use_chain_relative = True
tf = torch.rand((b, n_res, tf_dim))
ri = torch.rand((b, n_res))
msa = torch.rand((b, n_clust, n_res, msa_dim))
asym_ids_flat = torch.Tensor(random_asym_ids(n_res))
asym_id = torch.tile(asym_ids_flat.unsqueeze(0), (b, 1))
entity_id = asym_id
sym_id = torch.zeros_like(entity_id)
if consts.is_multimer:
ie = InputEmbedderMultimer(tf_dim, msa_dim, c_z, c_m,
max_relative_idx=max_relative_idx,
use_chain_relative=use_chain_relative,
max_relative_chain=max_relative_chain)
batch = {"target_feat": tf, "residue_index": ri, "msa_feat": msa,
"asym_id": asym_id, "entity_id": entity_id, "sym_id": sym_id}
msa_emb, pair_emb = ie(batch)
else:
ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
msa_emb, pair_emb = ie(tf=tf, ri=ri, msa=msa, inplace_safe=False)
ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
msa_emb, pair_emb = ie(tf, ri, msa)
self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m))
self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))
......
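Annotation: for context, a hedged sketch of the chain-relative position encoding that `InputEmbedderMultimer` consumes, following the AlphaFold-Multimer relpos scheme; the module's actual implementation may differ in detail.

```python
# Sketch of chain-relative relpos: residue offsets are clipped within a
# chain and cross-chain pairs fall into a dedicated sentinel bin.
import torch
import torch.nn.functional as F

def relpos_sketch(ri, asym_id, max_relative_idx=32):
    offset = ri[..., :, None] - ri[..., None, :]
    same_chain = asym_id[..., :, None] == asym_id[..., None, :]
    clipped = torch.clamp(offset + max_relative_idx, 0, 2 * max_relative_idx)
    sentinel = torch.full_like(clipped, 2 * max_relative_idx + 1)
    bins = torch.where(same_chain, clipped, sentinel)
    return F.one_hot(bins.long(), 2 * max_relative_idx + 2)
```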
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import torch
import numpy as np
import unittest
......@@ -48,6 +49,8 @@ class TestEvoformerStack(unittest.TestCase):
transition_n = 2
msa_dropout = 0.15
pair_stack_dropout = 0.25
opm_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
inf = 1e9
eps = 1e-10
......@@ -65,6 +68,8 @@ class TestEvoformerStack(unittest.TestCase):
transition_n,
msa_dropout,
pair_stack_dropout,
opm_first,
fuse_projection_weights,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
......@@ -174,6 +179,8 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n = 5
msa_dropout = 0.15
pair_stack_dropout = 0.25
opm_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
inf = 1e9
eps = 1e-10
......@@ -190,6 +197,8 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n,
msa_dropout,
pair_stack_dropout,
opm_first,
fuse_projection_weights,
ckpt=False,
inf=inf,
eps=eps,
......@@ -277,7 +286,7 @@ class TestMSATransition(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0].core.msa_transition(
model.evoformer.blocks[0].msa_transition(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
)
......
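Annotation: the `fuse_projection_weights` gate above keys off the model name; note that `re.fullmatch` already anchors the pattern, so the `^...$` in the original is redundant. A compact equivalent:

```python
# Equivalent predicate for the fused-projection gate used in both stacks.
import re

def uses_fused_projections(model_name: str) -> bool:
    return re.fullmatch(r"model_[1-5]_multimer_v3", model_name) is not None

assert uses_fused_projections("model_1_multimer_v3")
assert not uses_fused_projections("model_1_ptm")
```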
......@@ -25,13 +25,16 @@ from openfold.np.residue_constants import (
)
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4
from tests.data_utils import random_affines_4x4, random_asym_ids
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
......@@ -40,6 +43,19 @@ if compare_utils.alphafold_is_installed():
class TestFeats(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
@compare_utils.skip_unless_alphafold_installed()
def test_pseudo_beta_fn_compare(self):
def test_pbf(aatype, all_atom_pos, all_atom_mask):
......@@ -131,7 +147,9 @@ class TestFeats(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_atom37_to_frames_compare(self):
def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
return alphafold.model.all_atom.atom37_to_frames(
if consts.is_multimer:
all_atom_positions = self.am_rigid.Vec3Array.from_array(all_atom_positions)
return self.am_atom.atom37_to_frames(
aatype, all_atom_positions, all_atom_mask
)
......@@ -150,9 +168,23 @@ class TestFeats(unittest.TestCase):
}
out_gt = f.apply({}, None, **batch)
to_tensor = lambda t: torch.tensor(np.array(t))
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
to_tensor = (lambda t: torch.tensor(np.array(t))
if not isinstance(t, self.am_rigid.Rigid3Array)
else torch.tensor(np.array(t.to_array())))
else:
to_tensor = lambda t: torch.tensor(np.array(t))
out_gt = {k: to_tensor(v) for k, v in out_gt.items()}
def rigid3x4_to_4x4(rigid3arr):
four_by_four = torch.zeros(*rigid3arr.shape[:-2], 4, 4)
four_by_four[..., :3, :4] = rigid3arr
four_by_four[..., 3, 3] = 1
return four_by_four
def flat12_to_4x4(flat12):
rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
trans = flat12[..., 9:]
......@@ -164,10 +196,12 @@ class TestFeats(unittest.TestCase):
return four_by_four
out_gt["rigidgroups_gt_frames"] = flat12_to_4x4(
convert_func = rigid3x4_to_4x4 if consts.is_multimer else flat12_to_4x4
out_gt["rigidgroups_gt_frames"] = convert_func(
out_gt["rigidgroups_gt_frames"]
)
out_gt["rigidgroups_alt_gt_frames"] = flat12_to_4x4(
out_gt["rigidgroups_alt_gt_frames"] = convert_func(
out_gt["rigidgroups_alt_gt_frames"]
)
......@@ -187,7 +221,13 @@ class TestFeats(unittest.TestCase):
n = 5
rots = torch.rand((batch_size, n, 3, 3))
trans = torch.rand((batch_size, n, 3))
ts = Rigid(Rotation(rot_mats=rots), trans)
if consts.is_multimer:
rotation = Rot3Array.from_array(rots)
translation = Vec3Array.from_array(trans)
ts = Rigid3Array(rotation, translation)
else:
ts = Rigid(Rotation(rot_mats=rots), trans)
angles = torch.rand((batch_size, n, 7, 2))
......@@ -208,7 +248,7 @@ class TestFeats(unittest.TestCase):
def run_torsion_angles_to_frames(
aatype, backb_to_global, torsion_angles_sin_cos
):
return alphafold.model.all_atom.torsion_angles_to_frames(
return self.am_atom.torsion_angles_to_frames(
aatype,
backb_to_global,
torsion_angles_sin_cos,
......@@ -221,10 +261,17 @@ class TestFeats(unittest.TestCase):
aatype = np.random.randint(0, 21, size=(n_res,))
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float()
)
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)
......@@ -240,13 +287,21 @@ class TestFeats(unittest.TestCase):
)
# Convert the Rigids to 4x4 transformation tensors
rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt.rot))
trans_gt = list(
map(lambda x: torch.as_tensor(np.array(x)), out_gt.trans)
)
rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)
out_gt_rot = out_gt.rot if not consts.is_multimer else out_gt.rotation.to_array()
out_gt_trans = out_gt.trans if not consts.is_multimer else out_gt.translation.to_array()
if consts.is_multimer:
rots_gt = torch.as_tensor(np.array(out_gt_rot))
trans_gt = torch.as_tensor(np.array(out_gt_trans))
else:
rots_gt = list(map(lambda x: torch.as_tensor(np.array(x)), out_gt_rot))
trans_gt = list(
map(lambda x: torch.as_tensor(np.array(x)), out_gt_trans)
)
rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)
transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
bottom_row[..., 3] = 1
......@@ -264,7 +319,13 @@ class TestFeats(unittest.TestCase):
rots = torch.rand((batch_size, n_res, 8, 3, 3))
trans = torch.rand((batch_size, n_res, 8, 3))
ts = Rigid(Rotation(rot_mats=rots), trans)
if consts.is_multimer:
rotation = Rot3Array.from_array(rots)
translation = Vec3Array.from_array(trans)
ts = Rigid3Array(rotation, translation)
else:
ts = Rigid(Rotation(rot_mats=rots), trans)
f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()
......@@ -282,8 +343,7 @@ class TestFeats(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_frames_and_literature_positions_to_atom14_pos_compare(self):
def run_f(aatype, affines):
am = alphafold.model
return am.all_atom.frames_and_literature_positions_to_atom14_pos(
return self.am_atom.frames_and_literature_positions_to_atom14_pos(
aatype, affines
)
......@@ -294,16 +354,27 @@ class TestFeats(unittest.TestCase):
aatype = np.random.randint(0, 21, size=(n_res,))
affines = random_affines_4x4((n_res, 8))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float()
)
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
out_gt = f.apply({}, None, aatype, rigids)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = torch.stack(
[torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
)
if consts.is_multimer:
out_gt = torch.as_tensor(np.array(out_gt.to_array()))
else:
out_gt = torch.stack(
[torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
)
out_repro = feats.frames_and_literature_positions_to_atom14_pos(
transformations.cuda(),
......
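Annotation: a quick check of the `rigid3x4_to_4x4` helper introduced above. It pads a `[R|t]` block into a homogeneous transform, so the identity pose maps to the 4x4 identity.

```python
# Identity-pose check for rigid3x4_to_4x4 (standalone restatement of its body).
import torch

rigid3x4 = torch.cat([torch.eye(3), torch.zeros(3, 1)], dim=-1)  # (3, 4)
four_by_four = torch.zeros(4, 4)
four_by_four[:3, :4] = rigid3x4
four_by_four[3, 3] = 1
assert torch.equal(four_by_four, torch.eye(4))
```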
......@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
)
][1].transpose(-1, -2)
),
model.evoformer.blocks[1].core.outer_product_mean.linear_1.weight,
model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
),
]
......
......@@ -13,18 +13,18 @@
# limitations under the License.
import os
import math
import torch
import numpy as np
from pathlib import Path
import unittest
import ml_collections as mlc
from openfold.data import data_transforms
from openfold.np import residue_constants
from openfold.utils.rigid_utils import (
Rotation,
Rigid,
)
import openfold.utils.feats as feats
from openfold.utils.loss import (
torsion_angle_loss,
compute_fape,
......@@ -43,6 +43,8 @@ from openfold.utils.loss import (
sidechain_loss,
tm_loss,
compute_plddt,
compute_tm,
chain_center_of_mass_loss
)
from openfold.utils.tensor_utils import (
tree_map,
......@@ -51,7 +53,7 @@ from openfold.utils.tensor_utils import (
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_vector, random_affines_4x4
from tests.data_utils import random_affines_vector, random_affines_4x4, random_asym_ids
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
......@@ -64,7 +66,30 @@ def affine_vector_to_4x4(affine):
return r.to_tensor_4x4()
def affine_vector_to_rigid(am_rigid, affine):
rigid_flat = np.split(affine, 7, axis=-1)
rigid_flat = [r.squeeze(-1) for r in rigid_flat]
qw, qx, qy, qz = rigid_flat[:4]
trans = rigid_flat[4:]
rotation = am_rigid.Rot3Array.from_quaternion(qw, qx, qy, qz, normalize=True)
translation = am_rigid.Vec3Array(*trans)
return am_rigid.Rigid3Array(rotation, translation)
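Annotation: for reference, the 7-vector layout the helper above assumes is `(qw, qx, qy, qz, tx, ty, tz)`, so the identity pose is:

```python
# Identity pose in the quaternion-plus-translation layout used above.
import numpy as np
identity_affine = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
```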
class TestLoss(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_run_torsion_angle_loss(self):
batch_size = consts.batch_size
n_res = consts.n_res
......@@ -127,7 +152,10 @@ class TestLoss(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_between_residue_bond_loss_compare(self):
def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
return alphafold.model.all_atom.between_residue_bond_loss(
if consts.is_multimer:
pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
return self.am_atom.between_residue_bond_loss(
pred_pos,
pred_atom_mask,
residue_index,
......@@ -184,12 +212,22 @@ class TestLoss(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_between_residue_clash_loss_compare(self):
def run_brcl(pred_pos, atom_exists, atom_radius, res_ind):
return alphafold.model.all_atom.between_residue_clash_loss(
def run_brcl(pred_pos, atom_exists, atom_radius, res_ind, asym_id):
if consts.is_multimer:
pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
return self.am_atom.between_residue_clash_loss(
pred_pos,
atom_exists,
atom_radius,
res_ind,
asym_id
)
return self.am_atom.between_residue_clash_loss(
pred_pos,
atom_exists,
atom_radius,
res_ind,
res_ind
)
f = hk.transform(run_brcl)
......@@ -198,10 +236,24 @@ class TestLoss(unittest.TestCase):
pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
atom_exists = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
atom_radius = np.random.rand(n_res, 14).astype(np.float32)
res_ind = np.arange(
n_res,
)
residx_atom14_to_atom37 = np.random.randint(0, 37, (n_res, 14)).astype(np.int64)
atomtype_radius = [
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
]
atomtype_radius = np.array(atomtype_radius).astype(np.float32)
atom_radius = (
atom_exists
* atomtype_radius[residx_atom14_to_atom37]
)
asym_id = None
if consts.is_multimer:
asym_id = random_asym_ids(n_res)
out_gt = f.apply(
{},
......@@ -210,6 +262,7 @@ class TestLoss(unittest.TestCase):
atom_exists,
atom_radius,
res_ind,
asym_id
)
out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
......@@ -219,6 +272,7 @@ class TestLoss(unittest.TestCase):
torch.tensor(atom_exists).cuda(),
torch.tensor(atom_radius).cuda(),
torch.tensor(res_ind).cuda(),
torch.tensor(asym_id).cuda() if asym_id is not None else None,
)
out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)
......@@ -242,6 +296,36 @@ class TestLoss(unittest.TestCase):
torch.max(torch.abs(out_gt - out_repro)) < consts.eps
)
@compare_utils.skip_unless_alphafold_installed()
def test_compute_ptm_compare(self):
n_res = consts.n_res
max_bin = 31
no_bins = 64
logits = np.random.rand(n_res, n_res, no_bins)
boundaries = np.linspace(0, max_bin, num=(no_bins - 1))
ptm_gt = alphafold.common.confidence.predicted_tm_score(logits, boundaries)
ptm_gt = torch.tensor(ptm_gt)
logits_t = torch.tensor(logits)
ptm_repro = compute_tm(logits_t, no_bins=no_bins, max_bin=max_bin)
self.assertTrue(
torch.max(torch.abs(ptm_gt - ptm_repro)) < consts.eps
)
if consts.is_multimer:
asym_id = random_asym_ids(n_res)
iptm_gt = alphafold.common.confidence.predicted_tm_score(logits, boundaries,
asym_id=asym_id, interface=True)
iptm_gt = torch.tensor(iptm_gt)
iptm_repro = compute_tm(logits_t, no_bins=no_bins, max_bin=max_bin,
asym_id=torch.tensor(asym_id), interface=True)
self.assertTrue(
torch.max(torch.abs(iptm_gt - iptm_repro)) < consts.eps
)
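Annotation: both implementations being compared reduce the aligned-error logits the same way; a hedged sketch of that reduction follows (bin-center handling, masking, and the d0 clipping are simplified relative to the real code).

```python
# Sketch of the pTM reduction: softmax logits into error-bin probabilities,
# score bins with a TM kernel, take the expectation, then the best column mean.
import numpy as np

def ptm_sketch(logits, max_bin=31, no_bins=64):
    n_res = logits.shape[0]
    boundaries = np.linspace(0, max_bin, no_bins - 1)
    step = boundaries[1] - boundaries[0]
    centers = np.concatenate([boundaries + step / 2, [boundaries[-1] + step]])
    d0 = max(1.24 * (n_res - 15) ** (1.0 / 3.0) - 1.8, 0.5)  # TM-score d0
    probs = np.exp(logits - logits.max(-1, keepdims=True))
    probs /= probs.sum(-1, keepdims=True)
    expected_tm = (probs / (1.0 + (centers / d0) ** 2)).sum(-1)  # per-pair E[TM]
    return expected_tm.mean(-1).max()  # best-aligned residue's mean
```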
def test_find_structural_violations(self):
n = consts.n_res
......@@ -265,8 +349,21 @@ class TestLoss(unittest.TestCase):
def test_find_structural_violations_compare(self):
def run_fsv(batch, pos, config):
cwd = os.getcwd()
os.chdir("tests/test_data")
loss = alphafold.model.folding.find_structural_violations(
fpath = Path(__file__).parent.resolve() / "test_data"
os.chdir(str(fpath))
if consts.is_multimer:
atom14_pred_pos = self.am_rigid.Vec3Array.from_array(pos)
return self.am_fold.find_structural_violations(
batch['aatype'],
batch['residue_index'],
batch['atom14_atom_exists'],
atom14_pred_pos,
config,
batch['asym_id']
)
loss = self.am_fold.find_structural_violations(
batch,
pos,
config,
......@@ -287,6 +384,9 @@ class TestLoss(unittest.TestCase):
).astype(np.int64),
}
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
pred_pos = np.random.rand(n_res, 14, 3)
config = mlc.ConfigDict(
......@@ -380,14 +480,14 @@ class TestLoss(unittest.TestCase):
n_seq = consts.n_seq
value = {
"logits": np.random.rand(n_res, n_seq, 23).astype(np.float32),
"logits": np.random.rand(n_res, n_seq, consts.msa_logits).astype(np.float32),
}
batch = {
"true_msa": np.random.randint(0, 21, (n_res, n_seq)),
"bert_mask": np.random.randint(0, 2, (n_res, n_seq)).astype(
np.float32
),
)
}
out_gt = f.apply({}, None, value, batch)["loss"]
......@@ -399,7 +499,9 @@ class TestLoss(unittest.TestCase):
with torch.no_grad():
out_repro = masked_msa_loss(
value["logits"],
**batch,
batch["true_msa"],
batch["bert_mask"],
consts.msa_logits
)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
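Annotation: the signature change above makes the class count explicit (22 for multimer, 23 for monomer). A hedged sketch of the objective itself:

```python
# Sketch of masked-MSA loss: cross-entropy over num_classes logits,
# averaged over the BERT-masked positions only.
import torch
import torch.nn.functional as F

def masked_msa_loss_sketch(logits, true_msa, bert_mask, num_classes=22):
    errors = F.cross_entropy(
        logits.reshape(-1, num_classes),
        true_msa.reshape(-1).long(),
        reduction="none",
    ).reshape(true_msa.shape)
    return (errors * bert_mask).sum() / (bert_mask.sum() + 1e-8)
```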
......@@ -506,10 +608,28 @@ class TestLoss(unittest.TestCase):
c_chi_loss = config.model.heads.structure_module
def run_supervised_chi_loss(value, batch):
if consts.is_multimer:
pred_angles = np.reshape(
value['sidechains']['angles_sin_cos'], [-1, consts.n_res, 7, 2])
unnormed_angles = np.reshape(
value['sidechains']['unnormalized_angles_sin_cos'], [-1, consts.n_res, 7, 2])
chi_loss, _, _ = self.am_fold.supervised_chi_loss(
batch['seq_mask'],
batch['chi_mask'],
batch['aatype'],
batch['chi_angles'],
pred_angles,
unnormed_angles,
c_chi_loss
)
return chi_loss
ret = {
"loss": jax.numpy.array(0.0),
}
alphafold.model.folding.supervised_chi_loss(
self.am_fold.supervised_chi_loss(
ret, batch, value, c_chi_loss
)
return ret["loss"]
......@@ -561,6 +681,40 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_violation_loss(self):
config = compare_utils.get_alphafold_config()
c_viol = config.model.heads.structure_module
n_res = consts.n_res
batch = {
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"residue_index": np.arange(n_res),
"aatype": np.random.randint(0, 21, (n_res,)),
}
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
batch = tree_map(lambda n: torch.tensor(n).cuda(), batch, np.ndarray)
atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
atom14_pred_pos = torch.tensor(atom14_pred_pos).cuda()
batch = data_transforms.make_atom14_masks(batch)
loss_sum_clash = violation_loss(
find_structural_violations(batch, atom14_pred_pos, **c_viol),
average_clashes=False, **batch
)
loss_sum_clash = loss_sum_clash.cpu()
loss_avg_clash = violation_loss(
find_structural_violations(batch, atom14_pred_pos, **c_viol),
average_clashes=True, **batch
)
loss_avg_clash = loss_avg_clash.cpu()
@compare_utils.skip_unless_alphafold_installed()
def test_violation_loss_compare(self):
config = compare_utils.get_alphafold_config()
......@@ -570,15 +724,31 @@ class TestLoss(unittest.TestCase):
ret = {
"loss": np.array(0.0).astype(np.float32),
}
if consts.is_multimer:
atom14_pred_pos = self.am_rigid.Vec3Array.from_array(atom14_pred_pos)
viol = self.am_fold.find_structural_violations(
batch['aatype'],
batch['residue_index'],
batch['atom14_atom_exists'],
atom14_pred_pos,
c_viol,
batch['asym_id']
)
return self.am_fold.structural_violation_loss(mask=batch['atom14_atom_exists'],
violations=viol,
config=c_viol)
value = {}
value[
"violations"
] = alphafold.model.folding.find_structural_violations(
] = self.am_fold.find_structural_violations(
batch,
atom14_pred_pos,
c_viol,
)
alphafold.model.folding.structural_violation_loss(
self.am_fold.structural_violation_loss(
ret,
batch,
value,
......@@ -593,13 +763,17 @@ class TestLoss(unittest.TestCase):
batch = {
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"residue_index": np.arange(n_res),
"aatype": np.random.randint(0, 21, (n_res,)),
"aatype": np.random.randint(0, 21, (n_res,))
}
alphafold.model.tf.data_transforms.make_atom14_masks(batch)
batch = {k: np.array(v) for k, v in batch.items()}
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
alphafold.model.tf.data_transforms.make_atom14_masks(batch)
batch = {k: np.array(v) for k, v in batch.items()}
out_gt = f.apply({}, None, batch, atom14_pred_pos)
out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
......@@ -676,10 +850,31 @@ class TestLoss(unittest.TestCase):
c_sm = config.model.heads.structure_module
def run_bb_loss(batch, value):
if consts.is_multimer:
intra_chain_mask = (batch["asym_id"][..., None] == batch["asym_id"][..., None, :]).astype(np.float32)
gt_rigid = affine_vector_to_rigid(self.am_rigid, batch["backbone_affine_tensor"])
target_rigid = affine_vector_to_rigid(self.am_rigid, value['traj'])
intra_chain_bb_loss, intra_chain_fape = self.am_fold.backbone_loss(
gt_rigid=gt_rigid,
gt_frames_mask=batch["backbone_affine_mask"],
gt_positions_mask=batch["backbone_affine_mask"],
target_rigid=target_rigid,
config=c_sm.intra_chain_fape,
pair_mask=intra_chain_mask)
interface_bb_loss, interface_fape = self.am_fold.backbone_loss(
gt_rigid=gt_rigid,
gt_frames_mask=batch["backbone_affine_mask"],
gt_positions_mask=batch["backbone_affine_mask"],
target_rigid=target_rigid,
config=c_sm.interface_fape,
pair_mask=1. - intra_chain_mask)
return intra_chain_bb_loss + interface_bb_loss
ret = {
"loss": np.array(0.0),
}
alphafold.model.folding.backbone_loss(ret, batch, value, c_sm)
self.am_fold.backbone_loss(ret, batch, value, c_sm)
return ret["loss"]
f = hk.transform(run_bb_loss)
......@@ -691,7 +886,7 @@ class TestLoss(unittest.TestCase):
"backbone_affine_mask": np.random.randint(0, 2, (n_res,)).astype(
np.float32
),
"use_clamped_fape": np.array(0.0),
"use_clamped_fape": np.array(0.0)
}
value = {
......@@ -703,6 +898,9 @@ class TestLoss(unittest.TestCase):
),
}
if consts.is_multimer:
batch["asym_id"] = random_asym_ids(n_res)
out_gt = f.apply({}, None, batch, value)
out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
......@@ -715,8 +913,18 @@ class TestLoss(unittest.TestCase):
)
batch["backbone_rigid_mask"] = batch["backbone_affine_mask"]
out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
out_repro = out_repro.cpu()
if consts.is_multimer:
intra_chain_mask = (batch["asym_id"][..., None]
== batch["asym_id"][..., None, :]).to(dtype=value["traj"].dtype)
intra_chain_out = backbone_loss(traj=value["traj"], pair_mask=intra_chain_mask,
**{**batch, **c_sm.intra_chain_fape})
interface_out = backbone_loss(traj=value["traj"], pair_mask=1. - intra_chain_mask,
**{**batch, **c_sm.interface_fape})
out_repro = intra_chain_out + interface_out
out_repro = out_repro.cpu()
else:
out_repro = backbone_loss(traj=value["traj"], **{**batch, **c_sm})
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
......@@ -726,9 +934,29 @@ class TestLoss(unittest.TestCase):
c_sm = config.model.heads.structure_module
def run_sidechain_loss(batch, value, atom14_pred_positions):
if consts.is_multimer:
atom14_pred_positions = self.am_rigid.Vec3Array.from_array(atom14_pred_positions)
all_atom_positions = self.am_rigid.Vec3Array.from_array(batch["all_atom_positions"])
gt_positions, gt_mask, alt_naming_is_better = self.am_fold.compute_atom14_gt(
aatype=batch["aatype"], all_atom_positions=all_atom_positions,
all_atom_mask=batch["all_atom_mask"], pred_pos=atom14_pred_positions)
pred_frames = self.am_rigid.Rigid3Array.from_array4x4(value["sidechains"]["frames"])
pred_positions = self.am_rigid.Vec3Array.from_array(value["sidechains"]["atom_pos"])
gt_sc_frames, gt_sc_frames_mask = self.am_fold.compute_frames(
aatype=batch["aatype"],
all_atom_positions=all_atom_positions,
all_atom_mask=batch["all_atom_mask"],
use_alt=alt_naming_is_better)
return self.am_fold.sidechain_loss(gt_sc_frames,
gt_sc_frames_mask,
gt_positions,
gt_mask,
pred_frames,
pred_positions,
c_sm)['loss']
batch = {
**batch,
**alphafold.model.all_atom.atom37_to_frames(
**self.am_atom.atom37_to_frames(
batch["aatype"],
batch["all_atom_positions"],
batch["all_atom_mask"],
......@@ -738,21 +966,21 @@ class TestLoss(unittest.TestCase):
v["sidechains"] = {}
v["sidechains"][
"frames"
] = alphafold.model.r3.rigids_from_tensor4x4(
] = self.am_rigid.rigids_from_tensor4x4(
value["sidechains"]["frames"]
)
v["sidechains"]["atom_pos"] = alphafold.model.r3.vecs_from_tensor(
v["sidechains"]["atom_pos"] = self.am_rigid.vecs_from_tensor(
value["sidechains"]["atom_pos"]
)
v.update(
alphafold.model.folding.compute_renamed_ground_truth(
self.am_fold.compute_renamed_ground_truth(
batch,
atom14_pred_positions,
)
)
value = v
ret = alphafold.model.folding.sidechain_loss(batch, value, c_sm)
ret = self.am_fold.sidechain_loss(batch, value, c_sm)
return ret["loss"]
f = hk.transform(run_sidechain_loss)
......@@ -816,6 +1044,7 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
@unittest.skipIf(not consts.is_multimer and "ptm" not in consts.model, "Not enabled for non-ptm models.")
def test_tm_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_tm = config.model.heads.predicted_aligned_error
......@@ -882,6 +1111,33 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
def test_chain_center_of_mass_loss(self):
batch_size = consts.batch_size
n_res = consts.n_res
batch = {
"all_atom_positions": np.random.rand(batch_size, n_res, 37, 3).astype(np.float32) * 10.0,
"all_atom_mask": np.random.randint(0, 2, (batch_size, n_res, 37)).astype(np.float32),
"asym_id": np.stack([random_asym_ids(n_res) for _ in range(batch_size)])
}
config = {
"weight": 0.05,
"clamp_distance": -4.0,
}
final_atom_positions = torch.rand(batch_size, n_res, 37, 3).cuda()
to_tensor = lambda t: torch.tensor(t).cuda()
batch = tree_map(to_tensor, batch, np.ndarray)
out_repro = chain_center_of_mass_loss(
all_atom_pred_pos=final_atom_positions,
**{**batch, **config},
)
out_repro = out_repro.cpu()
if __name__ == "__main__":
unittest.main()
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
import pickle
import torch
import torch.nn as nn
......@@ -20,8 +21,7 @@ import unittest
from openfold.config import model_config
from openfold.data import data_transforms
from openfold.model.model import AlphaFold
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
from openfold.utils.tensor_utils import tensor_tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import (
......@@ -36,13 +36,26 @@ if compare_utils.alphafold_is_installed():
class TestModel(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_dry_run(self):
n_seq = consts.n_seq
n_templ = consts.n_templ
n_res = consts.n_res
n_extra_seq = consts.n_extra
c = model_config("model_1")
c = model_config(consts.model, train=True)
c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test
......@@ -56,6 +69,7 @@ class TestModel(unittest.TestCase):
).float()
batch["aatype"] = torch.argmax(batch["target_feat"], dim=-1)
batch["residue_index"] = torch.arange(n_res)
batch["msa_feat"] = torch.rand((n_seq, n_res, c.model.input_embedder.msa_dim))
t_feats = random_template_feats(n_templ, n_res)
batch.update({k: torch.tensor(v) for k, v in t_feats.items()})
......@@ -68,6 +82,12 @@ class TestModel(unittest.TestCase):
batch.update(data_transforms.make_atom14_masks(batch))
batch["no_recycling_iters"] = torch.tensor(2.)
if consts.is_multimer:
batch["asym_id"] = torch.randint(0, 1, size=(n_res,))
batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch["sym_id"] = torch.randint(0, 1, size=(n_res,))
batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
)
......@@ -77,10 +97,14 @@ class TestModel(unittest.TestCase):
out = model(batch)
@compare_utils.skip_unless_alphafold_installed()
@unittest.skipIf(consts.is_multimer, "Additional changes required for multimer.")
def test_compare(self):
#TODO: Fix test data for multimer MSA features
def run_alphafold(batch):
config = compare_utils.get_alphafold_config()
model = alphafold.model.modules.AlphaFold(config.model)
model = self.am_modules.AlphaFold(config.model)
return model(
batch=batch,
is_training=False,
......@@ -91,7 +115,8 @@ class TestModel(unittest.TestCase):
params = compare_utils.fetch_alphafold_module_weights("")
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
fpath = Path(__file__).parent.resolve() / "test_data/sample_feats.pickle"
with open(str(fpath), "rb") as fp:
batch = pickle.load(fp)
out_gt = f.apply(params, jax.random.PRNGKey(42), batch)
......@@ -100,7 +125,8 @@ class TestModel(unittest.TestCase):
# atom37_to_atom14 doesn't like batches
batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
out_gt = self.am_atom.atom37_to_atom14(out_gt, batch)
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
batch["no_recycling_iters"] = np.array([3., 3., 3., 3.,])
......
......@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0].core
model.evoformer.blocks[0]
.outer_product_mean(
torch.as_tensor(msa_act).cuda(),
chunk_size=4,
......
......@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0].core
model.evoformer.blocks[0].pair_stack
.pair_transition(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
chunk_size=4,
......
......@@ -13,20 +13,17 @@
# limitations under the License.
import torch
import numpy as np
import unittest
from openfold.model.primitives import (
Attention,
)
from openfold.model.primitives import Attention
from tests.config import consts
class TestLMA(unittest.TestCase):
def test_lma_vs_attention(self):
batch_size = consts.batch_size
c_hidden = 32
n = 2**12
c_hidden = 32
n = 2 ** 12
no_heads = 4
q = torch.rand(batch_size, n, c_hidden).cuda()
......@@ -34,20 +31,17 @@ class TestLMA(unittest.TestCase):
bias = [torch.rand(no_heads, 1, n)]
bias = [b.cuda() for b in bias]
gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
o_fill = torch.rand(c_hidden, c_hidden * no_heads)
a = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
with torch.no_grad():
l = a(q, kv, biases=bias, use_lma=True)
real = a(q, kv, biases=bias)
self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
if __name__ == "__main__":
unittest.main()
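Annotation: the test above exercises the low-memory attention (LMA) path against stock attention. A toy sketch of the chunking idea; the chunk size and single-axis chunking are simplifications of the real kernel.

```python
# Toy chunked attention: process query blocks sequentially so the full
# (n x n) attention matrix is never materialized at once.
import torch

def chunked_attention_sketch(q, k, v, chunk=1024):
    scale = q.shape[-1] ** -0.5
    out = []
    for i in range(0, q.shape[-2], chunk):
        a = torch.softmax(
            q[..., i:i + chunk, :] @ k.transpose(-1, -2) * scale, dim=-1
        )
        out.append(a @ v)
    return torch.cat(out, dim=-2)
```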
......@@ -18,21 +18,19 @@ import unittest
from openfold.data.data_transforms import make_atom14_masks_np
from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
restype_atom14_mask,
restype_atom14_rigid_group_positions,
restype_atom37_mask,
)
from openfold.model.structure_module import (
StructureModule,
StructureModuleTransition,
BackboneUpdate,
AngleResnet,
InvariantPointAttention,
)
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import (
......@@ -46,6 +44,19 @@ if compare_utils.alphafold_is_installed():
class TestStructureModule(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_structure_module_shape(self):
batch_size = consts.batch_size
n = consts.n_res
......@@ -81,6 +92,7 @@ class TestStructureModule(unittest.TestCase):
trans_scale_factor,
ar_epsilon,
inf,
is_multimer=consts.is_multimer
)
s = torch.rand((batch_size, n, c_s))
......@@ -89,7 +101,11 @@ class TestStructureModule(unittest.TestCase):
out = sm({"single": s, "pair": z}, f)
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
if consts.is_multimer:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
else:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
)
......@@ -121,11 +137,14 @@ class TestStructureModule(unittest.TestCase):
c_global = config.model.global_config
def run_sm(representations, batch):
sm = alphafold.model.folding.StructureModule(c_sm, c_global)
sm = self.am_fold.StructureModule(c_sm, c_global)
representations = {
k: jax.lax.stop_gradient(v) for k, v in representations.items()
}
batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
if consts.is_multimer:
return sm(representations, batch, is_training=False, compute_loss=True)
return sm(representations, batch, is_training=False)
f = hk.transform(run_sm)
......@@ -181,6 +200,19 @@ class TestStructureModule(unittest.TestCase):
class TestInvariantPointAttention(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_shape(self):
c_m = 13
c_z = 17
......@@ -197,13 +229,18 @@ class TestInvariantPointAttention(unittest.TestCase):
mask = torch.ones((batch_size, n_res))
rot_mats = torch.rand((batch_size, n_res, 3, 3))
rots = Rotation(rot_mats=rot_mats, quats=None)
trans = torch.rand((batch_size, n_res, 3))
r = Rigid(rots, trans)
if consts.is_multimer:
rotation = Rot3Array.from_array(rot_mats)
translation = Vec3Array.from_array(trans)
r = Rigid3Array(rotation, translation)
else:
rots = Rotation(rot_mats=rot_mats, quats=None)
r = Rigid(rots, trans)
ipa = InvariantPointAttention(
c_m, c_z, c_hidden, no_heads, no_qp, no_vp
c_m, c_z, c_hidden, no_heads, no_qp, no_vp, is_multimer=consts.is_multimer
)
shape_before = s.shape
......@@ -215,16 +252,26 @@ class TestInvariantPointAttention(unittest.TestCase):
def test_ipa_compare(self):
def run_ipa(act, static_feat_2d, mask, affine):
config = compare_utils.get_alphafold_config()
ipa = alphafold.model.folding.InvariantPointAttention(
ipa = self.am_fold.InvariantPointAttention(
config.model.heads.structure_module,
config.model.global_config,
)
attn = ipa(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
affine=affine,
)
if consts.is_multimer:
attn = ipa(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
rigid=affine
)
else:
attn = ipa(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
affine=affine
)
return attn
f = hk.transform(run_ipa)
......@@ -238,13 +285,20 @@ class TestInvariantPointAttention(unittest.TestCase):
sample_mask = np.ones((n_res, 1))
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats
if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = rigids
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
quats = self.am_rigid.rigids_to_quataffine(rigids)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats
ipa_params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/structure_module/"
......
......@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import re
import torch
import numpy as np
import unittest
......@@ -19,7 +20,6 @@ from openfold.model.template import (
TemplatePointwiseAttention,
TemplatePairStack,
)
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats
......@@ -54,6 +54,19 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class TestTemplatePairStack(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_shape(self):
batch_size = consts.batch_size
c_t = consts.c_t
......@@ -65,6 +78,8 @@ class TestTemplatePairStack(unittest.TestCase):
dropout = 0.25
n_templ = consts.n_templ
n_res = consts.n_res
tri_mul_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
blocks_per_ckpt = None
chunk_size = 4
inf = 1e7
......@@ -78,6 +93,8 @@ class TestTemplatePairStack(unittest.TestCase):
no_heads=no_heads,
pair_transition_n=pt_inner_dim,
dropout_rate=dropout,
tri_mul_first=tri_mul_first,
fuse_projection_weights=fuse_projection_weights,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
......@@ -96,12 +113,40 @@ class TestTemplatePairStack(unittest.TestCase):
def run_template_pair_stack(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_ee = config.model.embeddings_and_evoformer
tps = alphafold.model.modules.TemplatePairStack(
c_ee.template.template_pair_stack,
config.model.global_config,
name="template_pair_stack",
)
act = tps(pair_act, pair_mask, is_training=False)
if consts.is_multimer:
safe_key = alphafold.model.prng.SafeKey(hk.next_rng_key())
template_iteration = self.am_modules.TemplateEmbeddingIteration(
c_ee.template.template_pair_stack,
config.model.global_config,
name='template_embedding_iteration')
def template_iteration_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
act = template_iteration(
act=act,
pair_mask=pair_mask,
is_training=False,
safe_key=safe_subkey)
return (act, safe_key)
if config.model.global_config.use_remat:
template_iteration_fn = hk.remat(template_iteration_fn)
safe_key, safe_subkey = safe_key.split()
template_stack = alphafold.model.layer_stack.layer_stack(
c_ee.template.template_pair_stack.num_block)(
template_iteration_fn)
act, _ = template_stack((pair_act, safe_subkey))
else:
tps = self.am_modules.TemplatePairStack(
c_ee.template.template_pair_stack,
config.model.global_config,
name="template_pair_stack",
)
act = tps(pair_act, pair_mask, is_training=False)
ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
act = ln(act)
return act
......@@ -115,10 +160,16 @@ class TestTemplatePairStack(unittest.TestCase):
low=0, high=2, size=(n_res, n_res)
).astype(np.float32)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_pair_stack"
)
if consts.is_multimer:
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_embedding_iteration"
)
else:
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_pair_stack"
)
params.update(
compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
......@@ -132,7 +183,7 @@ class TestTemplatePairStack(unittest.TestCase):
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.template_pair_stack(
out_repro = model.template_embedder.template_pair_stack(
torch.as_tensor(pair_act).unsqueeze(-4).cuda(),
torch.as_tensor(pair_mask).unsqueeze(-3).cuda(),
chunk_size=None,
......@@ -143,15 +194,32 @@ class TestTemplatePairStack(unittest.TestCase):
class Template(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def test_template_embedding(pair, batch, mask_2d):
def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):
config = compare_utils.get_alphafold_config()
te = alphafold.model.modules.TemplateEmbedding(
te = self.am_modules.TemplateEmbedding(
config.model.embeddings_and_evoformer.template,
config.model.global_config,
)
act = te(pair, batch, mask_2d, is_training=False)
if consts.is_multimer:
act = te(pair, batch, mask_2d, multichain_mask_2d=mc_mask_2d, is_training=False)
else:
act = te(pair, batch, mask_2d, is_training=False)
return act
f = hk.transform(test_template_embedding)
......@@ -162,6 +230,14 @@ class Template(unittest.TestCase):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
multichain_mask_2d = None
if consts.is_multimer:
asym_id = batch['asym_id'][0]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
).astype(np.float32)
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
# Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights(
......@@ -169,7 +245,7 @@ class Template(unittest.TestCase):
)
out_gt = f.apply(
params, jax.random.PRNGKey(42), pair_act, batch, pair_mask
params, jax.random.PRNGKey(42), pair_act, batch, pair_mask, multichain_mask_2d
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
......@@ -177,13 +253,30 @@ class Template(unittest.TestCase):
batch["target_feat"] = np.eye(22)[inds]
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
inplace_safe=False
)
template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
if consts.is_multimer:
out_repro = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
chunk_size=consts.chunk_size,
multichain_mask_2d=torch.as_tensor(multichain_mask_2d).cuda(),
use_lma=False,
inplace_safe=False
)
else:
out_repro = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
chunk_size=consts.chunk_size,
use_lma=False,
inplace_safe=False
)
out_repro = out_repro["template_pair_embedding"]
out_repro = out_repro.cpu()
......
......@@ -86,9 +86,9 @@ class TestTriangularAttention(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].core.tri_att_start
model.evoformer.blocks[0].pair_stack.tri_att_start
if starting
else model.evoformer.blocks[0].core.tri_att_end
else model.evoformer.blocks[0].pair_stack.tri_att_end
)
# To save memory, the full model transposes inputs outside of the
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import torch
import re
import numpy as np
import unittest
from openfold.model.triangular_multiplicative_update import *
......@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
c_z = consts.c_z
c = 11
tm = TriangleMultiplicationOutgoing(
c_z,
c,
)
if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model):
tm = FusedTriangleMultiplicationOutgoing(
c_z,
c,
)
else:
tm = TriangleMultiplicationOutgoing(
c_z,
c,
)
n_res = consts.c_z
batch_size = consts.batch_size
......@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
config.model.global_config,
name=name,
)
act = tri_mul(act=pair_act, mask=pair_mask)
act = tri_mul(pair_act, pair_mask)
return act
f = hk.transform(run_tri_mul)
......@@ -85,10 +92,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].core.tri_mul_in
model.evoformer.blocks[0].pair_stack.tri_mul_in
if incoming
else model.evoformer.blocks[0].core.tri_mul_out
else model.evoformer.blocks[0].pair_stack.tri_mul_out
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
......@@ -112,12 +120,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
pair_mask = pair_mask.astype(np.float32)
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].core.tri_mul_in
model.evoformer.blocks[0].pair_stack.tri_mul_in
if incoming
else model.evoformer.blocks[0].core.tri_mul_out
else model.evoformer.blocks[0].pair_stack.tri_mul_out
)
out_stock = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
......
......@@ -23,7 +23,7 @@ from openfold.utils.rigid_utils import (
quat_to_rot,
rot_to_quat,
)
from openfold.utils.tensor_utils import chunk_layer, _chunk_slice
from openfold.utils.chunk_utils import chunk_layer, _chunk_slice
import tests.compare_utils as compare_utils
from tests.config import consts
......