# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shape tests and AlphaFold comparison tests for the structure module."""

import torch
import numpy as np
import unittest

from openfold.data.data_transforms import make_atom14_masks_np
from openfold.np.residue_constants import (
    restype_atom14_mask,
    restype_atom37_mask,
)
from openfold.model.structure_module import (
    StructureModule,
    StructureModuleTransition,
    AngleResnet,
    InvariantPointAttention,
)
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import (
    random_affines_4x4,
)

# The reference implementation and its JAX/Haiku stack are optional; the
# comparison tests below are skipped when they are unavailable.
if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


def _bind_alphafold_modules(cls):
    """Attach the AlphaFold reference submodules to the test class ``cls``.

    Does nothing when AlphaFold is not installed. Otherwise sets the class
    attributes ``am_atom``, ``am_fold``, ``am_modules`` and ``am_rigid``,
    selecting the multimer variants when ``consts.is_multimer`` is true.
    """
    if not compare_utils.alphafold_is_installed():
        return

    if consts.is_multimer:
        cls.am_atom = alphafold.model.all_atom_multimer
        cls.am_fold = alphafold.model.folding_multimer
        cls.am_modules = alphafold.model.modules_multimer
        cls.am_rigid = alphafold.model.geometry
    else:
        cls.am_atom = alphafold.model.all_atom
        cls.am_fold = alphafold.model.folding
        cls.am_modules = alphafold.model.modules
        cls.am_rigid = alphafold.model.r3


class TestStructureModule(unittest.TestCase):
    """Tests for ``StructureModule`` and ``StructureModuleTransition``."""

    @classmethod
    def setUpClass(cls):
        """Bind the AlphaFold reference modules used by the compare test."""
        _bind_alphafold_modules(cls)

    def test_structure_module_shape(self):
        """Check the shapes of the frames, angles and positions outputs."""
        batch_size = consts.batch_size
        n = consts.n_res
        c_s = consts.c_s
        c_z = consts.c_z
        c_ipa = 13
        c_resnet = 17
        no_heads_ipa = 6
        no_query_points = 4
        no_value_points = 4
        dropout_rate = 0.1
        no_layers = 3
        no_transition_layers = 3
        no_resnet_layers = 3
        ar_epsilon = 1e-6
        no_angles = 7
        trans_scale_factor = 10
        inf = 1e5

        sm = StructureModule(
            c_s,
            c_z,
            c_ipa,
            c_resnet,
            no_heads_ipa,
            no_query_points,
            no_value_points,
            dropout_rate,
            no_layers,
            no_transition_layers,
            no_resnet_layers,
            no_angles,
            trans_scale_factor,
            ar_epsilon,
            inf,
            is_multimer=consts.is_multimer,
        )

        s = torch.rand((batch_size, n, c_s))
        z = torch.rand((batch_size, n, n, c_z))
        f = torch.randint(low=0, high=21, size=(batch_size, n)).long()

        out = sm({"single": s, "pair": z}, f)

        # assertEqual (rather than assertTrue on ==) reports both shapes
        # when the check fails.
        if consts.is_multimer:
            self.assertEqual(
                out["frames"].shape, (no_layers, batch_size, n, 4, 4)
            )
        else:
            self.assertEqual(
                out["frames"].shape, (no_layers, batch_size, n, 7)
            )

        self.assertEqual(
            out["angles"].shape, (no_layers, batch_size, n, no_angles, 2)
        )
        self.assertEqual(
            out["positions"].shape, (no_layers, batch_size, n, 14, 3)
        )

    def test_structure_module_transition_shape(self):
        """Check that the transition block preserves its input shape."""
        batch_size = 2
        n = 5
        c = 7
        num_layers = 3
        dropout = 0.1

        smt = StructureModuleTransition(c, num_layers, dropout)

        s = torch.rand((batch_size, n, c))

        shape_before = s.shape
        s = smt(s)
        shape_after = s.shape

        self.assertEqual(shape_before, shape_after)

    @compare_utils.skip_unless_alphafold_installed()
    def test_structure_module_compare(self):
        """Compare final atom14 positions against the AlphaFold module."""
        config = compare_utils.get_alphafold_config()
        c_sm = config.model.heads.structure_module
        c_global = config.model.global_config

        def run_sm(representations, batch):
            sm = self.am_fold.StructureModule(c_sm, c_global)
            representations = {
                k: jax.lax.stop_gradient(v) for k, v in representations.items()
            }
            batch = {k: jax.lax.stop_gradient(v) for k, v in batch.items()}
            if consts.is_multimer:
                return sm(
                    representations,
                    batch,
                    is_training=False,
                    compute_loss=True,
                )
            return sm(representations, batch, is_training=False)

        f = hk.transform(run_sm)

        n_res = 200

        representations = {
            "single": np.random.rand(n_res, consts.c_s).astype(np.float32),
            "pair": np.random.rand(n_res, n_res, consts.c_z).astype(
                np.float32
            ),
        }

        batch = {
            "seq_mask": np.random.randint(0, 2, (n_res,)).astype(np.float32),
            "aatype": np.random.randint(0, 21, (n_res,)),
        }

        # Per-residue atom existence masks looked up by residue type.
        batch["atom14_atom_exists"] = np.take(
            restype_atom14_mask, batch["aatype"], axis=0
        )
        batch["atom37_atom_exists"] = np.take(
            restype_atom37_mask, batch["aatype"], axis=0
        )

        batch.update(make_atom14_masks_np(batch))

        params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/structure_module"
        )

        key = jax.random.PRNGKey(42)
        out_gt = f.apply(params, key, representations, batch)
        out_gt = torch.as_tensor(
            np.array(out_gt["final_atom14_positions"].block_until_ready())
        )

        # Inference only: no autograd graph is needed for the comparison
        # (mirrors test_ipa_compare).
        with torch.no_grad():
            model = compare_utils.get_global_pretrained_openfold()
            out_repro = model.structure_module(
                {
                    "single": torch.as_tensor(
                        representations["single"]
                    ).cuda(),
                    "pair": torch.as_tensor(representations["pair"]).cuda(),
                },
                torch.as_tensor(batch["aatype"]).cuda(),
                mask=torch.as_tensor(batch["seq_mask"]).cuda(),
                inplace_safe=False,
            )
        out_repro = out_repro["positions"][-1].cpu()

        # The structure module, thanks to angle normalization, is very volatile
        # We only assess the mean here. Heuristically speaking, it seems to
        # have lower error in general on real rather than synthetic data.
        compare_utils.assert_mean_abs_diff_small(out_gt, out_repro, 0.05)


class TestInvariantPointAttention(unittest.TestCase):
    """Tests for ``InvariantPointAttention``."""

    @classmethod
    def setUpClass(cls):
        """Bind the AlphaFold reference modules used by the compare test."""
        _bind_alphafold_modules(cls)

    def test_shape(self):
        """Check that IPA returns a tensor shaped like its single input."""
        c_m = 13
        c_z = 17
        c_hidden = 19
        no_heads = 5
        no_qp = 7
        no_vp = 11

        batch_size = 2
        n_res = 23

        s = torch.rand((batch_size, n_res, c_m))
        z = torch.rand((batch_size, n_res, n_res, c_z))
        mask = torch.ones((batch_size, n_res))

        rot_mats = torch.rand((batch_size, n_res, 3, 3))
        trans = torch.rand((batch_size, n_res, 3))

        # The multimer code path uses a different rigid-transform type.
        if consts.is_multimer:
            rotation = Rot3Array.from_array(rot_mats)
            translation = Vec3Array.from_array(trans)
            r = Rigid3Array(rotation, translation)
        else:
            rots = Rotation(rot_mats=rot_mats, quats=None)
            r = Rigid(rots, trans)

        ipa = InvariantPointAttention(
            c_m,
            c_z,
            c_hidden,
            no_heads,
            no_qp,
            no_vp,
            is_multimer=consts.is_multimer,
        )

        shape_before = s.shape
        s = ipa(s, z, r, mask)

        self.assertEqual(s.shape, shape_before)

    @compare_utils.skip_unless_alphafold_installed()
    def test_ipa_compare(self):
        """Compare IPA output against the AlphaFold implementation."""

        def run_ipa(act, static_feat_2d, mask, affine):
            config = compare_utils.get_alphafold_config()
            ipa = self.am_fold.InvariantPointAttention(
                config.model.heads.structure_module,
                config.model.global_config,
            )
            # The two reference implementations name the transform
            # argument differently (``rigid`` vs ``affine``).
            if consts.is_multimer:
                attn = ipa(
                    inputs_1d=act,
                    inputs_2d=static_feat_2d,
                    mask=mask,
                    rigid=affine,
                )
            else:
                attn = ipa(
                    inputs_1d=act,
                    inputs_2d=static_feat_2d,
                    mask=mask,
                    affine=affine,
                )
            return attn

        f = hk.transform(run_ipa)

        n_res = consts.n_res
        c_s = consts.c_s
        c_z = consts.c_z

        sample_act = np.random.rand(n_res, c_s)
        sample_2d = np.random.rand(n_res, n_res, c_z)
        sample_mask = np.ones((n_res, 1))

        # Build the same random transforms for both implementations.
        affines = random_affines_4x4((n_res,))

        if consts.is_multimer:
            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
            transformations = Rigid3Array.from_tensor_4x4(
                torch.as_tensor(affines).float().cuda()
            )
            sample_affine = rigids
        else:
            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
            quats = self.am_rigid.rigids_to_quataffine(rigids)
            transformations = Rigid.from_tensor_4x4(
                torch.as_tensor(affines).float().cuda()
            )
            sample_affine = quats

        ipa_params = compare_utils.fetch_alphafold_module_weights(
            "alphafold/alphafold_iteration/structure_module/"
            + "fold_iteration/invariant_point_attention"
        )

        out_gt = f.apply(
            ipa_params, None, sample_act, sample_2d, sample_mask, sample_affine
        ).block_until_ready()
        out_gt = torch.as_tensor(np.array(out_gt))

        with torch.no_grad():
            model = compare_utils.get_global_pretrained_openfold()
            out_repro = model.structure_module.ipa(
                torch.as_tensor(sample_act).float().cuda(),
                torch.as_tensor(sample_2d).float().cuda(),
                transformations,
                # The reference mask has a trailing singleton axis; it is
                # squeezed away before being passed here.
                torch.as_tensor(sample_mask.squeeze(-1)).float().cuda(),
            ).cpu()

        compare_utils.assert_max_abs_diff_small(out_gt, out_repro, consts.eps)


class TestAngleResnet(unittest.TestCase):
    """Tests for ``AngleResnet``."""

    def test_shape(self):
        """Check the shape of the second value returned by the resnet."""
        batch_size = 2
        n = 3
        c_s = 13
        c_hidden = 11
        no_layers = 5
        no_angles = 7
        epsilon = 1e-12

        ar = AngleResnet(c_s, c_hidden, no_layers, no_angles, epsilon)
        a = torch.rand((batch_size, n, c_s))
        a_initial = torch.rand((batch_size, n, c_s))

        _, a = ar(a, a_initial)

        self.assertEqual(a.shape, (batch_size, n, no_angles, 2))


if __name__ == "__main__":
    unittest.main()