Commit d8ee9c5f authored by Christina Floristean's avatar Christina Floristean
Browse files

All non-cuda tests passing for monomer/multimer. Tri mul/attn and OPM order switched.

parent 260db67f
...@@ -12,14 +12,17 @@ ...@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import random
import torch import torch
import numpy as np
import unittest import unittest
from tests.config import consts
from tests.data_utils import random_asym_ids
from openfold.model.embedders import ( from openfold.model.embedders import (
InputEmbedder, InputEmbedder,
InputEmbedderMultimer,
RecyclingEmbedder, RecyclingEmbedder,
TemplateAngleEmbedder, TemplateAngleEmbedder,
TemplatePairEmbedder, TemplatePairEmbedder
) )
...@@ -35,13 +38,30 @@ class TestInputEmbedder(unittest.TestCase): ...@@ -35,13 +38,30 @@ class TestInputEmbedder(unittest.TestCase):
n_res = 17 n_res = 17
n_clust = 19 n_clust = 19
max_relative_chain = 2
max_relative_idx = 32
use_chain_relative = True
tf = torch.rand((b, n_res, tf_dim)) tf = torch.rand((b, n_res, tf_dim))
ri = torch.rand((b, n_res)) ri = torch.rand((b, n_res))
msa = torch.rand((b, n_clust, n_res, msa_dim)) msa = torch.rand((b, n_clust, n_res, msa_dim))
asym_ids_flat = torch.Tensor(random_asym_ids(n_res))
asym_id = torch.tile(asym_ids_flat.unsqueeze(0), (b, 1))
entity_id = asym_id
sym_id = torch.zeros_like(entity_id)
batch = {"target_feat": tf, "residue_index": ri, "msa_feat": msa}
if consts.is_multimer:
ie = InputEmbedderMultimer(tf_dim, msa_dim, c_z, c_m,
max_relative_idx=max_relative_idx,
use_chain_relative=use_chain_relative,
max_relative_chain=max_relative_chain)
batch.update({"asym_id": asym_id, "entity_id": entity_id, "sym_id": sym_id})
else:
ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k) ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
msa_emb, pair_emb = ie(tf, ri, msa) msa_emb, pair_emb = ie(batch)
self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m)) self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m))
self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z)) self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))
......
...@@ -48,6 +48,7 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -48,6 +48,7 @@ class TestEvoformerStack(unittest.TestCase):
transition_n = 2 transition_n = 2
msa_dropout = 0.15 msa_dropout = 0.15
pair_stack_dropout = 0.25 pair_stack_dropout = 0.25
opm_first = consts.is_multimer
inf = 1e9 inf = 1e9
eps = 1e-10 eps = 1e-10
...@@ -65,6 +66,7 @@ class TestEvoformerStack(unittest.TestCase): ...@@ -65,6 +66,7 @@ class TestEvoformerStack(unittest.TestCase):
transition_n, transition_n,
msa_dropout, msa_dropout,
pair_stack_dropout, pair_stack_dropout,
opm_first,
blocks_per_ckpt=None, blocks_per_ckpt=None,
inf=inf, inf=inf,
eps=eps, eps=eps,
...@@ -156,6 +158,7 @@ class TestExtraMSAStack(unittest.TestCase): ...@@ -156,6 +158,7 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n = 5 transition_n = 5
msa_dropout = 0.15 msa_dropout = 0.15
pair_stack_dropout = 0.25 pair_stack_dropout = 0.25
opm_first = consts.is_multimer
inf = 1e9 inf = 1e9
eps = 1e-10 eps = 1e-10
...@@ -172,6 +175,7 @@ class TestExtraMSAStack(unittest.TestCase): ...@@ -172,6 +175,7 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n, transition_n,
msa_dropout, msa_dropout,
pair_stack_dropout, pair_stack_dropout,
opm_first,
ckpt=False, ckpt=False,
inf=inf, inf=inf,
eps=eps, eps=eps,
...@@ -259,7 +263,7 @@ class TestMSATransition(unittest.TestCase): ...@@ -259,7 +263,7 @@ class TestMSATransition(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
out_repro = ( out_repro = (
model.evoformer.blocks[0].core.msa_transition( model.evoformer.blocks[0].msa_transition(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(), torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(), mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
) )
......
...@@ -25,6 +25,9 @@ from openfold.np.residue_constants import ( ...@@ -25,6 +25,9 @@ from openfold.np.residue_constants import (
) )
import openfold.utils.feats as feats import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
tree_map, tree_map,
tensor_tree_map, tensor_tree_map,
...@@ -40,6 +43,19 @@ if compare_utils.alphafold_is_installed(): ...@@ -40,6 +43,19 @@ if compare_utils.alphafold_is_installed():
class TestFeats(unittest.TestCase): class TestFeats(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_pseudo_beta_fn_compare(self): def test_pseudo_beta_fn_compare(self):
def test_pbf(aatype, all_atom_pos, all_atom_mask): def test_pbf(aatype, all_atom_pos, all_atom_mask):
...@@ -131,7 +147,9 @@ class TestFeats(unittest.TestCase): ...@@ -131,7 +147,9 @@ class TestFeats(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_atom37_to_frames_compare(self): def test_atom37_to_frames_compare(self):
def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask): def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
return alphafold.model.all_atom.atom37_to_frames( if consts.is_multimer:
all_atom_positions = self.am_rigid.Vec3Array.from_array(all_atom_positions)
return self.am_atom.atom37_to_frames(
aatype, all_atom_positions, all_atom_mask aatype, all_atom_positions, all_atom_mask
) )
...@@ -150,7 +168,14 @@ class TestFeats(unittest.TestCase): ...@@ -150,7 +168,14 @@ class TestFeats(unittest.TestCase):
} }
out_gt = f.apply({}, None, **batch) out_gt = f.apply({}, None, **batch)
if consts.is_multimer:
to_tensor = (lambda t: torch.tensor(np.array(t))
if not isinstance(t, self.am_rigid.Rigid3Array)
else torch.tensor(np.array(t.to_array())).view(*t.shape[:2], 12))
else:
to_tensor = lambda t: torch.tensor(np.array(t)) to_tensor = lambda t: torch.tensor(np.array(t))
out_gt = {k: to_tensor(v) for k, v in out_gt.items()} out_gt = {k: to_tensor(v) for k, v in out_gt.items()}
def flat12_to_4x4(flat12): def flat12_to_4x4(flat12):
...@@ -187,6 +212,12 @@ class TestFeats(unittest.TestCase): ...@@ -187,6 +212,12 @@ class TestFeats(unittest.TestCase):
n = 5 n = 5
rots = torch.rand((batch_size, n, 3, 3)) rots = torch.rand((batch_size, n, 3, 3))
trans = torch.rand((batch_size, n, 3)) trans = torch.rand((batch_size, n, 3))
if consts.is_multimer:
rotation = Rot3Array.from_array(rots)
translation = Vec3Array.from_array(trans)
ts = Rigid3Array(rotation, translation)
else:
ts = Rigid(Rotation(rot_mats=rots), trans) ts = Rigid(Rotation(rot_mats=rots), trans)
angles = torch.rand((batch_size, n, 7, 2)) angles = torch.rand((batch_size, n, 7, 2))
...@@ -208,7 +239,7 @@ class TestFeats(unittest.TestCase): ...@@ -208,7 +239,7 @@ class TestFeats(unittest.TestCase):
def run_torsion_angles_to_frames( def run_torsion_angles_to_frames(
aatype, backb_to_global, torsion_angles_sin_cos aatype, backb_to_global, torsion_angles_sin_cos
): ):
return alphafold.model.all_atom.torsion_angles_to_frames( return self.am_atom.torsion_angles_to_frames(
aatype, aatype,
backb_to_global, backb_to_global,
torsion_angles_sin_cos, torsion_angles_sin_cos,
...@@ -221,7 +252,14 @@ class TestFeats(unittest.TestCase): ...@@ -221,7 +252,14 @@ class TestFeats(unittest.TestCase):
aatype = np.random.randint(0, 21, size=(n_res,)) aatype = np.random.randint(0, 21, size=(n_res,))
affines = random_affines_4x4((n_res,)) affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float()
)
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
transformations = Rigid.from_tensor_4x4( transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float() torch.as_tensor(affines).float()
) )
...@@ -264,6 +302,12 @@ class TestFeats(unittest.TestCase): ...@@ -264,6 +302,12 @@ class TestFeats(unittest.TestCase):
rots = torch.rand((batch_size, n_res, 8, 3, 3)) rots = torch.rand((batch_size, n_res, 8, 3, 3))
trans = torch.rand((batch_size, n_res, 8, 3)) trans = torch.rand((batch_size, n_res, 8, 3))
if consts.is_multimer:
rotation = Rot3Array.from_array(rots)
translation = Vec3Array.from_array(trans)
ts = Rigid3Array(rotation, translation)
else:
ts = Rigid(Rotation(rot_mats=rots), trans) ts = Rigid(Rotation(rot_mats=rots), trans)
f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long() f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()
...@@ -277,13 +321,15 @@ class TestFeats(unittest.TestCase): ...@@ -277,13 +321,15 @@ class TestFeats(unittest.TestCase):
torch.tensor(restype_atom14_rigid_group_positions), torch.tensor(restype_atom14_rigid_group_positions),
) )
if consts.is_multimer:
xyz = xyz.to_tensor()
self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3)) self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_frames_and_literature_positions_to_atom14_pos_compare(self): def test_frames_and_literature_positions_to_atom14_pos_compare(self):
def run_f(aatype, affines): def run_f(aatype, affines):
am = alphafold.model return self.am_atom.frames_and_literature_positions_to_atom14_pos(
return am.all_atom.frames_and_literature_positions_to_atom14_pos(
aatype, affines aatype, affines
) )
...@@ -294,13 +340,24 @@ class TestFeats(unittest.TestCase): ...@@ -294,13 +340,24 @@ class TestFeats(unittest.TestCase):
aatype = np.random.randint(0, 21, size=(n_res,)) aatype = np.random.randint(0, 21, size=(n_res,))
affines = random_affines_4x4((n_res, 8)) affines = random_affines_4x4((n_res, 8))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float()
)
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
transformations = Rigid.from_tensor_4x4( transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float() torch.as_tensor(affines).float()
) )
out_gt = f.apply({}, None, aatype, rigids) out_gt = f.apply({}, None, aatype, rigids)
jax.tree_map(lambda x: x.block_until_ready(), out_gt) jax.tree_map(lambda x: x.block_until_ready(), out_gt)
if consts.is_multimer:
out_gt = torch.as_tensor(np.array(out_gt.to_array()))
else:
out_gt = torch.stack( out_gt = torch.stack(
[torch.as_tensor(np.array(x)) for x in out_gt], dim=-1 [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
) )
......
...@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase): ...@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
) )
][1].transpose(-1, -2) ][1].transpose(-1, -2)
), ),
model.evoformer.blocks[1].core.outer_product_mean.linear_1.weight, model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
), ),
] ]
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
# limitations under the License. # limitations under the License.
import os import os
import math
import torch import torch
import numpy as np import numpy as np
import unittest import unittest
...@@ -24,7 +23,6 @@ from openfold.utils.rigid_utils import ( ...@@ -24,7 +23,6 @@ from openfold.utils.rigid_utils import (
Rotation, Rotation,
Rigid, Rigid,
) )
import openfold.utils.feats as feats
from openfold.utils.loss import ( from openfold.utils.loss import (
torsion_angle_loss, torsion_angle_loss,
compute_fape, compute_fape,
...@@ -51,7 +49,7 @@ from openfold.utils.tensor_utils import ( ...@@ -51,7 +49,7 @@ from openfold.utils.tensor_utils import (
) )
import tests.compare_utils as compare_utils import tests.compare_utils as compare_utils
from tests.config import consts from tests.config import consts
from tests.data_utils import random_affines_vector, random_affines_4x4 from tests.data_utils import random_affines_vector, random_affines_4x4, random_asym_ids
if compare_utils.alphafold_is_installed(): if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold() alphafold = compare_utils.import_alphafold()
...@@ -64,7 +62,30 @@ def affine_vector_to_4x4(affine): ...@@ -64,7 +62,30 @@ def affine_vector_to_4x4(affine):
return r.to_tensor_4x4() return r.to_tensor_4x4()
def affine_vector_to_rigid(am_rigid, affine):
rigid_flat = np.split(affine, 7, axis=-1)
rigid_flat = [r.squeeze(-1) for r in rigid_flat]
qw, qx, qy, qz = rigid_flat[:4]
trans = rigid_flat[4:]
rotation = am_rigid.Rot3Array.from_quaternion(qw, qx, qy, qz, normalize=True)
translation = am_rigid.Vec3Array(*trans)
return am_rigid.Rigid3Array(rotation, translation)
class TestLoss(unittest.TestCase): class TestLoss(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_run_torsion_angle_loss(self): def test_run_torsion_angle_loss(self):
batch_size = consts.batch_size batch_size = consts.batch_size
n_res = consts.n_res n_res = consts.n_res
...@@ -127,7 +148,10 @@ class TestLoss(unittest.TestCase): ...@@ -127,7 +148,10 @@ class TestLoss(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_between_residue_bond_loss_compare(self): def test_between_residue_bond_loss_compare(self):
def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype): def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
return alphafold.model.all_atom.between_residue_bond_loss( if consts.is_multimer:
pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
return self.am_atom.between_residue_bond_loss(
pred_pos, pred_pos,
pred_atom_mask, pred_atom_mask,
residue_index, residue_index,
...@@ -184,12 +208,22 @@ class TestLoss(unittest.TestCase): ...@@ -184,12 +208,22 @@ class TestLoss(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_between_residue_clash_loss_compare(self): def test_between_residue_clash_loss_compare(self):
def run_brcl(pred_pos, atom_exists, atom_radius, res_ind): def run_brcl(pred_pos, atom_exists, atom_radius, res_ind, asym_id):
return alphafold.model.all_atom.between_residue_clash_loss( if consts.is_multimer:
pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
return self.am_atom.between_residue_clash_loss(
pred_pos, pred_pos,
atom_exists, atom_exists,
atom_radius, atom_radius,
res_ind, res_ind,
asym_id
)
return self.am_atom.between_residue_clash_loss(
pred_pos,
atom_exists,
atom_radius,
res_ind
) )
f = hk.transform(run_brcl) f = hk.transform(run_brcl)
...@@ -202,6 +236,7 @@ class TestLoss(unittest.TestCase): ...@@ -202,6 +236,7 @@ class TestLoss(unittest.TestCase):
res_ind = np.arange( res_ind = np.arange(
n_res, n_res,
) )
asym_id = random_asym_ids(n_res)
out_gt = f.apply( out_gt = f.apply(
{}, {},
...@@ -210,6 +245,7 @@ class TestLoss(unittest.TestCase): ...@@ -210,6 +245,7 @@ class TestLoss(unittest.TestCase):
atom_exists, atom_exists,
atom_radius, atom_radius,
res_ind, res_ind,
asym_id
) )
out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt) out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt) out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
...@@ -266,7 +302,19 @@ class TestLoss(unittest.TestCase): ...@@ -266,7 +302,19 @@ class TestLoss(unittest.TestCase):
def run_fsv(batch, pos, config): def run_fsv(batch, pos, config):
cwd = os.getcwd() cwd = os.getcwd()
os.chdir("tests/test_data") os.chdir("tests/test_data")
loss = alphafold.model.folding.find_structural_violations(
if consts.is_multimer:
atom14_pred_pos = self.am_rigid.Vec3Array.from_array(pos)
return self.am_fold.find_structural_violations(
batch['aatype'],
batch['residue_index'],
batch['atom14_atom_exists'],
atom14_pred_pos,
config,
batch['asym_id']
)
loss = self.am_fold.find_structural_violations(
batch, batch,
pos, pos,
config, config,
...@@ -285,6 +333,7 @@ class TestLoss(unittest.TestCase): ...@@ -285,6 +333,7 @@ class TestLoss(unittest.TestCase):
"residx_atom14_to_atom37": np.random.randint( "residx_atom14_to_atom37": np.random.randint(
0, 37, (n_res, 14) 0, 37, (n_res, 14)
).astype(np.int64), ).astype(np.int64),
"asym_id": random_asym_ids(n_res)
} }
pred_pos = np.random.rand(n_res, 14, 3) pred_pos = np.random.rand(n_res, 14, 3)
...@@ -380,7 +429,7 @@ class TestLoss(unittest.TestCase): ...@@ -380,7 +429,7 @@ class TestLoss(unittest.TestCase):
n_seq = consts.n_seq n_seq = consts.n_seq
value = { value = {
"logits": np.random.rand(n_res, n_seq, 23).astype(np.float32), "logits": np.random.rand(n_res, n_seq, consts.msa_logits).astype(np.float32),
} }
batch = { batch = {
...@@ -506,10 +555,28 @@ class TestLoss(unittest.TestCase): ...@@ -506,10 +555,28 @@ class TestLoss(unittest.TestCase):
c_chi_loss = config.model.heads.structure_module c_chi_loss = config.model.heads.structure_module
def run_supervised_chi_loss(value, batch): def run_supervised_chi_loss(value, batch):
if consts.is_multimer:
pred_angles = np.reshape(
value['sidechains']['angles_sin_cos'], [-1, consts.n_res, 7, 2])
unnormed_angles = np.reshape(
value['sidechains']['unnormalized_angles_sin_cos'], [-1, consts.n_res, 7, 2])
chi_loss, _, _ = self.am_fold.supervised_chi_loss(
batch['seq_mask'],
batch['chi_mask'],
batch['aatype'],
batch['chi_angles'],
pred_angles,
unnormed_angles,
c_chi_loss
)
return chi_loss
ret = { ret = {
"loss": jax.numpy.array(0.0), "loss": jax.numpy.array(0.0),
} }
alphafold.model.folding.supervised_chi_loss( self.am_fold.supervised_chi_loss(
ret, batch, value, c_chi_loss ret, batch, value, c_chi_loss
) )
return ret["loss"] return ret["loss"]
...@@ -570,15 +637,31 @@ class TestLoss(unittest.TestCase): ...@@ -570,15 +637,31 @@ class TestLoss(unittest.TestCase):
ret = { ret = {
"loss": np.array(0.0).astype(np.float32), "loss": np.array(0.0).astype(np.float32),
} }
if consts.is_multimer:
atom14_pred_pos = self.am_rigid.Vec3Array.from_array(atom14_pred_pos)
viol = self.am_fold.find_structural_violations(
batch['aatype'],
batch['residue_index'],
batch['atom14_atom_exists'],
atom14_pred_pos,
c_viol,
batch['asym_id']
)
return self.am_fold.structural_violation_loss(mask=batch['atom14_atom_exists'],
violations=viol,
config=c_viol)
value = {} value = {}
value[ value[
"violations" "violations"
] = alphafold.model.folding.find_structural_violations( ] = self.am_fold.find_structural_violations(
batch, batch,
atom14_pred_pos, atom14_pred_pos,
c_viol, c_viol,
) )
alphafold.model.folding.structural_violation_loss(
self.am_fold.structural_violation_loss(
ret, ret,
batch, batch,
value, value,
...@@ -594,12 +677,14 @@ class TestLoss(unittest.TestCase): ...@@ -594,12 +677,14 @@ class TestLoss(unittest.TestCase):
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32), "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"residue_index": np.arange(n_res), "residue_index": np.arange(n_res),
"aatype": np.random.randint(0, 21, (n_res,)), "aatype": np.random.randint(0, 21, (n_res,)),
"asym_id": random_asym_ids(n_res)
} }
alphafold.model.tf.data_transforms.make_atom14_masks(batch)
batch = {k: np.array(v) for k, v in batch.items()}
atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32) atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
alphafold.model.tf.data_transforms.make_atom14_masks(batch)
batch = {k: np.array(v) for k, v in batch.items()}
out_gt = f.apply({}, None, batch, atom14_pred_pos) out_gt = f.apply({}, None, batch, atom14_pred_pos)
out_gt = torch.tensor(np.array(out_gt.block_until_ready())) out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
...@@ -676,10 +761,31 @@ class TestLoss(unittest.TestCase): ...@@ -676,10 +761,31 @@ class TestLoss(unittest.TestCase):
c_sm = config.model.heads.structure_module c_sm = config.model.heads.structure_module
def run_bb_loss(batch, value): def run_bb_loss(batch, value):
if consts.is_multimer:
intra_chain_mask = (batch["asym_id"][..., None] == batch["asym_id"][..., None, :]).astype(np.float32)
gt_rigid = affine_vector_to_rigid(self.am_rigid, batch["backbone_affine_tensor"])
target_rigid = affine_vector_to_rigid(self.am_rigid, value['traj'])
intra_chain_bb_loss, intra_chain_fape = self.am_fold.backbone_loss(
gt_rigid=gt_rigid,
gt_frames_mask=batch["backbone_affine_mask"],
gt_positions_mask=batch["backbone_affine_mask"],
target_rigid=target_rigid,
config=c_sm.intra_chain_fape,
pair_mask=intra_chain_mask)
interface_bb_loss, interface_fape = self.am_fold.backbone_loss(
gt_rigid=gt_rigid,
gt_frames_mask=batch["backbone_affine_mask"],
gt_positions_mask=batch["backbone_affine_mask"],
target_rigid=target_rigid,
config=c_sm.interface_fape,
pair_mask=1. - intra_chain_mask)
return intra_chain_bb_loss + interface_bb_loss
ret = { ret = {
"loss": np.array(0.0), "loss": np.array(0.0),
} }
alphafold.model.folding.backbone_loss(ret, batch, value, c_sm) self.am_fold.backbone_loss(ret, batch, value, c_sm)
return ret["loss"] return ret["loss"]
f = hk.transform(run_bb_loss) f = hk.transform(run_bb_loss)
...@@ -692,6 +798,7 @@ class TestLoss(unittest.TestCase): ...@@ -692,6 +798,7 @@ class TestLoss(unittest.TestCase):
np.float32 np.float32
), ),
"use_clamped_fape": np.array(0.0), "use_clamped_fape": np.array(0.0),
"asym_id": random_asym_ids(n_res)
} }
value = { value = {
...@@ -726,9 +833,29 @@ class TestLoss(unittest.TestCase): ...@@ -726,9 +833,29 @@ class TestLoss(unittest.TestCase):
c_sm = config.model.heads.structure_module c_sm = config.model.heads.structure_module
def run_sidechain_loss(batch, value, atom14_pred_positions): def run_sidechain_loss(batch, value, atom14_pred_positions):
if consts.is_multimer:
atom14_pred_positions = self.am_rigid.Vec3Array.from_array(atom14_pred_positions)
all_atom_positions = self.am_rigid.Vec3Array.from_array(batch["all_atom_positions"])
gt_positions, gt_mask, alt_naming_is_better = self.am_fold.compute_atom14_gt(
aatype=batch["aatype"], all_atom_positions=all_atom_positions,
all_atom_mask=batch["all_atom_mask"], pred_pos=atom14_pred_positions)
pred_frames = self.am_rigid.Rigid3Array.from_array4x4(value["sidechains"]["frames"])
pred_positions = self.am_rigid.Vec3Array.from_array(value["sidechains"]["atom_pos"])
gt_sc_frames, gt_sc_frames_mask = self.am_fold.compute_frames(
aatype=batch["aatype"],
all_atom_positions=all_atom_positions,
all_atom_mask=batch["all_atom_mask"],
use_alt=alt_naming_is_better)
return self.am_fold.sidechain_loss(gt_sc_frames,
gt_sc_frames_mask,
gt_positions,
gt_mask,
pred_frames,
pred_positions,
c_sm)['loss']
batch = { batch = {
**batch, **batch,
**alphafold.model.all_atom.atom37_to_frames( **self.am_atom.atom37_to_frames(
batch["aatype"], batch["aatype"],
batch["all_atom_positions"], batch["all_atom_positions"],
batch["all_atom_mask"], batch["all_atom_mask"],
...@@ -752,7 +879,7 @@ class TestLoss(unittest.TestCase): ...@@ -752,7 +879,7 @@ class TestLoss(unittest.TestCase):
) )
value = v value = v
ret = alphafold.model.folding.sidechain_loss(batch, value, c_sm) ret = self.am_fold.sidechain_loss(batch, value, c_sm)
return ret["loss"] return ret["loss"]
f = hk.transform(run_sidechain_loss) f = hk.transform(run_sidechain_loss)
...@@ -816,6 +943,7 @@ class TestLoss(unittest.TestCase): ...@@ -816,6 +943,7 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
@unittest.skipIf(not consts.is_multimer and "ptm" not in consts.model, "Not enabled for non-ptm models.")
def test_tm_loss_compare(self): def test_tm_loss_compare(self):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
c_tm = config.model.heads.predicted_aligned_error c_tm = config.model.heads.predicted_aligned_error
......
...@@ -20,8 +20,7 @@ import unittest ...@@ -20,8 +20,7 @@ import unittest
from openfold.config import model_config from openfold.config import model_config
from openfold.data import data_transforms from openfold.data import data_transforms
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
import openfold.utils.feats as feats from openfold.utils.tensor_utils import tensor_tree_map
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
import tests.compare_utils as compare_utils import tests.compare_utils as compare_utils
from tests.config import consts from tests.config import consts
from tests.data_utils import ( from tests.data_utils import (
...@@ -36,13 +35,26 @@ if compare_utils.alphafold_is_installed(): ...@@ -36,13 +35,26 @@ if compare_utils.alphafold_is_installed():
class TestModel(unittest.TestCase): class TestModel(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_dry_run(self): def test_dry_run(self):
n_seq = consts.n_seq n_seq = consts.n_seq
n_templ = consts.n_templ n_templ = consts.n_templ
n_res = consts.n_res n_res = consts.n_res
n_extra_seq = consts.n_extra n_extra_seq = consts.n_extra
c = model_config("model_1") c = model_config(consts.model)
c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test # deepspeed for this test
...@@ -68,6 +80,12 @@ class TestModel(unittest.TestCase): ...@@ -68,6 +80,12 @@ class TestModel(unittest.TestCase):
batch.update(data_transforms.make_atom14_masks(batch)) batch.update(data_transforms.make_atom14_masks(batch))
batch["no_recycling_iters"] = torch.tensor(2.) batch["no_recycling_iters"] = torch.tensor(2.)
if consts.is_multimer:
batch["asym_id"] = torch.randint(0, 1, size=(n_res,))
batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch["sym_id"] = torch.randint(0, 1, size=(n_res,))
batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
add_recycling_dims = lambda t: ( add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters) t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
) )
...@@ -80,7 +98,8 @@ class TestModel(unittest.TestCase): ...@@ -80,7 +98,8 @@ class TestModel(unittest.TestCase):
def test_compare(self): def test_compare(self):
def run_alphafold(batch): def run_alphafold(batch):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
model = alphafold.model.modules.AlphaFold(config.model)
model = self.am_modules.AlphaFold(config.model)
return model( return model(
batch=batch, batch=batch,
is_training=False, is_training=False,
...@@ -100,7 +119,8 @@ class TestModel(unittest.TestCase): ...@@ -100,7 +119,8 @@ class TestModel(unittest.TestCase):
# atom37_to_atom14 doesn't like batches # atom37_to_atom14 doesn't like batches
batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0] batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0] batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
out_gt = self.am_atom.atom37_to_atom14(out_gt, batch)
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready())) out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
batch["no_recycling_iters"] = np.array([3., 3., 3., 3.,]) batch["no_recycling_iters"] = np.array([3., 3., 3., 3.,])
......
...@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase): ...@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
out_repro = ( out_repro = (
model.evoformer.blocks[0].core model.evoformer.blocks[0]
.outer_product_mean( .outer_product_mean(
torch.as_tensor(msa_act).cuda(), torch.as_tensor(msa_act).cuda(),
chunk_size=4, chunk_size=4,
......
...@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase): ...@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
out_repro = ( out_repro = (
model.evoformer.blocks[0].core model.evoformer.blocks[0].pair_stack
.pair_transition( .pair_transition(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
chunk_size=4, chunk_size=4,
......
...@@ -13,12 +13,10 @@ ...@@ -13,12 +13,10 @@
# limitations under the License. # limitations under the License.
import torch import torch
import numpy as np
import unittest import unittest
from openfold.model.primitives import ( from openfold.model.primitives import (
Attention, Attention
LowMemoryAttention,
) )
from tests.config import consts from tests.config import consts
...@@ -40,7 +38,7 @@ class TestLMA(unittest.TestCase): ...@@ -40,7 +38,7 @@ class TestLMA(unittest.TestCase):
gating_fill = torch.rand(c_hidden * no_heads, c_hidden) gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
o_fill = torch.rand(c_hidden, c_hidden * no_heads) o_fill = torch.rand(c_hidden, c_hidden * no_heads)
lma = LowMemoryAttention( lma = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda() ).cuda()
a = Attention( a = Attention(
...@@ -60,7 +58,7 @@ class TestLMA(unittest.TestCase): ...@@ -60,7 +58,7 @@ class TestLMA(unittest.TestCase):
m.linear_o.weight.copy_(o_fill) m.linear_o.weight.copy_(o_fill)
with torch.no_grad(): with torch.no_grad():
l = lma(q, k, v, 1024, 4096, biases=bias) l = lma(q, k, v, biases=bias, use_lma=True, q_chunk_size=1024, kv_chunk_size=4096)
real = a(q, k, v, biases=bias) real = a(q, k, v, biases=bias)
self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps) self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
......
...@@ -18,21 +18,19 @@ import unittest ...@@ -18,21 +18,19 @@ import unittest
from openfold.data.data_transforms import make_atom14_masks_np from openfold.data.data_transforms import make_atom14_masks_np
from openfold.np.residue_constants import ( from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
restype_atom14_mask, restype_atom14_mask,
restype_atom14_rigid_group_positions,
restype_atom37_mask, restype_atom37_mask,
) )
from openfold.model.structure_module import ( from openfold.model.structure_module import (
StructureModule, StructureModule,
StructureModuleTransition, StructureModuleTransition,
BackboneUpdate,
AngleResnet, AngleResnet,
InvariantPointAttention, InvariantPointAttention,
) )
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
import tests.compare_utils as compare_utils import tests.compare_utils as compare_utils
from tests.config import consts from tests.config import consts
from tests.data_utils import ( from tests.data_utils import (
...@@ -46,6 +44,19 @@ if compare_utils.alphafold_is_installed(): ...@@ -46,6 +44,19 @@ if compare_utils.alphafold_is_installed():
class TestStructureModule(unittest.TestCase): class TestStructureModule(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_structure_module_shape(self): def test_structure_module_shape(self):
batch_size = consts.batch_size batch_size = consts.batch_size
n = consts.n_res n = consts.n_res
...@@ -81,6 +92,7 @@ class TestStructureModule(unittest.TestCase): ...@@ -81,6 +92,7 @@ class TestStructureModule(unittest.TestCase):
trans_scale_factor, trans_scale_factor,
ar_epsilon, ar_epsilon,
inf, inf,
is_multimer=consts.is_multimer
) )
s = torch.rand((batch_size, n, c_s)) s = torch.rand((batch_size, n, c_s))
...@@ -89,7 +101,11 @@ class TestStructureModule(unittest.TestCase): ...@@ -89,7 +101,11 @@ class TestStructureModule(unittest.TestCase):
out = sm(s, z, f) out = sm(s, z, f)
if consts.is_multimer:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
else:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7)) self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
self.assertTrue( self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2) out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
) )
...@@ -121,11 +137,14 @@ class TestStructureModule(unittest.TestCase): ...@@ -121,11 +137,14 @@ class TestStructureModule(unittest.TestCase):
c_global = config.model.global_config c_global = config.model.global_config
def run_sm(representations, batch): def run_sm(representations, batch):
sm = alphafold.model.folding.StructureModule(c_sm, c_global) sm = self.am_fold.StructureModule(c_sm, c_global)
representations = { representations = {
k: jax.lax.stop_gradient(v) for k, v in representations.items() k: jax.lax.stop_gradient(v) for k, v in representations.items()
} }
batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()} batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
if consts.is_multimer:
return sm(representations, batch, is_training=False, compute_loss=True)
return sm(representations, batch, is_training=False) return sm(representations, batch, is_training=False)
f = hk.transform(run_sm) f = hk.transform(run_sm)
...@@ -178,6 +197,19 @@ class TestStructureModule(unittest.TestCase): ...@@ -178,6 +197,19 @@ class TestStructureModule(unittest.TestCase):
class TestInvariantPointAttention(unittest.TestCase): class TestInvariantPointAttention(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_shape(self): def test_shape(self):
c_m = 13 c_m = 13
c_z = 17 c_z = 17
...@@ -194,13 +226,18 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -194,13 +226,18 @@ class TestInvariantPointAttention(unittest.TestCase):
mask = torch.ones((batch_size, n_res)) mask = torch.ones((batch_size, n_res))
rot_mats = torch.rand((batch_size, n_res, 3, 3)) rot_mats = torch.rand((batch_size, n_res, 3, 3))
rots = Rotation(rot_mats=rot_mats, quats=None)
trans = torch.rand((batch_size, n_res, 3)) trans = torch.rand((batch_size, n_res, 3))
if consts.is_multimer:
rotation = Rot3Array.from_array(rot_mats)
translation = Vec3Array.from_array(trans)
r = Rigid3Array(rotation, translation)
else:
rots = Rotation(rot_mats=rot_mats, quats=None)
r = Rigid(rots, trans) r = Rigid(rots, trans)
ipa = InvariantPointAttention( ipa = InvariantPointAttention(
c_m, c_z, c_hidden, no_heads, no_qp, no_vp c_m, c_z, c_hidden, no_heads, no_qp, no_vp, is_multimer=consts.is_multimer
) )
shape_before = s.shape shape_before = s.shape
...@@ -212,16 +249,26 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -212,16 +249,26 @@ class TestInvariantPointAttention(unittest.TestCase):
def test_ipa_compare(self): def test_ipa_compare(self):
def run_ipa(act, static_feat_2d, mask, affine): def run_ipa(act, static_feat_2d, mask, affine):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
ipa = alphafold.model.folding.InvariantPointAttention( ipa = self.am_fold.InvariantPointAttention(
config.model.heads.structure_module, config.model.heads.structure_module,
config.model.global_config, config.model.global_config,
) )
if consts.is_multimer:
attn = ipa(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
rigid=affine
)
else:
attn = ipa( attn = ipa(
inputs_1d=act, inputs_1d=act,
inputs_2d=static_feat_2d, inputs_2d=static_feat_2d,
mask=mask, mask=mask,
affine=affine, affine=affine
) )
return attn return attn
f = hk.transform(run_ipa) f = hk.transform(run_ipa)
...@@ -235,12 +282,19 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -235,12 +282,19 @@ class TestInvariantPointAttention(unittest.TestCase):
sample_mask = np.ones((n_res, 1)) sample_mask = np.ones((n_res, 1))
affines = random_affines_4x4((n_res,)) affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids) if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float()
)
sample_affine = rigids
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
quats = self.am_rigid.rigids_to_quataffine(rigids)
transformations = Rigid.from_tensor_4x4( transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda() torch.as_tensor(affines).float().cuda()
) )
sample_affine = quats sample_affine = quats
ipa_params = compare_utils.fetch_alphafold_module_weights( ipa_params = compare_utils.fetch_alphafold_module_weights(
......
...@@ -19,7 +19,6 @@ from openfold.model.template import ( ...@@ -19,7 +19,6 @@ from openfold.model.template import (
TemplatePointwiseAttention, TemplatePointwiseAttention,
TemplatePairStack, TemplatePairStack,
) )
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils import tests.compare_utils as compare_utils
from tests.config import consts from tests.config import consts
from tests.data_utils import random_template_feats from tests.data_utils import random_template_feats
...@@ -54,6 +53,19 @@ class TestTemplatePointwiseAttention(unittest.TestCase): ...@@ -54,6 +53,19 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class TestTemplatePairStack(unittest.TestCase): class TestTemplatePairStack(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_shape(self): def test_shape(self):
batch_size = consts.batch_size batch_size = consts.batch_size
c_t = consts.c_t c_t = consts.c_t
...@@ -65,6 +77,7 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -65,6 +77,7 @@ class TestTemplatePairStack(unittest.TestCase):
dropout = 0.25 dropout = 0.25
n_templ = consts.n_templ n_templ = consts.n_templ
n_res = consts.n_res n_res = consts.n_res
tri_mul_first = consts.is_multimer
blocks_per_ckpt = None blocks_per_ckpt = None
chunk_size = 4 chunk_size = 4
inf = 1e7 inf = 1e7
...@@ -78,6 +91,7 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -78,6 +91,7 @@ class TestTemplatePairStack(unittest.TestCase):
no_heads=no_heads, no_heads=no_heads,
pair_transition_n=pt_inner_dim, pair_transition_n=pt_inner_dim,
dropout_rate=dropout, dropout_rate=dropout,
tri_mul_first=tri_mul_first,
blocks_per_ckpt=None, blocks_per_ckpt=None,
inf=inf, inf=inf,
eps=eps, eps=eps,
...@@ -96,7 +110,35 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -96,7 +110,35 @@ class TestTemplatePairStack(unittest.TestCase):
def run_template_pair_stack(pair_act, pair_mask): def run_template_pair_stack(pair_act, pair_mask):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
c_ee = config.model.embeddings_and_evoformer c_ee = config.model.embeddings_and_evoformer
tps = alphafold.model.modules.TemplatePairStack(
if consts.is_multimer:
safe_key = alphafold.model.prng.SafeKey(hk.next_rng_key())
template_iteration = self.am_modules.TemplateEmbeddingIteration(
c_ee.template.template_pair_stack,
config.model.global_config,
name='template_embedding_iteration')
def template_iteration_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
act = template_iteration(
act=act,
pair_mask=pair_mask,
is_training=False,
safe_key=safe_subkey)
return (act, safe_key)
if config.model.global_config.use_remat:
template_iteration_fn = hk.remat(template_iteration_fn)
safe_key, safe_subkey = safe_key.split()
template_stack = alphafold.model.layer_stack.layer_stack(
c_ee.template.template_pair_stack.num_block)(
template_iteration_fn)
act, _ = template_stack((pair_act, safe_subkey))
else:
tps = self.am_modules.TemplatePairStack(
c_ee.template.template_pair_stack, c_ee.template.template_pair_stack,
config.model.global_config, config.model.global_config,
name="template_pair_stack", name="template_pair_stack",
...@@ -115,6 +157,12 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -115,6 +157,12 @@ class TestTemplatePairStack(unittest.TestCase):
low=0, high=2, size=(n_res, n_res) low=0, high=2, size=(n_res, n_res)
).astype(np.float32) ).astype(np.float32)
if consts.is_multimer:
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_embedding_iteration"
)
else:
params = compare_utils.fetch_alphafold_module_weights( params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/" "alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_pair_stack" + "single_template_embedding/template_pair_stack"
...@@ -132,7 +180,7 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -132,7 +180,7 @@ class TestTemplatePairStack(unittest.TestCase):
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
out_repro = model.template_pair_stack( out_repro = model.template_embedder.template_pair_stack(
torch.as_tensor(pair_act).unsqueeze(-4).cuda(), torch.as_tensor(pair_act).unsqueeze(-4).cuda(),
torch.as_tensor(pair_mask).unsqueeze(-3).cuda(), torch.as_tensor(pair_mask).unsqueeze(-3).cuda(),
chunk_size=None, chunk_size=None,
...@@ -143,14 +191,31 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -143,14 +191,31 @@ class TestTemplatePairStack(unittest.TestCase):
class Template(unittest.TestCase): class Template(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_compare(self): def test_compare(self):
def test_template_embedding(pair, batch, mask_2d): def test_template_embedding(pair, batch, mask_2d):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
te = alphafold.model.modules.TemplateEmbedding( te = self.am_modules.TemplateEmbedding(
config.model.embeddings_and_evoformer.template, config.model.embeddings_and_evoformer.template,
config.model.global_config, config.model.global_config,
) )
if consts.is_multimer:
act = te(pair, batch, mask_2d, multichain_mask_2d=multichain_mask_2d, is_training=False)
else:
act = te(pair, batch, mask_2d, is_training=False) act = te(pair, batch, mask_2d, is_training=False)
return act return act
...@@ -162,6 +227,14 @@ class Template(unittest.TestCase): ...@@ -162,6 +227,14 @@ class Template(unittest.TestCase):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32) pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res) batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"] batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
if consts.is_multimer:
asym_id = batch['asym_id'][0]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
).astype(np.float32)
batch["multichain_mask_2d"] = multichain_mask_2d
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32) pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
# Fetch pretrained parameters (but only from one block)] # Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights( params = compare_utils.fetch_alphafold_module_weights(
...@@ -177,12 +250,26 @@ class Template(unittest.TestCase): ...@@ -177,12 +250,26 @@ class Template(unittest.TestCase):
batch["target_feat"] = np.eye(22)[inds] batch["target_feat"] = np.eye(22)[inds]
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
out_repro = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()}, template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
if consts.is_multimer:
out_repro = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(), torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(), torch.as_tensor(pair_mask).cuda(),
templ_dim=0, templ_dim=0,
chunk_size=consts.chunk_size,
multichain_mask_2d=multichain_mask_2d,
) )
else:
out_repro = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
chunk_size=consts.chunk_size
)
out_repro = out_repro["template_pair_embedding"] out_repro = out_repro["template_pair_embedding"]
out_repro = out_repro.cpu() out_repro = out_repro.cpu()
......
...@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase): ...@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
module = ( module = (
model.evoformer.blocks[0].core.tri_att_start model.evoformer.blocks[0].pair_stack.tri_att_start
if starting if starting
else model.evoformer.blocks[0].core.tri_att_end else model.evoformer.blocks[0].pair_stack.tri_att_end
) )
out_repro = module( out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
......
...@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
module = ( module = (
model.evoformer.blocks[0].core.tri_mul_in model.evoformer.blocks[0].pair_stack.tri_mul_in
if incoming if incoming
else model.evoformer.blocks[0].core.tri_mul_out else model.evoformer.blocks[0].pair_stack.tri_mul_out
) )
out_repro = module( out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment