Commit 304b5ff7 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Begin to spruce up unit tests, fix config

parent abd78418
......@@ -2,7 +2,15 @@ import copy
import ml_collections as mlc
def model_config(name, train=False):
def set_inf(c, inf):
for k, v in c.items():
if(isinstance(v, mlc.ConfigDict)):
set_inf(v, inf)
elif(k == "inf"):
c[k] = inf
def model_config(name, train=False, low_prec=False):
c = copy.deepcopy(config)
if(name == "model_1"):
pass
......@@ -16,28 +24,34 @@ def model_config(name, train=False):
c.model.template.enabled = False
elif(name == "model_1_ptm"):
c.model.heads.tm.enabled = True
c.model.loss.tm.weight = 0.1
c.loss.tm.weight = 0.1
elif(name == "model_2_ptm"):
c.model.heads.tm.enabled = True
c.model.loss.tm.weight = 0.1
c.loss.tm.weight = 0.1
elif(name == "model_3_ptm"):
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.model.loss.tm.weight = 0.1
c.loss.tm.weight = 0.1
elif(name == "model_4_ptm"):
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.model.loss.tm.weight = 0.1
c.loss.tm.weight = 0.1
elif(name == "model_5_ptm"):
c.model.template.enabled = False
c.model.heads.tm.enabled = True
c.model.loss.tm.weight = 0.1
c.loss.tm.weight = 0.1
else:
raise ValueError("Invalid model name")
if(train):
c.globals.model.blocks_per_ckpt = 1
c.globals.chunk_size = None
if(low_prec):
c.globals.eps = 1e-4
# If we want exact numerical parity with the original, inf can't be
# a global constant
set_inf(c, 1e4)
return c
......@@ -51,7 +65,6 @@ blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
chunk_size = mlc.FieldReference(4, field_type=int)
aux_distogram_bins = mlc.FieldReference(64, field_type=int)
eps = mlc.FieldReference(1e-8, field_type=float)
inf = mlc.FieldReference(1e8, field_type=float)
config = mlc.ConfigDict({
# Recurring FieldReferences that can be changed globally here
......@@ -64,7 +77,6 @@ config = mlc.ConfigDict({
"c_e": c_e,
"c_s": c_s,
"eps": eps,
"inf": inf,
},
"model": {
"no_cycles": 4,
......@@ -82,7 +94,7 @@ config = mlc.ConfigDict({
"min_bin": 3.25,
"max_bin": 20.75,
"no_bins": 15,
"inf": inf,#1e8,
"inf": 1e8,
},
"template": {
"distogram": {
......@@ -111,7 +123,7 @@ config = mlc.ConfigDict({
"dropout_rate": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": inf,
"inf": 1e9,
},
"template_pointwise_attention": {
"c_t": c_t,
......@@ -121,9 +133,9 @@ config = mlc.ConfigDict({
"c_hidden": 16,
"no_heads": 4,
"chunk_size": chunk_size,
"inf": inf,#1e-9,
"inf": 1e9,
},
"inf": inf,
"inf": 1e9,
"eps": eps,#1e-6,
"enabled": True,
"embed_angles": True,
......@@ -148,7 +160,7 @@ config = mlc.ConfigDict({
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": inf,#1e9,
"inf": 1e9,
"eps": eps,#1e-10,
},
"enabled": True,
......@@ -169,7 +181,7 @@ config = mlc.ConfigDict({
"pair_dropout": 0.25,
"blocks_per_ckpt": blocks_per_ckpt,
"chunk_size": chunk_size,
"inf": inf,#1e9,
"inf": 1e9,
"eps": eps,#1e-10,
},
"structure_module": {
......@@ -187,7 +199,7 @@ config = mlc.ConfigDict({
"no_angles": 7,
"trans_scale_factor": 10,
"epsilon": eps,#1e-12,
"inf": inf,#1e5,
"inf": 1e5,
},
"heads": {
"lddt": {
......
......@@ -316,7 +316,6 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, N_res, H]
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
# [*, N_res, N_res]
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
square_mask = self.inf * (square_mask - 1)
......@@ -721,7 +720,6 @@ class StructureModule(nn.Module):
# [*, N]
t = T.identity(s.shape[:-1], s.dtype, s.device, self.training)
outputs = []
for i in range(self.no_blocks):
# [*, N, C_s]
......
......@@ -23,6 +23,8 @@ from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import (
batched_gather,
one_hot,
tree_map,
tensor_tree_map,
)
......@@ -143,6 +145,13 @@ def compute_residx(batch):
return out
def compute_residx_np(batch):
batch = tree_map(lambda n: torch.tensor(n), batch, np.ndarray)
out = compute_residx(batch)
out = tensor_tree_map(lambda t: np.array(t), out)
return out
def atom14_to_atom37(atom14, batch):
atom37_data = batched_gather(
atom14,
......
......@@ -19,7 +19,7 @@ setup(
name='openfold',
version='1.0.0',
description='A PyTorch reimplementation of DeepMind\'s AlphaFold 2',
author='Gustaf Ahdritz',
author='Gustaf Ahdritz & DeepMind',
author_email='gahdritz@gmail.com',
license='Apache License, Version 2.0',
url='https://github.com/aqlaboratory/openfold',
......
import os
import importlib
import pkgutil
import sys
import unittest
import numpy as np
from config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_
from tests.config import consts
# Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also
# forces it to proactively free memory that it allocates)
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_PLATFORM_NAME"] = "gpu"
def alphafold_is_installed():
return importlib.util.find_spec("alphafold") is not None
def skip_unless_alphafold_installed():
return unittest.skipUnless(alphafold_is_installed(), "Requires AlphaFold")
def import_alphafold():
"""
If AlphaFold is installed using the provided setuptools script, this
is necessary to expose all of AlphaFold's precious insides
"""
if("alphafold" in sys.modules):
return sys.modules["alphafold"]
module = importlib.import_module("alphafold")
# Forcefully import alphafold's submodules
submodules = pkgutil.walk_packages(
module.__path__, prefix=("alphafold.")
)
for submodule_info in submodules:
importlib.import_module(submodule_info.name)
sys.modules["alphafold"] = module
globals()["alphafold"] = module
return module
def get_alphafold_config():
config = alphafold.model.config.model_config("model_1_ptm")
config.model.global_config.deterministic = True
return config
_param_path = "openfold/resources/params/params_model_1_ptm.npz"
_model = None
def get_global_pretrained_openfold():
global _model
if(_model is None):
_model = AlphaFold(model_config("model_1_ptm").model)
_model = _model.eval()
if(not os.path.exists(_param_path)):
raise FileNotFoundError(
"""Cannot load pretrained parameters. Make sure to run the
installation script before running tests."""
)
import_jax_weights_(_model, _param_path)
_model = _model.cuda()
return _model
_orig_weights = None
def _get_orig_weights():
global _orig_weights
if(_orig_weights is None):
_orig_weights = np.load(_param_path)
return _orig_weights
def _remove_key_prefix(d, prefix):
for k, v in list(d.items()):
if(k.startswith(prefix)):
d.pop(k)
d[k[len(prefix):]] = v
def fetch_alphafold_module_weights(weight_path):
orig_weights = _get_orig_weights()
params = {
k:v for k,v in orig_weights.items()
if weight_path in k
}
if('/' in weight_path):
spl = weight_path.split('/')
spl = spl if len(spl[-1]) != 0 else spl[:-1]
module_name = spl[-1]
prefix = '/'.join(spl[:-1]) + '/'
_remove_key_prefix(params, prefix)
params = alphafold.model.utils.flat_params_to_haiku(params)
return params
import ml_collections as mlc
consts = mlc.ConfigDict({
"batch_size": 2,
"n_res": 11,
"n_seq": 13,
"n_templ": 3,
"n_extra": 17,
"eps": 5e-4,
# For compatibility with DeepMind's pretrained weights, it's easiest for
# everyone if these take their real values.
"c_m": 256,
"c_z": 128,
"c_s": 384,
"c_t": 64,
"c_e": 64,
})
......@@ -54,7 +54,7 @@ def random_extra_msa_feats(n_extra, n, batch_size=None):
return batch
def random_affine_vectors(dim):
def random_affines_vector(dim):
prod_dim = 1
for d in dim:
prod_dim *= d
......@@ -68,7 +68,7 @@ def random_affine_vectors(dim):
return affines.reshape(*dim, 7)
def random_affine_4x4s(dim):
def random_affines_4x4(dim):
prod_dim = 1
for d in dim:
prod_dim *= d
......
......@@ -15,21 +15,33 @@
import torch
import numpy as np
import unittest
from alphafold.model.evoformer import *
from openfold.model.evoformer import (
MSATransition,
EvoformerStack,
ExtraMSAStack,
)
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestEvoformerStack(unittest.TestCase):
def test_shape(self):
batch_size = 5
s_t = 27
n_res = 29
c_m = 7
c_z = 11
batch_size = consts.batch_size
n_seq = consts.n_seq
n_res = consts.n_res
c_m = consts.c_m
c_z = consts.c_z
c_hidden_msa_att = 12
c_hidden_opm = 17
c_hidden_mul = 19
c_hidden_pair_att = 14
c_s = 23
c_s = consts.c_s
no_heads_msa = 3
no_heads_pair = 7
no_blocks = 2
......@@ -59,9 +71,9 @@ class TestEvoformerStack(unittest.TestCase):
eps=eps,
).eval()
m = torch.rand((batch_size, s_t, n_res, c_m))
m = torch.rand((batch_size, n_seq, n_res, c_m))
z = torch.rand((batch_size, n_res, n_res, c_z))
msa_mask = torch.randint(0, 2, size=(batch_size, s_t, n_res))
msa_mask = torch.randint(0, 2, size=(batch_size, n_seq, n_res))
pair_mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
shape_m_before = m.shape
......@@ -73,6 +85,59 @@ class TestEvoformerStack(unittest.TestCase):
self.assertTrue(z.shape == shape_z_before)
self.assertTrue(s.shape == (batch_size, n_res, c_s))
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def run_ei(activations, masks):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
ei = alphafold.model.modules.EvoformerIteration(
c_e, config.model.global_config, is_extra_msa=False)
return ei(activations, masks, is_training=False)
f = hk.transform(run_ei)
n_res = consts.n_res
n_seq = consts.n_seq
activations = {
'msa': np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32),
'pair': np.random.rand(n_res, n_res, consts.c_z).astype(np.float32),
}
masks = {
'msa': np.random.randint(0, 2, (n_seq, n_res)).astype(np.float32),
'pair': np.random.randint(0, 2, (n_res, n_res)).astype(np.float32),
}
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
key = jax.random.PRNGKey(42)
out_gt = f.apply(
params, key, activations, masks
)
jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt_msa = torch.as_tensor(np.array(out_gt["msa"]))
out_gt_pair = torch.as_tensor(np.array(out_gt["pair"]))
model = compare_utils.get_global_pretrained_openfold()
out_repro_msa, out_repro_pair = model.evoformer.blocks[0](
torch.as_tensor(activations["msa"]).cuda(),
torch.as_tensor(activations["pair"]).cuda(),
torch.as_tensor(masks["msa"]).cuda(),
torch.as_tensor(masks["pair"]).cuda(),
_mask_trans=False,
)
out_repro_msa = out_repro_msa.cpu()
out_repro_pair = out_repro_pair.cpu()
assert(torch.max(torch.abs(out_repro_msa - out_gt_msa) < consts.eps))
assert(torch.max(torch.abs(out_repro_pair - out_gt_pair) < consts.eps))
class TestExtraMSAStack(unittest.TestCase):
def test_shape(self):
......@@ -143,6 +208,47 @@ class TestMSATransition(unittest.TestCase):
self.assertTrue(shape_before == shape_after)
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def run_msa_transition(msa_act, msa_mask):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
msa_trans = alphafold.model.modules.Transition(
c_e.msa_transition,
config.model.global_config,
name="msa_transition"
)
act = msa_trans(act=msa_act, mask=msa_mask)
return act
f = hk.transform(run_msa_transition)
n_res = consts.n_res
n_seq = consts.n_seq
msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
msa_mask = np.ones((n_seq, n_res)).astype(np.float32) # no mask here either
# Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
"msa_transition"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].msa_transition(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
if __name__ == "__main__":
unittest.main()
......@@ -17,23 +17,37 @@ import torch
import numpy as np
import unittest
from alphafold.utils.loss import *
from alphafold.utils.utils import T
from openfold.utils.loss import (
torsion_angle_loss,
compute_fape,
between_residue_bond_loss,
between_residue_clash_loss,
find_structural_violations,
)
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import tensor_tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestLoss(unittest.TestCase):
def test_run_torsion_angle_loss(self):
batch_size = 2
n = 5
batch_size = consts.batch_size
n_res = consts.n_res
a = torch.rand((batch_size, n, 7, 2))
a_gt = torch.rand((batch_size, n, 7, 2))
a_alt_gt = torch.rand((batch_size, n, 7, 2))
a = torch.rand((batch_size, n_res, 7, 2))
a_gt = torch.rand((batch_size, n_res, 7, 2))
a_alt_gt = torch.rand((batch_size, n_res, 7, 2))
loss = torsion_angle_loss(a, a_gt, a_alt_gt)
def test_run_fape(self):
batch_size = 2
batch_size = consts.batch_size
n_frames = 7
n_atoms = 5
......@@ -45,12 +59,23 @@ class TestLoss(unittest.TestCase):
trans_gt = torch.rand((batch_size, n_frames, 3))
t = T(rots, trans)
t_gt = T(rots_gt, trans_gt)
frames_mask = torch.randint(0, 2, (batch_size, n_frames)).float()
positions_mask = torch.randint(0, 2, (batch_size, n_atoms)).float()
length_scale = 10
loss = compute_fape(
pred_frames=t,
target_frames=t_gt,
frames_mask=frames_mask,
pred_positions=x,
target_positions=x_gt,
positions_mask=positions_mask,
length_scale=length_scale,
)
loss = compute_fape(t, x, t_gt, x_gt)
def test_between_residue_bond_loss(self):
bs = 2
n = 10
def test_run_between_residue_bond_loss(self):
bs = consts.batch_size
n = consts.n_res
pred_pos = torch.rand(bs, n, 14, 3)
pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
residue_index = torch.arange(n).unsqueeze(0)
......@@ -63,9 +88,52 @@ class TestLoss(unittest.TestCase):
aatype,
)
@compare_utils.skip_unless_alphafold_installed()
def test_between_residue_bond_loss_compare(self):
def run_brbl(pred_pos, pred_atom_mask, residue_index, aatype):
return alphafold.model.all_atom.between_residue_bond_loss(
pred_pos,
pred_atom_mask,
residue_index,
aatype,
)
f = hk.transform(run_brbl)
n_res = consts.n_res
pred_pos = np.random.rand(n_res, 14, 3).astype(np.float32)
pred_atom_mask = np.random.randint(0, 2, (n_res, 14)).astype(np.float32)
residue_index = np.arange(n_res)
aatype = np.random.randint(0, 22, (n_res,))
out_gt = f.apply(
{}, None,
pred_pos,
pred_atom_mask,
residue_index,
aatype,
)
out_gt = jax.tree_map(lambda x: x.block_until_ready(), out_gt)
out_gt = jax.tree_map(lambda x: torch.tensor(np.copy(x)), out_gt)
out_repro = between_residue_bond_loss(
torch.tensor(pred_pos).cuda(),
torch.tensor(pred_atom_mask).cuda(),
torch.tensor(residue_index).cuda(),
torch.tensor(aatype).cuda(),
)
out_repro = tensor_tree_map(lambda x: x.cpu(), out_repro)
for k in out_gt.keys():
self.assertTrue(
torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
)
def test_between_residue_clash_loss(self):
bs = 2
n = 10
bs = consts.batch_size
n = consts.n_res
pred_pos = torch.rand(bs, n, 14, 3)
pred_atom_mask = torch.randint(0, 2, (bs, n, 14))
atom14_atom_radius = torch.rand(bs, n, 14)
......@@ -79,7 +147,7 @@ class TestLoss(unittest.TestCase):
)
def test_find_structural_violations(self):
n = 10
n = consts.n_res
batch = {
"atom14_atom_exists": torch.randint(0, 2, (n, 14)),
......@@ -90,12 +158,12 @@ class TestLoss(unittest.TestCase):
pred_pos = torch.rand(n, 14, 3)
config = ml_collections.ConfigDict({
config = {
"clash_overlap_tolerance": 1.5,
"violation_tolerance_factor": 12.0,
})
}
find_structural_violations(batch, pred_pos, config)
find_structural_violations(batch, pred_pos, **config)
if __name__ == "__main__":
......
......@@ -12,25 +12,35 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import torch
import torch.nn as nn
import numpy as np
import unittest
from config import *
from alphafold.model.model import *
from alphafold.utils.utils import my_tree_map
from tests.alphafold.utils.utils import (
from openfold.model.model import AlphaFold
import openfold.utils.feats as feats
from openfold.utils.tensor_utils import tree_map, tensor_tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import (
random_template_feats,
random_extra_msa_feats,
)
if(compare_utils.alphafold_is_installed()):
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestModel(unittest.TestCase):
def test_dry_run(self):
batch_size = 2
n_seq = 5
n_templ = 7
n_res = 11
n_extra_seq = 13
batch_size = consts.batch_size
n_seq = consts.n_seq
n_templ = consts.n_templ
n_res = consts.n_res
n_extra_seq = consts.n_extra
c = model_config("model_1").model
c.no_cycles = 2
......@@ -59,20 +69,65 @@ class TestModel(unittest.TestCase):
batch.update({k:torch.tensor(v) for k, v in extra_feats.items()})
batch["msa_mask"] = torch.randint(
low=0, high=2, size=(batch_size, n_seq, n_res)
)
).float()
batch["seq_mask"] = torch.randint(
low=0, high=2, size=(batch_size, n_res)
)
).float()
batch.update(feats.compute_residx(batch))
add_recycling_dims = lambda t: (
t.unsqueeze(-1).expand(*t.shape, c.no_cycles)
)
batch = my_tree_map(add_recycling_dims, batch, torch.Tensor)
batch = tensor_tree_map(add_recycling_dims, batch)
with torch.no_grad():
out = model(batch)
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def run_alphafold(batch):
config = compare_utils.get_alphafold_config()
model = alphafold.model.modules.AlphaFold(config.model)
return model(
batch=batch, is_training=False, return_representations=True,
)
f = hk.transform(run_alphafold)
params = compare_utils.fetch_alphafold_module_weights('')
with open("tests/test_data/sample_feats.pickle", "rb") as fp:
batch = pickle.load(fp)
out_gt = jax.jit(f.apply)(params, jax.random.PRNGKey(42), batch)
out_gt = out_gt["structure_module"]["final_atom_positions"]
# atom37_to_atom14 doesn't like batches
batch["residx_atom14_to_atom37"] = batch["residx_atom14_to_atom37"][0]
batch["atom14_atom_exists"] = batch["atom14_atom_exists"][0]
out_gt = alphafold.model.all_atom.atom37_to_atom14(out_gt, batch)
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
batch = {
k:torch.as_tensor(v).cuda() for k,v in batch.items()
}
batch["aatype"] = batch["aatype"].long()
batch["template_aatype"] = batch["template_aatype"].long()
batch["extra_msa"] = batch["extra_msa"].long()
batch["residx_atom37_to_atom14"] = batch["residx_atom37_to_atom14"].long()
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch)
with torch.no_grad():
model = compare_utils.get_global_pretrained_openfold()
out_repro = model(batch)
out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)
out_repro = out_repro["sm"]["positions"][-1]
out_repro = out_repro.squeeze(0)
if __name__ == "__main__":
unittest.main()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 1e-3))
......@@ -15,23 +15,36 @@
import torch
import numpy as np
import unittest
from alphafold.model.msa import *
from openfold.model.msa import (
MSARowAttentionWithPairBias,
MSAColumnAttention,
MSAColumnGlobalAttention,
)
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestMSARowAttentionWithPairBias(unittest.TestCase):
def test_shape(self):
batch_size = 2
s_t = 3
n = 5
c_m = 7
c_z = 11
batch_size = consts.batch_size
n_seq = consts.n_seq
n_res = consts.n_res
c_m = consts.c_m
c_z = consts.c_z
c = 52
no_heads = 4
chunk_size=None
mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads)
mrapb = MSARowAttentionWithPairBias(c_m, c_z, c, no_heads, chunk_size)
m = torch.rand((batch_size, s_t, n, c_m))
z = torch.rand((batch_size, n, n, c_z))
m = torch.rand((batch_size, n_seq, n_res, c_m))
z = torch.rand((batch_size, n_res, n_res, c_z))
shape_before = m.shape
m = mrapb(m, z)
......@@ -39,19 +52,65 @@ class TestMSARowAttentionWithPairBias(unittest.TestCase):
self.assertTrue(shape_before == shape_after)
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def run_msa_row_att(msa_act, msa_mask, pair_act):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
msa_row = alphafold.model.modules.MSARowAttentionWithPairBias(
c_e.msa_row_attention_with_pair_bias,
config.model.global_config
)
act = msa_row(
msa_act=msa_act, msa_mask=msa_mask, pair_act=pair_act
)
return act
f = hk.transform(run_msa_row_att)
n_res = consts.n_res
n_seq = consts.n_seq
msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
msa_mask = np.random.randint(
low=0, high=2, size=(n_seq, n_res)
).astype(np.float32)
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
# Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
"msa_row_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask, pair_act
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].msa_att_row(
torch.as_tensor(msa_act).cuda(),
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(msa_mask).cuda(),
).cpu()
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
class TestMSAColumnAttention(unittest.TestCase):
def test_shape(self):
batch_size = 2
s_t = 3
n = 5
c_m = 7
batch_size = consts.batch_size
n_seq = consts.n_seq
n_res = consts.n_res
c_m = consts.c_m
c = 44
no_heads = 4
msaca = MSAColumnAttention(c_m, c, no_heads)
x = torch.rand((batch_size, s_t, n, c_m))
x = torch.rand((batch_size, n_seq, n_res, c_m))
shape_before = x.shape
x = msaca(x)
......@@ -59,19 +118,63 @@ class TestMSAColumnAttention(unittest.TestCase):
self.assertTrue(shape_before == shape_after)
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def run_msa_col_att(msa_act, msa_mask):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
msa_col = alphafold.model.modules.MSAColumnAttention(
c_e.msa_column_attention,
config.model.global_config
)
act = msa_col(
msa_act=msa_act, msa_mask=msa_mask
)
return act
f = hk.transform(run_msa_col_att)
n_res = consts.n_res
n_seq = consts.n_seq
msa_act = np.random.rand(n_seq, n_res, consts.c_m).astype(np.float32)
msa_mask = np.random.randint(
low=0, high=2, size=(n_seq, n_res)
).astype(np.float32)
# Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
"msa_column_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].msa_att_col(
torch.as_tensor(msa_act).cuda(),
torch.as_tensor(msa_mask).cuda(),
).cpu()
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
class TestMSAColumnGlobalAttention(unittest.TestCase):
def test_shape(self):
batch_size = 2
s_t = 3
n = 5
c_m = 7
batch_size = consts.batch_size
n_seq = consts.n_seq
n_res = consts.n_res
c_m = consts.c_m
c = 44
no_heads = 4
msagca = MSAColumnGlobalAttention(c_m, c, no_heads)
x = torch.rand((batch_size, s_t, n, c_m))
x = torch.rand((batch_size, n_seq, n_res, c_m))
shape_before = x.shape
x = msagca(x)
......@@ -79,6 +182,48 @@ class TestMSAColumnGlobalAttention(unittest.TestCase):
self.assertTrue(shape_before == shape_after)
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def run_msa_col_global_att(msa_act, msa_mask):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
msa_col = alphafold.model.modules.MSAColumnGlobalAttention(
c_e.msa_column_attention,
config.model.global_config,
name="msa_column_global_attention"
)
act = msa_col(msa_act=msa_act, msa_mask=msa_mask)
return act
f = hk.transform(run_msa_col_global_att)
n_res = consts.n_res
n_seq = consts.n_seq
c_e = consts.c_e
msa_act = np.random.rand(n_seq, n_res, c_e)
msa_mask = np.random.randint(low=0, high=2, size=(n_seq, n_res))
# Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/extra_msa_stack/" +
"msa_column_global_attention"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.extra_msa_stack.stack.blocks[0].msa_att_col(
torch.as_tensor(msa_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(msa_mask, dtype=torch.float32).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
if __name__ == "__main__":
unittest.main()
......@@ -15,25 +15,79 @@
import torch
import numpy as np
import unittest
from alphafold.model.outer_product_mean import *
from openfold.model.outer_product_mean import OuterProductMean
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestOuterProductMean(unittest.TestCase):
def test_shape(self):
batch_size = 2
s = 5
n_res = 7
c_m = 11
c = 13
c_z = 17
c = 31
opm = OuterProductMean(c_m, c_z, c)
opm = OuterProductMean(consts.c_m, consts.c_z, c)
m = torch.rand((batch_size, s, n_res, c_m))
mask = torch.randint(0, 2, size=(batch_size, s, n_res))
m = torch.rand(
(consts.batch_size, consts.n_seq, consts.n_res, consts.c_m)
)
mask = torch.randint(
0, 2, size=(consts.batch_size, consts.n_seq, consts.n_res)
)
m = opm(m, mask)
self.assertTrue(m.shape == (batch_size, n_res, n_res, c_z))
self.assertTrue(
m.shape == (consts.batch_size, consts.n_res, consts.n_res, consts.c_z)
)
@compare_utils.skip_unless_alphafold_installed()
def test_opm_compare(self):
def run_opm(msa_act, msa_mask):
config = compare_utils.get_alphafold_config()
c_evo = config.model.embeddings_and_evoformer.evoformer
opm = alphafold.model.modules.OuterProductMean(
c_evo.outer_product_mean,
config.model.global_config,
consts.c_z,
)
act = opm(act=msa_act, mask=msa_mask)
return act
f = hk.transform(run_opm)
n_res = consts.n_res
n_seq = consts.n_seq
c_m = consts.c_m
msa_act = np.random.rand(n_seq, n_res, c_m).astype(np.float32) * 100
msa_mask = np.random.randint(
low=0, high=2, size=(n_seq, n_res)
).astype(np.float32)
# Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/" +
"evoformer_iteration/outer_product_mean"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, msa_act, msa_mask
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].outer_product_mean(
torch.as_tensor(msa_act).cuda(),
mask=torch.as_tensor(msa_mask).cuda(),
).cpu()
# Even when correct, OPM has large, precision-related errors. It gets
# a special pass from consts.eps.
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < 5e-4))
if __name__ == "__main__":
......
......@@ -15,18 +15,26 @@
import torch
import numpy as np
import unittest
from alphafold.model.pair_transition import *
from openfold.model.pair_transition import PairTransition
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestPairTransition(unittest.TestCase):
def test_shape(self):
c_z = 5
c_z = consts.c_z
n = 4
pt = PairTransition(c_z, n)
batch_size = 4
n_res = 256
batch_size = consts.batch_size
n_res = consts.n_res
z = torch.rand((batch_size, n_res, n_res, c_z))
mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
......@@ -36,6 +44,47 @@ class TestPairTransition(unittest.TestCase):
self.assertTrue(shape_before == shape_after)
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def run_pair_transition(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
pt = alphafold.model.modules.Transition(
c_e.pair_transition,
config.model.global_config,
name="pair_transition"
)
act = pt(act=pair_act, mask=pair_mask)
return act
f = hk.transform(run_pair_transition)
n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.ones((n_res, n_res)).astype(np.float32) # no mask
# Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
"pair_transition"
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, pair_act, pair_mask
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt.block_until_ready()))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.evoformer.blocks[0].pair_transition(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
if __name__ == "__main__":
unittest.main()
......
......@@ -16,18 +16,28 @@ import torch
import numpy as np
import unittest
from alphafold.np.residue_constants import (
from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
restype_atom14_mask,
restype_atom14_rigid_group_positions,
)
from alphafold.model.structure_module import *
from alphafold.model.structure_module import (
from openfold.model.structure_module import *
from openfold.model.structure_module import (
_torsion_angles_to_frames,
_frames_and_literature_positions_to_atom14_pos,
)
from alphafold.utils.utils import T
from openfold.utils.affine_utils import T
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import (
random_affines_4x4,
)
if(compare_utils.alphafold_is_installed()):
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestStructureModule(unittest.TestCase):
......@@ -75,7 +85,7 @@ class TestStructureModule(unittest.TestCase):
out = sm(s, z, f)
self.assertTrue(
out["transformations"].shape == (no_layers, batch_size, n, 4, 4)
out["frames"].shape == (no_layers, batch_size, n, 4, 4)
)
self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
......@@ -190,6 +200,62 @@ class TestInvariantPointAttention(unittest.TestCase):
self.assertTrue(s.shape == shape_before)
@compare_utils.skip_unless_alphafold_installed()
def test_ipa_compare(self):
def run_ipa(act, static_feat_2d, mask, affine):
config = compare_utils.get_alphafold_config()
ipa = alphafold.model.folding.InvariantPointAttention(
config.model.heads.structure_module,
config.model.global_config,
)
attn = ipa(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
affine=affine
)
return attn
f = hk.transform(run_ipa)
n_res = consts.n_res
c_s = consts.c_s
c_z = consts.c_z
sample_act = np.random.rand(n_res, c_s)
sample_2d = np.random.rand(n_res, n_res, c_z)
sample_mask = np.ones((n_res, 1))
affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids)
transformations = T.from_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats
ipa_params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/structure_module/" +
"fold_iteration/invariant_point_attention"
)
out_gt = f.apply(
ipa_params, None, sample_act, sample_2d, sample_mask, sample_affine
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
with torch.no_grad():
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.structure_module.ipa(
torch.as_tensor(sample_act).float().cuda(),
torch.as_tensor(sample_2d).float().cuda(),
transformations,
torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)
class TestAngleResnet(unittest.TestCase):
def test_shape(self):
......
......@@ -15,23 +15,38 @@
import torch
import numpy as np
import unittest
from alphafold.model.template import *
from openfold.model.template import (
TemplatePointwiseAttention,
TemplatePairStack,
)
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_template_feats
if(compare_utils.alphafold_is_installed()):
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestTemplatePointwiseAttention(unittest.TestCase):
def test_shape(self):
batch_size = 2
s_t = 3
c_t = 5
c_z = 7
batch_size = consts.batch_size
n_seq = consts.n_seq
c_t = consts.c_t
c_z = consts.c_z
c = 26
no_heads = 13
n = 17
n_res = consts.n_res
inf = 1e7
tpa = TemplatePointwiseAttention(c_t, c_z, c, no_heads, chunk_size=4)
tpa = TemplatePointwiseAttention(
c_t, c_z, c, no_heads, chunk_size=4, inf=inf
)
t = torch.rand((batch_size, s_t, n, n, c_t))
z = torch.rand((batch_size, n, n, c_z))
t = torch.rand((batch_size, n_seq, n_res, n_res, c_t))
z = torch.rand((batch_size, n_res, n_res, c_z))
z_update = tpa(t, z)
......@@ -40,17 +55,20 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class TestTemplatePairStack(unittest.TestCase):
def test_shape(self):
batch_size = 2
c_t = 5
batch_size = consts.batch_size
c_t = consts.c_t
c_hidden_tri_att = 7
c_hidden_tri_mul = 7
no_blocks = 2
no_heads = 4
pt_inner_dim = 15
dropout = 0.25
n_templ = 3
n_res = 5
n_templ = consts.n_templ
n_res = consts.n_res
blocks_per_ckpt = None
chunk_size = 4
inf=1e7
eps=1e-7
tpe = TemplatePairStack(
c_t,
......@@ -60,7 +78,10 @@ class TestTemplatePairStack(unittest.TestCase):
no_heads=no_heads,
pair_transition_n=pt_inner_dim,
dropout_rate=dropout,
blocks_per_ckpt=None,
chunk_size=chunk_size,
inf=inf,
eps=eps,
)
t = torch.rand((batch_size, n_templ, n_res, n_res, c_t))
......@@ -71,7 +92,98 @@ class TestTemplatePairStack(unittest.TestCase):
self.assertTrue(shape_before == shape_after)
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def run_template_pair_stack(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_ee = config.model.embeddings_and_evoformer
tps = alphafold.model.modules.TemplatePairStack(
c_ee.template.template_pair_stack,
config.model.global_config,
name="template_pair_stack"
)
act = tps(pair_act, pair_mask, is_training=False)
ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
act = ln(act)
return act
f = hk.transform(run_template_pair_stack)
n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_t).astype(np.float32)
pair_mask = np.random.randint(
low=0, high=2, size=(n_res, n_res)
).astype(np.float32)
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/" +
"single_template_embedding/template_pair_stack"
)
params.update(compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/" +
"single_template_embedding/output_layer_norm"
))
out_gt = f.apply(
params, jax.random.PRNGKey(42), pair_act, pair_mask
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.template_pair_stack(
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
_mask_trans=False,
).cpu()
self.assertTrue(torch.all(torch.abs(out_gt - out_repro) < consts.eps))
class Template(unittest.TestCase):
@compare_utils.skip_unless_alphafold_installed()
def test_compare(self):
def test_template_embedding(pair, batch, mask_2d):
config = compare_utils.get_alphafold_config()
te = alphafold.model.modules.TemplateEmbedding(
config.model.embeddings_and_evoformer.template,
config.model.global_config
)
act = te(pair, batch, mask_2d, is_training=False)
return act
f = hk.transform(test_template_embedding)
n_res = consts.n_res
n_templ = consts.n_templ
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res)
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
# Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding"
)
out_gt = f.apply(
params, jax.random.PRNGKey(42), pair_act, batch, pair_mask
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]
model = compare_utils.get_global_pretrained_openfold()
out_repro = model.embed_templates(
{k:torch.as_tensor(v).cuda() for k,v in batch.items()},
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
)
out_repro = out_repro["template_pair_embedding"]
out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
if __name__ == "__main__":
......
......@@ -15,12 +15,21 @@
import torch
import numpy as np
import unittest
from alphafold.model.triangular_attention import *
from openfold.model.triangular_attention import TriangleAttention
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestTriangularAttention(unittest.TestCase):
def test_shape(self):
c_z = 2
c_z = consts.c_z
c = 12
no_heads = 4
starting = True
......@@ -32,8 +41,8 @@ class TestTriangularAttention(unittest.TestCase):
starting
)
batch_size = 4
n_res = 7
batch_size = consts.batch_size
n_res = consts.n_res
x = torch.rand((batch_size, n_res, n_res, c_z))
shape_before = x.shape
......@@ -42,9 +51,61 @@ class TestTriangularAttention(unittest.TestCase):
self.assertTrue(shape_before == shape_after)
if __name__ == "__main__":
unittest.main()
def _tri_att_compare(self, starting=False):
name = (
"triangle_attention_" +
("starting" if starting else "ending") +
"_node"
)
def run_tri_att(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
tri_att = alphafold.model.modules.TriangleAttention(
c_e.triangle_attention_starting_node if starting else
c_e.triangle_attention_ending_node,
config.model.global_config,
name=name,
)
act = tri_att(pair_act=pair_act, pair_mask=pair_mask)
return act
f = hk.transform(run_tri_att)
n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_z)
pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
# Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
name
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, pair_act, pair_mask
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].tri_att_start if starting else
model.evoformer.blocks[0].tri_att_end
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
def test_tri_att_end_compare(self):
self._tri_att_compare()
def test_tri_att_start_compare(self):
self._tri_att_compare(starting=True)
if __name__ == "__main__":
unittest.main()
......@@ -15,12 +15,20 @@
import torch
import numpy as np
import unittest
from alphafold.model.triangular_multiplicative_update import *
from openfold.model.triangular_multiplicative_update import *
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils
from tests.config import consts
if(compare_utils.alphafold_is_installed()):
alphafold = compare_utils.import_alphafold()
import jax
import haiku as hk
class TestTriangularMultiplicativeUpdate(unittest.TestCase):
def test_shape(self):
c_z = 7
c_z = consts.c_z
c = 11
outgoing = True
......@@ -30,8 +38,8 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
outgoing,
)
n_res = 5
batch_size = 2
n_res = consts.c_z
batch_size = consts.batch_size
x = torch.rand((batch_size, n_res, n_res, c_z))
mask = torch.randint(0, 2, size=(batch_size, n_res, n_res))
......@@ -41,6 +49,63 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
self.assertTrue(shape_before == shape_after)
def _tri_mul_compare(self, incoming=False):
name = (
"triangle_multiplication_" +
("incoming" if incoming else "outgoing")
)
def run_tri_mul(pair_act, pair_mask):
config = compare_utils.get_alphafold_config()
c_e = config.model.embeddings_and_evoformer.evoformer
tri_mul = alphafold.model.modules.TriangleMultiplication(
c_e.triangle_multiplication_incoming if incoming else
c_e.triangle_multiplication_outgoing,
config.model.global_config,
name=name,
)
act = tri_mul(act=pair_act, mask=pair_mask)
return act
f = hk.transform(run_tri_mul)
n_res = consts.n_res
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
pair_mask = pair_mask.astype(np.float32)
# Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" +
name
)
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray)
out_gt = f.apply(
params, None, pair_act, pair_mask
).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold()
module = (
model.evoformer.blocks[0].tri_mul_in if incoming else
model.evoformer.blocks[0].tri_mul_out
)
out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps))
@compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self):
self._tri_mul_compare()
@compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_in_compare(self):
self._tri_mul_compare(incoming=True)
if __name__ == "__main__":
unittest.main()
......
......@@ -16,7 +16,8 @@ import math
import torch
import unittest
from alphafold.utils.utils import *
from openfold.utils.affine_utils import *
from openfold.utils.tensor_utils import *
X_90_ROT = torch.tensor([
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment