"lib/llm/vscode:/vscode.git/clone" did not exist on "42ce6931a788c31a9f1f99b51336ba15e189f617"
Commit 260db67f authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Finish multimer inference

parent 6e68d6b0
......@@ -29,6 +29,7 @@ def nonensembled_transform_fns(common_cfg, mode_cfg):
data_transforms.cast_to_64bit_ints,
data_transforms_multimer.make_msa_profile,
data_transforms_multimer.create_target_feat,
data_transforms.make_atom14_masks,
]
if(common_cfg.use_templates):
......
......@@ -868,7 +868,7 @@ class TemplateEmbedderMultimer(nn.Module):
raw_atom_pos = single_template_feats["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_tensor(raw_atom_pos)
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos,
single_template_feats["template_all_atom_mask"],
......
......@@ -363,9 +363,6 @@ class InvariantPointAttention(nn.Module):
a *= math.sqrt(1.0 / (3 * self.c_hidden))
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
for c in q_pts:
print(type(c))
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts[..., None, :, :] - k_pts[..., None, :, :, :]
......@@ -669,7 +666,7 @@ class StructureModule(nn.Module):
self, r, f # [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
self._init_residue_constants(r.dtype, r.device)
return frames_and_literature_positions_to_atom14_pos(
r,
f,
......@@ -818,11 +815,11 @@ class StructureModule(nn.Module):
)
preds = {
"frames": rigids.scale_translation(self.trans_scale_factor).to_tensor7(),
"frames": rigids.scale_translation(self.trans_scale_factor).to_tensor(),
"sidechain_frames": all_frames_to_global.to_tensor_4x4(),
"unnormalized_angles": unnormalized_angles,
"angles": angles,
"positions": pred_xyz,
"positions": pred_xyz.to_tensor(),
}
outputs.append(preds)
......
......@@ -205,7 +205,7 @@ def from_proteinnet_string(proteinnet_str: str) -> Protein:
)
def _chain_end(atom_index, end_resname, chain_name, residue_indx) -> str:
def _chain_end(atom_index, end_resname, chain_name, residue_index) -> str:
chain_end = 'TER'
return(
f'{chain_end:<6}{atom_index:>5} {end_resname:>3} '
......
......@@ -22,6 +22,7 @@ from typing import Dict
from openfold.np import protein
import openfold.np.residue_constants as rc
from openfold.utils.geometry import rigid_matrix_vector, rotation_matrix
from openfold.utils.rigid_utils import Rotation, Rigid
from openfold.utils.tensor_utils import (
batched_gather,
......@@ -213,15 +214,14 @@ def torsion_angles_to_frames(
# This follows the original code rather than the supplement, which uses
# different indices.
all_rots = alpha.new_zeros(default_r.get_rots().get_rot_mats().shape)
all_rots = alpha.new_zeros(default_r.shape + (3, 3))
all_rots[..., 0, 0] = 1
all_rots[..., 1, 1] = alpha[..., 1]
all_rots[..., 1, 2] = -alpha[..., 0]
all_rots[..., 2, 1:] = alpha
all_rots = Rigid(Rotation(rot_mats=all_rots), None)
all_frames = default_r.compose(all_rots)
all_rots = rotation_matrix.Rot3Array.from_array(all_rots)
all_frames = default_r.compose_rotation(all_rots)
chi2_frame_to_frame = all_frames[..., 5]
chi3_frame_to_frame = all_frames[..., 6]
......@@ -232,7 +232,7 @@ def torsion_angles_to_frames(
chi3_frame_to_bb = chi2_frame_to_bb.compose(chi3_frame_to_frame)
chi4_frame_to_bb = chi3_frame_to_bb.compose(chi4_frame_to_frame)
all_frames_to_bb = Rigid.cat(
all_frames_to_bb = rigid_matrix_vector.Rigid3Array.cat(
[
all_frames[..., :5],
chi2_frame_to_bb.unsqueeze(-1),
......@@ -248,7 +248,7 @@ def torsion_angles_to_frames(
def frames_and_literature_positions_to_atom14_pos(
r: Rigid,
r: rigid_matrix_vector.Rigid3Array,
aatype: torch.Tensor,
default_frames,
group_idx,
......@@ -275,8 +275,8 @@ def frames_and_literature_positions_to_atom14_pos(
lambda x: torch.sum(x, dim=-1)
)
# [*, N, 14, 1]
atom_mask = atom_mask[aatype, ...].unsqueeze(-1)
# [*, N, 14]
atom_mask = atom_mask[aatype, ...]
# [*, N, 14, 3]
lit_positions = lit_positions[aatype, ...]
......
......@@ -22,8 +22,6 @@ class QuatRigid(nn.Module):
# NOTE: During training, this needs to be run in higher precision
rigid_flat = self.linear(activations.to(torch.float32))
print(rigid_flat.shape)
rigid_flat = torch.unbind(rigid_flat, dim=-1)
if(self.full_quat):
qw, qx, qy, qz = rigid_flat[:4]
......
......@@ -67,6 +67,9 @@ class Rigid3Array:
"""Apply Rigid3Array transform to point."""
return self.rotation.apply_to_point(point) + self.translation
def apply(self, point: torch.Tensor) -> vector.Vec3Array:
    """Apply this rigid transform to a point supplied as a raw tensor."""
    as_vec = vector.Vec3Array.from_array(point)
    return self.apply_to_point(as_vec)
def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
"""Apply inverse Rigid3Array transform to point."""
new_point = point - self.translation
......@@ -74,7 +77,28 @@ class Rigid3Array:
def compose_rotation(self, other_rotation):
    """Compose this transform's rotation with `other_rotation`.

    Returns a new Rigid3Array carrying the composed rotation and a clone
    of the (unchanged) translation.
    """
    # De-duplicated diff residue: a superseded return line referenced an
    # undefined name `trans` (NameError on every call); only the corrected
    # line is kept.
    rot = self.rotation @ other_rotation
    return Rigid3Array(rot, self.translation.clone())
def compose(self, other_rigid):
    """Compose two rigid transforms; convenience alias for the `@` operator."""
    composed = self @ other_rigid
    return composed
def unsqueeze(self, dim: int):
    """Insert a singleton dimension at `dim` in both rotation and translation."""
    new_rot = self.rotation.unsqueeze(dim)
    new_trans = self.translation.unsqueeze(dim)
    return Rigid3Array(new_rot, new_trans)
@property
def shape(self) -> torch.Size:
    """Batch shape of the transform, read from one rotation component tensor."""
    reference = self.rotation.xx
    return reference.shape
@property
def dtype(self) -> torch.dtype:
    """Dtype of the underlying component tensors (via the rotation's xx entry)."""
    reference = self.rotation.xx
    return reference.dtype
@property
def device(self) -> torch.device:
    """Device of the underlying component tensors (via the rotation's xx entry)."""
    reference = self.rotation.xx
    return reference.device
@classmethod
def identity(cls, shape, device) -> Rigid3Array:
......@@ -87,28 +111,51 @@ class Rigid3Array:
@classmethod
def cat(cls, rigids: List[Rigid3Array], dim: int) -> Rigid3Array:
    """Concatenate a list of Rigid3Arrays along dimension `dim`."""
    # De-duplicated diff residue: both the superseded bare-name calls
    # (Rot3Array.cat / Vec3Array.cat) and the new module-qualified calls
    # were present, leaving a malformed argument list. Only the qualified
    # form is kept.
    rotations = rotation_matrix.Rot3Array.cat(
        [r.rotation for r in rigids], dim=dim
    )
    translations = vector.Vec3Array.cat(
        [r.translation for r in rigids], dim=dim
    )
    return cls(rotations, translations)
def scale_translation(self, factor: Float) -> Rigid3Array:
    """Return a copy whose translation is multiplied by `factor`; the rotation is shared."""
    scaled = self.translation * factor
    return Rigid3Array(self.rotation, scaled)
def to_array(self):
    """Pack the rotation matrix and the translation (as a trailing column) into one tensor."""
    rot_part = self.rotation.to_array()
    trans_part = self.translation.to_array()
    return torch.cat([rot_part, trans_part[..., None]], dim=-1)
def to_tensor(self) -> torch.Tensor:
    """Convert to a homogeneous [..., 4, 4] transformation tensor."""
    rot = self.rotation.to_tensor()
    trans = self.translation.to_tensor()
    out = torch.zeros(
        rot.shape[:-2] + (4, 4), device=rot.device, dtype=rot.dtype
    )
    # Top-left 3x3 block holds the rotation, last column the translation,
    # bottom-right corner the homogeneous 1.
    out[..., :3, :3] = rot
    out[..., :3, 3] = trans
    out[..., 3, 3] = 1.0
    return out
def to_tensor_4x4(self) -> torch.Tensor:
    """Alias of to_tensor(), which already emits a [..., 4, 4] matrix."""
    tensor_4x4 = self.to_tensor()
    return tensor_4x4
def reshape(self, new_shape) -> Rigid3Array:
    """Reshape the batch dimensions of rotation and translation to `new_shape`."""
    rots = self.rotation.reshape(new_shape)
    trans = self.translation.reshape(new_shape)
    # Fixed typo: the constructor was misspelled "Rigid3Aray", which raised
    # NameError whenever reshape() was called.
    return Rigid3Array(rots, trans)
def stop_rot_gradient(self) -> Rigid3Array:
    """Detach gradients flowing through the rotation; the translation keeps its grads."""
    detached_rot = self.rotation.stop_gradient()
    return Rigid3Array(detached_rot, self.translation)
@classmethod
def from_array(cls, array):
    """Build a Rigid3Array from a tensor whose last two dims hold a 3x3
    rotation in [..., :3, :3] and a translation column in [..., :3, 3]."""
    # De-duplicated diff residue: the superseded slicing
    # (array[..., :3] / array[..., -1]) was still present above the
    # corrected version; only the corrected slicing is kept.
    rot = rotation_matrix.Rot3Array.from_array(
        array[..., :3, :3],
    )
    vec = vector.Vec3Array.from_array(array[..., :3, 3])
    return cls(rot, vec)
@classmethod
......@@ -124,5 +171,6 @@ class Rigid3Array:
array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]
)
translation = vector.Vec3Array(
array[..., 0, 3], array[..., 1, 3], array[..., 2, 3])
array[..., 0, 3], array[..., 1, 3], array[..., 2, 3]
)
return cls(rotation, translation)
......@@ -22,6 +22,7 @@ import numpy as np
from openfold.utils.geometry import struct_of_array
from openfold.utils.geometry import utils
from openfold.utils.geometry import vector
from openfold.utils.tensor_utils import tensor_tree_map
COMPONENTS = ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz']
......@@ -59,6 +60,13 @@ class Rot3Array:
}
)
def __matmul__(self, other: Rot3Array) -> Rot3Array:
    """Compose two Rot3Arrays (matrix product self @ other)."""
    # The columns of the product are this rotation applied to each column
    # of `other`.
    col0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
    col1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
    col2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
    return Rot3Array(
        col0.x, col1.x, col2.x,
        col0.y, col1.y, col2.y,
        col0.z, col1.z, col2.z,
    )
def map_tensor_fn(self, fn) -> Rot3Array:
field_names = utils.get_field_names(Rot3Array)
return Rot3Array(
......@@ -88,12 +96,19 @@ class Rot3Array:
"""Applies inverse Rot3Array to point."""
return self.inverse().apply_to_point(point)
def __matmul__(self, other: Rot3Array) -> Rot3Array:
"""Composes two Rot3Arrays."""
c0 = self.apply_to_point(vector.Vec3Array(other.xx, other.yx, other.zx))
c1 = self.apply_to_point(vector.Vec3Array(other.xy, other.yy, other.zy))
c2 = self.apply_to_point(vector.Vec3Array(other.xz, other.yz, other.zz))
return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
def unsqueeze(self, dim: int):
    """Insert a singleton dimension at `dim` in every rotation component."""
    expanded = [getattr(self, name).unsqueeze(dim) for name in COMPONENTS]
    return Rot3Array(*expanded)
def stop_gradient(self) -> Rot3Array:
    """Detach every rotation component from the autograd graph."""
    detached = [getattr(self, name).detach() for name in COMPONENTS]
    return Rot3Array(*detached)
@classmethod
def identity(cls, shape, device) -> Rot3Array:
......@@ -130,9 +145,11 @@ class Rot3Array:
@classmethod
def from_array(cls, array: torch.Tensor) -> Rot3Array:
    """Construct a Rot3Array from a rotation-matrix tensor of shape [..., 3, 3]."""
    # De-duplicated diff residue: the superseded one-liner
    # `return cls(torch.unbind(array, dim=-2))` passed three row tensors to a
    # nine-component constructor; only the corrected expansion is kept.
    rows = torch.unbind(array, dim=-2)
    rc = [torch.unbind(e, dim=-1) for e in rows]
    return cls(*[e for row in rc for e in row])
def to_tensor(self) -> torch.Tensor:
    """Convert this Rot3Array to a rotation-matrix tensor of shape [..., 3, 3]."""
    # De-duplicated diff residue: the pre-rename `to_array` header and a stray
    # diff hunk marker were interleaved with the renamed `to_tensor` body.
    # NOTE(review): the xx-row stack line fell outside the visible hunk; it is
    # reconstructed by symmetry with the yx/zx rows — confirm against the repo.
    return torch.stack(
        [
            torch.stack([self.xx, self.xy, self.xz], dim=-1),
            torch.stack([self.yx, self.yy, self.yz], dim=-1),
            torch.stack([self.zx, self.zy, self.zz], dim=-1)
        ],
        dim=-2
    )
@classmethod
def from_quaternion(cls,
......
......@@ -134,13 +134,20 @@ class Vec3Array:
return Vec3Array(x, y, z)
def sum(self, dim: int) -> Vec3Array:
    """Sum each coordinate tensor along `dim`, returning a reduced Vec3Array."""
    # De-duplicated diff residue: the old untyped `def sum(self, dim)` header
    # was still present directly above the annotated one.
    return Vec3Array(
        torch.sum(self.x, dim=dim),
        torch.sum(self.y, dim=dim),
        torch.sum(self.z, dim=dim),
    )
def unsqueeze(self, dim: int):
    """Insert a singleton dimension at `dim` in each coordinate tensor."""
    x, y, z = self.x.unsqueeze(dim), self.y.unsqueeze(dim), self.z.unsqueeze(dim)
    return Vec3Array(x, y, z)
@classmethod
def zeros(cls, shape, device="cpu"):
"""Return Vec3Array corresponding to zeros of given shape."""
......@@ -150,11 +157,11 @@ class Vec3Array:
torch.zeros(shape, dtype=torch.float32, device=device)
)
def to_tensor(self) -> torch.Tensor:
    """Stack x, y, z into a single tensor of shape [..., 3]."""
    # De-duplicated diff residue: the pre-rename `to_array` header was still
    # present directly above the renamed `to_tensor` definition.
    return torch.stack([self.x, self.y, self.z], dim=-1)
@classmethod
def from_array(cls, tensor):
    """Split a [..., 3] tensor into x, y, z component tensors."""
    # De-duplicated diff residue: the pre-rename `from_tensor` header was
    # still present directly above the renamed `from_array` definition.
    return cls(*torch.unbind(tensor, dim=-1))
@classmethod
......
......@@ -623,8 +623,8 @@ def import_jax_weights_(model, npz_path, version="model_1"):
flat_keys = list(flat.keys())
incorrect = [k for k in flat_keys if k not in keys]
missing = [k for k in keys if k not in flat_keys]
print(f"Incorrect: {incorrect}")
print(f"Missing: {missing}")
# print(f"Incorrect: {incorrect}")
# print(f"Missing: {missing}")
assert len(incorrect) == 0
# assert(sorted(list(flat.keys())) == sorted(list(data.keys())))
......
......@@ -217,7 +217,8 @@ def main(args):
unrelaxed_protein = protein.from_prediction(
features=batch,
result=out,
b_factors=plddt_b_factors
b_factors=plddt_b_factors,
remove_leading_feature_dimension=not is_multimer,
)
# Save the unrelaxed PDB.
......@@ -227,6 +228,9 @@ def main(args):
with open(unrelaxed_output_path, 'w') as f:
f.write(protein.to_pdb(unrelaxed_protein))
print(unrelaxed_output_path)
print("asdjfh klasjdhf lkasjdhf lkjasdhflkjasdh fkl jasdhfklj hasdkljf hasldkjfh lkasjdfh lkajsdhflk asd")
amber_relaxer = relax.AmberRelaxation(
use_gpu=(args.model_device != "cpu"),
**config.relax,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment