Commit d8ee9c5f authored by Christina Floristean's avatar Christina Floristean
Browse files

All non-cuda tests passing for monomer/multimer. Tri mul/attn and OPM order switched.

parent 260db67f
......@@ -12,14 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import torch
import numpy as np
import unittest
from tests.config import consts
from tests.data_utils import random_asym_ids
from openfold.model.embedders import (
InputEmbedder,
InputEmbedderMultimer,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
TemplatePairEmbedder
)
......@@ -35,13 +38,30 @@ class TestInputEmbedder(unittest.TestCase):
n_res = 17
n_clust = 19
max_relative_chain = 2
max_relative_idx = 32
use_chain_relative = True
tf = torch.rand((b, n_res, tf_dim))
ri = torch.rand((b, n_res))
msa = torch.rand((b, n_clust, n_res, msa_dim))
ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
msa_emb, pair_emb = ie(tf, ri, msa)
asym_ids_flat = torch.Tensor(random_asym_ids(n_res))
asym_id = torch.tile(asym_ids_flat.unsqueeze(0), (b, 1))
entity_id = asym_id
sym_id = torch.zeros_like(entity_id)
batch = {"target_feat": tf, "residue_index": ri, "msa_feat": msa}
if consts.is_multimer:
ie = InputEmbedderMultimer(tf_dim, msa_dim, c_z, c_m,
max_relative_idx=max_relative_idx,
use_chain_relative=use_chain_relative,
max_relative_chain=max_relative_chain)
batch.update({"asym_id": asym_id, "entity_id": entity_id, "sym_id": sym_id})
else:
ie = InputEmbedder(tf_dim, msa_dim, c_z, c_m, relpos_k)
msa_emb, pair_emb = ie(batch)
self.assertTrue(msa_emb.shape == (b, n_clust, n_res, c_m))
self.assertTrue(pair_emb.shape == (b, n_res, n_res, c_z))
......
......@@ -48,6 +48,7 @@ class TestEvoformerStack(unittest.TestCase):
transition_n = 2
msa_dropout = 0.15
pair_stack_dropout = 0.25
opm_first = consts.is_multimer
inf = 1e9
eps = 1e-10
......@@ -65,6 +66,7 @@ class TestEvoformerStack(unittest.TestCase):
transition_n,
msa_dropout,
pair_stack_dropout,
opm_first,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
......@@ -156,6 +158,7 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n = 5
msa_dropout = 0.15
pair_stack_dropout = 0.25
opm_first = consts.is_multimer
inf = 1e9
eps = 1e-10
......@@ -172,6 +175,7 @@ class TestExtraMSAStack(unittest.TestCase):
transition_n,
msa_dropout,
pair_stack_dropout,
opm_first,
ckpt=False,
inf=inf,
eps=eps,
......@@ -259,7 +263,7 @@ class TestMSATransition(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0].core.msa_transition(
model.evoformer.blocks[0].msa_transition(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
)
......
......@@ -25,6 +25,9 @@ from openfold.np.residue_constants import (
)
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
......@@ -40,6 +43,19 @@ if compare_utils.alphafold_is_installed():
class TestFeats(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
@compare_utils.skip_unless_alphafold_installed()
def test_pseudo_beta_fn_compare(self):
def test_pbf(aatype, all_atom_pos, all_atom_mask):
......@@ -131,7 +147,9 @@ class TestFeats(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_atom37_to_frames_compare(self):
def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
return alphafold.model.all_atom.atom37_to_frames(
if consts.is_multimer:
all_atom_positions = self.am_rigid.Vec3Array.from_array(all_atom_positions)
return self.am_atom.atom37_to_frames(
aatype, all_atom_positions, all_atom_mask
)
......@@ -150,7 +168,14 @@ class TestFeats(unittest.TestCase):
}
out_gt = f.apply({}, None, **batch)
to_tensor = lambda t: torch.tensor(np.array(t))
if consts.is_multimer:
to_tensor = (lambda t: torch.tensor(np.array(t))
if not isinstance(t, self.am_rigid.Rigid3Array)
else torch.tensor(np.array(t.to_array())).view(*t.shape[:2], 12))
else:
to_tensor = lambda t: torch.tensor(np.array(t))
out_gt = {k: to_tensor(v) for k, v in out_gt.items()}
def flat12_to_4x4(flat12):
......@@ -187,7 +212,13 @@ class TestFeats(unittest.TestCase):
n = 5
rots = torch.rand((batch_size, n, 3, 3))
trans = torch.rand((batch_size, n, 3))
ts = Rigid(Rotation(rot_mats=rots), trans)
if consts.is_multimer:
rotation = Rot3Array.from_array(rots)
translation = Vec3Array.from_array(trans)
ts = Rigid3Array(rotation, translation)
else:
ts = Rigid(Rotation(rot_mats=rots), trans)
angles = torch.rand((batch_size, n, 7, 2))
......@@ -208,7 +239,7 @@ class TestFeats(unittest.TestCase):
def run_torsion_angles_to_frames(
aatype, backb_to_global, torsion_angles_sin_cos
):
return alphafold.model.all_atom.torsion_angles_to_frames(
return self.am_atom.torsion_angles_to_frames(
aatype,
backb_to_global,
torsion_angles_sin_cos,
......@@ -221,10 +252,17 @@ class TestFeats(unittest.TestCase):
aatype = np.random.randint(0, 21, size=(n_res,))
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float()
)
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)
......@@ -264,7 +302,13 @@ class TestFeats(unittest.TestCase):
rots = torch.rand((batch_size, n_res, 8, 3, 3))
trans = torch.rand((batch_size, n_res, 8, 3))
ts = Rigid(Rotation(rot_mats=rots), trans)
if consts.is_multimer:
rotation = Rot3Array.from_array(rots)
translation = Vec3Array.from_array(trans)
ts = Rigid3Array(rotation, translation)
else:
ts = Rigid(Rotation(rot_mats=rots), trans)
f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()
......@@ -277,13 +321,15 @@ class TestFeats(unittest.TestCase):
torch.tensor(restype_atom14_rigid_group_positions),
)
if consts.is_multimer:
xyz = xyz.to_tensor()
self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))
@compare_utils.skip_unless_alphafold_installed()
def test_frames_and_literature_positions_to_atom14_pos_compare(self):
def run_f(aatype, affines):
am = alphafold.model
return am.all_atom.frames_and_literature_positions_to_atom14_pos(
return self.am_atom.frames_and_literature_positions_to_atom14_pos(
aatype, affines
)
......@@ -294,16 +340,27 @@ class TestFeats(unittest.TestCase):
aatype = np.random.randint(0, 21, size=(n_res,))
affines = random_affines_4x4((n_res, 8))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float()
)
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float()
)
out_gt = f.apply({}, None, aatype, rigids)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = torch.stack(
[torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
)
if consts.is_multimer:
out_gt = torch.as_tensor(np.array(out_gt.to_array()))
else:
out_gt = torch.stack(
[torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
)
out_repro = feats.frames_and_literature_positions_to_atom14_pos(
transformations.cuda(),
......
......@@ -65,7 +65,7 @@ class TestImportWeights(unittest.TestCase):
)
][1].transpose(-1, -2)
),
model.evoformer.blocks[1].core.outer_product_mean.linear_1.weight,
model.evoformer.blocks[1].outer_product_mean.linear_1.weight,
),
]
......
......@@ -13,7 +13,6 @@
# limitations under the License.
import os
import math
import torch
import numpy as np
import unittest
......@@ -24,7 +23,6 @@ from openfold.utils.rigid_utils import (
Rotation,
Rigid,
)
import openfold.utils.feats as feats
from openfold.utils.loss import (
torsion_angle_loss,
compute_fape,
......@@ -51,7 +49,7 @@ from openfold.utils.tensor_utils import (
)
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_vector, random_affines_4x4
from tests.data_utils import random_affines_vector, random_affines_4x4, random_asym_ids
if compare_utils.alphafold_is_installed():
alphafold = compare_utils.import_alphafold()
......@@ -64,7 +62,30 @@ def affine_vector_to_4x4(affine):
return r.to_tensor_4x4()
def affine_vector_to_rigid(am_rigid, affine):
rigid_flat = np.split(affine, 7, axis=-1)
rigid_flat = [r.squeeze(-1) for r in rigid_flat]
qw, qx, qy, qz = rigid_flat[:4]
trans = rigid_flat[4:]
rotation = am_rigid.Rot3Array.from_quaternion(qw, qx, qy, qz, normalize=True)
translation = am_rigid.Vec3Array(*trans)
return am_rigid.Rigid3Array(rotation, translation)
class TestLoss(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_run_torsion_angle_loss(self):
batch_size = consts.batch_size
n_res = consts.n_res
......@@ -127,7 +148,10 @@ class TestLoss(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_between_residue_bond_loss_compare(self):
def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
return alphafold.model.all_atom.between_residue_bond_loss(
if consts.is_multimer:
pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
return self.am_atom.between_residue_bond_loss(
pred_pos,
pred_atom_mask,
residue_index,
......@@ -184,12 +208,22 @@ class TestLoss(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_between_residue_clash_loss_compare(self):
def run_brcl(pred_pos, atom_exists, atom_radius, res_ind):
return alphafold.model.all_atom.between_residue_clash_loss(
def run_brcl(pred_pos, atom_exists, atom_radius, res_ind, asym_id):
if consts.is_multimer:
pred_pos = self.am_rigid.Vec3Array.from_array(pred_pos)
return self.am_atom.between_residue_clash_loss(
pred_pos,
atom_exists,
atom_radius,
res_ind,
asym_id
)
return self.am_atom.between_residue_clash_loss(
pred_pos,
atom_exists,
atom_radius,
res_ind,
res_ind
)
f = hk.transform(run_brcl)
......@@ -202,6 +236,7 @@ class TestLoss(unittest.TestCase):
res_ind = np.arange(
n_res,
)
asym_id = random_asym_ids(n_res)
out_gt = f.apply(
{},
......@@ -210,6 +245,7 @@ class TestLoss(unittest.TestCase):
atom_exists,
atom_radius,
res_ind,
asym_id
)
out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
......@@ -266,7 +302,19 @@ class TestLoss(unittest.TestCase):
def run_fsv(batch, pos, config):
cwd = os.getcwd()
os.chdir("tests/test_data")
loss = alphafold.model.folding.find_structural_violations(
if consts.is_multimer:
atom14_pred_pos = self.am_rigid.Vec3Array.from_array(pos)
return self.am_fold.find_structural_violations(
batch['aatype'],
batch['residue_index'],
batch['atom14_atom_exists'],
atom14_pred_pos,
config,
batch['asym_id']
)
loss = self.am_fold.find_structural_violations(
batch,
pos,
config,
......@@ -285,6 +333,7 @@ class TestLoss(unittest.TestCase):
"residx_atom14_to_atom37": np.random.randint(
0, 37, (n_res, 14)
).astype(np.int64),
"asym_id": random_asym_ids(n_res)
}
pred_pos = np.random.rand(n_res, 14, 3)
......@@ -380,7 +429,7 @@ class TestLoss(unittest.TestCase):
n_seq = consts.n_seq
value = {
"logits": np.random.rand(n_res, n_seq, 23).astype(np.float32),
"logits": np.random.rand(n_res, n_seq, consts.msa_logits).astype(np.float32),
}
batch = {
......@@ -506,10 +555,28 @@ class TestLoss(unittest.TestCase):
c_chi_loss = config.model.heads.structure_module
def run_supervised_chi_loss(value, batch):
if consts.is_multimer:
pred_angles = np.reshape(
value['sidechains']['angles_sin_cos'], [-1, consts.n_res, 7, 2])
unnormed_angles = np.reshape(
value['sidechains']['unnormalized_angles_sin_cos'], [-1, consts.n_res, 7, 2])
chi_loss, _, _ = self.am_fold.supervised_chi_loss(
batch['seq_mask'],
batch['chi_mask'],
batch['aatype'],
batch['chi_angles'],
pred_angles,
unnormed_angles,
c_chi_loss
)
return chi_loss
ret = {
"loss": jax.numpy.array(0.0),
}
alphafold.model.folding.supervised_chi_loss(
self.am_fold.supervised_chi_loss(
ret, batch, value, c_chi_loss
)
return ret["loss"]
......@@ -570,15 +637,31 @@ class TestLoss(unittest.TestCase):
ret = {
"loss": np.array(0.0).astype(np.float32),
}
if consts.is_multimer:
atom14_pred_pos = self.am_rigid.Vec3Array.from_array(atom14_pred_pos)
viol = self.am_fold.find_structural_violations(
batch['aatype'],
batch['residue_index'],
batch['atom14_atom_exists'],
atom14_pred_pos,
c_viol,
batch['asym_id']
)
return self.am_fold.structural_violation_loss(mask=batch['atom14_atom_exists'],
violations=viol,
config=c_viol)
value = {}
value[
"violations"
] = alphafold.model.folding.find_structural_violations(
] = self.am_fold.find_structural_violations(
batch,
atom14_pred_pos,
c_viol,
)
alphafold.model.folding.structural_violation_loss(
self.am_fold.structural_violation_loss(
ret,
batch,
value,
......@@ -594,12 +677,14 @@ class TestLoss(unittest.TestCase):
"seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
"residue_index": np.arange(n_res),
"aatype": np.random.randint(0, 21, (n_res,)),
"asym_id": random_asym_ids(n_res)
}
alphafold.model.tf.data_transforms.make_atom14_masks(batch)
batch = {k: np.array(v) for k, v in batch.items()}
atom14_pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
alphafold.model.tf.data_transforms.make_atom14_masks(batch)
batch = {k: np.array(v) for k, v in batch.items()}
out_gt = f.apply({}, None, batch, atom14_pred_pos)
out_gt = torch.tensor(np.array(out_gt.block_until_ready()))
......@@ -676,10 +761,31 @@ class TestLoss(unittest.TestCase):
c_sm = config.model.heads.structure_module
def run_bb_loss(batch, value):
if consts.is_multimer:
intra_chain_mask = (batch["asym_id"][..., None] == batch["asym_id"][..., None, :]).astype(np.float32)
gt_rigid = affine_vector_to_rigid(self.am_rigid, batch["backbone_affine_tensor"])
target_rigid = affine_vector_to_rigid(self.am_rigid, value['traj'])
intra_chain_bb_loss, intra_chain_fape = self.am_fold.backbone_loss(
gt_rigid=gt_rigid,
gt_frames_mask=batch["backbone_affine_mask"],
gt_positions_mask=batch["backbone_affine_mask"],
target_rigid=target_rigid,
config=c_sm.intra_chain_fape,
pair_mask=intra_chain_mask)
interface_bb_loss, interface_fape = self.am_fold.backbone_loss(
gt_rigid=gt_rigid,
gt_frames_mask=batch["backbone_affine_mask"],
gt_positions_mask=batch["backbone_affine_mask"],
target_rigid=target_rigid,
config=c_sm.interface_fape,
pair_mask=1. - intra_chain_mask)
return intra_chain_bb_loss + interface_bb_loss
ret = {
"loss": np.array(0.0),
}
alphafold.model.folding.backbone_loss(ret, batch, value, c_sm)
self.am_fold.backbone_loss(ret, batch, value, c_sm)
return ret["loss"]
f = hk.transform(run_bb_loss)
......@@ -692,6 +798,7 @@ class TestLoss(unittest.TestCase):
np.float32
),
"use_clamped_fape": np.array(0.0),
"asym_id": random_asym_ids(n_res)
}
value = {
......@@ -726,9 +833,29 @@ class TestLoss(unittest.TestCase):
c_sm = config.model.heads.structure_module
def run_sidechain_loss(batch, value, atom14_pred_positions):
if consts.is_multimer:
atom14_pred_positions = self.am_rigid.Vec3Array.from_array(atom14_pred_positions)
all_atom_positions = self.am_rigid.Vec3Array.from_array(batch["all_atom_positions"])
gt_positions, gt_mask, alt_naming_is_better = self.am_fold.compute_atom14_gt(
aatype=batch["aatype"], all_atom_positions=all_atom_positions,
all_atom_mask=batch["all_atom_mask"], pred_pos=atom14_pred_positions)
pred_frames = self.am_rigid.Rigid3Array.from_array4x4(value["sidechains"]["frames"])
pred_positions = self.am_rigid.Vec3Array.from_array(value["sidechains"]["atom_pos"])
gt_sc_frames, gt_sc_frames_mask = self.am_fold.compute_frames(
aatype=batch["aatype"],
all_atom_positions=all_atom_positions,
all_atom_mask=batch["all_atom_mask"],
use_alt=alt_naming_is_better)
return self.am_fold.sidechain_loss(gt_sc_frames,
gt_sc_frames_mask,
gt_positions,
gt_mask,
pred_frames,
pred_positions,
c_sm)['loss']
batch = {
**batch,
**alphafold.model.all_atom.atom37_to_frames(
**self.am_atom.atom37_to_frames(
batch["aatype"],
batch["all_atom_positions"],
batch["all_atom_mask"],
......@@ -752,7 +879,7 @@ class TestLoss(unittest.TestCase):
)
value = v
ret = alphafold.model.folding.sidechain_loss(batch, value, c_sm)
ret = self.am_fold.sidechain_loss(batch, value, c_sm)
return ret["loss"]
f = hk.transform(run_sidechain_loss)
......@@ -816,6 +943,7 @@ class TestLoss(unittest.TestCase):
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
@compare_utils.skip_unless_alphafold_installed()
@unittest.skipIf(not consts.is_multimer and "ptm" not in consts.model, "Not enabled for non-ptm models.")
def test_tm_loss_compare(self):
config = compare_utils.get_alphafold_config()
c_tm = config.model.heads.predicted_aligned_error
......
......@@ -20,8 +20,7 @@ import unittest
from openfold.config import model_config
from openfold.data import data_transforms
from openfold.model.model import AlphaFold
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
from openfold.utils.tensor_utils import tensor_tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import (
......@@ -36,13 +35,26 @@ if compare_utils.alphafold_is_installed():
class TestModel(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_dry_run(self):
n_seq = consts.n_seq
n_templ = consts.n_templ
n_res = consts.n_res
n_extra_seq = consts.n_extra
c = model_config("model_1")
c = model_config(consts.model)
c.model.evoformer_stack.no_blocks = 4 # no need to go overboard here
c.model.evoformer_stack.blocks_per_ckpt = None # don't want to set up
# deepspeed for this test
......@@ -68,6 +80,12 @@ class TestModel(unittest.TestCase):
batch.update(data_transforms.make_atom14_masks(batch))
batch["no_recycling_iters"] = torch.tensor(2.)
if consts.is_multimer:
batch["asym_id"] = torch.randint(0, 1, size=(n_res,))
batch["entity_id"] = torch.randint(0, 1, size=(n_res,))
batch["sym_id"] = torch.randint(0, 1, size=(n_res,))
batch["extra_deletion_matrix"] = torch.randint(0, 2, size=(n_extra_seq, n_res))
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.data.common.max_recycling_iters)
)
......@@ -80,7 +98,8 @@ class TestModel(unittest.TestCase):
def test_compare(self):
def run_alphafold(batch):
config = compare_utils.get_alphafold_config()
model = alphafold.model.modules.AlphaFold(config.model)
model = self.am_modules.AlphaFold(config.model)
return model(
batch=batch,
is_training=False,
......@@ -100,7 +119,8 @@ class TestModel(unittest.TestCase):
# atom37_to_atom14 doesn't like batches
batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
out_gt = self.am_atom.atom37_to_atom14(out_gt, batch)
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
batch["no_recycling_iters"] = np.array([3., 3., 3., 3.,])
......
......@@ -81,7 +81,7 @@ class TestOuterProductMean(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0].core
model.evoformer.blocks[0]
.outer_product_mean(
torch.as_tensor(msa_act).cuda(),
chunk_size=4,
......
......@@ -76,7 +76,7 @@ class TestPairTransition(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
out_repro = (
model.evoformer.blocks[0].core
model.evoformer.blocks[0].pair_stack
.pair_transition(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
chunk_size=4,
......
......@@ -13,12 +13,10 @@
# limitations under the License.
import torch
import numpy as np
import unittest
from openfold.model.primitives import (
Attention,
LowMemoryAttention,
Attention
)
from tests.config import consts
......@@ -40,7 +38,7 @@ class TestLMA(unittest.TestCase):
gating_fill = torch.rand(c_hidden * no_heads, c_hidden)
o_fill = torch.rand(c_hidden, c_hidden * no_heads)
lma = LowMemoryAttention(
lma = Attention(
c_hidden, c_hidden, c_hidden, c_hidden, no_heads
).cuda()
a = Attention(
......@@ -60,7 +58,7 @@ class TestLMA(unittest.TestCase):
m.linear_o.weight.copy_(o_fill)
with torch.no_grad():
l = lma(q, k, v, 1024, 4096, biases=bias)
l = lma(q, k, v, biases=bias, use_lma=True, q_chunk_size=1024, kv_chunk_size=4096)
real = a(q, k, v, biases=bias)
self.assertTrue(torch.max(torch.abs(l - real)) < consts.eps)
......
......@@ -18,21 +18,19 @@ import unittest
from openfold.data.data_transforms import make_atom14_masks_np
from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
restype_atom14_mask,
restype_atom14_rigid_group_positions,
restype_atom37_mask,
)
from openfold.model.structure_module import (
StructureModule,
StructureModuleTransition,
BackboneUpdate,
AngleResnet,
InvariantPointAttention,
)
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import (
......@@ -46,6 +44,19 @@ if compare_utils.alphafold_is_installed():
class TestStructureModule(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_structure_module_shape(self):
batch_size = consts.batch_size
n = consts.n_res
......@@ -81,6 +92,7 @@ class TestStructureModule(unittest.TestCase):
trans_scale_factor,
ar_epsilon,
inf,
is_multimer=consts.is_multimer
)
s = torch.rand((batch_size, n, c_s))
......@@ -89,7 +101,11 @@ class TestStructureModule(unittest.TestCase):
out = sm(s, z, f)
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
if consts.is_multimer:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
else:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
)
......@@ -121,11 +137,14 @@ class TestStructureModule(unittest.TestCase):
c_global = config.model.global_config
def run_sm(representations, batch):
sm = alphafold.model.folding.StructureModule(c_sm, c_global)
sm = self.am_fold.StructureModule(c_sm, c_global)
representations = {
k: jax.lax.stop_gradient(v) for k, v in representations.items()
}
batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
if consts.is_multimer:
return sm(representations, batch, is_training=False, compute_loss=True)
return sm(representations, batch, is_training=False)
f = hk.transform(run_sm)
......@@ -178,6 +197,19 @@ class TestStructureModule(unittest.TestCase):
class TestInvariantPointAttention(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_shape(self):
c_m = 13
c_z = 17
......@@ -194,13 +226,18 @@ class TestInvariantPointAttention(unittest.TestCase):
mask = torch.ones((batch_size, n_res))
rot_mats = torch.rand((batch_size, n_res, 3, 3))
rots = Rotation(rot_mats=rot_mats, quats=None)
trans = torch.rand((batch_size, n_res, 3))
r = Rigid(rots, trans)
if consts.is_multimer:
rotation = Rot3Array.from_array(rot_mats)
translation = Vec3Array.from_array(trans)
r = Rigid3Array(rotation, translation)
else:
rots = Rotation(rot_mats=rot_mats, quats=None)
r = Rigid(rots, trans)
ipa = InvariantPointAttention(
c_m, c_z, c_hidden, no_heads, no_qp, no_vp
c_m, c_z, c_hidden, no_heads, no_qp, no_vp, is_multimer=consts.is_multimer
)
shape_before = s.shape
......@@ -212,16 +249,26 @@ class TestInvariantPointAttention(unittest.TestCase):
def test_ipa_compare(self):
def run_ipa(act, static_feat_2d, mask, affine):
config = compare_utils.get_alphafold_config()
ipa = alphafold.model.folding.InvariantPointAttention(
ipa = self.am_fold.InvariantPointAttention(
config.model.heads.structure_module,
config.model.global_config,
)
attn = ipa(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
affine=affine,
)
if consts.is_multimer:
attn = ipa(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
rigid=affine
)
else:
attn = ipa(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
affine=affine
)
return attn
f = hk.transform(run_ipa)
......@@ -235,13 +282,20 @@ class TestInvariantPointAttention(unittest.TestCase):
sample_mask = np.ones((n_res, 1))
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats
if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float()
)
sample_affine = rigids
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
quats = self.am_rigid.rigids_to_quataffine(rigids)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats
ipa_params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/structure_module/"
......
......@@ -19,7 +19,6 @@ from openfold.model.template import (
TemplatePointwiseAttention,
TemplatePairStack,
)
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats
......@@ -54,6 +53,19 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class TestTemplatePairStack(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_shape(self):
batch_size = consts.batch_size
c_t = consts.c_t
......@@ -65,6 +77,7 @@ class TestTemplatePairStack(unittest.TestCase):
dropout = 0.25
n_templ = consts.n_templ
n_res = consts.n_res
tri_mul_first = consts.is_multimer
blocks_per_ckpt = None
chunk_size = 4
inf = 1e7
......@@ -78,6 +91,7 @@ class TestTemplatePairStack(unittest.TestCase):
no_heads=no_heads,
pair_transition_n=pt_inner_dim,
dropout_rate=dropout,
tri_mul_first=tri_mul_first,
blocks_per_ckpt=None,
inf=inf,
eps=eps,
......@@ -96,12 +110,40 @@ class TestTemplatePairStack(unittest.TestCase):
def run_template_pair_stack(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_ee = config.model.embeddings_and_evoformer
tps = alphafold.model.modules.TemplatePairStack(
c_ee.template.template_pair_stack,
config.model.global_config,
name="template_pair_stack",
)
act = tps(pair_act, pair_mask, is_training=False)
if consts.is_multimer:
safe_key = alphafold.model.prng.SafeKey(hk.next_rng_key())
template_iteration = self.am_modules.TemplateEmbeddingIteration(
c_ee.template.template_pair_stack,
config.model.global_config,
name='template_embedding_iteration')
def template_iteration_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
act = template_iteration(
act=act,
pair_mask=pair_mask,
is_training=False,
safe_key=safe_subkey)
return (act, safe_key)
if config.model.global_config.use_remat:
template_iteration_fn = hk.remat(template_iteration_fn)
safe_key, safe_subkey = safe_key.split()
template_stack = alphafold.model.layer_stack.layer_stack(
c_ee.template.template_pair_stack.num_block)(
template_iteration_fn)
act, _ = template_stack((pair_act, safe_subkey))
else:
tps = self.am_modules.TemplatePairStack(
c_ee.template.template_pair_stack,
config.model.global_config,
name="template_pair_stack",
)
act = tps(pair_act, pair_mask, is_training=False)
ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
act = ln(act)
return act
......@@ -115,10 +157,16 @@ class TestTemplatePairStack(unittest.TestCase):
low=0, high=2, size=(n_res, n_res)
).astype(np.float32)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_pair_stack"
)
if consts.is_multimer:
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_embedding_iteration"
)
else:
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_pair_stack"
)
params.update(
compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
......@@ -132,7 +180,7 @@ class TestTemplatePairStack(unittest.TestCase):
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.template_pair_stack(
out_repro = model.template_embedder.template_pair_stack(
torch.as_tensor(pair_act).unsqueeze(-4).cuda(),
torch.as_tensor(pair_mask).unsqueeze(-3).cuda(),
chunk_size=None,
......@@ -143,15 +191,32 @@ class TestTemplatePairStack(unittest.TestCase):
class Template(unittest.TestCase):
@classmethod
def setUpClass(cls):
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def test_template_embedding(pair, batch, mask_2d):
config = compare_utils.get_alphafold_config()
te = alphafold.model.modules.TemplateEmbedding(
te = self.am_modules.TemplateEmbedding(
config.model.embeddings_and_evoformer.template,
config.model.global_config,
)
act = te(pair, batch, mask_2d, is_training=False)
if consts.is_multimer:
act = te(pair, batch, mask_2d, multichain_mask_2d=multichain_mask_2d, is_training=False)
else:
act = te(pair, batch, mask_2d, is_training=False)
return act
f = hk.transform(test_template_embedding)
......@@ -162,6 +227,14 @@ class Template(unittest.TestCase):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
if consts.is_multimer:
asym_id = batch['asym_id'][0]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
).astype(np.float32)
batch["multichain_mask_2d"] = multichain_mask_2d
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
# Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights(
......@@ -177,12 +250,26 @@ class Template(unittest.TestCase):
batch["target_feat"] = np.eye(22)[inds]
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
)
template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
if consts.is_multimer:
out_repro = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
chunk_size=consts.chunk_size,
multichain_mask_2d=multichain_mask_2d,
)
else:
out_repro = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
chunk_size=consts.chunk_size
)
out_repro = out_repro["template_pair_embedding"]
out_repro = out_repro.cpu()
......
......@@ -85,9 +85,9 @@ class TestTriangularAttention(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].core.tri_att_start
model.evoformer.blocks[0].pair_stack.tri_att_start
if starting
else model.evoformer.blocks[0].core.tri_att_end
else model.evoformer.blocks[0].pair_stack.tri_att_end
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
......
......@@ -87,9 +87,9 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].core.tri_mul_in
model.evoformer.blocks[0].pair_stack.tri_mul_in
if incoming
else model.evoformer.blocks[0].core.tri_mul_out
else model.evoformer.blocks[0].pair_stack.tri_mul_out
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment