Unverified commit 4b410596, authored by Gustaf Ahdritz, committed by GitHub

Merge pull request #199 from ebetica/upstream_updates

Minor optimizations & fixes to support ESMFold
Parents: 349fdbd9, 023596d2
```diff
@@ -573,11 +573,11 @@ class StructureModule(nn.Module):
         self.epsilon = epsilon
         self.inf = inf
 
-        # To be lazily initialized later
-        self.default_frames = None
-        self.group_idx = None
-        self.atom_mask = None
-        self.lit_positions = None
+        # Buffers to be lazily initialized later
+        # self.default_frames
+        # self.group_idx
+        # self.atom_mask
+        # self.lit_positions
 
         self.layer_norm_s = LayerNorm(self.c_s)
         self.layer_norm_z = LayerNorm(self.c_z)
@@ -723,6 +723,7 @@ class StructureModule(nn.Module):
                 "unnormalized_angles": unnormalized_angles,
                 "angles": angles,
                 "positions": pred_xyz,
+                "states": s,
             }
 
             outputs.append(preds)
```
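The `"states"` entry added above returns the single representation `s` from each structure-module iteration alongside the predicted frames, angles, and positions, so downstream consumers such as ESMFold (which this PR targets) can read those hidden states. A minimal sketch of consuming the new key, assuming `outputs` is the per-iteration list of dicts built in this loop:

```python
import torch

# Hypothetical consumer of the new "states" key: stack the per-iteration
# single representations (each assumed to be [*, N_res, c_s]) along a new
# leading dimension.
states = torch.stack([pred["states"] for pred in outputs], dim=0)
final_state = states[-1]  # hidden state after the last iteration
```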
```diff
@@ -742,32 +743,48 @@ class StructureModule(nn.Module):
         return outputs
 
     def _init_residue_constants(self, float_dtype, device):
-        if self.default_frames is None:
-            self.default_frames = torch.tensor(
-                restype_rigid_group_default_frame,
-                dtype=float_dtype,
-                device=device,
-                requires_grad=False,
+        if not hasattr(self, "default_frames"):
+            self.register_buffer(
+                "default_frames",
+                torch.tensor(
+                    restype_rigid_group_default_frame,
+                    dtype=float_dtype,
+                    device=device,
+                    requires_grad=False,
+                ),
+                persistent=False,
             )
-        if self.group_idx is None:
-            self.group_idx = torch.tensor(
-                restype_atom14_to_rigid_group,
-                device=device,
-                requires_grad=False,
+        if not hasattr(self, "group_idx"):
+            self.register_buffer(
+                "group_idx",
+                torch.tensor(
+                    restype_atom14_to_rigid_group,
+                    device=device,
+                    requires_grad=False,
+                ),
+                persistent=False,
             )
-        if self.atom_mask is None:
-            self.atom_mask = torch.tensor(
-                restype_atom14_mask,
-                dtype=float_dtype,
-                device=device,
-                requires_grad=False,
+        if not hasattr(self, "atom_mask"):
+            self.register_buffer(
+                "atom_mask",
+                torch.tensor(
+                    restype_atom14_mask,
+                    dtype=float_dtype,
+                    device=device,
+                    requires_grad=False,
+                ),
+                persistent=False,
             )
-        if self.lit_positions is None:
-            self.lit_positions = torch.tensor(
-                restype_atom14_rigid_group_positions,
-                dtype=float_dtype,
-                device=device,
-                requires_grad=False,
+        if not hasattr(self, "lit_positions"):
+            self.register_buffer(
+                "lit_positions",
+                torch.tensor(
+                    restype_atom14_rigid_group_positions,
+                    dtype=float_dtype,
+                    device=device,
+                    requires_grad=False,
+                ),
+                persistent=False,
             )
 
     def torsion_angles_to_frames(self, r, alpha, f):
```
...
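Together, the first and third hunks replace plain-attribute residue constants with lazily registered, non-persistent buffers. The `self.<name> = None` placeholders in `__init__` are removed because `register_buffer` rejects names that already exist as ordinary attributes, which is also why the guard changes from `is None` to `not hasattr`. As buffers, the constants now follow the module through `.to()`, `.half()`, and `.cuda()`, and `persistent=False` keeps these derivable tensors out of `state_dict`, so existing checkpoints remain loadable. A minimal standalone sketch of the pattern (hypothetical module with placeholder data, not the PR's code):

```python
import torch
import torch.nn as nn

class LazyConstants(nn.Module):
    """Hypothetical module illustrating lazily registered, non-persistent buffers."""

    def _init_constants(self, float_dtype, device):
        # hasattr becomes True once the buffer exists, so this body runs only once
        if not hasattr(self, "default_frames"):
            self.register_buffer(
                "default_frames",
                torch.zeros(21, 8, 4, 4, dtype=float_dtype, device=device),  # placeholder data
                persistent=False,  # derivable constant: keep it out of state_dict
            )

m = LazyConstants()
m._init_constants(torch.float32, torch.device("cpu"))
assert "default_frames" not in m.state_dict()   # non-persistent: never serialized
m = m.double()                                  # buffers follow module-wide casts/moves
assert m.default_frames.dtype == torch.float64
```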
```diff
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from __future__ import annotations
+from functools import lru_cache
 from typing import Tuple, Any, Sequence, Callable, Optional
 
 import numpy as np
@@ -84,7 +85,7 @@ def rot_vec_mul(
         dim=-1,
     )
 
-
+@lru_cache(maxsize=None)
 def identity_rot_mats(
     batch_dims: Tuple[int],
     dtype: Optional[torch.dtype] = None,
@@ -101,6 +102,7 @@ def identity_rot_mats(
     return rots
 
 
+@lru_cache(maxsize=None)
 def identity_trans(
     batch_dims: Tuple[int],
     dtype: Optional[torch.dtype] = None,
@@ -116,6 +118,7 @@ def identity_trans(
     return trans
 
 
+@lru_cache(maxsize=None)
 def identity_quats(
     batch_dims: Tuple[int],
     dtype: Optional[torch.dtype] = None,
```
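The three `@lru_cache(maxsize=None)` decorators memoize the identity-tensor constructors. All of their arguments are valid cache keys (`batch_dims` is a tuple, and `torch.dtype`/`torch.device` objects are hashable), so repeated calls with the same shapes and placement return the already-built tensor instead of reallocating it. A standalone sketch of the effect, using a hypothetical simplified function:

```python
from functools import lru_cache
from typing import Optional, Tuple
import torch

@lru_cache(maxsize=None)
def identity_trans_sketch(
    batch_dims: Tuple[int, ...],
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    # Built once per distinct (batch_dims, dtype, device) key
    return torch.zeros((*batch_dims, 3), dtype=dtype, device=device)

a = identity_trans_sketch((4, 128), torch.float32, torch.device("cpu"))
b = identity_trans_sketch((4, 128), torch.float32, torch.device("cpu"))
assert a is b  # cache hit: the same tensor object is returned
```

One consequence of this design: the returned tensors are shared across callers and must be treated as read-only, since in-place mutation would corrupt every later cache hit.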
```diff
@@ -175,7 +178,7 @@ def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
     quat = quat[..., None] * quat[..., None, :]
 
     # [4, 4, 3, 3]
-    mat = quat.new_tensor(_QTR_MAT, requires_grad=False)
+    mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device)
 
     # [*, 4, 4, 3, 3]
     shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
@@ -230,10 +233,20 @@ _QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
 
 _QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
 
+_CACHED_QUATS = {
+    "_QTR_MAT": _QTR_MAT,
+    "_QUAT_MULTIPLY": _QUAT_MULTIPLY,
+    "_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC
+}
+
+
+@lru_cache(maxsize=None)
+def _get_quat(quat_key, dtype, device):
+    return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
+
 
 def quat_multiply(quat1, quat2):
     """Multiply a quaternion by another quaternion."""
-    mat = quat1.new_tensor(_QUAT_MULTIPLY)
+    mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
     reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
     return torch.sum(
         reshaped_mat *
@@ -245,7 +258,7 @@ def quat_multiply(quat1, quat2):
 
 def quat_multiply_by_vec(quat, vec):
     """Multiply a quaternion by a pure-vector quaternion."""
-    mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
+    mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
     reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
     return torch.sum(
         reshaped_mat *
```
...
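`_get_quat` extends the same caching idea to the module-level NumPy constant tables: `quat.new_tensor(...)` re-copied the table into a fresh tensor on every call, while the cached version materializes each `(key, dtype, device)` combination exactly once and serves it from the `lru_cache` afterwards. A sketch of the pattern with a hypothetical stand-in table:

```python
from functools import lru_cache
import numpy as np
import torch

# Stand-in for the real constant tables (_QTR_MAT etc.); names are hypothetical.
_CACHED_TABLES = {"QTR": np.zeros((4, 4, 3, 3))}

@lru_cache(maxsize=None)
def _get_table(key: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # One copy per (key, dtype, device); every later call is a cache hit
    return torch.tensor(_CACHED_TABLES[key], dtype=dtype, device=device)

q = torch.randn(10, 4)
mat = _get_table("QTR", q.dtype, q.device)
assert _get_table("QTR", q.dtype, q.device) is mat
```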