"alphafold/model/structure_module.py" did not exist on "3d9c2de3f1857072be78ee80396553a5347bda57"
Unverified Commit 4b410596 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz Committed by GitHub
Browse files

Merge pull request #199 from ebetica/upstream_updates

Minor optimizations & fixes to support ESMFold
parents 349fdbd9 023596d2
......@@ -573,11 +573,11 @@ class StructureModule(nn.Module):
self.epsilon = epsilon
self.inf = inf
# To be lazily initialized later
self.default_frames = None
self.group_idx = None
self.atom_mask = None
self.lit_positions = None
# Buffers to be lazily initialized later
# self.default_frames
# self.group_idx
# self.atom_mask
# self.lit_positions
self.layer_norm_s = LayerNorm(self.c_s)
self.layer_norm_z = LayerNorm(self.c_z)
......@@ -723,6 +723,7 @@ class StructureModule(nn.Module):
"unnormalized_angles": unnormalized_angles,
"angles": angles,
"positions": pred_xyz,
"states": s,
}
outputs.append(preds)
......@@ -742,32 +743,48 @@ class StructureModule(nn.Module):
return outputs
def _init_residue_constants(self, float_dtype, device):
if self.default_frames is None:
self.default_frames = torch.tensor(
if not hasattr(self, "default_frames"):
self.register_buffer(
"default_frames",
torch.tensor(
restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
requires_grad=False,
),
persistent=False,
)
if self.group_idx is None:
self.group_idx = torch.tensor(
if not hasattr(self, "group_idx"):
self.register_buffer(
"group_idx",
torch.tensor(
restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
),
persistent=False,
)
if self.atom_mask is None:
self.atom_mask = torch.tensor(
if not hasattr(self, "atom_mask"):
self.register_buffer(
"atom_mask",
torch.tensor(
restype_atom14_mask,
dtype=float_dtype,
device=device,
requires_grad=False,
),
persistent=False,
)
if self.lit_positions is None:
self.lit_positions = torch.tensor(
if not hasattr(self, "lit_positions"):
self.register_buffer(
"lit_positions",
torch.tensor(
restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
requires_grad=False,
),
persistent=False,
)
def torsion_angles_to_frames(self, r, alpha, f):
......
......@@ -14,6 +14,7 @@
# limitations under the License.
from __future__ import annotations
from functools import lru_cache
from typing import Tuple, Any, Sequence, Callable, Optional
import numpy as np
......@@ -84,7 +85,7 @@ def rot_vec_mul(
dim=-1,
)
@lru_cache(maxsize=None)
def identity_rot_mats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
......@@ -101,6 +102,7 @@ def identity_rot_mats(
return rots
@lru_cache(maxsize=None)
def identity_trans(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
......@@ -116,6 +118,7 @@ def identity_trans(
return trans
@lru_cache(maxsize=None)
def identity_quats(
batch_dims: Tuple[int],
dtype: Optional[torch.dtype] = None,
......@@ -175,7 +178,7 @@ def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
quat = quat[..., None] * quat[..., None, :]
# [4, 4, 3, 3]
mat = quat.new_tensor(_QTR_MAT, requires_grad=False)
mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device)
# [*, 4, 4, 3, 3]
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
......@@ -230,10 +233,20 @@ _QUAT_MULTIPLY[:, :, 3] = [[ 0, 0, 0, 1],
_QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
_CACHED_QUATS = {
"_QTR_MAT": _QTR_MAT,
"_QUAT_MULTIPLY": _QUAT_MULTIPLY,
"_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC
}
@lru_cache(maxsize=None)
def _get_quat(quat_key, dtype, device):
return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
def quat_multiply(quat1, quat2):
"""Multiply a quaternion by another quaternion."""
mat = quat1.new_tensor(_QUAT_MULTIPLY)
mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
return torch.sum(
reshaped_mat *
......@@ -245,7 +258,7 @@ def quat_multiply(quat1, quat2):
def quat_multiply_by_vec(quat, vec):
"""Multiply a quaternion by a pure-vector quaternion."""
mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
return torch.sum(
reshaped_mat *
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment