# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import numpy as np
import unittest

import openfold.data.data_transforms as data_transforms
from openfold.np.residue_constants import (
    restype_rigid_group_default_frame,
    restype_atom14_to_rigid_group,
    restype_atom14_mask,
    restype_atom14_rigid_group_positions,
)
import openfold.utils.feats as feats
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.geometry.rigid_matrix_vector import Rigid3Array
from openfold.utils.geometry.rotation_matrix import Rot3Array
from openfold.utils.geometry.vector import Vec3Array
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
)

import tests.compare_utils as compare_utils
from tests.config import consts
from tests.data_utils import random_affines_4x4, random_asym_ids

if compare_utils.alphafold_is_installed():
    alphafold = compare_utils.import_alphafold()
    import jax
    import haiku as hk


class TestFeats(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        if compare_utils.alphafold_is_installed():
            if consts.is_multimer:
                cls.am_atom = alphafold.model.all_atom_multimer
                cls.am_fold = alphafold.model.folding_multimer
                cls.am_modules = alphafold.model.modules_multimer
                cls.am_rigid = alphafold.model.geometry
            else:
                cls.am_atom = alphafold.model.all_atom
                cls.am_fold = alphafold.model.folding
                cls.am_modules = alphafold.model.modules
                cls.am_rigid = alphafold.model.r3

    @compare_utils.skip_unless_alphafold_installed()
    def test_pseudo_beta_fn_compare(self):
        def test_pbf(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.modules.pseudo_beta_fn(
                aatype, all_atom_pos, all_atom_mask,
            )

        f = hk.transform(test_pbf)

        n_res = consts.n_res

        aatype = np.random.randint(0, 22, (n_res,))
        all_atom_pos = np.random.rand(n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_res, 37))

        out_gt_pos, out_gt_mask = f.apply(
            {}, None, aatype, all_atom_pos, all_atom_mask
        )
        out_gt_pos = torch.tensor(np.array(out_gt_pos.block_until_ready()))
        out_gt_mask = torch.tensor(np.array(out_gt_mask.block_until_ready()))

        out_repro_pos, out_repro_mask = feats.pseudo_beta_fn(
            torch.tensor(aatype).cuda(),
            torch.tensor(all_atom_pos).cuda(),
            torch.tensor(all_atom_mask).cuda(),
        )
        out_repro_pos = out_repro_pos.cpu()
        out_repro_mask = out_repro_mask.cpu()

        self.assertTrue(
            torch.max(torch.abs(out_gt_pos - out_repro_pos)) < consts.eps
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt_mask - out_repro_mask)) < consts.eps
        )
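
    # pseudo_beta_fn selects the CB coordinate of each residue and falls back
    # to CA for glycine, which has no CB atom; the comparison test above
    # checks that behaviour against AlphaFold. The method below is a minimal
    # CPU-only sanity sketch of the same function: it asserts output shapes
    # only, and it assumes the (pseudo_beta, pseudo_beta_mask) return
    # signature already relied on by the comparison test.
    def test_pseudo_beta_fn_shape(self):
        n_res = consts.n_res

        aatype = torch.randint(0, 21, (n_res,))
        all_atom_pos = torch.rand((n_res, 37, 3))
        all_atom_mask = torch.randint(0, 2, (n_res, 37)).float()

        # One pseudo-beta coordinate and one mask entry per residue
        pseudo_beta, pseudo_beta_mask = feats.pseudo_beta_fn(
            aatype, all_atom_pos, all_atom_mask
        )

        self.assertTrue(pseudo_beta.shape == (n_res, 3))
        self.assertTrue(pseudo_beta_mask.shape == (n_res,))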

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_torsion_angles_compare(self):
        def run_test(aatype, all_atom_pos, all_atom_mask):
            return alphafold.model.all_atom.atom37_to_torsion_angles(
                aatype,
                all_atom_pos,
                all_atom_mask,
                placeholder_for_undefined=False,
            )

        f = hk.transform(run_test)

        n_templ = 7
        n_res = 13

        aatype = np.random.randint(0, 22, (n_templ, n_res)).astype(np.int64)
        all_atom_pos = np.random.rand(n_templ, n_res, 37, 3).astype(np.float32)
        all_atom_mask = np.random.randint(0, 2, (n_templ, n_res, 37)).astype(
            np.float32
        )

        out_gt = f.apply({}, None, aatype, all_atom_pos, all_atom_mask)
        out_gt = jax.tree_map(lambda x: torch.as_tensor(np.array(x)), out_gt)

        out_repro = data_transforms.atom37_to_torsion_angles()(
            {
                "aatype": torch.as_tensor(aatype).cuda(),
                "all_atom_positions": torch.as_tensor(all_atom_pos).cuda(),
                "all_atom_mask": torch.as_tensor(all_atom_mask).cuda(),
            },
        )

        tasc = out_repro["torsion_angles_sin_cos"].cpu()
        atasc = out_repro["alt_torsion_angles_sin_cos"].cpu()
        tam = out_repro["torsion_angles_mask"].cpu()

        # This function is extremely sensitive to floating point imprecisions,
        # so it is given much greater latitude in comparison tests.
        self.assertTrue(
            torch.mean(torch.abs(out_gt["torsion_angles_sin_cos"] - tasc))
            < 0.01
        )
        self.assertTrue(
            torch.mean(torch.abs(out_gt["alt_torsion_angles_sin_cos"] - atasc))
            < 0.01
        )
        self.assertTrue(
            torch.max(torch.abs(out_gt["torsion_angles_mask"] - tam))
            < consts.eps
        )

    @compare_utils.skip_unless_alphafold_installed()
    def test_atom37_to_frames_compare(self):
        def run_atom37_to_frames(aatype, all_atom_positions, all_atom_mask):
            if consts.is_multimer:
                all_atom_positions = self.am_rigid.Vec3Array.from_array(
                    all_atom_positions
                )
            return self.am_atom.atom37_to_frames(
                aatype, all_atom_positions, all_atom_mask
            )

        f = hk.transform(run_atom37_to_frames)

        n_res = consts.n_res

        batch = {
            "aatype": np.random.randint(0, 21, (n_res,)),
            "all_atom_positions": np.random.rand(n_res, 37, 3).astype(
                np.float32
            ),
            "all_atom_mask": np.random.randint(0, 2, (n_res, 37)).astype(
                np.float32
            ),
        }

        out_gt = f.apply({}, None, **batch)

        if consts.is_multimer:
            batch["asym_id"] = random_asym_ids(n_res)
            to_tensor = (
                lambda t: torch.tensor(np.array(t))
                if not isinstance(t, self.am_rigid.Rigid3Array)
                else torch.tensor(np.array(t.to_array()))
            )
        else:
            to_tensor = lambda t: torch.tensor(np.array(t))

        out_gt = {k: to_tensor(v) for k, v in out_gt.items()}

        def rigid3x4_to_4x4(rigid3arr):
            four_by_four = torch.zeros(*rigid3arr.shape[:-2], 4, 4)
            four_by_four[..., :3, :4] = rigid3arr
            four_by_four[..., 3, 3] = 1
            return four_by_four

        def flat12_to_4x4(flat12):
            rot = flat12[..., :9].view(*flat12.shape[:-1], 3, 3)
            trans = flat12[..., 9:]
            four_by_four = torch.zeros(*flat12.shape[:-1], 4, 4)
            four_by_four[..., :3, :3] = rot
            four_by_four[..., :3, 3] = trans
            four_by_four[..., 3, 3] = 1
            return four_by_four

        convert_func = rigid3x4_to_4x4 if consts.is_multimer else flat12_to_4x4

        out_gt["rigidgroups_gt_frames"] = convert_func(
            out_gt["rigidgroups_gt_frames"]
        )
        out_gt["rigidgroups_alt_gt_frames"] = convert_func(
            out_gt["rigidgroups_alt_gt_frames"]
        )

        to_tensor = lambda t: torch.tensor(np.array(t)).cuda()
        batch = tree_map(to_tensor, batch, np.ndarray)

        out_repro = data_transforms.atom37_to_frames(batch)
        out_repro = tensor_tree_map(lambda t: t.cpu(), out_repro)

        for k, v in out_gt.items():
            self.assertTrue(
                torch.max(torch.abs(out_gt[k] - out_repro[k])) < consts.eps
            )

    def test_torsion_angles_to_frames_shape(self):
        batch_size = 2
        n = 5

        rots = torch.rand((batch_size, n, 3, 3))
        trans = torch.rand((batch_size, n, 3))

        if consts.is_multimer:
            rotation = Rot3Array.from_array(rots)
            translation = Vec3Array.from_array(trans)
            ts = Rigid3Array(rotation, translation)
        else:
            ts = Rigid(Rotation(rot_mats=rots), trans)

        angles = torch.rand((batch_size, n, 7, 2))
        aas = torch.tensor([i % 2 for i in range(n)])
        aas = torch.stack([aas for _ in range(batch_size)])

        frames = feats.torsion_angles_to_frames(
            ts,
            angles,
            aas,
            torch.tensor(restype_rigid_group_default_frame),
        )

        self.assertTrue(frames.shape == (batch_size, n, 8))
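
    # Background for the comparison test below: torsion_angles_to_frames
    # combines the per-residue backbone frame with 7 torsion angles (omega,
    # phi, psi, and the four chi angles), each supplied as a (sin, cos) pair,
    # hence the (..., 7, 2) angle tensors. It produces 8 rigid-group frames
    # per residue, with group 0 being the backbone frame itself, which is why
    # the shape test above expects frames.shape == (batch_size, n, 8). The
    # resting orientation of each group comes from
    # restype_rigid_group_default_frame, imported at the top of this file.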

    @compare_utils.skip_unless_alphafold_installed()
    def test_torsion_angles_to_frames_compare(self):
        def run_torsion_angles_to_frames(
            aatype, backb_to_global, torsion_angles_sin_cos
        ):
            return self.am_atom.torsion_angles_to_frames(
                aatype,
                backb_to_global,
                torsion_angles_sin_cos,
            )

        f = hk.transform(run_torsion_angles_to_frames)

        n_res = consts.n_res
        aatype = np.random.randint(0, 21, size=(n_res,))
        affines = random_affines_4x4((n_res,))

        if consts.is_multimer:
            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
            transformations = Rigid3Array.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )
        else:
            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
            transformations = Rigid.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )

        torsion_angles_sin_cos = np.random.rand(n_res, 7, 2)

        out_gt = f.apply({}, None, aatype, rigids, torsion_angles_sin_cos)
        jax.tree_map(lambda x: x.block_until_ready(), out_gt)

        out = feats.torsion_angles_to_frames(
            transformations.cuda(),
            torch.as_tensor(torsion_angles_sin_cos).cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
        )

        # Convert the Rigids to 4x4 transformation tensors
        out_gt_rot = (
            out_gt.rot
            if not consts.is_multimer
            else out_gt.rotation.to_array()
        )
        out_gt_trans = (
            out_gt.trans
            if not consts.is_multimer
            else out_gt.translation.to_array()
        )

        if consts.is_multimer:
            rots_gt = torch.as_tensor(np.array(out_gt_rot))
            trans_gt = torch.as_tensor(np.array(out_gt_trans))
        else:
            rots_gt = list(
                map(lambda x: torch.as_tensor(np.array(x)), out_gt_rot)
            )
            trans_gt = list(
                map(lambda x: torch.as_tensor(np.array(x)), out_gt_trans)
            )

            rots_gt = torch.cat([x.unsqueeze(-1) for x in rots_gt], dim=-1)
            rots_gt = rots_gt.view(*rots_gt.shape[:-1], 3, 3)
            trans_gt = torch.cat([x.unsqueeze(-1) for x in trans_gt], dim=-1)

        transforms_gt = torch.cat([rots_gt, trans_gt.unsqueeze(-1)], dim=-1)
        bottom_row = torch.zeros((*rots_gt.shape[:-2], 1, 4))
        bottom_row[..., 3] = 1
        transforms_gt = torch.cat([transforms_gt, bottom_row], dim=-2)

        transforms_repro = out.to_tensor_4x4().cpu()

        self.assertTrue(
            torch.max(torch.abs(transforms_gt - transforms_repro)) < consts.eps
        )

    def test_frames_and_literature_positions_to_atom14_pos_shape(self):
        batch_size = consts.batch_size
        n_res = consts.n_res

        rots = torch.rand((batch_size, n_res, 8, 3, 3))
        trans = torch.rand((batch_size, n_res, 8, 3))

        if consts.is_multimer:
            rotation = Rot3Array.from_array(rots)
            translation = Vec3Array.from_array(trans)
            ts = Rigid3Array(rotation, translation)
        else:
            ts = Rigid(Rotation(rot_mats=rots), trans)

        f = torch.randint(low=0, high=21, size=(batch_size, n_res)).long()

        xyz = feats.frames_and_literature_positions_to_atom14_pos(
            ts,
            f,
            torch.tensor(restype_rigid_group_default_frame),
            torch.tensor(restype_atom14_to_rigid_group),
            torch.tensor(restype_atom14_mask),
            torch.tensor(restype_atom14_rigid_group_positions),
        )

        self.assertTrue(xyz.shape == (batch_size, n_res, 14, 3))
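
    # Background for the comparison test below:
    # frames_and_literature_positions_to_atom14_pos rebuilds the 14 heavy-atom
    # coordinates of each residue by applying one of the residue's 8
    # rigid-group frames to idealized ("literature") local coordinates.
    # Roughly, for residue i and atom slot a (a sketch of the idea, not the
    # exact implementation):
    #
    #     group = restype_atom14_to_rigid_group[aatype[i], a]
    #     lit_xyz = restype_atom14_rigid_group_positions[aatype[i], a]
    #     x_global = frames[i, group].apply(lit_xyz)
    #     x_global = x_global * restype_atom14_mask[aatype[i], a]
    #
    # which is why the constant tables imported at the top of this file are
    # passed into both the shape test above and the comparison test below.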

    @compare_utils.skip_unless_alphafold_installed()
    def test_frames_and_literature_positions_to_atom14_pos_compare(self):
        def run_f(aatype, affines):
            return self.am_atom.frames_and_literature_positions_to_atom14_pos(
                aatype, affines
            )

        f = hk.transform(run_f)

        n_res = consts.n_res
        aatype = np.random.randint(0, 21, size=(n_res,))
        affines = random_affines_4x4((n_res, 8))

        if consts.is_multimer:
            rigids = self.am_rigid.Rigid3Array.from_array4x4(affines)
            transformations = Rigid3Array.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )
        else:
            rigids = self.am_rigid.rigids_from_tensor4x4(affines)
            transformations = Rigid.from_tensor_4x4(
                torch.as_tensor(affines).float()
            )

        out_gt = f.apply({}, None, aatype, rigids)
        jax.tree_map(lambda x: x.block_until_ready(), out_gt)

        if consts.is_multimer:
            out_gt = torch.as_tensor(np.array(out_gt.to_array()))
        else:
            out_gt = torch.stack(
                [torch.as_tensor(np.array(x)) for x in out_gt], dim=-1
            )

        out_repro = feats.frames_and_literature_positions_to_atom14_pos(
            transformations.cuda(),
            torch.as_tensor(aatype).cuda(),
            torch.tensor(restype_rigid_group_default_frame).cuda(),
            torch.tensor(restype_atom14_to_rigid_group).cuda(),
            torch.tensor(restype_atom14_mask).cuda(),
            torch.tensor(restype_atom14_rigid_group_positions).cuda(),
        ).cpu()

        self.assertTrue(torch.max(torch.abs(out_gt - out_repro)) < consts.eps)


if __name__ == "__main__":
    unittest.main()
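
# How this module is expected to be run (an assumed invocation; the module
# path may differ in the repository's test setup):
#
#     python -m unittest tests.test_feats
#
# The *_compare tests are skipped unless AlphaFold, JAX, and Haiku are
# importable, and they move tensors to CUDA, so they also need a GPU. The
# *_shape tests run on the CPU without AlphaFold.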