Unverified commit bb3f51e5, authored by Christina Floristean, committed by GitHub
Browse files

Merge pull request #405 from aqlaboratory/multimer

Full multimer merge
parents ce211367 c33a0bd6
...@@ -30,7 +30,7 @@ class TestLMA(unittest.TestCase): ...@@ -30,7 +30,7 @@ class TestLMA(unittest.TestCase):
q, kv, _, biases = random_attention_inputs(batch_size=consts.batch_size, q, kv, _, biases = random_attention_inputs(batch_size=consts.batch_size,
n_seq=consts.n_seq, n_seq=consts.n_seq,
n=2**12, n=2 ** 12,
no_heads=no_heads, no_heads=no_heads,
c_hidden=c_hidden) c_hidden=c_hidden)
...@@ -44,10 +44,10 @@ class TestLMA(unittest.TestCase): ...@@ -44,10 +44,10 @@ class TestLMA(unittest.TestCase):
l = a(q, kv, biases=biases, use_lma=True).cpu() l = a(q, kv, biases=biases, use_lma=True).cpu()
real = a(q, kv, biases=biases).cpu() real = a(q, kv, biases=biases).cpu()
err = torch.max(torch.abs(l - real)) err = torch.max(torch.abs(l - real))
self.assertTrue(err < consts.eps, f'Error: {err}') self.assertTrue(err < consts.eps, f'Error: {err}')
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
\ No newline at end of file
...@@ -18,21 +18,19 @@ import unittest ...@@ -18,21 +18,19 @@ import unittest
from openfold.data.data_transforms import make_atom14_masks_np from openfold.data.data_transforms import make_atom14_masks_np
from openfold.np.residue_constants import ( from openfold.np.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
restype_atom14_mask, restype_atom14_mask,
restype_atom14_rigid_group_positions,
restype_atom37_mask, restype_atom37_mask,
) )
from openfold.model.structure_module import ( from openfold.model.structure_module import (
StructureModule, StructureModule,
StructureModuleTransition, StructureModuleTransition,
BackboneUpdate,
AngleResnet, AngleResnet,
InvariantPointAttention, InvariantPointAttention,
) )
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
import tests.compare_utils as compare_utils import tests.compare_utils as compare_utils
from tests.config import consts from tests.config import consts
from tests.data_utils import ( from tests.data_utils import (
...@@ -46,6 +44,20 @@ if compare_utils.alphafold_is_installed(): ...@@ -46,6 +44,20 @@ if compare_utils.alphafold_is_installed():
class TestStructureModule(unittest.TestCase): class TestStructureModule(unittest.TestCase):
@classmethod
def setUpClass(cls):
if compare_utils.alphafold_is_installed():
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_structure_module_shape(self): def test_structure_module_shape(self):
batch_size = consts.batch_size batch_size = consts.batch_size
n = consts.n_res n = consts.n_res
...@@ -81,6 +93,7 @@ class TestStructureModule(unittest.TestCase): ...@@ -81,6 +93,7 @@ class TestStructureModule(unittest.TestCase):
trans_scale_factor, trans_scale_factor,
ar_epsilon, ar_epsilon,
inf, inf,
is_multimer=consts.is_multimer
) )
s = torch.rand((batch_size, n, c_s)) s = torch.rand((batch_size, n, c_s))
...@@ -89,7 +102,11 @@ class TestStructureModule(unittest.TestCase): ...@@ -89,7 +102,11 @@ class TestStructureModule(unittest.TestCase):
out = sm({"single": s, "pair": z}, f) out = sm({"single": s, "pair": z}, f)
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7)) if consts.is_multimer:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 4, 4))
else:
self.assertTrue(out["frames"].shape == (no_layers, batch_size, n, 7))
self.assertTrue( self.assertTrue(
out["angles"].shape == (no_layers, batch_size, n, no_angles, 2) out["angles"].shape == (no_layers, batch_size, n, no_angles, 2)
) )
...@@ -121,11 +138,14 @@ class TestStructureModule(unittest.TestCase): ...@@ -121,11 +138,14 @@ class TestStructureModule(unittest.TestCase):
c_global = config.model.global_config c_global = config.model.global_config
def run_sm(representations, batch): def run_sm(representations, batch):
sm = alphafold.model.folding.StructureModule(c_sm, c_global) sm = self.am_fold.StructureModule(c_sm, c_global)
representations = { representations = {
k: jax.lax.stop_gradient(v) for k, v in representations.items() k: jax.lax.stop_gradient(v) for k, v in representations.items()
} }
batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()} batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
if consts.is_multimer:
return sm(representations, batch, is_training=False, compute_loss=True)
return sm(representations, batch, is_training=False) return sm(representations, batch, is_training=False)
f = hk.transform(run_sm) f = hk.transform(run_sm)
...@@ -177,10 +197,24 @@ class TestStructureModule(unittest.TestCase): ...@@ -177,10 +197,24 @@ class TestStructureModule(unittest.TestCase):
# The structure module, thanks to angle normalization, is very volatile # The structure module, thanks to angle normalization, is very volatile
# We only assess the mean here. Heuristically speaking, it seems to # We only assess the mean here. Heuristically speaking, it seems to
# have lower error in general on real rather than synthetic data. # have lower error in general on real rather than synthetic data.
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < 0.05) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 0.05)
class TestInvariantPointAttention(unittest.TestCase): class TestInvariantPointAttention(unittest.TestCase):
@classmethod
def setUpClass(cls):
if compare_utils.alphafold_is_installed():
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_shape(self): def test_shape(self):
c_m = 13 c_m = 13
c_z = 17 c_z = 17
...@@ -197,13 +231,18 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -197,13 +231,18 @@ class TestInvariantPointAttention(unittest.TestCase):
mask = torch.ones((batch_size, n_res)) mask = torch.ones((batch_size, n_res))
rot_mats = torch.rand((batch_size, n_res, 3, 3)) rot_mats = torch.rand((batch_size, n_res, 3, 3))
rots = Rotation(rot_mats=rot_mats, quats=None)
trans = torch.rand((batch_size, n_res, 3)) trans = torch.rand((batch_size, n_res, 3))
r = Rigid(rots, trans) if consts.is_multimer:
rotation = Rot3Array.from_array(rot_mats)
translation = Vec3Array.from_array(trans)
r = Rigid3Array(rotation, translation)
else:
rots = Rotation(rot_mats=rot_mats, quats=None)
r = Rigid(rots, trans)
ipa = InvariantPointAttention( ipa = InvariantPointAttention(
c_m, c_z, c_hidden, no_heads, no_qp, no_vp c_m, c_z, c_hidden, no_heads, no_qp, no_vp, is_multimer=consts.is_multimer
) )
shape_before = s.shape shape_before = s.shape
...@@ -215,16 +254,26 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -215,16 +254,26 @@ class TestInvariantPointAttention(unittest.TestCase):
def test_ipa_compare(self): def test_ipa_compare(self):
def run_ipa(act, static_feat_2d, mask, affine): def run_ipa(act, static_feat_2d, mask, affine):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
ipa = alphafold.model.folding.InvariantPointAttention( ipa = self.am_fold.InvariantPointAttention(
config.model.heads.structure_module, config.model.heads.structure_module,
config.model.global_config, config.model.global_config,
) )
attn = ipa(
inputs_1d=act, if consts.is_multimer:
inputs_2d=static_feat_2d, attn = ipa(
mask=mask, inputs_1d=act,
affine=affine, inputs_2d=static_feat_2d,
) mask=mask,
rigid=affine
)
else:
attn = ipa(
inputs_1d=act,
inputs_2d=static_feat_2d,
mask=mask,
affine=affine
)
return attn return attn
f = hk.transform(run_ipa) f = hk.transform(run_ipa)
...@@ -238,13 +287,20 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -238,13 +287,20 @@ class TestInvariantPointAttention(unittest.TestCase):
sample_mask = np.ones((n_res, 1)) sample_mask = np.ones((n_res, 1))
affines = random_affines_4x4((n_res,)) affines = random_affines_4x4((n_res,))
rigids = alphafold.model.r3.rigids_from_tensor4x4(affines)
quats = alphafold.model.r3.rigids_to_quataffine(rigids)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats if consts.is_multimer:
rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
transformations = Rigid3Array.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = rigids
else:
rigids = self.am_rigid.rigids_from_tensor4x4(affines)
quats = self.am_rigid.rigids_to_quataffine(rigids)
transformations = Rigid.from_tensor_4x4(
torch.as_tensor(affines).float().cuda()
)
sample_affine = quats
ipa_params = compare_utils.fetch_alphafold_module_weights( ipa_params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/structure_module/" "alphafold/alphafold_iteration/structure_module/"
...@@ -265,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase): ...@@ -265,7 +321,7 @@ class TestInvariantPointAttention(unittest.TestCase):
torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(), torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
).cpu() ).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class TestAngleResnet(unittest.TestCase): class TestAngleResnet(unittest.TestCase):
......
...@@ -12,6 +12,7 @@ ...@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import re
import torch import torch
import numpy as np import numpy as np
import unittest import unittest
...@@ -19,7 +20,6 @@ from openfold.model.template import ( ...@@ -19,7 +20,6 @@ from openfold.model.template import (
TemplatePointwiseAttention, TemplatePointwiseAttention,
TemplatePairStack, TemplatePairStack,
) )
from openfold.utils.tensor_utils import tree_map
import tests.compare_utils as compare_utils import tests.compare_utils as compare_utils
from tests.config import consts from tests.config import consts
from tests.data_utils import random_template_feats from tests.data_utils import random_template_feats
...@@ -54,6 +54,20 @@ class TestTemplatePointwiseAttention(unittest.TestCase): ...@@ -54,6 +54,20 @@ class TestTemplatePointwiseAttention(unittest.TestCase):
class TestTemplatePairStack(unittest.TestCase): class TestTemplatePairStack(unittest.TestCase):
@classmethod
def setUpClass(cls):
if compare_utils.alphafold_is_installed():
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
def test_shape(self): def test_shape(self):
batch_size = consts.batch_size batch_size = consts.batch_size
c_t = consts.c_t c_t = consts.c_t
...@@ -65,6 +79,8 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -65,6 +79,8 @@ class TestTemplatePairStack(unittest.TestCase):
dropout = 0.25 dropout = 0.25
n_templ = consts.n_templ n_templ = consts.n_templ
n_res = consts.n_res n_res = consts.n_res
tri_mul_first = consts.is_multimer
fuse_projection_weights = True if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model) else False
blocks_per_ckpt = None blocks_per_ckpt = None
chunk_size = 4 chunk_size = 4
inf = 1e7 inf = 1e7
...@@ -78,6 +94,8 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -78,6 +94,8 @@ class TestTemplatePairStack(unittest.TestCase):
no_heads=no_heads, no_heads=no_heads,
pair_transition_n=pt_inner_dim, pair_transition_n=pt_inner_dim,
dropout_rate=dropout, dropout_rate=dropout,
tri_mul_first=tri_mul_first,
fuse_projection_weights=fuse_projection_weights,
blocks_per_ckpt=None, blocks_per_ckpt=None,
inf=inf, inf=inf,
eps=eps, eps=eps,
...@@ -96,12 +114,40 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -96,12 +114,40 @@ class TestTemplatePairStack(unittest.TestCase):
def run_template_pair_stack(pair_act, pair_mask): def run_template_pair_stack(pair_act, pair_mask):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
c_ee = config.model.embeddings_and_evoformer c_ee = config.model.embeddings_and_evoformer
tps = alphafold.model.modules.TemplatePairStack(
c_ee.template.template_pair_stack, if consts.is_multimer:
config.model.global_config, safe_key = alphafold.model.prng.SafeKey(hk.next_rng_key())
name="template_pair_stack", template_iteration = self.am_modules.TemplateEmbeddingIteration(
) c_ee.template.template_pair_stack,
act = tps(pair_act, pair_mask, is_training=False) config.model.global_config,
name='template_embedding_iteration')
def template_iteration_fn(x):
act, safe_key = x
safe_key, safe_subkey = safe_key.split()
act = template_iteration(
act=act,
pair_mask=pair_mask,
is_training=False,
safe_key=safe_subkey)
return (act, safe_key)
if config.model.global_config.use_remat:
template_iteration_fn = hk.remat(template_iteration_fn)
safe_key, safe_subkey = safe_key.split()
template_stack = alphafold.model.layer_stack.layer_stack(
c_ee.template.template_pair_stack.num_block)(
template_iteration_fn)
act, _ = template_stack((pair_act, safe_subkey))
else:
tps = self.am_modules.TemplatePairStack(
c_ee.template.template_pair_stack,
config.model.global_config,
name="template_pair_stack",
)
act = tps(pair_act, pair_mask, is_training=False)
ln = hk.LayerNorm([-1], True, True, name="output_layer_norm") ln = hk.LayerNorm([-1], True, True, name="output_layer_norm")
act = ln(act) act = ln(act)
return act return act
...@@ -115,10 +161,16 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -115,10 +161,16 @@ class TestTemplatePairStack(unittest.TestCase):
low=0, high=2, size=(n_res, n_res) low=0, high=2, size=(n_res, n_res)
).astype(np.float32) ).astype(np.float32)
params = compare_utils.fetch_alphafold_module_weights( if consts.is_multimer:
"alphafold/alphafold_iteration/evoformer/template_embedding/" params = compare_utils.fetch_alphafold_module_weights(
+ "single_template_embedding/template_pair_stack" "alphafold/alphafold_iteration/evoformer/template_embedding/"
) + "single_template_embedding/template_embedding_iteration"
)
else:
params = compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/"
+ "single_template_embedding/template_pair_stack"
)
params.update( params.update(
compare_utils.fetch_alphafold_module_weights( compare_utils.fetch_alphafold_module_weights(
"alphafold/alphafold_iteration/evoformer/template_embedding/" "alphafold/alphafold_iteration/evoformer/template_embedding/"
...@@ -132,26 +184,44 @@ class TestTemplatePairStack(unittest.TestCase): ...@@ -132,26 +184,44 @@ class TestTemplatePairStack(unittest.TestCase):
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
out_repro = model.template_pair_stack( out_repro = model.template_embedder.template_pair_stack(
torch.as_tensor(pair_act).unsqueeze(-4).cuda(), torch.as_tensor(pair_act).unsqueeze(-4).cuda(),
torch.as_tensor(pair_mask).unsqueeze(-3).cuda(), torch.as_tensor(pair_mask).unsqueeze(-3).cuda(),
chunk_size=None, chunk_size=None,
_mask_trans=False, _mask_trans=False,
).cpu() ).cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)
class Template(unittest.TestCase): class Template(unittest.TestCase):
@classmethod
def setUpClass(cls):
if compare_utils.alphafold_is_installed():
if consts.is_multimer:
cls.am_atom = alphafold.model.all_atom_multimer
cls.am_fold = alphafold.model.folding_multimer
cls.am_modules = alphafold.model.modules_multimer
cls.am_rigid = alphafold.model.geometry
else:
cls.am_atom = alphafold.model.all_atom
cls.am_fold = alphafold.model.folding
cls.am_modules = alphafold.model.modules
cls.am_rigid = alphafold.model.r3
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_compare(self): def test_compare(self):
def test_template_embedding(pair, batch, mask_2d): def test_template_embedding(pair, batch, mask_2d, mc_mask_2d):
config = compare_utils.get_alphafold_config() config = compare_utils.get_alphafold_config()
te = alphafold.model.modules.TemplateEmbedding( te = self.am_modules.TemplateEmbedding(
config.model.embeddings_and_evoformer.template, config.model.embeddings_and_evoformer.template,
config.model.global_config, config.model.global_config,
) )
act = te(pair, batch, mask_2d, is_training=False)
if consts.is_multimer:
act = te(pair, batch, mask_2d, multichain_mask_2d=mc_mask_2d, is_training=False)
else:
act = te(pair, batch, mask_2d, is_training=False)
return act return act
f = hk.transform(test_template_embedding) f = hk.transform(test_template_embedding)
...@@ -162,6 +232,14 @@ class Template(unittest.TestCase): ...@@ -162,6 +232,14 @@ class Template(unittest.TestCase):
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32) pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res) batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"] batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
multichain_mask_2d = None
if consts.is_multimer:
asym_id = batch['asym_id'][0]
multichain_mask_2d = (
asym_id[..., None] == asym_id[..., None, :]
).astype(np.float32)
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32) pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
# Fetch pretrained parameters (but only from one block)] # Fetch pretrained parameters (but only from one block)]
params = compare_utils.fetch_alphafold_module_weights( params = compare_utils.fetch_alphafold_module_weights(
...@@ -169,7 +247,7 @@ class Template(unittest.TestCase): ...@@ -169,7 +247,7 @@ class Template(unittest.TestCase):
) )
out_gt = f.apply( out_gt = f.apply(
params, jax.random.PRNGKey(42), pair_act, batch, pair_mask params, jax.random.PRNGKey(42), pair_act, batch, pair_mask, multichain_mask_2d
).block_until_ready() ).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
...@@ -177,17 +255,36 @@ class Template(unittest.TestCase): ...@@ -177,17 +255,36 @@ class Template(unittest.TestCase):
batch["target_feat"] = np.eye(22)[inds] batch["target_feat"] = np.eye(22)[inds]
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
out_repro = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()}, template_feats = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
torch.as_tensor(pair_act).cuda(), if consts.is_multimer:
torch.as_tensor(pair_mask).cuda(), out_repro_all = model.template_embedder(
templ_dim=0, template_feats,
inplace_safe=False torch.as_tensor(pair_act).cuda(),
) torch.as_tensor(pair_mask).cuda(),
out_repro = out_repro["template_pair_embedding"] templ_dim=0,
chunk_size=consts.chunk_size,
multichain_mask_2d=torch.as_tensor(multichain_mask_2d).cuda(),
_mask_trans=False,
use_lma=False,
inplace_safe=False
)
else:
out_repro_all = model.template_embedder(
template_feats,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
chunk_size=consts.chunk_size,
mask_trans=False,
use_lma=False,
inplace_safe=False
)
out_repro = out_repro_all["template_pair_embedding"]
out_repro = out_repro.cpu() out_repro = out_repro.cpu()
self.assertTrue(torch.max(torch.abs(out_gt - out_repro) < consts.eps)) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -79,16 +79,16 @@ class TestTriangularAttention(unittest.TestCase): ...@@ -79,16 +79,16 @@ class TestTriangularAttention(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name + name
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
module = ( module = (
model.evoformer.blocks[0].core.tri_att_start model.evoformer.blocks[0].pair_stack.tri_att_start
if starting if starting
else model.evoformer.blocks[0].core.tri_att_end else model.evoformer.blocks[0].pair_stack.tri_att_end
) )
# To save memory, the full model transposes inputs outside of the # To save memory, the full model transposes inputs outside of the
...@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase): ...@@ -102,7 +102,7 @@ class TestTriangularAttention(unittest.TestCase):
chunk_size=None, chunk_size=None,
).cpu() ).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_tri_att_end_compare(self): def test_tri_att_end_compare(self):
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
# limitations under the License. # limitations under the License.
import torch import torch
import re
import numpy as np import numpy as np
import unittest import unittest
from openfold.model.triangular_multiplicative_update import * from openfold.model.triangular_multiplicative_update import *
...@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -31,10 +32,16 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
c_z = consts.c_z c_z = consts.c_z
c = 11 c = 11
tm = TriangleMultiplicationOutgoing( if re.fullmatch("^model_[1-5]_multimer_v3$", consts.model):
c_z, tm = FusedTriangleMultiplicationOutgoing(
c, c_z,
) c,
)
else:
tm = TriangleMultiplicationOutgoing(
c_z,
c,
)
n_res = consts.c_z n_res = consts.c_z
batch_size = consts.batch_size batch_size = consts.batch_size
...@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -62,7 +69,7 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
config.model.global_config, config.model.global_config,
name=name, name=name,
) )
act = tri_mul(act=pair_act, mask=pair_mask) act = tri_mul(pair_act, pair_mask)
return act return act
f = hk.transform(run_tri_mul) f = hk.transform(run_tri_mul)
...@@ -78,24 +85,25 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -78,24 +85,25 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
"alphafold/alphafold_iteration/evoformer/evoformer_iteration/" "alphafold/alphafold_iteration/evoformer/evoformer_iteration/"
+ name + name
) )
params = tree_map(lambda n: n[0], params, jax.numpy.DeviceArray) params = tree_map(lambda n: n[0], params, jax.Array)
out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready() out_gt = f.apply(params, None, pair_act, pair_mask).block_until_ready()
out_gt = torch.as_tensor(np.array(out_gt)) out_gt = torch.as_tensor(np.array(out_gt))
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
module = ( module = (
model.evoformer.blocks[0].core.tri_mul_in model.evoformer.blocks[0].pair_stack.tri_mul_in
if incoming if incoming
else model.evoformer.blocks[0].core.tri_mul_out else model.evoformer.blocks[0].pair_stack.tri_mul_out
) )
out_repro = module( out_repro = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(), mask=torch.as_tensor(pair_mask, dtype=torch.float32).cuda(),
inplace_safe=True, _inplace_chunk_size=4, inplace_safe=True, _inplace_chunk_size=4,
).cpu() ).cpu()
self.assertTrue(torch.mean(torch.abs(out_gt - out_repro)) < consts.eps) compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, consts.eps)
@compare_utils.skip_unless_alphafold_installed() @compare_utils.skip_unless_alphafold_installed()
def test_tri_mul_out_compare(self): def test_tri_mul_out_compare(self):
...@@ -112,12 +120,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase): ...@@ -112,12 +120,11 @@ class TestTriangularMultiplicativeUpdate(unittest.TestCase):
pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res)) pair_mask = np.random.randint(low=0, high=2, size=(n_res, n_res))
pair_mask = pair_mask.astype(np.float32) pair_mask = pair_mask.astype(np.float32)
model = compare_utils.get_global_pretrained_openfold() model = compare_utils.get_global_pretrained_openfold()
module = ( module = (
model.evoformer.blocks[0].core.tri_mul_in model.evoformer.blocks[0].pair_stack.tri_mul_in
if incoming if incoming
else model.evoformer.blocks[0].core.tri_mul_out else model.evoformer.blocks[0].pair_stack.tri_mul_out
) )
out_stock = module( out_stock = module(
torch.as_tensor(pair_act, dtype=torch.float32).cuda(), torch.as_tensor(pair_act, dtype=torch.float32).cuda(),
......
import argparse import argparse
import logging import logging
import os import os
import random
import sys import sys
import time
import numpy as np
import pytorch_lightning as pl import pytorch_lightning as pl
from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
from pytorch_lightning.plugins.environments import SLURMEnvironment
import torch import torch
from openfold.config import model_config from openfold.config import model_config
from openfold.data.data_modules import ( from openfold.data.data_modules import OpenFoldDataModule, OpenFoldMultimerDataModule
OpenFoldDataModule,
DummyDataLoader,
)
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.model.torchscript import script_preset_ from openfold.model.torchscript import script_preset_
from openfold.np import residue_constants from openfold.np import residue_constants
from openfold.utils.argparse import remove_arguments from openfold.utils.argparse_utils import remove_arguments
from openfold.utils.callbacks import ( from openfold.utils.callbacks import (
EarlyStoppingVerbose, EarlyStoppingVerbose,
) )
from openfold.utils.exponential_moving_average import ExponentialMovingAverage from openfold.utils.exponential_moving_average import ExponentialMovingAverage
from openfold.utils.loss import AlphaFoldLoss, lddt_ca from openfold.utils.loss import AlphaFoldLoss, lddt_ca
from openfold.utils.lr_schedulers import AlphaFoldLRScheduler from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
from openfold.utils.multi_chain_permutation import multi_chain_permutation_align
from openfold.utils.seed import seed_everything from openfold.utils.seed import seed_everything
from openfold.utils.superimposition import superimpose from openfold.utils.superimposition import superimpose
from openfold.utils.tensor_utils import tensor_tree_map from openfold.utils.tensor_utils import tensor_tree_map
...@@ -39,6 +33,7 @@ from openfold.utils.validation_metrics import ( ...@@ -39,6 +33,7 @@ from openfold.utils.validation_metrics import (
) )
from openfold.utils.import_weights import ( from openfold.utils.import_weights import (
import_jax_weights_, import_jax_weights_,
import_openfold_weights_
) )
from scripts.zero_to_fp32 import ( from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint, get_fp32_state_dict_from_zero_checkpoint,
...@@ -53,7 +48,10 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -53,7 +48,10 @@ class OpenFoldWrapper(pl.LightningModule):
super(OpenFoldWrapper, self).__init__() super(OpenFoldWrapper, self).__init__()
self.config = config self.config = config
self.model = AlphaFold(config) self.model = AlphaFold(config)
self.is_multimer = self.config.globals.is_multimer
self.loss = AlphaFoldLoss(config.loss) self.loss = AlphaFoldLoss(config.loss)
self.ema = ExponentialMovingAverage( self.ema = ExponentialMovingAverage(
model=self.model, decay=config.ema.decay model=self.model, decay=config.ema.decay
) )
...@@ -98,12 +96,19 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -98,12 +96,19 @@ class OpenFoldWrapper(pl.LightningModule):
if(self.ema.device != batch["aatype"].device): if(self.ema.device != batch["aatype"].device):
self.ema.to(batch["aatype"].device) self.ema.to(batch["aatype"].device)
ground_truth = batch.pop('gt_features', None)
# Run the model # Run the model
outputs = self(batch) outputs = self(batch)
# Remove the recycling dimension # Remove the recycling dimension
batch = tensor_tree_map(lambda t: t[..., -1], batch) batch = tensor_tree_map(lambda t: t[..., -1], batch)
if self.is_multimer:
batch = multi_chain_permutation_align(out=outputs,
features=batch,
ground_truth=ground_truth)
# Compute loss # Compute loss
loss, loss_breakdown = self.loss( loss, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True outputs, batch, _return_breakdown=True
...@@ -126,13 +131,21 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -126,13 +131,21 @@ class OpenFoldWrapper(pl.LightningModule):
clone_param = lambda t: t.detach().clone() clone_param = lambda t: t.detach().clone()
self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict()) self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
self.model.load_state_dict(self.ema.state_dict()["params"]) self.model.load_state_dict(self.ema.state_dict()["params"])
ground_truth = batch.pop('gt_features', None)
# Run the model # Run the model
outputs = self(batch) outputs = self(batch)
batch = tensor_tree_map(lambda t: t[..., -1], batch) batch = tensor_tree_map(lambda t: t[..., -1], batch)
# Compute loss and other metrics
batch["use_clamped_fape"] = 0. batch["use_clamped_fape"] = 0.
if self.is_multimer:
batch = multi_chain_permutation_align(out=outputs,
features=batch,
ground_truth=ground_truth)
# Compute loss and other metrics
_, loss_breakdown = self.loss( _, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True outputs, batch, _return_breakdown=True
) )
...@@ -221,6 +234,7 @@ class OpenFoldWrapper(pl.LightningModule): ...@@ -221,6 +234,7 @@ class OpenFoldWrapper(pl.LightningModule):
lr_scheduler = AlphaFoldLRScheduler( lr_scheduler = AlphaFoldLRScheduler(
optimizer, optimizer,
last_epoch=self.last_lr_step
) )
return { return {
...@@ -265,8 +279,8 @@ def main(args): ...@@ -265,8 +279,8 @@ def main(args):
train=True, train=True,
low_prec=(str(args.precision) == "16") low_prec=(str(args.precision) == "16")
) )
model_module = OpenFoldWrapper(config) model_module = OpenFoldWrapper(config)
if(args.resume_from_ckpt): if(args.resume_from_ckpt):
if(os.path.isdir(args.resume_from_ckpt)): if(os.path.isdir(args.resume_from_ckpt)):
last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt) last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
...@@ -281,7 +295,7 @@ def main(args): ...@@ -281,7 +295,7 @@ def main(args):
else: else:
sd = torch.load(args.resume_from_ckpt) sd = torch.load(args.resume_from_ckpt)
sd = {k[len("module."):]:v for k,v in sd.items()} sd = {k[len("module."):]:v for k,v in sd.items()}
model_module.load_state_dict(sd) import_openfold_weights_(model=model_module, state_dict=sd)
logging.info("Successfully loaded model weights...") logging.info("Successfully loaded model weights...")
if(args.resume_from_jax_params): if(args.resume_from_jax_params):
model_module.load_from_jax(args.resume_from_jax_params) model_module.load_from_jax(args.resume_from_jax_params)
...@@ -291,12 +305,18 @@ def main(args): ...@@ -291,12 +305,18 @@ def main(args):
if(args.script_modules): if(args.script_modules):
script_preset_(model_module) script_preset_(model_module)
#data_module = DummyDataLoader("new_batch.pickle") if "multimer" in args.config_preset:
data_module = OpenFoldDataModule( data_module = OpenFoldMultimerDataModule(
config=config.data, config=config.data,
batch_seed=args.seed, batch_seed=args.seed,
**vars(args) **vars(args)
) )
else:
data_module = OpenFoldDataModule(
config=config.data,
batch_seed=args.seed,
**vars(args)
)
data_module.prepare_data() data_module.prepare_data()
data_module.setup() data_module.setup()
...@@ -416,6 +436,10 @@ if __name__ == "__main__": ...@@ -416,6 +436,10 @@ if __name__ == "__main__":
help='''Cutoff for all templates. In training mode, templates are also help='''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target''' filtered by the release date of the target'''
) )
parser.add_argument(
"--train_mmcif_data_cache_path", type=str, default=None,
help="Path to the json file which records all the information of mmcif structures used during training"
)
parser.add_argument( parser.add_argument(
"--use_single_seq_mode", type=str, default=False, "--use_single_seq_mode", type=str, default=False,
help="Use single sequence embeddings instead of MSAs." help="Use single sequence embeddings instead of MSAs."
...@@ -436,6 +460,10 @@ if __name__ == "__main__": ...@@ -436,6 +460,10 @@ if __name__ == "__main__":
"--val_alignment_dir", type=str, default=None, "--val_alignment_dir", type=str, default=None,
help="Directory containing precomputed validation alignments" help="Directory containing precomputed validation alignments"
) )
parser.add_argument(
"--val_mmcif_data_cache_path", type=str, default=None,
help="path to the json file which records all the information of mmcif structures used during validation"
)
parser.add_argument( parser.add_argument(
"--kalign_binary_path", type=str, default='/usr/bin/kalign', "--kalign_binary_path", type=str, default='/usr/bin/kalign',
help="Path to the kalign binary" help="Path to the kalign binary"
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment