Commit 260db67f authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Finish multimer inference

parent 6e68d6b0
...@@ -29,6 +29,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg): ...@@ -29,6 +29,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.cast_to_64bit_ints, data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile, data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat, data_transforms_multimer.create_target_feat,
data_transforms.make_atom14_masks,
] ]
if(common_cfg.use_templates): if(common_cfg.use_templates):
......
...@@ -868,7 +868,7 @@ class TemplateEmbedderMultimer(nn.Module): ...@@ -868,7 +868,7 @@ class TemplateEmbedderMultimer(nn.Module):
raw_atom_pos = single_template_feats["template_all_atom_positions"] raw_atom_pos = single_template_feats["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_tensor(raw_atom_pos) atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = all_atom_multimer.make_backbone_affine( rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos, atom_pos,
single_template_feats["template_all_atom_mask"], single_template_feats["template_all_atom_mask"],
......
...@@ -363,9 +363,6 @@ class InvariantPointAttention(nn.Module): ...@@ -363,9 +363,6 @@ class InvariantPointAttention(nn.Module):
a *= math.sqrt(1.0 / (3 * self.c_hidden)) a *= math.sqrt(1.0 / (3 * self.c_hidden))
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1))) a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
for c in q_pts:
print(type(c))
# [*, N_res, N_res, H, P_q, 3] # [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :] pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :]
...@@ -669,7 +666,7 @@ class StructureModule(nn.Module): ...@@ -669,7 +666,7 @@ class StructureModule(nn.Module):
self, r, f # [*, N, 8] # [*, N] self, r, f # [*, N, 8] # [*, N]
): ):
# Lazily initialize the residue constants on the correct device # Lazily initialize the residue constants on the correct device
self._init_residue_constants(r.get_rots().dtype, r.get_rots().device) self._init_residue_constants(r.dtype, r.device)
return frames_and_literature_positions_to_atom14_pos( return frames_and_literature_positions_to_atom14_pos(
r, r,
f, f,
...@@ -818,11 +815,11 @@ class StructureModule(nn.Module): ...@@ -818,11 +815,11 @@ class StructureModule(nn.Module):
) )
preds = { preds = {
"frames": rigids.scale_translation(self.trans_scale_factor).to_tensor7(), "frames": rigids.scale_translation(self.trans_scale_factor).to_tensor(),
"sidechain_frames": all_frames_to_global.to_tensor_4x4(), "sidechain_frames": all_frames_to_global.to_tensor_4x4(),
"unnormalized_angles": unnormalized_angles, "unnormalized_angles": unnormalized_angles,
"angles": angles, "angles": angles,
"positions": pred_xyz, "positions": pred_xyz.to_tensor(),
} }
outputs.append(preds) outputs.append(preds)
......
...@@ -205,7 +205,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein: ...@@ -205,7 +205,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
) )
def _chain_end(atom_index, end_resname, chain_name, residue_indx) -> str: def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
chain_end = 'TER' chain_end = 'TER'
return( return(
f'{chain_end:<6}{atom_index:>5} {end_resname:>3} ' f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
......
...@@ -22,6 +22,7 @@ from typing import Dict ...@@ -22,6 +22,7 @@ from typing import Dict
from openfold.np import protein from openfold.np import protein
import openfold.np.residue_constants as rc import openfold.np.residue_constants as rc
from openfold.utils.geometry import rigid_matrix_vector, rotation_matrix
from openfold.utils.rigid_utils import Rotation, Rigid from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
batched_gather, batched_gather,
...@@ -213,15 +214,14 @@ def torsion_angles_to_frames( ...@@ -213,15 +214,14 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses # This follows the original code rather than the supplement, which uses
# different indices. # different indices.
all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape) all_rots = alpha.new_zeros(default_r.shape + (3, 3))
all_rots[..., 0, 0] = 1 all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1] all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0] all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha all_rots[..., 2, 1:] = alpha
all_rots = Rigid(Rotation(rot_mats=all_rots), None) all_rots = rotation_matrix.Rot3Array.from_array(all_rots)
all_frames = default_r.compose_rotation(all_rots)
all_frames = default_r.compose(all_rots)
chi2_frame_to_frame = all_frames[..., 5] chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6] chi3_frame_to_frame = all_frames[..., 6]
...@@ -232,7 +232,7 @@ def torsion_angles_to_frames( ...@@ -232,7 +232,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame) chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame) chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = Rigid.cat( all_frames_to_bb = rigid_matrix_vector.Rigid3Array.cat(
[ [
all_frames[..., :5], all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1), chi2_frame_to_bb.unsqueeze(-1),
...@@ -248,7 +248,7 @@ def torsion_angles_to_frames( ...@@ -248,7 +248,7 @@ def torsion_angles_to_frames(
def frames_and_literature_positions_to_atom14_pos( def frames_and_literature_positions_to_atom14_pos(
r: Rigid, r: rigid_matrix_vector.Rigid3Array,
aatype: torch.Tensor, aatype: torch.Tensor,
default_frames, default_frames,
group_idx, group_idx,
...@@ -275,8 +275,8 @@ def frames_and_literature_positions_to_atom14_pos( ...@@ -275,8 +275,8 @@ def frames_and_literature_positions_to_atom14_pos(
lambda x: torch.sum(x, dim=-1) lambda x: torch.sum(x, dim=-1)
) )
# [*, N, 14, 1] # [*, N, 14]
atom_mask = atom_mask[aatype, ...].unsqueeze(-1) atom_mask = atom_mask[aatype, ...]
# [*, N, 14, 3] # [*, N, 14, 3]
lit_positions = lit_positions[aatype, ...] lit_positions = lit_positions[aatype, ...]
......
...@@ -22,8 +22,6 @@ class QuatRigid(nn.Module): ...@@ -22,8 +22,6 @@ class QuatRigid(nn.Module):
# NOTE: During training, this needs to be run in higher precision # NOTE: During training, this needs to be run in higher precision
rigid_flat = self.linear(activations.to(torch.float32)) rigid_flat = self.linear(activations.to(torch.float32))
print(rigid_flat.shape)
rigid_flat = torch.unbind(rigid_flat, dim=-1) rigid_flat = torch.unbind(rigid_flat, dim=-1)
if(self.full_quat): if(self.full_quat):
qw, qx, qy, qz = rigid_flat[:4] qw, qx, qy, qz = rigid_flat[:4]
......
...@@ -67,6 +67,9 @@ class Rigid3Array: ...@@ -67,6 +67,9 @@ class Rigid3Array:
"""Apply Rigid3Array transform to point.""" """Apply Rigid3Array transform to point."""
return self.rotation.apply_to_point(point) + self.translation return self.rotation.apply_to_point(point) + self.translation
def apply(self, point: torch.Tensor) -> vector.Vec3Array:
return self.apply_to_point(vector.Vec3Array.from_array(point))
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply inverse Rigid3Array transform to point.""" """Apply inverse Rigid3Array transform to point."""
new_point = point - self.translation new_point = point - self.translation
...@@ -74,7 +77,28 @@ class Rigid3Array: ...@@ -74,7 +77,28 @@ class Rigid3Array:
def compose_rotation(self, other_rotation): def compose_rotation(self, other_rotation):
rot = self.rotation @ other_rotation rot = self.rotation @ other_rotation
return Rigid3Array(rot, trans.clone()) return Rigid3Array(rot, self.translation.clone())
def compose(self, other_rigid):
return self @ other_rigid
def unsqueeze(self, dim: int):
return Rigid3Array(
self.rotation.unsqueeze(dim),
self.translation.unsqueeze(dim),
)
@property
def shape(self) -> torch.Size:
return self.rotation.xx.shape
@property
def dtype(self) -> torch.dtype:
return self.rotation.xx.dtype
@property
def device(self) -> torch.device:
return self.rotation.xx.device
@classmethod @classmethod
def identity(cls, shape, device) -> Rigid3Array: def identity(cls, shape, device) -> Rigid3Array:
...@@ -87,28 +111,51 @@ class Rigid3Array: ...@@ -87,28 +111,51 @@ class Rigid3Array:
@classmethod @classmethod
def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array: def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
return cls( return cls(
Rot3Array.cat([r.rotation for r in rigids], dim=dim), rotation_matrix.Rot3Array.cat(
Vec3Array.cat([r.translation for r in rigids], dim=dim), [r.rotation for r in rigids], dim=dim
),
vector.Vec3Array.cat(
[r.translation for r in rigids], dim=dim
),
) )
def scale_translation(self, factor: Float) -> Rigid3Array: def scale_translation(self, factor: Float) -> Rigid3Array:
"""Scale translation in Rigid3Array by 'factor'.""" """Scale translation in Rigid3Array by 'factor'."""
return Rigid3Array(self.rotation, self.translation * factor) return Rigid3Array(self.rotation, self.translation * factor)
def to_array(self): def to_tensor(self) -> torch.Tensor:
rot_array = self.rotation.to_array() rot_array = self.rotation.to_tensor()
vec_array = self.translation.to_array() vec_array = self.translation.to_tensor()
return torch.cat([rot_array, vec_array[..., None]], dim=-1) array = torch.zeros(
rot_array.shape[:-2] + (4, 4),
device=rot_array.device,
dtype=rot_array.dtype
)
array[..., :3, :3] = rot_array
array[..., :3, 3] = vec_array
array[..., 3, 3] = 1.
return array
def to_tensor_4x4(self) -> torch.Tensor:
return self.to_tensor()
def reshape(self, new_shape) -> Rigid3Array: def reshape(self, new_shape) -> Rigid3Array:
rots = self.rotation.reshape(new_shape) rots = self.rotation.reshape(new_shape)
trans = self.translation.reshape(new_shape) trans = self.translation.reshape(new_shape)
return Rigid3Aray(rots, trans) return Rigid3Aray(rots, trans)
def stop_rot_gradient(self) -> Rigid3Array:
return Rigid3Array(
self.rotation.stop_gradient(),
self.translation,
)
@classmethod @classmethod
def from_array(cls, array): def from_array(cls, array):
rot = rotation_matrix.Rot3Array.from_array(array[..., :3]) rot = rotation_matrix.Rot3Array.from_array(
vec = vector.Vec3Array.from_array(array[..., -1]) array[..., :3, :3],
)
vec = vector.Vec3Array.from_array(array[..., :3, 3])
return cls(rot, vec) return cls(rot, vec)
@classmethod @classmethod
...@@ -124,5 +171,6 @@ class Rigid3Array: ...@@ -124,5 +171,6 @@ class Rigid3Array:
array[..., 2, 0], array[..., 2, 1], array[..., 2, 2] array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
) )
translation = vector.Vec3Array( translation = vector.Vec3Array(
array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]) array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
)
return cls(rotation, translation) return cls(rotation, translation)
...@@ -22,6 +22,7 @@ import numpy as np ...@@ -22,6 +22,7 @@ import numpy as np
from openfold.utils.geometry import struct_of_array from openfold.utils.geometry import struct_of_array
from openfold.utils.geometry import utils from openfold.utils.geometry import utils
from openfold.utils.geometry import vector from openfold.utils.geometry import vector
from openfold.utils.tensor_utils import tensor_tree_map
COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz'] COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']
...@@ -59,6 +60,13 @@ class Rot3Array: ...@@ -59,6 +60,13 @@ class Rot3Array:
} }
) )
def __matmul__(self, other: Rot3Array) -> Rot3Array:
"""Composes two Rot3Arrays."""
c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
def map_tensor_fn(self, fn) -> Rot3Array: def map_tensor_fn(self, fn) -> Rot3Array:
field_names = utils.get_field_names(Rot3Array) field_names = utils.get_field_names(Rot3Array)
return Rot3Array( return Rot3Array(
...@@ -88,12 +96,19 @@ class Rot3Array: ...@@ -88,12 +96,19 @@ class Rot3Array:
"""Applies inverse Rot3Array to point.""" """Applies inverse Rot3Array to point."""
return self.inverse().apply_to_point(point) return self.inverse().apply_to_point(point)
def __matmul__(self, other: Rot3Array) -> Rot3Array:
"""Composes two Rot3Arrays.""" def unsqueeze(self, dim: int):
c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx)) return Rot3Array(
c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy)) *tensor_tree_map(
c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz)) lambda t: t.unsqueeze(dim),
return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z) [getattr(self, c) for c in COMPONENTS]
)
)
def stop_gradient(self) -> Rot3Array:
return Rot3Array(
*[getattr(self, c).detach() for c in COMPONENTS]
)
@classmethod @classmethod
def identity(cls, shape, device) -> Rot3Array: def identity(cls, shape, device) -> Rot3Array:
...@@ -130,9 +145,11 @@ class Rot3Array: ...@@ -130,9 +145,11 @@ class Rot3Array:
@classmethod @classmethod
def from_array(cls, array: torch.Tensor) -> Rot3Array: def from_array(cls, array: torch.Tensor) -> Rot3Array:
"""Construct Rot3Array Matrix from array of shape. [..., 3, 3].""" """Construct Rot3Array Matrix from array of shape. [..., 3, 3]."""
return cls(torch.unbind(array, dim=-2)) rows = torch.unbind(array, dim=-2)
rc = [torch.unbind(e, dim=-1) for e in rows]
return cls(*[e for row in rc for e in row])
def to_array(self) -> torch.Tensor: def to_tensor(self) -> torch.Tensor:
"""Convert Rot3Array to array of shape [..., 3, 3].""" """Convert Rot3Array to array of shape [..., 3, 3]."""
return torch.stack( return torch.stack(
[ [
...@@ -140,7 +157,8 @@ class Rot3Array: ...@@ -140,7 +157,8 @@ class Rot3Array:
torch.stack([self.yx, self.yy, self.yz], dim=-1), torch.stack([self.yx, self.yy, self.yz], dim=-1),
torch.stack([self.zx, self.zy, self.zz], dim=-1) torch.stack([self.zx, self.zy, self.zz], dim=-1)
], ],
dim=-2) dim=-2
)
@classmethod @classmethod
def from_quaternion(cls, def from_quaternion(cls,
......
...@@ -134,13 +134,20 @@ class Vec3Array: ...@@ -134,13 +134,20 @@ class Vec3Array:
return Vec3Array(x, y, z) return Vec3Array(x, y, z)
def sum(self, dim) -> Vec3Array: def sum(self, dim: int) -> Vec3Array:
return Vec3Array( return Vec3Array(
torch.sum(self.x, dim=dim), torch.sum(self.x, dim=dim),
torch.sum(self.y, dim=dim), torch.sum(self.y, dim=dim),
torch.sum(self.z, dim=dim), torch.sum(self.z, dim=dim),
) )
def unsqueeze(self, dim: int):
return Vec3Array(
self.x.unsqueeze(dim),
self.y.unsqueeze(dim),
self.z.unsqueeze(dim),
)
@classmethod @classmethod
def zeros(cls, shape, device="cpu"): def zeros(cls, shape, device="cpu"):
"""Return Vec3Array corresponding to zeros of given shape.""" """Return Vec3Array corresponding to zeros of given shape."""
...@@ -150,11 +157,11 @@ class Vec3Array: ...@@ -150,11 +157,11 @@ class Vec3Array:
torch.zeros(shape, dtype=torch.float32, device=device) torch.zeros(shape, dtype=torch.float32, device=device)
) )
def to_array(self) -> torch.Tensor: def to_tensor(self) -> torch.Tensor:
return torch.stack([self.x, self.y, self.z], dim=-1) return torch.stack([self.x, self.y, self.z], dim=-1)
@classmethod @classmethod
def from_tensor(cls, tensor): def from_array(cls, tensor):
return cls(*torch.unbind(tensor, dim=-1)) return cls(*torch.unbind(tensor, dim=-1))
@classmethod @classmethod
......
...@@ -623,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -623,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
flat_keys = list(flat.keys()) flat_keys = list(flat.keys())
incorrect = [k for k in flat_keys if k not in keys] incorrect = [k for k in flat_keys if k not in keys]
missing = [k for k in keys if k not in flat_keys] missing = [k for k in keys if k not in flat_keys]
print(f"Incorrect: {incorrect}") # print(f"Incorrect: {incorrect}")
print(f"Missing: {missing}") # print(f"Missing: {missing}")
assert len(incorrect) == 0 assert len(incorrect) == 0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys()))) # assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
......
...@@ -217,7 +217,8 @@ def main(args): ...@@ -217,7 +217,8 @@ def main(args):
unrelaxed_protein = protein.from_prediction( unrelaxed_protein = protein.from_prediction(
features=batch, features=batch,
result=out, result=out,
b_factors=plddt_b_factors b_factors=plddt_b_factors,
remove_leading_feature_dimension=not is_multimer,
) )
# Save the unrelaxed PDB. # Save the unrelaxed PDB.
...@@ -227,6 +228,9 @@ def main(args): ...@@ -227,6 +228,9 @@ def main(args):
with open(unrelaxed_output_path, 'w') as f: with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein)) f.write(protein.to_pdb(unrelaxed_protein))
print(unrelaxed_output_path)
print("asdjfh klasjdhf lkasjdhf lkjasdhflkjasdh fkl jasdhfklj hasdkljf hasldkjfh lkasjdfh lkajsdhflk asd")
amber_relaxer = relax.AmberRelaxation( amber_relaxer = relax.AmberRelaxation(
use_gpu=(args.model_device != "cpu"), use_gpu=(args.model_device != "cpu"),
**config.relax, **config.relax,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment