Commit 42427a1c authored by Shenggan

add inference pipeline from openfold/alphafold

parent b3b3b445
@@ -3,10 +3,10 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from fastfold.model.kernel import LayerNorm
+from fastfold.model.fastnn.kernel import LayerNorm
-from fastfold.model.ops import Transition, SelfAttention
-from fastfold.model.kernel import bias_dropout_add
+from fastfold.model.fastnn.ops import Transition, SelfAttention
+from fastfold.model.fastnn.kernel import bias_dropout_add
 from fastfold.distributed import scatter, row_to_col
 from fastfold.distributed.comm_async import gather_async
......
@@ -2,12 +2,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange
-from fastfold.model.kernel import scale_mask_softmax, scale_mask_bias_softmax
-from fastfold.model.kernel import LayerNorm
+from fastfold.model.fastnn.kernel import scale_mask_softmax, scale_mask_bias_softmax
+from fastfold.model.fastnn.kernel import LayerNorm
 from .initializer import glorot_uniform_af
-from fastfold.model.kernel import bias_sigmod_ele
+from fastfold.model.fastnn.kernel import bias_sigmod_ele
 from fastfold.distributed import gather, scatter
 from fastfold.distributed.comm_async import gather_async, gather_async_opp
......
@@ -2,10 +2,10 @@ from fastfold.distributed.comm_async import gather_async
 import torch
 import torch.nn as nn
-from fastfold.model.kernel import LayerNorm
+from fastfold.model.fastnn.kernel import LayerNorm
 from fastfold.distributed.comm import col_to_row, row_to_col, scatter
-from fastfold.model.kernel import bias_dropout_add, bias_ele_dropout_residual
-from fastfold.model.ops import Linear, SelfAttention, Transition
+from fastfold.model.fastnn.kernel import bias_dropout_add, bias_ele_dropout_residual
+from fastfold.model.fastnn.ops import Linear, SelfAttention, Transition
 from fastfold.distributed.comm_async import gather_async_opp, gather_async
......
from .alphafold import AlphaFold
__all__ = ["AlphaFold"]
\ No newline at end of file
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
import torch.nn as nn
from fastfold.utils.feats import (
pseudo_beta_fn,
build_extra_msa_feat,
build_template_angle_feat,
build_template_pair_feat,
atom14_to_atom37,
)
from fastfold.model.nn.embedders import (
InputEmbedder,
RecyclingEmbedder,
TemplateAngleEmbedder,
TemplatePairEmbedder,
ExtraMSAEmbedder,
)
from fastfold.model.nn.evoformer import EvoformerStack, ExtraMSAStack
from fastfold.model.nn.heads import AuxiliaryHeads
import fastfold.common.residue_constants as residue_constants
from fastfold.model.nn.structure_module import StructureModule
from fastfold.model.nn.template import (
TemplatePairStack,
TemplatePointwiseAttention,
)
from fastfold.model.loss import (
compute_plddt,
)
from fastfold.utils.tensor_utils import (
dict_multimap,
tensor_tree_map,
)
class AlphaFold(nn.Module):
"""
AlphaFold 2.
Implements Algorithm 2 (but with training).
"""
def __init__(self, config):
"""
Args:
config:
A dict-like config object (like the one in config.py)
"""
super(AlphaFold, self).__init__()
self.globals = config.globals
config = config.model
template_config = config.template
extra_msa_config = config.extra_msa
# Main trunk + structure module
self.input_embedder = InputEmbedder(
**config["input_embedder"],
)
self.recycling_embedder = RecyclingEmbedder(
**config["recycling_embedder"],
)
self.template_angle_embedder = TemplateAngleEmbedder(
**template_config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**template_config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**template_config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**template_config["template_pointwise_attention"],
)
self.extra_msa_embedder = ExtraMSAEmbedder(
**extra_msa_config["extra_msa_embedder"],
)
self.extra_msa_stack = ExtraMSAStack(
**extra_msa_config["extra_msa_stack"],
)
self.evoformer = EvoformerStack(
**config["evoformer_stack"],
)
self.structure_module = StructureModule(
**config["structure_module"],
)
self.aux_heads = AuxiliaryHeads(
config["heads"],
)
self.config = config
def embed_templates(self, batch, z, pair_mask, templ_dim):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
if self.config.template.embed_angles:
template_angle_feat = build_template_angle_feat(
single_template_feats,
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
single_template_embeds["angle"] = a
# [*, S_t, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
inf=self.config.template.inf,
eps=self.config.template.eps,
**self.config.template.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
single_template_embeds.update({"pair": t})
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["pair"],
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
)
t = t * (torch.sum(batch["template_mask"]) > 0)
ret = {}
if self.config.template.embed_angles:
ret["template_angle_embedding"] = template_embeds["angle"]
ret.update({"template_pair_embedding": t})
return ret
def iteration(self, feats, m_1_prev, z_prev, x_prev, _recycle=True):
# Primary output dictionary
outputs = {}
dtype = next(self.parameters()).dtype
for k in feats:
if(feats[k].dtype == torch.float32):
feats[k] = feats[k].to(dtype=dtype)
# Grab some data about the input
batch_dims = feats["target_feat"].shape[:-2]
no_batch_dims = len(batch_dims)
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Prep some features
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
# Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
)
# Initialize the recycling embeddings, if needs be
if None in [m_1_prev, z_prev, x_prev]:
# [*, N, C_m]
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.input_embedder.c_m),
requires_grad=False,
)
# [*, N, N, C_z]
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.input_embedder.c_z),
requires_grad=False,
)
# [*, N, 3]
x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
)
x_prev = pseudo_beta_fn(
feats["aatype"], x_prev, None
).to(dtype=z.dtype)
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
)
# If the number of recycling iterations is 0, skip recycling
# altogether. We zero them this way instead of computing them
# conditionally to avoid leaving parameters unused, which has annoying
# implications for DDP training.
if(not _recycle):
m_1_prev_emb *= 0
z_prev_emb *= 0
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z]
z += z_prev_emb
# Possibly prevents memory fragmentation
del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if self.config.template.enabled:
template_feats = {
k: v for k, v in feats.items() if k.startswith("template_")
}
template_embeds = self.embed_templates(
template_feats,
z,
pair_mask.to(dtype=z.dtype),
no_batch_dims,
)
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
if self.config.template.embed_angles:
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
dim=-3
)
# [*, S, N]
torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2
)
# Embed extra MSA features + merge with pairwise embeddings
if self.config.extra_msa.enabled:
# [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
# [*, N, N, C_z]
z = self.extra_msa_stack(
a,
z,
msa_mask=feats["extra_msa_mask"].to(dtype=a.dtype),
chunk_size=self.globals.chunk_size,
pair_mask=pair_mask.to(dtype=z.dtype),
_mask_trans=self.config._mask_trans,
)
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask.to(dtype=m.dtype),
pair_mask=pair_mask.to(dtype=z.dtype),
chunk_size=self.globals.chunk_size,
_mask_trans=self.config._mask_trans,
)
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
outputs["single"] = s
# Predict 3D structure
outputs["sm"] = self.structure_module(
s,
z,
feats["aatype"],
mask=feats["seq_mask"].to(dtype=s.dtype),
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
)
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
outputs["final_affine_tensor"] = outputs["sm"]["frames"][-1]
# Save embeddings for use during the next recycling iteration
# [*, N, C_m]
m_1_prev = m[..., 0, :, :]
# [*, N, N, C_z]
z_prev = z
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
return outputs, m_1_prev, z_prev, x_prev
def _disable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = None
self.evoformer.blocks_per_ckpt = None
for b in self.extra_msa_stack.blocks:
b.ckpt = False
def _enable_activation_checkpointing(self):
self.template_pair_stack.blocks_per_ckpt = (
self.config.template.template_pair_stack.blocks_per_ckpt
)
self.evoformer.blocks_per_ckpt = (
self.config.evoformer_stack.blocks_per_ckpt
)
for b in self.extra_msa_stack.blocks:
b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt
def forward(self, batch):
"""
Args:
batch:
Dictionary of arguments outlined in Algorithm 2. Keys must
include the official names of the features in the
supplement subsection 1.2.9.
The final dimension of each input must have length equal to
the number of recycling iterations.
Features (without the recycling dimension):
"aatype" ([*, N_res]):
Contrary to the supplement, this tensor of residue
indices is not one-hot.
"target_feat" ([*, N_res, C_tf])
One-hot encoding of the target sequence. C_tf is
config.model.input_embedder.tf_dim.
"residue_index" ([*, N_res])
Tensor whose final dimension consists of
consecutive indices from 0 to N_res.
"msa_feat" ([*, N_seq, N_res, C_msa])
MSA features, constructed as in the supplement.
C_msa is config.model.input_embedder.msa_dim.
"seq_mask" ([*, N_res])
1-D sequence mask
"msa_mask" ([*, N_seq, N_res])
MSA mask
"pair_mask" ([*, N_res, N_res])
2-D pair mask
"extra_msa_mask" ([*, N_extra, N_res])
Extra MSA mask
"template_mask" ([*, N_templ])
Template mask (on the level of templates, not
residues)
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_positions"
([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
"template_pseudo_beta" ([*, N_templ, N_res, 3])
Positions of template carbon "pseudo-beta" atoms
(i.e. C_beta for all residues but glycine, for
which C_alpha is used instead)
"template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask
"""
# Initialize recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
# Disable activation checkpointing for the first few recycling iters
is_grad_enabled = torch.is_grad_enabled()
self._disable_activation_checkpointing()
# Main recycling loop
num_iters = batch["aatype"].shape[-1]
for cycle_no in range(num_iters):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = cycle_no == (num_iters - 1)
with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
if is_final_iter:
self._enable_activation_checkpointing()
# Sidestep AMP bug (PyTorch issue #65766)
if torch.is_autocast_enabled():
torch.clear_autocast_cache()
# Run the next iteration of the model
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats,
m_1_prev,
z_prev,
x_prev,
_recycle=(num_iters > 1)
)
# Run auxiliary heads
outputs.update(self.aux_heads(outputs))
return outputs
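# Illustrative sketch, not part of the original OpenFold/FastFold source: the
# docstring of AlphaFold.forward() above expects every feature to carry a
# trailing recycling dimension. The toy batch below (hypothetical shapes, only
# two features) shows how one slice per recycling iteration is peeled off with
# tensor_tree_map, exactly as done at the top of the recycling loop.
def _example_recycling_slices():
    no_recycling_iters = 3
    batch = {
        "aatype": torch.zeros(8, no_recycling_iters, dtype=torch.long),
        "seq_mask": torch.ones(8, no_recycling_iters),
    }
    feats = tensor_tree_map(lambda t: t[..., 0], batch)
    assert feats["aatype"].shape == (8,)
    assert feats["seq_mask"].shape == (8,)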
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import logging
import ml_collections
import numpy as np
import torch
import torch.nn as nn
from torch.distributions.bernoulli import Bernoulli
from typing import Dict, Optional, Tuple
from fastfold.common import residue_constants
from fastfold.utils import feats
from fastfold.utils.rigid_utils import Rotation, Rigid
from fastfold.utils.tensor_utils import (
tree_map,
tensor_tree_map,
masked_mean,
permute_final_dims,
batched_gather,
)
def softmax_cross_entropy(logits, labels):
loss = -1 * torch.sum(
labels * torch.nn.functional.log_softmax(logits, dim=-1),
dim=-1,
)
return loss
def sigmoid_cross_entropy(logits, labels):
log_p = torch.log(torch.sigmoid(logits))
log_not_p = torch.log(torch.sigmoid(-logits))
loss = -labels * log_p - (1 - labels) * log_not_p
return loss
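# Illustrative sketch, not part of the original source: for hard (one-hot)
# labels, softmax_cross_entropy above reduces to the standard per-sample cross
# entropy, so it agrees with torch.nn.functional.cross_entropy. The shapes
# below are arbitrary toy values.
def _example_softmax_cross_entropy():
    logits = torch.randn(4, 10)
    classes = torch.randint(0, 10, (4,))
    labels = torch.nn.functional.one_hot(classes, num_classes=10).float()
    ours = softmax_cross_entropy(logits, labels)
    ref = torch.nn.functional.cross_entropy(logits, classes, reduction="none")
    assert torch.allclose(ours, ref, atol=1e-6)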
def torsion_angle_loss(
a, # [*, N, 7, 2]
a_gt, # [*, N, 7, 2]
a_alt_gt, # [*, N, 7, 2]
):
# [*, N, 7]
norm = torch.norm(a, dim=-1)
# [*, N, 7, 2]
a = a / norm.unsqueeze(-1)
# [*, N, 7]
diff_norm_gt = torch.norm(a - a_gt, dim=-1)
diff_norm_alt_gt = torch.norm(a - a_alt_gt, dim=-1)
min_diff = torch.minimum(diff_norm_gt ** 2, diff_norm_alt_gt ** 2)
# [*]
l_torsion = torch.mean(min_diff, dim=(-1, -2))
l_angle_norm = torch.mean(torch.abs(norm - 1), dim=(-1, -2))
an_weight = 0.02
return l_torsion + an_weight * l_angle_norm
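# Illustrative note, not part of the original source: torsion_angle_loss first
# normalizes the predicted (sin, cos) pairs onto the unit circle, compares them
# against both the ground truth and the provided alternative ground truth
# (a_alt_gt), keeps the smaller squared difference per angle, and adds a small
# penalty (an_weight = 0.02) that pushes the unnormalized predictions toward
# unit norm.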
def compute_fape(
pred_frames: Rigid,
target_frames: Rigid,
frames_mask: torch.Tensor,
pred_positions: torch.Tensor,
target_positions: torch.Tensor,
positions_mask: torch.Tensor,
length_scale: float,
l1_clamp_distance: Optional[float] = None,
eps=1e-8,
) -> torch.Tensor:
"""
Computes FAPE loss.
Args:
pred_frames:
[*, N_frames] Rigid object of predicted frames
target_frames:
[*, N_frames] Rigid object of ground truth frames
frames_mask:
[*, N_frames] binary mask for the frames
pred_positions:
[*, N_pts, 3] predicted atom positions
target_positions:
[*, N_pts, 3] ground truth positions
positions_mask:
[*, N_pts] positions mask
length_scale:
Length scale by which the loss is divided
l1_clamp_distance:
Cutoff above which distance errors are disregarded
eps:
Small value used to regularize denominators
Returns:
[*] loss tensor
"""
# [*, N_frames, N_pts, 3]
local_pred_pos = pred_frames.invert()[..., None].apply(
pred_positions[..., None, :, :],
)
local_target_pos = target_frames.invert()[..., None].apply(
target_positions[..., None, :, :],
)
error_dist = torch.sqrt(
torch.sum((local_pred_pos - local_target_pos) ** 2, dim=-1) + eps
)
if l1_clamp_distance is not None:
error_dist = torch.clamp(error_dist, min=0, max=l1_clamp_distance)
normed_error = error_dist / length_scale
normed_error = normed_error * frames_mask[..., None]
normed_error = normed_error * positions_mask[..., None, :]
# FP16-friendly averaging. Roughly equivalent to:
#
# norm_factor = (
# torch.sum(frames_mask, dim=-1) *
# torch.sum(positions_mask, dim=-1)
# )
# normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
#
# ("roughly" because eps is necessarily duplicated in the latter)
normed_error = torch.sum(normed_error, dim=-1)
normed_error = (
normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
)
normed_error = torch.sum(normed_error, dim=-1)
normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))
return normed_error
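# Illustrative sketch, not part of the original source: the FP16-friendly,
# two-stage normalization at the end of compute_fape above divides by the
# frame-mask sum and the position-mask sum separately. For toy tensors (shapes
# assumed for the demo) it matches the single division it replaces.
def _example_fp16_friendly_mean():
    eps = 1e-8
    normed_error = torch.rand(2, 5, 7)      # [*, N_frames, N_pts]
    frames_mask = torch.ones(2, 5)          # [*, N_frames]
    positions_mask = torch.ones(2, 7)       # [*, N_pts]
    masked = normed_error * frames_mask[..., None] * positions_mask[..., None, :]
    # Direct masked mean over both dimensions at once
    direct = torch.sum(masked, dim=(-1, -2)) / (
        eps + torch.sum(frames_mask, dim=-1) * torch.sum(positions_mask, dim=-1)
    )
    # Two-stage version, dividing by each mask sum separately
    staged = torch.sum(masked, dim=-1)
    staged = staged / (eps + torch.sum(frames_mask, dim=-1))[..., None]
    staged = torch.sum(staged, dim=-1)
    staged = staged / (eps + torch.sum(positions_mask, dim=-1))
    assert torch.allclose(direct, staged, atol=1e-4)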
def backbone_loss(
backbone_rigid_tensor: torch.Tensor,
backbone_rigid_mask: torch.Tensor,
traj: torch.Tensor,
use_clamped_fape: Optional[torch.Tensor] = None,
clamp_distance: float = 10.0,
loss_unit_distance: float = 10.0,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
pred_aff = Rigid.from_tensor_7(traj)
pred_aff = Rigid(
Rotation(rot_mats=pred_aff.get_rots().get_rot_mats(), quats=None),
pred_aff.get_trans(),
)
# DISCREPANCY: DeepMind somehow gets a hold of a tensor_7 version of
# backbone tensor, normalizes it, and then turns it back to a rotation
# matrix. To avoid a potentially numerically unstable rotation matrix
# to quaternion conversion, we just use the original rotation matrix
# outright. This one hasn't been composed a bunch of times, though, so
# it might be fine.
gt_aff = Rigid.from_tensor_4x4(backbone_rigid_tensor)
fape_loss = compute_fape(
pred_aff,
gt_aff[None],
backbone_rigid_mask[None],
pred_aff.get_trans(),
gt_aff[None].get_trans(),
backbone_rigid_mask[None],
l1_clamp_distance=clamp_distance,
length_scale=loss_unit_distance,
eps=eps,
)
if use_clamped_fape is not None:
unclamped_fape_loss = compute_fape(
pred_aff,
gt_aff[None],
backbone_rigid_mask[None],
pred_aff.get_trans(),
gt_aff[None].get_trans(),
backbone_rigid_mask[None],
l1_clamp_distance=None,
length_scale=loss_unit_distance,
eps=eps,
)
fape_loss = fape_loss * use_clamped_fape + unclamped_fape_loss * (
1 - use_clamped_fape
)
# Average over the batch dimension
fape_loss = torch.mean(fape_loss)
return fape_loss
def sidechain_loss(
sidechain_frames: torch.Tensor,
sidechain_atom_pos: torch.Tensor,
rigidgroups_gt_frames: torch.Tensor,
rigidgroups_alt_gt_frames: torch.Tensor,
rigidgroups_gt_exists: torch.Tensor,
renamed_atom14_gt_positions: torch.Tensor,
renamed_atom14_gt_exists: torch.Tensor,
alt_naming_is_better: torch.Tensor,
clamp_distance: float = 10.0,
length_scale: float = 10.0,
eps: float = 1e-4,
**kwargs,
) -> torch.Tensor:
renamed_gt_frames = (
1.0 - alt_naming_is_better[..., None, None, None]
) * rigidgroups_gt_frames + alt_naming_is_better[
..., None, None, None
] * rigidgroups_alt_gt_frames
# Steamroll the inputs
sidechain_frames = sidechain_frames[-1]
batch_dims = sidechain_frames.shape[:-4]
sidechain_frames = sidechain_frames.view(*batch_dims, -1, 4, 4)
sidechain_frames = Rigid.from_tensor_4x4(sidechain_frames)
renamed_gt_frames = renamed_gt_frames.view(*batch_dims, -1, 4, 4)
renamed_gt_frames = Rigid.from_tensor_4x4(renamed_gt_frames)
rigidgroups_gt_exists = rigidgroups_gt_exists.reshape(*batch_dims, -1)
sidechain_atom_pos = sidechain_atom_pos[-1]
sidechain_atom_pos = sidechain_atom_pos.view(*batch_dims, -1, 3)
renamed_atom14_gt_positions = renamed_atom14_gt_positions.view(
*batch_dims, -1, 3
)
renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(*batch_dims, -1)
fape = compute_fape(
sidechain_frames,
renamed_gt_frames,
rigidgroups_gt_exists,
sidechain_atom_pos,
renamed_atom14_gt_positions,
renamed_atom14_gt_exists,
l1_clamp_distance=clamp_distance,
length_scale=length_scale,
eps=eps,
)
return fape
def fape_loss(
out: Dict[str, torch.Tensor],
batch: Dict[str, torch.Tensor],
config: ml_collections.ConfigDict,
) -> torch.Tensor:
bb_loss = backbone_loss(
traj=out["sm"]["frames"],
**{**batch, **config.backbone},
)
sc_loss = sidechain_loss(
out["sm"]["sidechain_frames"],
out["sm"]["positions"],
**{**batch, **config.sidechain},
)
loss = config.backbone.weight * bb_loss + config.sidechain.weight * sc_loss
# Average over the batch dimension
loss = torch.mean(loss)
return loss
def supervised_chi_loss(
angles_sin_cos: torch.Tensor,
unnormalized_angles_sin_cos: torch.Tensor,
aatype: torch.Tensor,
seq_mask: torch.Tensor,
chi_mask: torch.Tensor,
chi_angles_sin_cos: torch.Tensor,
chi_weight: float,
angle_norm_weight: float,
eps=1e-6,
**kwargs,
) -> torch.Tensor:
"""
Implements Algorithm 27 (torsionAngleLoss)
Args:
angles_sin_cos:
[*, N, 7, 2] predicted angles
unnormalized_angles_sin_cos:
The same angles, but unnormalized
aatype:
[*, N] residue indices
seq_mask:
[*, N] sequence mask
chi_mask:
[*, N, 7] angle mask
chi_angles_sin_cos:
[*, N, 7, 2] ground truth angles
chi_weight:
Weight for the angle component of the loss
angle_norm_weight:
Weight for the normalization component of the loss
Returns:
[*] loss tensor
"""
pred_angles = angles_sin_cos[..., 3:, :]
residue_type_one_hot = torch.nn.functional.one_hot(
aatype,
residue_constants.restype_num + 1,
)
chi_pi_periodic = torch.einsum(
"...ij,jk->ik",
residue_type_one_hot.type(angles_sin_cos.dtype),
angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
)
true_chi = chi_angles_sin_cos[None]
shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
true_chi_shifted = shifted_mask * true_chi
sq_chi_error = torch.sum((true_chi - pred_angles) ** 2, dim=-1)
sq_chi_error_shifted = torch.sum(
(true_chi_shifted - pred_angles) ** 2, dim=-1
)
sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
# The ol' switcheroo
sq_chi_error = sq_chi_error.permute(
*range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
)
sq_chi_loss = masked_mean(
chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
)
loss = chi_weight * sq_chi_loss
angle_norm = torch.sqrt(
torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
)
norm_error = torch.abs(angle_norm - 1.0)
norm_error = norm_error.permute(
*range(len(norm_error.shape))[1:-2], 0, -2, -1
)
angle_norm_loss = masked_mean(
seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
)
loss = loss + angle_norm_weight * angle_norm_loss
# Average over the batch dimension
loss = torch.mean(loss)
return loss
def compute_plddt(logits: torch.Tensor) -> torch.Tensor:
num_bins = logits.shape[-1]
bin_width = 1.0 / num_bins
bounds = torch.arange(
start=0.5 * bin_width, end=1.0, step=bin_width, device=logits.device
)
probs = torch.nn.functional.softmax(logits, dim=-1)
pred_lddt_ca = torch.sum(
probs * bounds.view(*((1,) * len(probs.shape[:-1])), *bounds.shape),
dim=-1,
)
return pred_lddt_ca * 100
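# Illustrative sketch, not part of the original source: compute_plddt above
# converts binned lDDT-Ca logits into an expected per-residue confidence in
# [0, 100] by probability-weighting the bin centers. Shapes are toy values.
def _example_compute_plddt():
    logits = torch.randn(1, 10, 50)          # [*, N_res, no_bins]
    plddt = compute_plddt(logits)            # [*, N_res]
    assert plddt.shape == (1, 10)
    assert torch.all((plddt >= 0) & (plddt <= 100))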
def lddt(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
cutoff: float = 15.0,
eps: float = 1e-10,
per_residue: bool = True,
) -> torch.Tensor:
n = all_atom_mask.shape[-2]
dmat_true = torch.sqrt(
eps
+ torch.sum(
(
all_atom_positions[..., None, :]
- all_atom_positions[..., None, :, :]
)
** 2,
dim=-1,
)
)
dmat_pred = torch.sqrt(
eps
+ torch.sum(
(
all_atom_pred_pos[..., None, :]
- all_atom_pred_pos[..., None, :, :]
)
** 2,
dim=-1,
)
)
dists_to_score = (
(dmat_true < cutoff)
* all_atom_mask
* permute_final_dims(all_atom_mask, (1, 0))
* (1.0 - torch.eye(n, device=all_atom_mask.device))
)
dist_l1 = torch.abs(dmat_true - dmat_pred)
score = (
(dist_l1 < 0.5).type(dist_l1.dtype)
+ (dist_l1 < 1.0).type(dist_l1.dtype)
+ (dist_l1 < 2.0).type(dist_l1.dtype)
+ (dist_l1 < 4.0).type(dist_l1.dtype)
)
score = score * 0.25
dims = (-1,) if per_residue else (-2, -1)
norm = 1.0 / (eps + torch.sum(dists_to_score, dim=dims))
score = norm * (eps + torch.sum(dists_to_score * score, dim=dims))
return score
def lddt_ca(
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
cutoff: float = 15.0,
eps: float = 1e-10,
per_residue: bool = True,
) -> torch.Tensor:
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
return lddt(
all_atom_pred_pos,
all_atom_positions,
all_atom_mask,
cutoff=cutoff,
eps=eps,
per_residue=per_residue,
)
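# Illustrative sketch, not part of the original source: when predictions and
# ground truth coincide, every pairwise distance difference is zero and the
# per-residue lDDT-Ca score is 1. Shapes below are toy values in atom37 format.
def _example_perfect_lddt_ca():
    coords = torch.randn(1, 8, 37, 3)
    mask = torch.ones(1, 8, 37)
    score = lddt_ca(coords, coords, mask)    # [*, N_res]
    assert torch.allclose(score, torch.ones_like(score), atol=1e-3)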
def lddt_loss(
logits: torch.Tensor,
all_atom_pred_pos: torch.Tensor,
all_atom_positions: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
cutoff: float = 15.0,
no_bins: int = 50,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps: float = 1e-10,
**kwargs,
) -> torch.Tensor:
n = all_atom_mask.shape[-2]
ca_pos = residue_constants.atom_order["CA"]
all_atom_pred_pos = all_atom_pred_pos[..., ca_pos, :]
all_atom_positions = all_atom_positions[..., ca_pos, :]
all_atom_mask = all_atom_mask[..., ca_pos : (ca_pos + 1)] # keep dim
score = lddt(
all_atom_pred_pos,
all_atom_positions,
all_atom_mask,
cutoff=cutoff,
eps=eps
)
score = score.detach()
bin_index = torch.floor(score * no_bins).long()
bin_index = torch.clamp(bin_index, max=(no_bins - 1))
lddt_ca_one_hot = torch.nn.functional.one_hot(
bin_index, num_classes=no_bins
)
errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
all_atom_mask = all_atom_mask.squeeze(-1)
loss = torch.sum(errors * all_atom_mask, dim=-1) / (
eps + torch.sum(all_atom_mask, dim=-1)
)
loss = loss * (
(resolution >= min_resolution) & (resolution <= max_resolution)
)
# Average over the batch dimension
loss = torch.mean(loss)
return loss
def distogram_loss(
logits,
pseudo_beta,
pseudo_beta_mask,
min_bin=2.3125,
max_bin=21.6875,
no_bins=64,
eps=1e-6,
**kwargs,
):
boundaries = torch.linspace(
min_bin,
max_bin,
no_bins - 1,
device=logits.device,
)
boundaries = boundaries ** 2
dists = torch.sum(
(pseudo_beta[..., None, :] - pseudo_beta[..., None, :, :]) ** 2,
dim=-1,
keepdim=True,
)
true_bins = torch.sum(dists > boundaries, dim=-1)
errors = softmax_cross_entropy(
logits,
torch.nn.functional.one_hot(true_bins, no_bins),
)
square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
# FP16-friendly sum. Equivalent to:
# mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
# (eps + torch.sum(square_mask, dim=(-1, -2))))
denom = eps + torch.sum(square_mask, dim=(-1, -2))
mean = errors * square_mask
mean = torch.sum(mean, dim=-1)
mean = mean / denom[..., None]
mean = torch.sum(mean, dim=-1)
# Average over the batch dimensions
mean = torch.mean(mean)
return mean
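# Illustrative note, not part of the original source: distogram_loss squares
# the bin boundaries and compares them against squared pairwise distances, so
# each residue pair is assigned to its distance bin without taking any square
# roots.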
def _calculate_bin_centers(boundaries: torch.Tensor):
step = boundaries[1] - boundaries[0]
bin_centers = boundaries + step / 2
bin_centers = torch.cat(
[bin_centers, (bin_centers[-1] + step).unsqueeze(-1)], dim=0
)
return bin_centers
def _calculate_expected_aligned_error(
alignment_confidence_breaks: torch.Tensor,
aligned_distance_error_probs: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
bin_centers = _calculate_bin_centers(alignment_confidence_breaks)
return (
torch.sum(aligned_distance_error_probs * bin_centers, dim=-1),
bin_centers[-1],
)
def compute_predicted_aligned_error(
logits: torch.Tensor,
max_bin: int = 31,
no_bins: int = 64,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Computes aligned confidence metrics from logits.
Args:
logits: [*, num_res, num_res, num_bins] the logits output from
PredictedAlignedErrorHead.
max_bin: Maximum bin value
no_bins: Number of bins
Returns:
aligned_confidence_probs: [*, num_res, num_res, num_bins] the predicted
aligned error probabilities over bins for each residue pair.
predicted_aligned_error: [*, num_res, num_res] the expected aligned distance
error for each pair of residues.
max_predicted_aligned_error: [*] the maximum predicted error possible.
"""
boundaries = torch.linspace(
0, max_bin, steps=(no_bins - 1), device=logits.device
)
aligned_confidence_probs = torch.nn.functional.softmax(logits, dim=-1)
(
predicted_aligned_error,
max_predicted_aligned_error,
) = _calculate_expected_aligned_error(
alignment_confidence_breaks=boundaries,
aligned_distance_error_probs=aligned_confidence_probs,
)
return {
"aligned_confidence_probs": aligned_confidence_probs,
"predicted_aligned_error": predicted_aligned_error,
"max_predicted_aligned_error": max_predicted_aligned_error,
}
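# Illustrative sketch, not part of the original source: with the default 64
# bins over [0, 31] Angstroms, compute_predicted_aligned_error returns per-pair
# bin probabilities, the expected aligned error, and the largest representable
# error (the last bin center). Shapes below are toy values.
def _example_predicted_aligned_error():
    logits = torch.randn(1, 16, 16, 64)      # [*, N_res, N_res, no_bins]
    out = compute_predicted_aligned_error(logits)
    assert out["aligned_confidence_probs"].shape == (1, 16, 16, 64)
    assert out["predicted_aligned_error"].shape == (1, 16, 16)
    assert float(out["max_predicted_aligned_error"]) > 31.0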
def compute_tm(
logits: torch.Tensor,
residue_weights: Optional[torch.Tensor] = None,
max_bin: int = 31,
no_bins: int = 64,
eps: float = 1e-8,
**kwargs,
) -> torch.Tensor:
if residue_weights is None:
residue_weights = logits.new_ones(logits.shape[-2])
boundaries = torch.linspace(
0, max_bin, steps=(no_bins - 1), device=logits.device
)
bin_centers = _calculate_bin_centers(boundaries)
torch.sum(residue_weights)
n = logits.shape[-2]
clipped_n = max(n, 19)
d0 = 1.24 * (clipped_n - 15) ** (1.0 / 3) - 1.8
probs = torch.nn.functional.softmax(logits, dim=-1)
tm_per_bin = 1.0 / (1 + (bin_centers ** 2) / (d0 ** 2))
predicted_tm_term = torch.sum(probs * tm_per_bin, dim=-1)
normed_residue_mask = residue_weights / (eps + residue_weights.sum())
per_alignment = torch.sum(predicted_tm_term * normed_residue_mask, dim=-1)
weighted = per_alignment * residue_weights
argmax = (weighted == torch.max(weighted)).nonzero()[0]
return per_alignment[tuple(argmax)]
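# Illustrative note, not part of the original source: d0 above is the standard
# TM-score distance scale d0(N) = 1.24 * (N - 15)^(1/3) - 1.8, with N clipped
# to at least 19 so that d0 stays positive for very short chains.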
def tm_loss(
logits,
final_affine_tensor,
backbone_rigid_tensor,
backbone_rigid_mask,
resolution,
max_bin=31,
no_bins=64,
min_resolution: float = 0.1,
max_resolution: float = 3.0,
eps=1e-8,
**kwargs,
):
pred_affine = Rigid.from_tensor_7(final_affine_tensor)
backbone_rigid = Rigid.from_tensor_4x4(backbone_rigid_tensor)
def _points(affine):
pts = affine.get_trans()[..., None, :, :]
return affine.invert()[..., None].apply(pts)
sq_diff = torch.sum(
(_points(pred_affine) - _points(backbone_rigid)) ** 2, dim=-1
)
sq_diff = sq_diff.detach()
boundaries = torch.linspace(
0, max_bin, steps=(no_bins - 1), device=logits.device
)
boundaries = boundaries ** 2
true_bins = torch.sum(sq_diff[..., None] > boundaries, dim=-1)
errors = softmax_cross_entropy(
logits, torch.nn.functional.one_hot(true_bins, no_bins)
)
square_mask = (
backbone_rigid_mask[..., None] * backbone_rigid_mask[..., None, :]
)
loss = torch.sum(errors * square_mask, dim=-1)
scale = 0.5 # hack to help FP16 training along
denom = eps + torch.sum(scale * square_mask, dim=(-1, -2))
loss = loss / denom[..., None]
loss = torch.sum(loss, dim=-1)
loss = loss * scale
loss = loss * (
(resolution >= min_resolution) & (resolution <= max_resolution)
)
# Average over the loss dimension
loss = torch.mean(loss)
return loss
def between_residue_bond_loss(
pred_atom_positions: torch.Tensor, # (*, N, 37/14, 3)
pred_atom_mask: torch.Tensor, # (*, N, 37/14)
residue_index: torch.Tensor, # (*, N)
aatype: torch.Tensor, # (*, N)
tolerance_factor_soft=12.0,
tolerance_factor_hard=12.0,
eps=1e-6,
) -> Dict[str, torch.Tensor]:
"""Flat-bottom loss to penalize structural violations between residues.
This is a loss penalizing any violation of the geometry around the peptide
bond between consecutive amino acids. This loss corresponds to
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 44, 45.
Args:
pred_atom_positions: Atom positions in atom37/14 representation
pred_atom_mask: Atom mask in atom37/14 representation
residue_index: Residue index for given amino acid, this is assumed to be
monotonically increasing.
aatype: Amino acid type of given residue
tolerance_factor_soft: soft tolerance factor measured in standard deviations
of pdb distributions
tolerance_factor_hard: hard tolerance factor measured in standard deviations
of pdb distributions
Returns:
Dict containing:
* 'c_n_loss_mean': Loss for peptide bond length violations
* 'ca_c_n_loss_mean': Loss for violations of bond angle around C spanned
by CA, C, N
* 'c_n_ca_loss_mean': Loss for violations of bond angle around N spanned
by C, N, CA
* 'per_residue_loss_sum': sum of all losses for each residue
* 'per_residue_violation_mask': mask denoting all residues with violation
present.
"""
# Get the positions of the relevant backbone atoms.
this_ca_pos = pred_atom_positions[..., :-1, 1, :]
this_ca_mask = pred_atom_mask[..., :-1, 1]
this_c_pos = pred_atom_positions[..., :-1, 2, :]
this_c_mask = pred_atom_mask[..., :-1, 2]
next_n_pos = pred_atom_positions[..., 1:, 0, :]
next_n_mask = pred_atom_mask[..., 1:, 0]
next_ca_pos = pred_atom_positions[..., 1:, 1, :]
next_ca_mask = pred_atom_mask[..., 1:, 1]
has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
# Compute loss for the C--N bond.
c_n_bond_length = torch.sqrt(
eps + torch.sum((this_c_pos - next_n_pos) ** 2, dim=-1)
)
# The C-N bond to proline has slightly different length because of the ring.
next_is_proline = aatype[..., 1:] == residue_constants.resname_to_idx["PRO"]
gt_length = (
~next_is_proline
) * residue_constants.between_res_bond_length_c_n[
0
] + next_is_proline * residue_constants.between_res_bond_length_c_n[
1
]
gt_stddev = (
~next_is_proline
) * residue_constants.between_res_bond_length_stddev_c_n[
0
] + next_is_proline * residue_constants.between_res_bond_length_stddev_c_n[
1
]
c_n_bond_length_error = torch.sqrt(eps + (c_n_bond_length - gt_length) ** 2)
c_n_loss_per_residue = torch.nn.functional.relu(
c_n_bond_length_error - tolerance_factor_soft * gt_stddev
)
mask = this_c_mask * next_n_mask * has_no_gap_mask
c_n_loss = torch.sum(mask * c_n_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
c_n_violation_mask = mask * (
c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
)
# Compute loss for the angles.
ca_c_bond_length = torch.sqrt(
eps + torch.sum((this_ca_pos - this_c_pos) ** 2, dim=-1)
)
n_ca_bond_length = torch.sqrt(
eps + torch.sum((next_n_pos - next_ca_pos) ** 2, dim=-1)
)
c_ca_unit_vec = (this_ca_pos - this_c_pos) / ca_c_bond_length[..., None]
c_n_unit_vec = (next_n_pos - this_c_pos) / c_n_bond_length[..., None]
n_ca_unit_vec = (next_ca_pos - next_n_pos) / n_ca_bond_length[..., None]
ca_c_n_cos_angle = torch.sum(c_ca_unit_vec * c_n_unit_vec, dim=-1)
gt_angle = residue_constants.between_res_cos_angles_ca_c_n[0]
gt_stddev = residue_constants.between_res_bond_length_stddev_c_n[0]
ca_c_n_cos_angle_error = torch.sqrt(
eps + (ca_c_n_cos_angle - gt_angle) ** 2
)
ca_c_n_loss_per_residue = torch.nn.functional.relu(
ca_c_n_cos_angle_error - tolerance_factor_soft * gt_stddev
)
mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
ca_c_n_loss = torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
ca_c_n_violation_mask = mask * (
ca_c_n_cos_angle_error > (tolerance_factor_hard * gt_stddev)
)
c_n_ca_cos_angle = torch.sum((-c_n_unit_vec) * n_ca_unit_vec, dim=-1)
gt_angle = residue_constants.between_res_cos_angles_c_n_ca[0]
gt_stddev = residue_constants.between_res_cos_angles_c_n_ca[1]
c_n_ca_cos_angle_error = torch.sqrt(
eps + torch.square(c_n_ca_cos_angle - gt_angle)
)
c_n_ca_loss_per_residue = torch.nn.functional.relu(
c_n_ca_cos_angle_error - tolerance_factor_soft * gt_stddev
)
mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
c_n_ca_loss = torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) / (
torch.sum(mask, dim=-1) + eps
)
c_n_ca_violation_mask = mask * (
c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
)
# Compute a per residue loss (equally distribute the loss to both
# neighbouring residues).
per_residue_loss_sum = (
c_n_loss_per_residue + ca_c_n_loss_per_residue + c_n_ca_loss_per_residue
)
per_residue_loss_sum = 0.5 * (
torch.nn.functional.pad(per_residue_loss_sum, (0, 1))
+ torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
)
# Compute hard violations.
violation_mask = torch.max(
torch.stack(
[c_n_violation_mask, ca_c_n_violation_mask, c_n_ca_violation_mask],
dim=-2,
),
dim=-2,
)[0]
violation_mask = torch.maximum(
torch.nn.functional.pad(violation_mask, (0, 1)),
torch.nn.functional.pad(violation_mask, (1, 0)),
)
return {
"c_n_loss_mean": c_n_loss,
"ca_c_n_loss_mean": ca_c_n_loss,
"c_n_ca_loss_mean": c_n_ca_loss,
"per_residue_loss_sum": per_residue_loss_sum,
"per_residue_violation_mask": violation_mask,
}
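# Illustrative note, not part of the original source: each bond-length and
# bond-angle term above is a flat-bottom penalty of the form
# relu(error - tolerance_factor_soft * stddev), so deviations within the soft
# tolerance of the literature geometry cost nothing and larger ones are
# penalized linearly; the hard tolerance only feeds the violation masks.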
def between_residue_clash_loss(
atom14_pred_positions: torch.Tensor,
atom14_atom_exists: torch.Tensor,
atom14_atom_radius: torch.Tensor,
residue_index: torch.Tensor,
overlap_tolerance_soft=1.5,
overlap_tolerance_hard=1.5,
eps=1e-10,
) -> Dict[str, torch.Tensor]:
"""Loss to penalize steric clashes between residues.
This is a loss penalizing any steric clashes due to non-bonded atoms in
different residues coming too close. This loss corresponds to the part with
different residues of
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
Args:
atom14_pred_positions: Predicted positions of atoms in
global prediction frame
atom14_atom_exists: Mask denoting whether atom at positions exists for given
amino acid type
atom14_atom_radius: Van der Waals radius for each atom.
residue_index: Residue index for given amino acid.
overlap_tolerance_soft: Soft tolerance factor.
overlap_tolerance_hard: Hard tolerance factor.
Returns:
Dict containing:
* 'mean_loss': average clash loss
* 'per_atom_loss_sum': sum of all clash losses per atom, shape (N, 14)
* 'per_atom_clash_mask': mask whether atom clashes with any other atom,
shape (N, 14)
"""
fp_type = atom14_pred_positions.dtype
# Create the distance matrix.
# (N, N, 14, 14)
dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_pred_positions[..., :, None, :, None, :]
- atom14_pred_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
# Create the mask for valid distances.
# shape (N, N, 14, 14)
dists_mask = (
atom14_atom_exists[..., :, None, :, None]
* atom14_atom_exists[..., None, :, None, :]
).type(fp_type)
# Mask out all the duplicate entries in the lower triangular matrix.
# Also mask out the diagonal (atom-pairs from the same residue) -- these atoms
# are handled separately.
dists_mask = dists_mask * (
residue_index[..., :, None, None, None]
< residue_index[..., None, :, None, None]
)
# Backbone C--N bond between subsequent residues is no clash.
c_one_hot = torch.nn.functional.one_hot(
residue_index.new_tensor(2), num_classes=14
)
c_one_hot = c_one_hot.reshape(
*((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
)
c_one_hot = c_one_hot.type(fp_type)
n_one_hot = torch.nn.functional.one_hot(
residue_index.new_tensor(0), num_classes=14
)
n_one_hot = n_one_hot.reshape(
*((1,) * len(residue_index.shape[:-1])), *n_one_hot.shape
)
n_one_hot = n_one_hot.type(fp_type)
neighbour_mask = (
residue_index[..., :, None, None, None] + 1
) == residue_index[..., None, :, None, None]
c_n_bonds = (
neighbour_mask
* c_one_hot[..., None, None, :, None]
* n_one_hot[..., None, None, None, :]
)
dists_mask = dists_mask * (1.0 - c_n_bonds)
# Disulfide bridge between two cysteines is no clash.
cys = residue_constants.restype_name_to_atom14_names["CYS"]
cys_sg_idx = cys.index("SG")
cys_sg_idx = residue_index.new_tensor(cys_sg_idx)
cys_sg_idx = cys_sg_idx.reshape(
*((1,) * len(residue_index.shape[:-1])), 1
).squeeze(-1)
cys_sg_one_hot = torch.nn.functional.one_hot(cys_sg_idx, num_classes=14)
disulfide_bonds = (
cys_sg_one_hot[..., None, None, :, None]
* cys_sg_one_hot[..., None, None, None, :]
)
dists_mask = dists_mask * (1.0 - disulfide_bonds)
# Compute the lower bound for the allowed distances.
# shape (N, N, 14, 14)
dists_lower_bound = dists_mask * (
atom14_atom_radius[..., :, None, :, None]
+ atom14_atom_radius[..., None, :, None, :]
)
# Compute the error.
# shape (N, N, 14, 14)
dists_to_low_error = dists_mask * torch.nn.functional.relu(
dists_lower_bound - overlap_tolerance_soft - dists
)
# Compute the mean loss.
# shape ()
mean_loss = torch.sum(dists_to_low_error) / (1e-6 + torch.sum(dists_mask))
# Compute the per atom loss sum.
# shape (N, 14)
per_atom_loss_sum = torch.sum(dists_to_low_error, dim=(-4, -2)) + torch.sum(
dists_to_low_error, dim=(-3, -1)
)
# Compute the hard clash mask.
# shape (N, N, 14, 14)
clash_mask = dists_mask * (
dists < (dists_lower_bound - overlap_tolerance_hard)
)
# Compute the per atom clash.
# shape (N, 14)
per_atom_clash_mask = torch.maximum(
torch.amax(clash_mask, dim=(-4, -2)),
torch.amax(clash_mask, dim=(-3, -1)),
)
return {
"mean_loss": mean_loss, # shape ()
"per_atom_loss_sum": per_atom_loss_sum, # shape (N, 14)
"per_atom_clash_mask": per_atom_clash_mask, # shape (N, 14)
}
def within_residue_violations(
atom14_pred_positions: torch.Tensor,
atom14_atom_exists: torch.Tensor,
atom14_dists_lower_bound: torch.Tensor,
atom14_dists_upper_bound: torch.Tensor,
tighten_bounds_for_loss=0.0,
eps=1e-10,
) -> Dict[str, torch.Tensor]:
"""Loss to penalize steric clashes within residues.
This is a loss penalizing any steric violations or clashes of non-bonded atoms
in a given peptide. This loss corresponds to the part with
the same residues of
Jumper et al. (2021) Suppl. Sec. 1.9.11, eq 46.
Args:
atom14_pred_positions ([*, N, 14, 3]):
Predicted positions of atoms in global prediction frame.
atom14_atom_exists ([*, N, 14]):
Mask denoting whether atom at positions exists for given
amino acid type
atom14_dists_lower_bound ([*, N, 14]):
Lower bound on allowed distances.
atom14_dists_upper_bound ([*, N, 14]):
Upper bound on allowed distances
tighten_bounds_for_loss ([*, N]):
Extra factor to tighten loss
Returns:
Dict containing:
* 'per_atom_loss_sum' ([*, N, 14]):
sum of all clash losses per atom
* 'per_atom_violations' ([*, N, 14]):
mask denoting whether each atom violates its distance bounds
"""
# Compute the mask for each residue.
dists_masks = 1.0 - torch.eye(14, device=atom14_atom_exists.device)[None]
dists_masks = dists_masks.reshape(
*((1,) * len(atom14_atom_exists.shape[:-2])), *dists_masks.shape
)
dists_masks = (
atom14_atom_exists[..., :, :, None]
* atom14_atom_exists[..., :, None, :]
* dists_masks
)
# Distance matrix
dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_pred_positions[..., :, :, None, :]
- atom14_pred_positions[..., :, None, :, :]
)
** 2,
dim=-1,
)
)
# Compute the loss.
dists_to_low_error = torch.nn.functional.relu(
atom14_dists_lower_bound + tighten_bounds_for_loss - dists
)
dists_to_high_error = torch.nn.functional.relu(
dists - (atom14_dists_upper_bound - tighten_bounds_for_loss)
)
loss = dists_masks * (dists_to_low_error + dists_to_high_error)
# Compute the per atom loss sum.
per_atom_loss_sum = torch.sum(loss, dim=-2) + torch.sum(loss, dim=-1)
# Compute the violations mask.
violations = dists_masks * (
(dists < atom14_dists_lower_bound) | (dists > atom14_dists_upper_bound)
)
# Compute the per atom violations.
per_atom_violations = torch.maximum(
torch.max(violations, dim=-2)[0], torch.max(violations, dim=-1)[0]
)
return {
"per_atom_loss_sum": per_atom_loss_sum,
"per_atom_violations": per_atom_violations,
}
def find_structural_violations(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor,
violation_tolerance_factor: float,
clash_overlap_tolerance: float,
**kwargs,
) -> Dict[str, torch.Tensor]:
"""Computes several checks for structural violations."""
# Compute between residue backbone violations of bonds and angles.
connection_violations = between_residue_bond_loss(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch["atom14_atom_exists"],
residue_index=batch["residue_index"],
aatype=batch["aatype"],
tolerance_factor_soft=violation_tolerance_factor,
tolerance_factor_hard=violation_tolerance_factor,
)
# Compute the Van der Waals radius for every atom
# (the first letter of the atom name is the element type).
# Shape: (N, 14).
atomtype_radius = [
residue_constants.van_der_waals_radius[name[0]]
for name in residue_constants.atom_types
]
atomtype_radius = atom14_pred_positions.new_tensor(atomtype_radius)
atom14_atom_radius = (
batch["atom14_atom_exists"]
* atomtype_radius[batch["residx_atom14_to_atom37"]]
)
# Compute the between residue clash loss.
between_residue_clashes = between_residue_clash_loss(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch["atom14_atom_exists"],
atom14_atom_radius=atom14_atom_radius,
residue_index=batch["residue_index"],
overlap_tolerance_soft=clash_overlap_tolerance,
overlap_tolerance_hard=clash_overlap_tolerance,
)
# Compute all within-residue violations (clashes,
# bond length and angle violations).
restype_atom14_bounds = residue_constants.make_atom14_dists_bounds(
overlap_tolerance=clash_overlap_tolerance,
bond_length_tolerance_factor=violation_tolerance_factor,
)
atom14_atom_exists = batch["atom14_atom_exists"]
atom14_dists_lower_bound = atom14_pred_positions.new_tensor(
restype_atom14_bounds["lower_bound"]
)[batch["aatype"]]
atom14_dists_upper_bound = atom14_pred_positions.new_tensor(
restype_atom14_bounds["upper_bound"]
)[batch["aatype"]]
residue_violations = within_residue_violations(
atom14_pred_positions=atom14_pred_positions,
atom14_atom_exists=batch["atom14_atom_exists"],
atom14_dists_lower_bound=atom14_dists_lower_bound,
atom14_dists_upper_bound=atom14_dists_upper_bound,
tighten_bounds_for_loss=0.0,
)
# Combine them to a single per-residue violation mask (used later for LDDT).
per_residue_violations_mask = torch.max(
torch.stack(
[
connection_violations["per_residue_violation_mask"],
torch.max(
between_residue_clashes["per_atom_clash_mask"], dim=-1
)[0],
torch.max(residue_violations["per_atom_violations"], dim=-1)[0],
],
dim=-1,
),
dim=-1,
)[0]
return {
"between_residues": {
"bonds_c_n_loss_mean": connection_violations["c_n_loss_mean"], # ()
"angles_ca_c_n_loss_mean": connection_violations[
"ca_c_n_loss_mean"
], # ()
"angles_c_n_ca_loss_mean": connection_violations[
"c_n_ca_loss_mean"
], # ()
"connections_per_residue_loss_sum": connection_violations[
"per_residue_loss_sum"
], # (N)
"connections_per_residue_violation_mask": connection_violations[
"per_residue_violation_mask"
], # (N)
"clashes_mean_loss": between_residue_clashes["mean_loss"], # ()
"clashes_per_atom_loss_sum": between_residue_clashes[
"per_atom_loss_sum"
], # (N, 14)
"clashes_per_atom_clash_mask": between_residue_clashes[
"per_atom_clash_mask"
], # (N, 14)
},
"within_residues": {
"per_atom_loss_sum": residue_violations[
"per_atom_loss_sum"
], # (N, 14)
"per_atom_violations": residue_violations[
"per_atom_violations"
], # (N, 14),
},
"total_per_residue_violations_mask": per_residue_violations_mask, # (N)
}
def find_structural_violations_np(
batch: Dict[str, np.ndarray],
atom14_pred_positions: np.ndarray,
config: ml_collections.ConfigDict,
) -> Dict[str, np.ndarray]:
to_tensor = lambda x: torch.tensor(x)
batch = tree_map(to_tensor, batch, np.ndarray)
atom14_pred_positions = to_tensor(atom14_pred_positions)
out = find_structural_violations(batch, atom14_pred_positions, **config)
to_np = lambda x: np.array(x)
np_out = tensor_tree_map(to_np, out)
return np_out
def extreme_ca_ca_distance_violations(
pred_atom_positions: torch.Tensor, # (N, 37(14), 3)
pred_atom_mask: torch.Tensor, # (N, 37(14))
residue_index: torch.Tensor, # (N)
max_angstrom_tolerance=1.5,
eps=1e-6,
) -> torch.Tensor:
"""Counts residues whose Ca is a large distance from its neighbour.
Measures the fraction of CA-CA pairs between consecutive amino acids whose
distance exceeds the ideal CA-CA distance by more than 'max_angstrom_tolerance'.
Args:
pred_atom_positions: Atom positions in atom37/14 representation
pred_atom_mask: Atom mask in atom37/14 representation
residue_index: Residue index for given amino acid, this is assumed to be
monotonically increasing.
max_angstrom_tolerance: Maximum distance allowed to not count as violation.
Returns:
Fraction of consecutive CA-CA pairs with violation.
"""
this_ca_pos = pred_atom_positions[..., :-1, 1, :]
this_ca_mask = pred_atom_mask[..., :-1, 1]
next_ca_pos = pred_atom_positions[..., 1:, 1, :]
next_ca_mask = pred_atom_mask[..., 1:, 1]
has_no_gap_mask = (residue_index[..., 1:] - residue_index[..., :-1]) == 1.0
ca_ca_distance = torch.sqrt(
eps + torch.sum((this_ca_pos - next_ca_pos) ** 2, dim=-1)
)
violations = (
ca_ca_distance - residue_constants.ca_ca
) > max_angstrom_tolerance
mask = this_ca_mask * next_ca_mask * has_no_gap_mask
mean = masked_mean(mask, violations, -1)
return mean
def compute_violation_metrics(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor, # (N, 14, 3)
violations: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""Compute several metrics to assess the structural violations."""
ret = {}
extreme_ca_ca_violations = extreme_ca_ca_distance_violations(
pred_atom_positions=atom14_pred_positions,
pred_atom_mask=batch["atom14_atom_exists"],
residue_index=batch["residue_index"],
)
ret["violations_extreme_ca_ca_distance"] = extreme_ca_ca_violations
ret["violations_between_residue_bond"] = masked_mean(
batch["seq_mask"],
violations["between_residues"][
"connections_per_residue_violation_mask"
],
dim=-1,
)
ret["violations_between_residue_clash"] = masked_mean(
mask=batch["seq_mask"],
value=torch.max(
violations["between_residues"]["clashes_per_atom_clash_mask"],
dim=-1,
)[0],
dim=-1,
)
ret["violations_within_residue"] = masked_mean(
mask=batch["seq_mask"],
value=torch.max(
violations["within_residues"]["per_atom_violations"], dim=-1
)[0],
dim=-1,
)
ret["violations_per_residue"] = masked_mean(
mask=batch["seq_mask"],
value=violations["total_per_residue_violations_mask"],
dim=-1,
)
return ret
def compute_violation_metrics_np(
batch: Dict[str, np.ndarray],
atom14_pred_positions: np.ndarray,
violations: Dict[str, np.ndarray],
) -> Dict[str, np.ndarray]:
to_tensor = lambda x: torch.tensor(x)
batch = tree_map(to_tensor, batch, np.ndarray)
atom14_pred_positions = to_tensor(atom14_pred_positions)
violations = tree_map(to_tensor, violations, np.ndarray)
out = compute_violation_metrics(batch, atom14_pred_positions, violations)
to_np = lambda x: np.array(x)
return tree_map(to_np, out, torch.Tensor)
def violation_loss(
violations: Dict[str, torch.Tensor],
atom14_atom_exists: torch.Tensor,
eps=1e-6,
**kwargs,
) -> torch.Tensor:
num_atoms = torch.sum(atom14_atom_exists)
l_clash = torch.sum(
violations["between_residues"]["clashes_per_atom_loss_sum"]
+ violations["within_residues"]["per_atom_loss_sum"]
)
l_clash = l_clash / (eps + num_atoms)
loss = (
violations["between_residues"]["bonds_c_n_loss_mean"]
+ violations["between_residues"]["angles_ca_c_n_loss_mean"]
+ violations["between_residues"]["angles_c_n_ca_loss_mean"]
+ l_clash
)
return loss
def compute_renamed_ground_truth(
batch: Dict[str, torch.Tensor],
atom14_pred_positions: torch.Tensor,
eps=1e-10,
) -> Dict[str, torch.Tensor]:
"""
Find optimal renaming of ground truth based on the predicted positions.
Alg. 26 "renameSymmetricGroundTruthAtoms"
This renamed ground truth is then used for all losses,
such that each loss moves the atoms in the same direction.
Args:
batch: Dictionary containing:
* atom14_gt_positions: Ground truth positions.
* atom14_alt_gt_positions: Ground truth positions with renaming swaps.
* atom14_atom_is_ambiguous: 1.0 for atoms that are affected by
renaming swaps.
* atom14_gt_exists: Mask for which atoms exist in ground truth.
* atom14_alt_gt_exists: Mask for which atoms exist in ground truth
after renaming.
* atom14_atom_exists: Mask for whether each atom is part of the given
amino acid type.
atom14_pred_positions: Array of atom positions in global frame with shape [*, N, 14, 3].
Returns:
Dictionary containing:
alt_naming_is_better: Array with 1.0 where alternative swap is better.
renamed_atom14_gt_positions: Array of optimal ground truth positions
after renaming swaps are performed.
renamed_atom14_gt_exists: Mask after renaming swap is performed.
"""
pred_dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_pred_positions[..., None, :, None, :]
- atom14_pred_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
atom14_gt_positions = batch["atom14_gt_positions"]
gt_dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_gt_positions[..., None, :, None, :]
- atom14_gt_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
atom14_alt_gt_positions = batch["atom14_alt_gt_positions"]
alt_gt_dists = torch.sqrt(
eps
+ torch.sum(
(
atom14_alt_gt_positions[..., None, :, None, :]
- atom14_alt_gt_positions[..., None, :, None, :, :]
)
** 2,
dim=-1,
)
)
lddt = torch.sqrt(eps + (pred_dists - gt_dists) ** 2)
alt_lddt = torch.sqrt(eps + (pred_dists - alt_gt_dists) ** 2)
atom14_gt_exists = batch["atom14_gt_exists"]
atom14_atom_is_ambiguous = batch["atom14_atom_is_ambiguous"]
mask = (
atom14_gt_exists[..., None, :, None]
* atom14_atom_is_ambiguous[..., None, :, None]
* atom14_gt_exists[..., None, :, None, :]
* (1.0 - atom14_atom_is_ambiguous[..., None, :, None, :])
)
per_res_lddt = torch.sum(mask * lddt, dim=(-1, -2, -3))
alt_per_res_lddt = torch.sum(mask * alt_lddt, dim=(-1, -2, -3))
fp_type = atom14_pred_positions.dtype
alt_naming_is_better = (alt_per_res_lddt < per_res_lddt).type(fp_type)
renamed_atom14_gt_positions = (
1.0 - alt_naming_is_better[..., None, None]
) * atom14_gt_positions + alt_naming_is_better[
..., None, None
] * atom14_alt_gt_positions
renamed_atom14_gt_mask = (
1.0 - alt_naming_is_better[..., None]
) * atom14_gt_exists + alt_naming_is_better[..., None] * batch[
"atom14_alt_gt_exists"
]
return {
"alt_naming_is_better": alt_naming_is_better,
"renamed_atom14_gt_positions": renamed_atom14_gt_positions,
"renamed_atom14_gt_exists": renamed_atom14_gt_mask,
}
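# Illustrative note, not part of the original source: despite their names,
# "lddt" and "alt_lddt" above are per-atom-pair absolute distance errors (not
# lDDT scores); the naming follows the reference implementation. The masked
# sums simply decide, per residue, whether the alternative (swapped) atom
# naming fits the prediction better.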
def experimentally_resolved_loss(
logits: torch.Tensor,
atom37_atom_exists: torch.Tensor,
all_atom_mask: torch.Tensor,
resolution: torch.Tensor,
min_resolution: float,
max_resolution: float,
eps: float = 1e-8,
**kwargs,
) -> torch.Tensor:
errors = sigmoid_cross_entropy(logits, all_atom_mask)
loss = torch.sum(errors * atom37_atom_exists, dim=-1)
loss = loss / (eps + torch.sum(atom37_atom_exists, dim=(-1, -2)))
loss = torch.sum(loss, dim=-1)
loss = loss * (
(resolution >= min_resolution) & (resolution <= max_resolution)
)
loss = torch.mean(loss)
return loss
def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
"""
Computes BERT-style masked MSA loss. Implements subsection 1.9.9.
Args:
logits: [*, N_seq, N_res, 23] predicted residue distribution
true_msa: [*, N_seq, N_res] true MSA
bert_mask: [*, N_seq, N_res] MSA mask
Returns:
Masked MSA loss
"""
errors = softmax_cross_entropy(
logits, torch.nn.functional.one_hot(true_msa, num_classes=23)
)
# FP16-friendly averaging. Equivalent to:
# loss = (
# torch.sum(errors * bert_mask, dim=(-1, -2)) /
# (eps + torch.sum(bert_mask, dim=(-1, -2)))
# )
loss = errors * bert_mask
loss = torch.sum(loss, dim=-1)
scale = 0.5
denom = eps + torch.sum(scale * bert_mask, dim=(-1, -2))
loss = loss / denom[..., None]
loss = torch.sum(loss, dim=-1)
loss = loss * scale
loss = torch.mean(loss)
return loss
def compute_drmsd(structure_1, structure_2, mask=None):
if(mask is not None):
structure_1 = structure_1 * mask[..., None]
structure_2 = structure_2 * mask[..., None]
d1 = structure_1[..., :, None, :] - structure_1[..., None, :, :]
d2 = structure_2[..., :, None, :] - structure_2[..., None, :, :]
d1 = d1 ** 2
d2 = d2 ** 2
d1 = torch.sqrt(torch.sum(d1, dim=-1))
d2 = torch.sqrt(torch.sum(d2, dim=-1))
drmsd = d1 - d2
drmsd = drmsd ** 2
drmsd = torch.sum(drmsd, dim=(-1, -2))
n = d1.shape[-1] if mask is None else torch.sum(mask, dim=-1)
drmsd = drmsd * (1 / (n * (n - 1))) if n > 1 else (drmsd * 0.)
drmsd = torch.sqrt(drmsd)
return drmsd
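# Illustrative sketch, not part of the original source: dRMSD compares
# intra-structure distance matrices, so it vanishes for identical inputs and is
# invariant to rigid-body motion of either structure. Toy coordinates below.
def _example_drmsd():
    coords = torch.randn(10, 3)
    shifted = coords + torch.tensor([1.0, -2.0, 3.0])   # pure translation
    assert float(compute_drmsd(coords, coords)) < 1e-6
    assert float(compute_drmsd(coords, shifted)) < 1e-4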
def compute_drmsd_np(structure_1, structure_2, mask=None):
structure_1 = torch.tensor(structure_1)
structure_2 = torch.tensor(structure_2)
if(mask is not None):
mask = torch.tensor(mask)
return compute_drmsd(structure_1, structure_2, mask)
class AlphaFoldLoss(nn.Module):
"""Aggregation of the various losses described in the supplement"""
def __init__(self, config):
super(AlphaFoldLoss, self).__init__()
self.config = config
def forward(self, out, batch, _return_breakdown=False):
if "violation" not in out.keys():
out["violation"] = find_structural_violations(
batch,
out["sm"]["positions"][-1],
**self.config.violation,
)
if "renamed_atom14_gt_positions" not in out.keys():
batch.update(
compute_renamed_ground_truth(
batch,
out["sm"]["positions"][-1],
)
)
loss_fns = {
"distogram": lambda: distogram_loss(
logits=out["distogram_logits"],
**{**batch, **self.config.distogram},
),
"experimentally_resolved": lambda: experimentally_resolved_loss(
logits=out["experimentally_resolved_logits"],
**{**batch, **self.config.experimentally_resolved},
),
"fape": lambda: fape_loss(
out,
batch,
self.config.fape,
),
"lddt": lambda: lddt_loss(
logits=out["lddt_logits"],
all_atom_pred_pos=out["final_atom_positions"],
**{**batch, **self.config.lddt},
),
"masked_msa": lambda: masked_msa_loss(
logits=out["masked_msa_logits"],
**{**batch, **self.config.masked_msa},
),
"supervised_chi": lambda: supervised_chi_loss(
out["sm"]["angles"],
out["sm"]["unnormalized_angles"],
**{**batch, **self.config.supervised_chi},
),
"violation": lambda: violation_loss(
out["violation"],
**batch,
),
}
if(self.config.tm.enabled):
loss_fns["tm"] = lambda: tm_loss(
logits=out["tm_logits"],
**{**batch, **out, **self.config.tm},
)
cum_loss = 0.
losses = {}
for loss_name, loss_fn in loss_fns.items():
weight = self.config[loss_name].weight
loss = loss_fn()
if(torch.isnan(loss) or torch.isinf(loss)):
logging.warning(f"{loss_name} loss is NaN. Skipping...")
loss = loss.new_tensor(0., requires_grad=True)
cum_loss = cum_loss + weight * loss
losses[loss_name] = loss.detach().clone()
losses["unscaled_loss"] = cum_loss.detach().clone()
# Scale the loss by the square root of the minimum of the crop size and
# the (average) sequence length. See subsection 1.9.
seq_len = torch.mean(batch["seq_length"].float())
crop_len = batch["aatype"].shape[-1]
        # torch.clamp keeps the scale factor a tensor even when crop_len (a
        # Python int) is the smaller of the two; torch.sqrt rejects Python scalars.
        cum_loss = cum_loss * torch.sqrt(torch.clamp(seq_len, max=crop_len))
losses["loss"] = cum_loss.detach().clone()
if(not _return_breakdown):
return cum_loss
return cum_loss, losses
\ No newline at end of file
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from functools import partialmethod
from typing import Union, List
class Dropout(nn.Module):
"""
Implementation of dropout with the ability to share the dropout mask
along a particular dimension.
If not in training mode, this module computes the identity function.
"""
def __init__(self, r: float, batch_dim: Union[int, List[int]]):
"""
Args:
r:
Dropout rate
batch_dim:
Dimension(s) along which the dropout mask is shared
"""
super(Dropout, self).__init__()
self.r = r
        if isinstance(batch_dim, int):
batch_dim = [batch_dim]
self.batch_dim = batch_dim
self.dropout = nn.Dropout(self.r)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
Tensor to which dropout is applied. Can have any shape
compatible with self.batch_dim
"""
shape = list(x.shape)
if self.batch_dim is not None:
for bd in self.batch_dim:
shape[bd] = 1
mask = x.new_ones(shape)
mask = self.dropout(mask)
        # Multiply out of place so the caller's tensor is not mutated
        x = x * mask
return x
class DropoutRowwise(Dropout):
"""
Convenience class for rowwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-3)
class DropoutColumnwise(Dropout):
"""
Convenience class for columnwise dropout as described in subsection
1.11.6.
"""
__init__ = partialmethod(Dropout.__init__, batch_dim=-2)
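# Illustrative sketch added for this review (not part of the upstream sources):
# how the shared dropout mask behaves. The `_example_shared_dropout` helper and
# the shapes below are made up.
def _example_shared_dropout():
    drop = DropoutRowwise(r=0.25)
    drop.train()
    z = torch.randn(1, 8, 8, 4)   # e.g. a [*, N_res, N_res, C_z] pair activation
    out = drop(z)
    # The dropout mask is sampled with shape [1, 1, 8, 4]: dim -3 is collapsed
    # to 1, so the same kept/dropped pattern is broadcast along that dimension.
    return out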
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from typing import Tuple
from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.utils.tensor_utils import one_hot
class InputEmbedder(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
relpos_k: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
relpos_k:
Window size used in relative positional encoding
"""
super(InputEmbedder, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.relpos_k = relpos_k
self.no_bins = 2 * relpos_k + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, ri: torch.Tensor):
"""
Computes relative positional encodings
Implements Algorithm 4.
Args:
ri:
"residue_index" features of shape [*, N]
"""
d = ri[..., None] - ri[..., None, :]
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
def forward(
self,
tf: torch.Tensor,
ri: torch.Tensor,
msa: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
tf:
"target_feat" features of shape [*, N_res, tf_dim]
ri:
"residue_index" features of shape [*, N_res]
msa:
"msa_feat" features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
pair_emb:
[*, N_res, N_res, C_z] pair embedding
"""
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
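# Illustrative sketch added for this review (not part of the upstream sources):
# wiring dummy features through the InputEmbedder. The helper name and all
# dimensions below are made up.
def _example_input_embedder():
    emb = InputEmbedder(tf_dim=22, msa_dim=49, c_z=8, c_m=16, relpos_k=32)
    n_res, n_clust = 10, 4
    tf = torch.randn(n_res, 22)            # "target_feat"
    ri = torch.arange(n_res)               # "residue_index"
    msa = torch.randn(n_clust, n_res, 49)  # "msa_feat"
    msa_emb, pair_emb = emb(tf, ri, msa)
    # msa_emb: [4, 10, 16], pair_emb: [10, 10, 8]
    return msa_emb, pair_emb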
class RecyclingEmbedder(nn.Module):
"""
Embeds the output of an iteration of the model for recycling.
Implements Algorithm 32.
"""
def __init__(
self,
c_m: int,
c_z: int,
min_bin: float,
max_bin: float,
no_bins: int,
inf: float = 1e8,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair embedding channel dimension
min_bin:
Smallest distogram bin (Angstroms)
max_bin:
Largest distogram bin (Angstroms)
no_bins:
Number of distogram bins
"""
super(RecyclingEmbedder, self).__init__()
self.c_m = c_m
self.c_z = c_z
self.min_bin = min_bin
self.max_bin = max_bin
self.no_bins = no_bins
self.inf = inf
self.bins = None
self.linear = Linear(self.no_bins, self.c_z)
self.layer_norm_m = LayerNorm(self.c_m)
self.layer_norm_z = LayerNorm(self.c_z)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
m:
First row of the MSA embedding. [*, N_res, C_m]
z:
[*, N_res, N_res, C_z] pair embedding
x:
[*, N_res, 3] predicted C_beta coordinates
Returns:
m:
[*, N_res, C_m] MSA embedding update
z:
[*, N_res, N_res, C_z] pair embedding update
"""
if self.bins is None:
self.bins = torch.linspace(
self.min_bin,
self.max_bin,
self.no_bins,
dtype=x.dtype,
device=x.device,
requires_grad=False,
)
# [*, N, C_m]
m_update = self.layer_norm_m(m)
# This squared method might become problematic in FP16 mode.
# I'm using it because my homegrown method had a stubborn discrepancy I
# couldn't find in time.
squared_bins = self.bins ** 2
upper = torch.cat(
[squared_bins[1:], squared_bins.new_tensor([self.inf])], dim=-1
)
d = torch.sum(
(x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdims=True
)
# [*, N, N, no_bins]
d = ((d > squared_bins) * (d < upper)).type(x.dtype)
# [*, N, N, C_z]
d = self.linear(d)
z_update = d + self.layer_norm_z(z)
return m_update, z_update
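# Illustrative sketch added for this review (not part of the upstream sources):
# recycling previous-iteration outputs. The helper name and all dimensions are
# made up; the bins above one-hot the squared C_beta distances.
def _example_recycling_embedder():
    emb = RecyclingEmbedder(c_m=16, c_z=8, min_bin=3.25, max_bin=20.75, no_bins=15)
    n_res = 10
    m = torch.randn(n_res, 16)            # first MSA row from the last iteration
    z = torch.randn(n_res, n_res, 8)      # pair embedding from the last iteration
    x = 10.0 * torch.randn(n_res, 3)      # pseudo C_beta coordinates (Angstroms)
    m_update, z_update = emb(m, z, x)
    # m_update: [10, 16], z_update: [10, 10, 8]
    return m_update, z_update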
class TemplateAngleEmbedder(nn.Module):
"""
Embeds the "template_angle_feat" feature.
Implements Algorithm 2, line 7.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Final dimension of "template_angle_feat"
c_out:
Output channel dimension
"""
super(TemplateAngleEmbedder, self).__init__()
self.c_out = c_out
self.c_in = c_in
self.linear_1 = Linear(self.c_in, self.c_out, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.c_out, self.c_out, init="relu")
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x: [*, N_templ, N_res, c_in] "template_angle_feat" features
Returns:
x: [*, N_templ, N_res, C_out] embedding
"""
x = self.linear_1(x)
x = self.relu(x)
x = self.linear_2(x)
return x
class TemplatePairEmbedder(nn.Module):
"""
Embeds "template_pair_feat" features.
Implements Algorithm 2, line 9.
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
            c_in:
                Input channel dimension
c_out:
Output channel dimension
"""
super(TemplatePairEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
# Despite there being no relu nearby, the source uses that initializer
self.linear = Linear(self.c_in, self.c_out, init="relu")
def forward(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""
Args:
x:
[*, C_in] input tensor
Returns:
[*, C_out] output tensor
"""
x = self.linear(x)
return x
class ExtraMSAEmbedder(nn.Module):
"""
Embeds unclustered MSA sequences.
Implements Algorithm 2, line 15
"""
def __init__(
self,
c_in: int,
c_out: int,
**kwargs,
):
"""
Args:
c_in:
Input channel dimension
c_out:
Output channel dimension
"""
super(ExtraMSAEmbedder, self).__init__()
self.c_in = c_in
self.c_out = c_out
self.linear = Linear(self.c_in, self.c_out)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Args:
x:
[*, N_extra_seq, N_res, C_in] "extra_msa_feat" features
Returns:
[*, N_extra_seq, N_res, C_out] embedding
"""
x = self.linear(x)
return x
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
from typing import Tuple, Optional
from functools import partial
from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.model.nn.dropout import DropoutRowwise, DropoutColumnwise
from fastfold.model.nn.msa import (
MSARowAttentionWithPairBias,
MSAColumnAttention,
MSAColumnGlobalAttention,
)
from fastfold.model.nn.outer_product_mean import OuterProductMean
from fastfold.model.nn.pair_transition import PairTransition
from fastfold.model.nn.triangular_attention import (
TriangleAttentionStartingNode,
TriangleAttentionEndingNode,
)
from fastfold.model.nn.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
)
from fastfold.utils.checkpointing import checkpoint_blocks, get_checkpoint_fn
from fastfold.utils.tensor_utils import chunk_layer
class MSATransition(nn.Module):
"""
Feed-forward network applied to MSA activations after attention.
Implements Algorithm 9
"""
def __init__(self, c_m, n):
"""
Args:
c_m:
MSA channel dimension
n:
                Factor by which c_m is multiplied to obtain the hidden
                channel dimension
"""
super(MSATransition, self).__init__()
self.c_m = c_m
self.n = n
self.layer_norm = LayerNorm(self.c_m)
self.linear_1 = Linear(self.c_m, self.n * self.c_m, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_m, self.c_m, init="final")
def _transition(self, m, mask):
m = self.linear_1(m)
m = self.relu(m)
m = self.linear_2(m) * mask
return m
@torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self._transition,
{"m": m, "mask": mask},
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
def forward(
self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA activation
mask:
                [*, N_seq, N_res] MSA mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA activation update
"""
# DISCREPANCY: DeepMind forgets to apply the MSA mask here.
if mask is None:
mask = m.new_ones(m.shape[:-1])
mask = mask.unsqueeze(-1)
m = self.layer_norm(m)
if chunk_size is not None:
m = self._chunk(m, mask, chunk_size)
else:
m = self._transition(m, mask)
return m
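# Illustrative sketch added for this review (not part of the upstream sources):
# chunking splits the flattened batch dims and reproduces the unchunked result.
# The helper name and all shapes below are made up.
def _example_msa_transition_chunking():
    trans = MSATransition(c_m=16, n=4)
    m = torch.randn(2, 8, 10, 16)                 # [*, N_seq, N_res, C_m]
    with torch.no_grad():
        full = trans(m)
        chunked = trans(m, chunk_size=4)
    assert torch.allclose(full, chunked, atol=1e-6)
    return full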
class EvoformerBlockCore(nn.Module):
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
pair_dropout: float,
inf: float,
eps: float,
_is_extra_msa_stack: bool = False,
):
super(EvoformerBlockCore, self).__init__()
self.msa_transition = MSATransition(
c_m=c_m,
n=transition_n,
)
self.outer_product_mean = OuterProductMean(
c_m,
c_z,
c_hidden_opm,
)
self.tri_mul_out = TriangleMultiplicationOutgoing(
c_z,
c_hidden_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
c_z,
c_hidden_mul,
)
self.tri_att_start = TriangleAttentionStartingNode(
c_z,
c_hidden_pair_att,
no_heads_pair,
inf=inf,
)
self.tri_att_end = TriangleAttentionEndingNode(
c_z,
c_hidden_pair_att,
no_heads_pair,
inf=inf,
)
self.pair_transition = PairTransition(
c_z,
transition_n,
)
self.ps_dropout_row_layer = DropoutRowwise(pair_dropout)
self.ps_dropout_col_layer = DropoutColumnwise(pair_dropout)
def forward(
self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
# DeepMind doesn't mask these transitions in the source, so _mask_trans
# should be disabled to better approximate the exact activations of
# the original.
msa_trans_mask = msa_mask if _mask_trans else None
pair_trans_mask = pair_mask if _mask_trans else None
m = m + self.msa_transition(
m, mask=msa_trans_mask, chunk_size=chunk_size
)
z = z + self.outer_product_mean(
m, mask=msa_mask, chunk_size=chunk_size
)
z = z + self.ps_dropout_row_layer(self.tri_mul_out(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(self.tri_mul_in(z, mask=pair_mask))
z = z + self.ps_dropout_row_layer(
self.tri_att_start(z, mask=pair_mask, chunk_size=chunk_size)
)
z = z + self.ps_dropout_col_layer(
self.tri_att_end(z, mask=pair_mask, chunk_size=chunk_size)
)
z = z + self.pair_transition(
z, mask=pair_trans_mask, chunk_size=chunk_size
)
return m, z
class EvoformerBlock(nn.Module):
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
):
super(EvoformerBlock, self).__init__()
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
self.msa_att_col = MSAColumnAttention(
c_m,
c_hidden_msa_att,
no_heads_msa,
inf=inf,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.core = EvoformerBlockCore(
c_m=c_m,
c_z=c_z,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
m = m + self.msa_dropout_layer(
self.msa_att_row(m, z=z, mask=msa_mask, chunk_size=chunk_size)
)
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m, z = self.core(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
return m, z
class ExtraMSABlock(nn.Module):
"""
    Almost identical to the standard EvoformerBlock, except that the
ExtraMSABlock uses GlobalAttention for MSA column attention and
requires more fine-grained control over checkpointing. Separated from
its twin to preserve the TorchScript-ability of the latter.
"""
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
ckpt: bool,
):
super(ExtraMSABlock, self).__init__()
self.ckpt = ckpt
self.msa_att_row = MSARowAttentionWithPairBias(
c_m=c_m,
c_z=c_z,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
)
self.msa_att_col = MSAColumnGlobalAttention(
c_in=c_m,
c_hidden=c_hidden_msa_att,
no_heads=no_heads_msa,
inf=inf,
eps=eps,
)
self.msa_dropout_layer = DropoutRowwise(msa_dropout)
self.core = EvoformerBlockCore(
c_m=c_m,
c_z=c_z,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: Optional[int] = None,
_chunk_logits: Optional[int] = 1024,
) -> Tuple[torch.Tensor, torch.Tensor]:
m = m + self.msa_dropout_layer(
self.msa_att_row(
m.clone(),
z=z.clone(),
mask=msa_mask,
chunk_size=chunk_size,
_chunk_logits=_chunk_logits if torch.is_grad_enabled() else None,
_checkpoint_chunks=
self.ckpt if torch.is_grad_enabled() else False,
)
)
def fn(m, z):
m = m + self.msa_att_col(m, mask=msa_mask, chunk_size=chunk_size)
m, z = self.core(
m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size
)
return m, z
if(torch.is_grad_enabled() and self.ckpt):
checkpoint_fn = get_checkpoint_fn()
m, z = checkpoint_fn(fn, m, z)
else:
m, z = fn(m, z)
return m, z
class EvoformerStack(nn.Module):
"""
Main Evoformer trunk.
Implements Algorithm 6.
"""
def __init__(
self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
c_s: int,
no_heads_msa: int,
no_heads_pair: int,
no_blocks: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
blocks_per_ckpt: int,
inf: float,
eps: float,
clear_cache_between_blocks: bool = False,
**kwargs,
):
"""
Args:
c_m:
MSA channel dimension
c_z:
Pair channel dimension
c_hidden_msa_att:
Hidden dimension in MSA attention
c_hidden_opm:
Hidden dimension in outer product mean module
c_hidden_mul:
Hidden dimension in multiplicative updates
c_hidden_pair_att:
Hidden dimension in triangular attention
c_s:
Channel dimension of the output "single" embedding
no_heads_msa:
Number of heads used for MSA attention
no_heads_pair:
Number of heads used for pair attention
no_blocks:
Number of Evoformer blocks in the stack
transition_n:
Factor by which to multiply c_m to obtain the MSATransition
hidden dimension
msa_dropout:
Dropout rate for MSA activations
pair_dropout:
Dropout used for pair activations
blocks_per_ckpt:
Number of Evoformer blocks in each activation checkpoint
clear_cache_between_blocks:
Whether to clear CUDA's GPU memory cache between blocks of the
stack. Slows down each block but can reduce fragmentation
"""
super(EvoformerStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for _ in range(no_blocks):
block = EvoformerBlock(
c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
)
self.blocks.append(block)
self.linear = Linear(c_m, c_s)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
msa_mask: torch.Tensor,
pair_mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
[*, N_seq, N_res] MSA mask
pair_mask:
[*, N_res, N_res] pair mask
Returns:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
s:
[*, N_res, C_s] single embedding (or None if extra MSA stack)
"""
blocks = [
partial(
b,
msa_mask=msa_mask,
pair_mask=pair_mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
]
if(self.clear_cache_between_blocks):
def block_with_cache_clear(block, *args):
torch.cuda.empty_cache()
return block(*args)
blocks = [partial(block_with_cache_clear, b) for b in blocks]
m, z = checkpoint_blocks(
blocks,
args=(m, z),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
s = self.linear(m[..., 0, :, :])
return m, z, s
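# Illustrative sketch added for this review (not part of the upstream sources):
# a miniature Evoformer stack run on random activations. The helper name and
# every dimension below are made up; real configs use much larger values.
def _example_evoformer_stack():
    stack = EvoformerStack(
        c_m=16, c_z=8,
        c_hidden_msa_att=8, c_hidden_opm=8, c_hidden_mul=8, c_hidden_pair_att=8,
        c_s=12,
        no_heads_msa=2, no_heads_pair=2,
        no_blocks=1, transition_n=2,
        msa_dropout=0.15, pair_dropout=0.25,
        blocks_per_ckpt=1, inf=1e9, eps=1e-8,
    ).eval()
    n_seq, n_res = 4, 10
    m = torch.randn(n_seq, n_res, 16)
    z = torch.randn(n_res, n_res, 8)
    msa_mask = torch.ones(n_seq, n_res)
    pair_mask = torch.ones(n_res, n_res)
    with torch.no_grad():
        m, z, s = stack(m, z, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=None)
    # m: [4, 10, 16], z: [10, 10, 8], s: [10, 12]
    return m, z, s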
class ExtraMSAStack(nn.Module):
"""
Implements Algorithm 18.
"""
def __init__(self,
c_m: int,
c_z: int,
c_hidden_msa_att: int,
c_hidden_opm: int,
c_hidden_mul: int,
c_hidden_pair_att: int,
no_heads_msa: int,
no_heads_pair: int,
no_blocks: int,
transition_n: int,
msa_dropout: float,
pair_dropout: float,
inf: float,
eps: float,
ckpt: bool,
clear_cache_between_blocks: bool = False,
**kwargs,
):
super(ExtraMSAStack, self).__init__()
self.clear_cache_between_blocks = clear_cache_between_blocks
self.blocks = nn.ModuleList()
for _ in range(no_blocks):
block = ExtraMSABlock(
c_m=c_m,
c_z=c_z,
c_hidden_msa_att=c_hidden_msa_att,
c_hidden_opm=c_hidden_opm,
c_hidden_mul=c_hidden_mul,
c_hidden_pair_att=c_hidden_pair_att,
no_heads_msa=no_heads_msa,
no_heads_pair=no_heads_pair,
transition_n=transition_n,
msa_dropout=msa_dropout,
pair_dropout=pair_dropout,
inf=inf,
eps=eps,
ckpt=ckpt,
)
self.blocks.append(block)
def forward(self,
m: torch.Tensor,
z: torch.Tensor,
chunk_size: int,
msa_mask: Optional[torch.Tensor] = None,
pair_mask: Optional[torch.Tensor] = None,
_mask_trans: bool = True,
) -> torch.Tensor:
"""
Args:
m:
[*, N_extra, N_res, C_m] extra MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding
msa_mask:
Optional [*, N_extra, N_res] MSA mask
pair_mask:
Optional [*, N_res, N_res] pair mask
Returns:
[*, N_res, N_res, C_z] pair update
"""
#checkpoint_fn = get_checkpoint_fn()
#blocks = [
# partial(b, msa_mask=msa_mask, pair_mask=pair_mask, chunk_size=chunk_size, _chunk_logits=None) for b in self.blocks
#]
#def dodo(b, *args):
# torch.cuda.empty_cache()
# return b(*args)
#blocks = [partial(dodo, b) for b in blocks]
#for b in blocks:
# if(torch.is_grad_enabled()):
# m, z = checkpoint_fn(b, *(m, z))
# else:
# m, z = b(m, z)
for b in self.blocks:
m, z = b(m, z, msa_mask, pair_mask, chunk_size=chunk_size)
if(self.clear_cache_between_blocks):
torch.cuda.empty_cache()
return z
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.model.loss import (
compute_plddt,
compute_tm,
compute_predicted_aligned_error,
)
class AuxiliaryHeads(nn.Module):
def __init__(self, config):
super(AuxiliaryHeads, self).__init__()
self.plddt = PerResidueLDDTCaPredictor(
**config["lddt"],
)
self.distogram = DistogramHead(
**config["distogram"],
)
self.masked_msa = MaskedMSAHead(
**config["masked_msa"],
)
self.experimentally_resolved = ExperimentallyResolvedHead(
**config["experimentally_resolved"],
)
if config.tm.enabled:
self.tm = TMScoreHead(
**config.tm,
)
self.config = config
def forward(self, outputs):
aux_out = {}
lddt_logits = self.plddt(outputs["sm"]["single"])
aux_out["lddt_logits"] = lddt_logits
# Required for relaxation later on
aux_out["plddt"] = compute_plddt(lddt_logits)
distogram_logits = self.distogram(outputs["pair"])
aux_out["distogram_logits"] = distogram_logits
masked_msa_logits = self.masked_msa(outputs["msa"])
aux_out["masked_msa_logits"] = masked_msa_logits
experimentally_resolved_logits = self.experimentally_resolved(
outputs["single"]
)
aux_out[
"experimentally_resolved_logits"
] = experimentally_resolved_logits
if self.config.tm.enabled:
tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits
aux_out["predicted_tm_score"] = compute_tm(
tm_logits, **self.config.tm
)
aux_out.update(
compute_predicted_aligned_error(
tm_logits,
**self.config.tm,
)
)
return aux_out
class PerResidueLDDTCaPredictor(nn.Module):
def __init__(self, no_bins, c_in, c_hidden):
super(PerResidueLDDTCaPredictor, self).__init__()
self.no_bins = no_bins
self.c_in = c_in
self.c_hidden = c_hidden
self.layer_norm = LayerNorm(self.c_in)
self.linear_1 = Linear(self.c_in, self.c_hidden, init="relu")
self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="relu")
self.linear_3 = Linear(self.c_hidden, self.no_bins, init="final")
self.relu = nn.ReLU()
def forward(self, s):
s = self.layer_norm(s)
s = self.linear_1(s)
s = self.relu(s)
s = self.linear_2(s)
s = self.relu(s)
s = self.linear_3(s)
return s
class DistogramHead(nn.Module):
"""
Computes a distogram probability distribution.
For use in computation of distogram loss, subsection 1.9.8
"""
def __init__(self, c_z, no_bins, **kwargs):
"""
Args:
c_z:
Input channel dimension
no_bins:
Number of distogram bins
"""
super(DistogramHead, self).__init__()
self.c_z = c_z
self.no_bins = no_bins
self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self, z): # [*, N, N, C_z]
"""
Args:
z:
[*, N_res, N_res, C_z] pair embedding
Returns:
            [*, N_res, N_res, no_bins] distogram logits
"""
# [*, N, N, no_bins]
logits = self.linear(z)
logits = logits + logits.transpose(-2, -3)
return logits
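# Illustrative sketch added for this review (not part of the upstream sources):
# the transpose-and-add above makes the logits symmetric in the two residue
# indices. The helper name and dimensions are made up.
def _example_distogram_symmetry():
    head = DistogramHead(c_z=8, no_bins=64)
    z = torch.randn(10, 10, 8)
    logits = head(z)
    assert torch.allclose(logits, logits.transpose(-2, -3))
    return logits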
class TMScoreHead(nn.Module):
"""
For use in computation of TM-score, subsection 1.9.7
"""
def __init__(self, c_z, no_bins, **kwargs):
"""
Args:
c_z:
Input channel dimension
no_bins:
Number of bins
"""
super(TMScoreHead, self).__init__()
self.c_z = c_z
self.no_bins = no_bins
self.linear = Linear(self.c_z, self.no_bins, init="final")
def forward(self, z):
"""
Args:
z:
[*, N_res, N_res, C_z] pairwise embedding
Returns:
[*, N_res, N_res, no_bins] prediction
"""
# [*, N, N, no_bins]
logits = self.linear(z)
return logits
class MaskedMSAHead(nn.Module):
"""
For use in computation of masked MSA loss, subsection 1.9.9
"""
def __init__(self, c_m, c_out, **kwargs):
"""
Args:
c_m:
MSA channel dimension
c_out:
Output channel dimension
"""
super(MaskedMSAHead, self).__init__()
self.c_m = c_m
self.c_out = c_out
self.linear = Linear(self.c_m, self.c_out, init="final")
def forward(self, m):
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
Returns:
[*, N_seq, N_res, C_out] reconstruction
"""
# [*, N_seq, N_res, C_out]
logits = self.linear(m)
return logits
class ExperimentallyResolvedHead(nn.Module):
"""
For use in computation of "experimentally resolved" loss, subsection
1.9.10
"""
def __init__(self, c_s, c_out, **kwargs):
"""
Args:
c_s:
Input channel dimension
c_out:
                Output channel dimension
"""
super(ExperimentallyResolvedHead, self).__init__()
self.c_s = c_s
self.c_out = c_out
self.linear = Linear(self.c_s, self.c_out, init="final")
def forward(self, s):
"""
Args:
s:
[*, N_res, C_s] single embedding
Returns:
[*, N, C_out] logits
"""
# [*, N, C_out]
logits = self.linear(s)
return logits
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
from typing import Optional, List, Tuple
from fastfold.model.nn.primitives import (
Linear,
LayerNorm,
Attention,
GlobalAttention,
_attention_chunked_trainable,
)
from fastfold.utils.checkpointing import get_checkpoint_fn
from fastfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
class MSAAttention(nn.Module):
def __init__(
self,
c_in,
c_hidden,
no_heads,
pair_bias=False,
c_z=None,
inf=1e9,
):
"""
Args:
c_in:
Input channel dimension
c_hidden:
Per-head hidden channel dimension
no_heads:
Number of attention heads
pair_bias:
Whether to use pair embedding bias
c_z:
Pair embedding channel dimension. Ignored unless pair_bias
is true
inf:
A large number to be used in computing the attention mask
"""
super(MSAAttention, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.pair_bias = pair_bias
self.c_z = c_z
self.inf = inf
self.layer_norm_m = LayerNorm(self.c_in)
self.layer_norm_z = None
self.linear_z = None
if self.pair_bias:
self.layer_norm_z = LayerNorm(self.c_z)
self.linear_z = Linear(
self.c_z, self.no_heads, bias=False, init="normal"
)
self.mha = Attention(
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
)
@torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self.mha,
{"q_x": m, "kv_x": m, "biases": biases},
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
def _prep_inputs(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [*, N_seq, N_res, C_m]
m = self.layer_norm_m(m)
n_seq, n_res = m.shape[-3:-1]
if mask is None:
# [*, N_seq, N_res]
mask = m.new_ones(
m.shape[:-3] + (n_seq, n_res),
)
# [*, N_seq, 1, 1, N_res]
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# This step simply returns a larger view of the bias, and does not
# consume additional memory.
# [*, N_seq, no_heads, N_res, N_res]
#bias = bias.expand(
# ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
#)
if (self.pair_bias and
z is not None and # For the
self.layer_norm_z is not None and # benefit of
self.linear_z is not None # TorchScript
):
# [*, N_res, N_res, C_z]
z = self.layer_norm_z(z)
# [*, N_res, N_res, no_heads]
z = self.linear_z(z)
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
return m, mask_bias, z
@torch.jit.ignore
def _chunked_msa_attn(self,
m: torch.Tensor,
z: Optional[torch.Tensor],
mask: Optional[torch.Tensor],
chunk_logits: int,
checkpoint: bool,
) -> torch.Tensor:
MSA_DIM = -4
def _get_qkv(m, z):
m, mask_bias, z = self._prep_inputs(m, z, mask)
q, k, v = self.mha._prep_qkv(m, m)
return m, q, k, v, mask_bias, z
checkpoint_fn = get_checkpoint_fn()
if(torch.is_grad_enabled() and checkpoint):
m, q, k, v, mask_bias, z = checkpoint_fn(_get_qkv, m, z)
else:
m, q, k, v, mask_bias, z = _get_qkv(m, z)
o = _attention_chunked_trainable(
query=q,
key=k,
value=v,
biases=[mask_bias, z],
chunk_size=chunk_logits,
chunk_dim=MSA_DIM,
checkpoint=checkpoint,
)
if(torch.is_grad_enabled() and checkpoint):
# Storing an additional m here is far from ideal
m = checkpoint_fn(self.mha._wrap_up, o, m)
else:
m = self.mha._wrap_up(o, m)
return m
def forward(self,
m: torch.Tensor,
z: Optional[torch.Tensor] = None,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
_chunk_logits: Optional[int] = None,
_checkpoint_chunks: Optional[bool] = None,
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
z:
[*, N_res, N_res, C_z] pair embedding. Required only if
pair_bias is True
mask:
[*, N_seq, N_res] MSA mask
chunk_size:
Size of chunks into which the inputs are split along their
batch dimensions. A low value decreases memory overhead at the
cost of slower execution. Chunking is not performed by default.
"""
if(_chunk_logits is not None):
return self._chunked_msa_attn(
m=m, z=z, mask=mask,
chunk_logits=_chunk_logits, checkpoint=_checkpoint_chunks
)
m, mask_bias, z = self._prep_inputs(m, z, mask)
biases = [mask_bias]
if(z is not None):
biases.append(z)
if chunk_size is not None:
m = self._chunk(m, biases, chunk_size)
else:
m = self.mha(
q_x=m,
kv_x=m,
biases=biases
)
return m
class MSARowAttentionWithPairBias(MSAAttention):
"""
Implements Algorithm 7.
"""
def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
"""
Args:
c_m:
Input channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Per-head hidden channel dimension
no_heads:
Number of attention heads
inf:
Large number used to construct attention masks
"""
super(MSARowAttentionWithPairBias, self).__init__(
c_m,
c_hidden,
no_heads,
pair_bias=True,
c_z=c_z,
inf=inf,
)
class MSAColumnAttention(nn.Module):
"""
Implements Algorithm 8.
By rights, this should also be a subclass of MSAAttention. Alas,
most inheritance isn't supported by TorchScript.
"""
def __init__(self, c_m, c_hidden, no_heads, inf=1e9):
"""
Args:
c_m:
MSA channel dimension
c_hidden:
Per-head hidden channel dimension
no_heads:
Number of attention heads
inf:
Large number used to construct attention masks
"""
super(MSAColumnAttention, self).__init__()
self.c_m = c_m
self.c_hidden = c_hidden
self.no_heads = no_heads
self.inf = inf
self._msa_att = MSAAttention(
c_in=c_m,
c_hidden=c_hidden,
no_heads=no_heads,
pair_bias=False,
c_z=None,
inf=inf,
)
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
mask:
[*, N_seq, N_res] MSA mask
chunk_size:
Size of chunks into which the inputs are split along their
batch dimensions. A low value decreases memory overhead at the
cost of slower execution. Chunking is not performed by default.
"""
# [*, N_res, N_seq, C_in]
m = m.transpose(-2, -3)
if mask is not None:
mask = mask.transpose(-1, -2)
m = self._msa_att(m, mask=mask, chunk_size=chunk_size)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
if mask is not None:
mask = mask.transpose(-1, -2)
return m
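# Illustrative sketch added for this review (not part of the upstream sources):
# column attention is row attention over the transposed MSA, so shapes are
# preserved. The helper name and dimensions are made up.
def _example_msa_column_attention():
    att = MSAColumnAttention(c_m=16, c_hidden=8, no_heads=2)
    m = torch.randn(4, 10, 16)        # [N_seq, N_res, C_m]
    mask = torch.ones(4, 10)
    out = att(m, mask=mask)
    # out: [4, 10, 16]
    return out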
class MSAColumnGlobalAttention(nn.Module):
def __init__(
self, c_in, c_hidden, no_heads, inf=1e9, eps=1e-10,
):
super(MSAColumnGlobalAttention, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.inf = inf
self.eps = eps
self.layer_norm_m = nn.LayerNorm(c_in)
self.global_attention = GlobalAttention(
c_in=c_in,
c_hidden=c_hidden,
no_heads=no_heads,
inf=inf,
eps=eps,
)
@torch.jit.ignore
def _chunk(self,
m: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
mha_input = {
"m": m,
"mask": mask,
}
return chunk_layer(
self.global_attention,
mha_input,
chunk_size=chunk_size,
no_batch_dims=len(m.shape[:-2]),
)
def forward(
self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:]
if mask is None:
# [*, N_seq, N_res]
mask = torch.ones(
m.shape[:-1],
dtype=m.dtype,
device=m.device,
).detach()
# [*, N_res, N_seq, C_in]
m = m.transpose(-2, -3)
mask = mask.transpose(-1, -2)
# [*, N_res, N_seq, C_in]
m = self.layer_norm_m(m)
if chunk_size is not None:
m = self._chunk(m, mask, chunk_size)
else:
m = self.global_attention(m=m, mask=mask)
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
return m
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Optional
import torch
import torch.nn as nn
from fastfold.model.nn.primitives import Linear
from fastfold.utils.tensor_utils import chunk_layer
class OuterProductMean(nn.Module):
"""
Implements Algorithm 10.
"""
def __init__(self, c_m, c_z, c_hidden, eps=1e-3):
"""
Args:
c_m:
MSA embedding channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Hidden channel dimension
"""
super(OuterProductMean, self).__init__()
self.c_m = c_m
self.c_z = c_z
self.c_hidden = c_hidden
self.eps = eps
self.layer_norm = nn.LayerNorm(c_m)
self.linear_1 = Linear(c_m, c_hidden)
self.linear_2 = Linear(c_m, c_hidden)
self.linear_out = Linear(c_hidden ** 2, c_z, init="final")
def _opm(self, a, b):
# [*, N_res, N_res, C, C]
outer = torch.einsum("...bac,...dae->...bdce", a, b)
# [*, N_res, N_res, C * C]
outer = outer.reshape(outer.shape[:-2] + (-1,))
# [*, N_res, N_res, C_z]
outer = self.linear_out(outer)
return outer
@torch.jit.ignore
def _chunk(self,
a: torch.Tensor,
b: torch.Tensor,
chunk_size: int
) -> torch.Tensor:
# Since the "batch dim" in this case is not a true batch dimension
# (in that the shape of the output depends on it), we need to
# iterate over it ourselves
a_reshape = a.reshape((-1,) + a.shape[-3:])
b_reshape = b.reshape((-1,) + b.shape[-3:])
out = []
for a_prime, b_prime in zip(a_reshape, b_reshape):
outer = chunk_layer(
partial(self._opm, b=b_prime),
{"a": a_prime},
chunk_size=chunk_size,
no_batch_dims=1,
)
out.append(outer)
outer = torch.stack(out, dim=0)
outer = outer.reshape(a.shape[:-3] + outer.shape[1:])
return outer
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
) -> torch.Tensor:
"""
Args:
m:
[*, N_seq, N_res, C_m] MSA embedding
mask:
[*, N_seq, N_res] MSA mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
if mask is None:
mask = m.new_ones(m.shape[:-1])
# [*, N_seq, N_res, C_m]
m = self.layer_norm(m)
# [*, N_seq, N_res, C]
mask = mask.unsqueeze(-1)
a = self.linear_1(m) * mask
b = self.linear_2(m) * mask
a = a.transpose(-2, -3)
b = b.transpose(-2, -3)
if chunk_size is not None:
outer = self._chunk(a, b, chunk_size)
else:
outer = self._opm(a, b)
# [*, N_res, N_res, 1]
norm = torch.einsum("...abc,...adc->...bdc", mask, mask)
# [*, N_res, N_res, C_z]
outer = outer / (self.eps + norm)
return outer
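# Illustrative sketch added for this review (not part of the upstream sources):
# the outer product mean turns MSA activations into a pairwise update,
# normalized by the number of unmasked sequences. Names and shapes are made up.
def _example_outer_product_mean():
    opm = OuterProductMean(c_m=16, c_z=8, c_hidden=4)
    m = torch.randn(6, 10, 16)        # [N_seq, N_res, C_m]
    mask = torch.ones(6, 10)
    z_update = opm(m, mask=mask)
    # z_update: [10, 10, 8]
    return z_update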
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.utils.tensor_utils import chunk_layer
class PairTransition(nn.Module):
"""
Implements Algorithm 15.
"""
def __init__(self, c_z, n):
"""
Args:
c_z:
Pair transition channel dimension
n:
Factor by which c_z is multiplied to obtain hidden channel
dimension
"""
super(PairTransition, self).__init__()
self.c_z = c_z
self.n = n
self.layer_norm = LayerNorm(self.c_z)
self.linear_1 = Linear(self.c_z, self.n * self.c_z, init="relu")
self.relu = nn.ReLU()
self.linear_2 = Linear(self.n * self.c_z, c_z, init="final")
def _transition(self, z, mask):
# [*, N_res, N_res, C_hidden]
z = self.linear_1(z)
z = self.relu(z)
# [*, N_res, N_res, C_z]
z = self.linear_2(z) * mask
return z
@torch.jit.ignore
def _chunk(self,
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: int,
) -> torch.Tensor:
return chunk_layer(
self._transition,
{"z": z, "mask": mask},
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""
Args:
z:
[*, N_res, N_res, C_z] pair embedding
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
# DISCREPANCY: DeepMind forgets to apply the mask in this module.
if mask is None:
mask = z.new_ones(z.shape[:-1])
# [*, N_res, N_res, 1]
mask = mask.unsqueeze(-1)
# [*, N_res, N_res, C_z]
z = self.layer_norm(z)
if chunk_size is not None:
z = self._chunk(z, mask, chunk_size)
else:
z = self._transition(z=z, mask=mask)
return z
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
from typing import Optional, Callable, List, Tuple, Sequence
import numpy as np
import torch
import torch.nn as nn
from scipy.stats import truncnorm
from fastfold.utils.checkpointing import get_checkpoint_fn
from fastfold.utils.tensor_utils import (
permute_final_dims,
flatten_final_dims,
_chunk_slice,
)
def _prod(nums):
out = 1
for n in nums:
out = out * n
return out
def _calculate_fan(linear_weight_shape, fan="fan_in"):
fan_out, fan_in = linear_weight_shape
if fan == "fan_in":
f = fan_in
elif fan == "fan_out":
f = fan_out
elif fan == "fan_avg":
f = (fan_in + fan_out) / 2
else:
raise ValueError("Invalid fan option")
return f
def trunc_normal_init_(weights, scale=1.0, fan="fan_in"):
shape = weights.shape
f = _calculate_fan(shape, fan)
scale = scale / max(1, f)
a = -2
b = 2
std = math.sqrt(scale) / truncnorm.std(a=a, b=b, loc=0, scale=1)
size = _prod(shape)
samples = truncnorm.rvs(a=a, b=b, loc=0, scale=std, size=size)
samples = np.reshape(samples, shape)
with torch.no_grad():
weights.copy_(torch.tensor(samples, device=weights.device))
def lecun_normal_init_(weights):
trunc_normal_init_(weights, scale=1.0)
def he_normal_init_(weights):
trunc_normal_init_(weights, scale=2.0)
def glorot_uniform_init_(weights):
nn.init.xavier_uniform_(weights, gain=1)
def final_init_(weights):
with torch.no_grad():
weights.fill_(0.0)
def gating_init_(weights):
with torch.no_grad():
weights.fill_(0.0)
def normal_init_(weights):
torch.nn.init.kaiming_normal_(weights, nonlinearity="linear")
def ipa_point_weights_init_(weights):
with torch.no_grad():
softplus_inverse_1 = 0.541324854612918
weights.fill_(softplus_inverse_1)
class Linear(nn.Linear):
"""
A Linear layer with built-in nonstandard initializations. Called just
like torch.nn.Linear.
Implements the initializers in 1.11.4, plus some additional ones found
in the code.
"""
def __init__(
self,
in_dim: int,
out_dim: int,
bias: bool = True,
init: str = "default",
init_fn: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None,
):
"""
Args:
in_dim:
The final dimension of inputs to the layer
out_dim:
The final dimension of layer outputs
bias:
Whether to learn an additive bias. True by default
init:
The initializer to use. Choose from:
"default": LeCun fan-in truncated normal initialization
"relu": He initialization w/ truncated normal distribution
"glorot": Fan-average Glorot uniform initialization
"gating": Weights=0, Bias=1
"normal": Normal initialization with std=1/sqrt(fan_in)
"final": Weights=0, Bias=0
Overridden by init_fn if the latter is not None.
init_fn:
A custom initializer taking weight and bias as inputs.
Overrides init if not None.
"""
super(Linear, self).__init__(in_dim, out_dim, bias=bias)
if bias:
with torch.no_grad():
self.bias.fill_(0)
if init_fn is not None:
init_fn(self.weight, self.bias)
else:
if init == "default":
lecun_normal_init_(self.weight)
elif init == "relu":
he_normal_init_(self.weight)
elif init == "glorot":
glorot_uniform_init_(self.weight)
elif init == "gating":
gating_init_(self.weight)
if bias:
with torch.no_grad():
self.bias.fill_(1.0)
elif init == "normal":
normal_init_(self.weight)
elif init == "final":
final_init_(self.weight)
else:
raise ValueError("Invalid init string.")
class LayerNorm(nn.Module):
def __init__(self, c_in, eps=1e-5):
super(LayerNorm, self).__init__()
self.c_in = (c_in,)
self.eps = eps
self.weight = nn.Parameter(torch.ones(c_in))
self.bias = nn.Parameter(torch.zeros(c_in))
def forward(self, x):
out = nn.functional.layer_norm(
x,
self.c_in,
self.weight,
self.bias,
self.eps,
)
return out
@torch.jit.ignore
def softmax(t: torch.Tensor, dim: int = -1) -> torch.Tensor:
"""
Softmax, but without automatic casting to fp32 when the input is of
type bfloat16
"""
s = torch.nn.functional.softmax(t, dim=dim)
return s
#@torch.jit.script
def _attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
biases: List[torch.Tensor]) -> torch.Tensor:
# [*, H, Q, C_hidden]
query = permute_final_dims(query, (1, 0, 2))
# [*, H, C_hidden, K]
key = permute_final_dims(key, (1, 2, 0))
# [*, H, V, C_hidden]
value = permute_final_dims(value, (1, 0, 2))
# [*, H, Q, K]
a = torch.matmul(query, key)
for b in biases:
a += b
a = softmax(a, -1)
# [*, H, Q, C_hidden]
a = torch.matmul(a, value)
# [*, Q, H, C_hidden]
a = a.transpose(-2, -3)
return a
@torch.jit.ignore
def _attention_chunked_trainable(
query,
key,
value,
biases,
chunk_size,
chunk_dim,
checkpoint,
):
if (checkpoint and len(biases) > 2):
raise ValueError("Checkpointed version permits only permits two bias terms")
def _checkpointable_attention(q, k, v, b1, b2):
bs = [b for b in [b1, b2] if b is not None]
return _attention(q, k, v, bs)
o_chunks = []
checkpoint_fn = get_checkpoint_fn()
count = query.shape[chunk_dim]
for start in range(0, count, chunk_size):
end = start + chunk_size
idx = [slice(None)] * len(query.shape)
idx[chunk_dim] = slice(start, end)
idx_tup = tuple(idx)
q_chunk = query[idx_tup]
k_chunk = key[idx_tup]
v_chunk = value[idx_tup]
def _slice_bias(b):
idx[chunk_dim] = (slice(start, end) if b.shape[chunk_dim] != 1 else slice(None))
return b[tuple(idx)]
if (checkpoint):
bias_1_chunk, bias_2_chunk = [
_slice_bias(b) if b is not None else None for b in (biases + [None, None])[:2]
]
o_chunk = checkpoint_fn(_checkpointable_attention, q_chunk, k_chunk, v_chunk,
bias_1_chunk, bias_2_chunk)
else:
bias_chunks = [_slice_bias(b) for b in biases]
o_chunk = _attention(q_chunk, k_chunk, v_chunk, bias_chunks)
o_chunks.append(o_chunk)
o = torch.cat(o_chunks, dim=chunk_dim)
return o
class Attention(nn.Module):
"""
Standard multi-head attention using AlphaFold's default layer
initialization. Allows multiple bias vectors.
"""
def __init__(
self,
c_q: int,
c_k: int,
c_v: int,
c_hidden: int,
no_heads: int,
gating: bool = True,
):
"""
Args:
c_q:
Input dimension of query data
c_k:
Input dimension of key data
c_v:
Input dimension of value data
c_hidden:
Per-head hidden dimension
no_heads:
Number of attention heads
gating:
Whether the output should be gated using query data
"""
super(Attention, self).__init__()
self.c_q = c_q
self.c_k = c_k
self.c_v = c_v
self.c_hidden = c_hidden
self.no_heads = no_heads
self.gating = gating
# DISCREPANCY: c_hidden is not the per-head channel dimension, as
# stated in the supplement, but the overall channel dimension.
self.linear_q = Linear(self.c_q, self.c_hidden * self.no_heads, bias=False, init="glorot")
self.linear_k = Linear(self.c_k, self.c_hidden * self.no_heads, bias=False, init="glorot")
self.linear_v = Linear(self.c_v, self.c_hidden * self.no_heads, bias=False, init="glorot")
self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_q, init="final")
self.linear_g = None
if self.gating:
self.linear_g = Linear(self.c_q, self.c_hidden * self.no_heads, init="gating")
self.sigmoid = nn.Sigmoid()
def _prep_qkv(self, q_x: torch.Tensor,
kv_x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(kv_x)
v = self.linear_v(kv_x)
# [*, Q/K, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
k = k.view(k.shape[:-1] + (self.no_heads, -1))
v = v.view(v.shape[:-1] + (self.no_heads, -1))
q /= math.sqrt(self.c_hidden)
return q, k, v
def _wrap_up(self, o: torch.Tensor, q_x: torch.Tensor) -> torch.Tensor:
if (self.linear_g is not None):
g = self.sigmoid(self.linear_g(q_x))
# [*, Q, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
o = o * g
# [*, Q, H * C_hidden]
o = flatten_final_dims(o, 2)
# [*, Q, C_q]
o = self.linear_o(o)
return o
def forward(
self,
q_x: torch.Tensor,
kv_x: torch.Tensor,
biases: Optional[List[torch.Tensor]] = None,
use_lma: bool = False,
q_chunk_size: Optional[int] = None,
kv_chunk_size: Optional[int] = None,
) -> torch.Tensor:
"""
Args:
q_x:
[*, Q, C_q] query data
kv_x:
[*, K, C_k] key data
biases:
List of biases that broadcast to [*, H, Q, K]
use_lma:
Whether to use low-memory attention
q_chunk_size:
Query chunk size (for LMA)
kv_chunk_size:
Key/Value chunk size (for LMA)
Returns
[*, Q, C_q] attention update
"""
if (biases is None):
biases = []
if (use_lma and (q_chunk_size is None or kv_chunk_size is None)):
raise ValueError("If use_lma is specified, q_chunk_size and kv_chunk_size must "
"be provided")
q, k, v = self._prep_qkv(q_x, kv_x)
if (use_lma):
biases = [b.expand(b.shape[:-2] + (q_x.shape[-2],) + (kv_x.shape[-2],)) for b in biases]
o = _lma(q, k, v, biases, q_chunk_size, kv_chunk_size)
else:
o = _attention(q, k, v, biases)
o = self._wrap_up(o, q_x)
return o
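# Illustrative sketch added for this review (not part of the upstream sources):
# gated multi-head attention with an additive mask bias that broadcasts to
# [*, H, Q, K]. The helper name and dimensions are made up.
def _example_attention_with_bias():
    att = Attention(c_q=16, c_k=16, c_v=16, c_hidden=8, no_heads=2)
    x = torch.randn(3, 10, 16)                          # [*, Q, C_q], Q == K
    mask = torch.ones(3, 10)
    mask_bias = (1e9 * (mask - 1))[..., None, None, :]  # [*, 1, 1, K]
    out = att(q_x=x, kv_x=x, biases=[mask_bias])
    # out: [3, 10, 16]
    return out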
class GlobalAttention(nn.Module):
def __init__(self, c_in, c_hidden, no_heads, inf, eps):
super(GlobalAttention, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.inf = inf
self.eps = eps
self.linear_q = Linear(c_in, c_hidden * no_heads, bias=False, init="glorot")
self.linear_k = Linear(
c_in,
c_hidden,
bias=False,
init="glorot",
)
self.linear_v = Linear(
c_in,
c_hidden,
bias=False,
init="glorot",
)
self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
self.sigmoid = nn.Sigmoid()
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# [*, N_res, C_in]
q = torch.sum(m * mask.unsqueeze(-1),
dim=-2) / (torch.sum(mask, dim=-1)[..., None] + self.eps)
# [*, N_res, H * C_hidden]
q = self.linear_q(q)
q *= (self.c_hidden**(-0.5))
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
# [*, N_res, N_seq, C_hidden]
k = self.linear_k(m)
v = self.linear_v(m)
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a = softmax(a)
# [*, N_res, H, C_hidden]
o = torch.matmul(
a,
v,
)
# [*, N_res, N_seq, C_hidden]
g = self.sigmoid(self.linear_g(m))
# [*, N_res, N_seq, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
# [*, N_res, N_seq, H, C_hidden]
o = o.unsqueeze(-3) * g
# [*, N_res, N_seq, H * C_hidden]
o = o.reshape(o.shape[:-2] + (-1,))
# [*, N_res, N_seq, C_in]
m = self.linear_o(o)
return m
def _lma(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
biases: List[torch.Tensor],
q_chunk_size: int,
kv_chunk_size: int,
):
no_q, no_kv = q.shape[-3], k.shape[-3]
# [*, Q, H, C_hidden]
o = q.new_zeros(q.shape)
for q_s in range(0, no_q, q_chunk_size):
q_chunk = q[..., q_s:q_s + q_chunk_size, :, :]
large_bias_chunks = [b[..., q_s:q_s + q_chunk_size, :] for b in biases]
maxes = []
weights = []
values = []
for kv_s in range(0, no_kv, kv_chunk_size):
k_chunk = k[..., kv_s:kv_s + kv_chunk_size, :, :]
v_chunk = v[..., kv_s:kv_s + kv_chunk_size, :, :]
small_bias_chunks = [b[..., kv_s:kv_s + kv_chunk_size] for b in large_bias_chunks]
a = torch.einsum(
"...qhd,...khd->...hqk",
q_chunk,
k_chunk,
)
for b in small_bias_chunks:
a += b
a = a.transpose(-2, -3)
max_a = torch.max(a, dim=-1, keepdim=True)[0]
exp_a = torch.exp(a - max_a)
exp_v = torch.einsum("...vhf,...qhv->...qhf", v_chunk, exp_a)
maxes.append(max_a.detach().squeeze(-1))
weights.append(torch.sum(exp_a, dim=-1))
values.append(exp_v)
chunk_max = torch.stack(maxes, dim=-3)
chunk_weights = torch.stack(weights, dim=-3)
chunk_values = torch.stack(values, dim=-4)
global_max = torch.max(chunk_max, dim=-3, keepdim=True)[0]
max_diffs = torch.exp(chunk_max - global_max)
chunk_values *= max_diffs.unsqueeze(-1)
chunk_weights *= max_diffs
all_values = torch.sum(chunk_values, dim=-4)
all_weights = torch.sum(chunk_weights.unsqueeze(-1), dim=-4)
q_chunk_out = all_values / all_weights
o[..., q_s:q_s + q_chunk_size, :, :] = q_chunk_out
return o
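# Illustrative sketch added for this review (not part of the upstream sources):
# the low-memory attention path should match the dense path up to floating
# point error. The helper name and chunk sizes are made up.
def _example_lma_matches_dense():
    att = Attention(c_q=16, c_k=16, c_v=16, c_hidden=8, no_heads=2)
    x = torch.randn(1, 32, 16)
    with torch.no_grad():
        dense = att(q_x=x, kv_x=x)
        lma = att(q_x=x, kv_x=x, use_lma=True, q_chunk_size=8, kv_chunk_size=8)
    assert torch.allclose(dense, lma, atol=1e-4)
    return dense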
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import torch
import torch.nn as nn
from typing import Optional, Tuple
from fastfold.model.nn.primitives import Linear, LayerNorm, ipa_point_weights_init_
from fastfold.common.residue_constants import (
restype_rigid_group_default_frame,
restype_atom14_to_rigid_group,
restype_atom14_mask,
restype_atom14_rigid_group_positions,
)
from fastfold.utils.feats import (
frames_and_literature_positions_to_atom14_pos,
torsion_angles_to_frames,
)
from fastfold.utils.rigid_utils import Rotation, Rigid
from fastfold.utils.tensor_utils import (
dict_multimap,
permute_final_dims,
flatten_final_dims,
)
class AngleResnetBlock(nn.Module):
def __init__(self, c_hidden):
"""
Args:
c_hidden:
Hidden channel dimension
"""
super(AngleResnetBlock, self).__init__()
self.c_hidden = c_hidden
self.linear_1 = Linear(self.c_hidden, self.c_hidden, init="relu")
self.linear_2 = Linear(self.c_hidden, self.c_hidden, init="final")
self.relu = nn.ReLU()
def forward(self, a: torch.Tensor) -> torch.Tensor:
s_initial = a
a = self.relu(a)
a = self.linear_1(a)
a = self.relu(a)
a = self.linear_2(a)
return a + s_initial
class AngleResnet(nn.Module):
"""
Implements Algorithm 20, lines 11-14
"""
def __init__(self, c_in, c_hidden, no_blocks, no_angles, epsilon):
"""
Args:
c_in:
Input channel dimension
c_hidden:
Hidden channel dimension
no_blocks:
Number of resnet blocks
no_angles:
Number of torsion angles to generate
epsilon:
Small constant for normalization
"""
super(AngleResnet, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_blocks = no_blocks
self.no_angles = no_angles
self.eps = epsilon
self.linear_in = Linear(self.c_in, self.c_hidden)
self.linear_initial = Linear(self.c_in, self.c_hidden)
self.layers = nn.ModuleList()
for _ in range(self.no_blocks):
layer = AngleResnetBlock(c_hidden=self.c_hidden)
self.layers.append(layer)
self.linear_out = Linear(self.c_hidden, self.no_angles * 2)
self.relu = nn.ReLU()
def forward(
self, s: torch.Tensor, s_initial: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
s:
[*, C_hidden] single embedding
s_initial:
[*, C_hidden] single embedding as of the start of the
StructureModule
Returns:
[*, no_angles, 2] predicted angles
"""
# NOTE: The ReLU's applied to the inputs are absent from the supplement
# pseudocode but present in the source. For maximal compatibility with
# the pretrained weights, I'm going with the source.
# [*, C_hidden]
s_initial = self.relu(s_initial)
s_initial = self.linear_initial(s_initial)
s = self.relu(s)
s = self.linear_in(s)
s = s + s_initial
for l in self.layers:
s = l(s)
s = self.relu(s)
# [*, no_angles * 2]
s = self.linear_out(s)
# [*, no_angles, 2]
s = s.view(s.shape[:-1] + (-1, 2))
unnormalized_s = s
norm_denom = torch.sqrt(
torch.clamp(
torch.sum(s ** 2, dim=-1, keepdim=True),
min=self.eps,
)
)
s = s / norm_denom
return unnormalized_s, s
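# Illustrative sketch added for this review (not part of the upstream sources):
# the resnet emits (sin, cos)-like pairs; the second return value is the same
# tensor normalized to (approximately) unit length. Names and shapes are made up.
def _example_angle_resnet():
    net = AngleResnet(c_in=16, c_hidden=32, no_blocks=2, no_angles=7, epsilon=1e-8)
    s = torch.randn(10, 16)
    s_initial = torch.randn(10, 16)
    unnormalized, normalized = net(s, s_initial)
    # unnormalized, normalized: [10, 7, 2]
    norms = torch.sqrt(torch.sum(normalized ** 2, dim=-1))  # all close to 1
    return unnormalized, normalized, norms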
class InvariantPointAttention(nn.Module):
"""
Implements Algorithm 22.
"""
def __init__(
self,
c_s: int,
c_z: int,
c_hidden: int,
no_heads: int,
no_qk_points: int,
no_v_points: int,
inf: float = 1e5,
eps: float = 1e-8,
):
"""
Args:
c_s:
Single representation channel dimension
c_z:
Pair representation channel dimension
c_hidden:
Hidden channel dimension
no_heads:
Number of attention heads
no_qk_points:
Number of query/key points to generate
no_v_points:
Number of value points to generate
"""
super(InvariantPointAttention, self).__init__()
self.c_s = c_s
self.c_z = c_z
self.c_hidden = c_hidden
self.no_heads = no_heads
self.no_qk_points = no_qk_points
self.no_v_points = no_v_points
self.inf = inf
self.eps = eps
# These linear layers differ from their specifications in the
# supplement. There, they lack bias and use Glorot initialization.
# Here as in the official source, they have bias and use the default
# Lecun initialization.
hc = self.c_hidden * self.no_heads
self.linear_q = Linear(self.c_s, hc)
self.linear_kv = Linear(self.c_s, 2 * hc)
hpq = self.no_heads * self.no_qk_points * 3
self.linear_q_points = Linear(self.c_s, hpq)
hpkv = self.no_heads * (self.no_qk_points + self.no_v_points) * 3
self.linear_kv_points = Linear(self.c_s, hpkv)
hpv = self.no_heads * self.no_v_points * 3
self.linear_b = Linear(self.c_z, self.no_heads)
self.head_weights = nn.Parameter(torch.zeros((no_heads)))
ipa_point_weights_init_(self.head_weights)
concat_out_dim = self.no_heads * (
self.c_z + self.c_hidden + self.no_v_points * 4
)
self.linear_out = Linear(concat_out_dim, self.c_s, init="final")
self.softmax = nn.Softmax(dim=-1)
self.softplus = nn.Softplus()
def forward(
self,
s: torch.Tensor,
z: torch.Tensor,
r: Rigid,
mask: torch.Tensor,
) -> torch.Tensor:
"""
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
r:
[*, N_res] transformation object
mask:
[*, N_res] mask
Returns:
[*, N_res, C_s] single representation update
"""
#######################################
# Generate scalar and point activations
#######################################
# [*, N_res, H * C_hidden]
q = self.linear_q(s)
kv = self.linear_kv(s)
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, 2 * C_hidden]
kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, C_hidden]
k, v = torch.split(kv, self.c_hidden, dim=-1)
# [*, N_res, H * P_q * 3]
q_pts = self.linear_q_points(s)
# This is kind of clunky, but it's how the original does it
# [*, N_res, H * P_q, 3]
q_pts = torch.split(q_pts, q_pts.shape[-1] // 3, dim=-1)
q_pts = torch.stack(q_pts, dim=-1)
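        # The query points are predicted in each residue's local frame;
        # r.apply maps them into the shared global frame so that pairwise
        # distances can be taken across residues.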
q_pts = r[..., None].apply(q_pts)
# [*, N_res, H, P_q, 3]
q_pts = q_pts.view(
q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
)
# [*, N_res, H * (P_q + P_v) * 3]
kv_pts = self.linear_kv_points(s)
# [*, N_res, H * (P_q + P_v), 3]
kv_pts = torch.split(kv_pts, kv_pts.shape[-1] // 3, dim=-1)
kv_pts = torch.stack(kv_pts, dim=-1)
kv_pts = r[..., None].apply(kv_pts)
# [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(kv_pts.shape[:-2] + (self.no_heads, -1, 3))
# [*, N_res, H, P_q/P_v, 3]
k_pts, v_pts = torch.split(
kv_pts, [self.no_qk_points, self.no_v_points], dim=-2
)
##########################
# Compute attention scores
##########################
# [*, N_res, N_res, H]
b = self.linear_b(z)
# [*, H, N_res, N_res]
a = torch.matmul(
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
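        # Scalar logits are scaled by 1/sqrt(3 * c_hidden) and the pair bias
        # by 1/sqrt(3) so that the three contributions to the logits (scalar,
        # pair bias, point distances) have comparable variance (Algorithm 22).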
a *= math.sqrt(1.0 / (3 * self.c_hidden))
a += (math.sqrt(1.0 / 3) * permute_final_dims(b, (2, 0, 1)))
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
pt_att = pt_att ** 2
# [*, N_res, N_res, H, P_q]
pt_att = sum(torch.unbind(pt_att, dim=-1))
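        # gamma^h of Algorithm 22: a learned, softplus-activated per-head
        # weight on the squared point distances; the constant below folds
        # together w_L = 1/sqrt(3) and w_C = sqrt(2 / (9 * no_qk_points)).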
head_weights = self.softplus(self.head_weights).view(
*((1,) * len(pt_att.shape[:-2]) + (-1, 1))
)
head_weights = head_weights * math.sqrt(
1.0 / (3 * (self.no_qk_points * 9.0 / 2))
)
pt_att = pt_att * head_weights
# [*, N_res, N_res, H]
pt_att = torch.sum(pt_att, dim=-1) * (-0.5)
# [*, N_res, N_res]
square_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
square_mask = self.inf * (square_mask - 1)
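        # square_mask is now an additive bias: 0 where both residues are
        # present, -self.inf where either is masked out.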
# [*, H, N_res, N_res]
pt_att = permute_final_dims(pt_att, (2, 0, 1))
a = a + pt_att
a = a + square_mask.unsqueeze(-3)
a = self.softmax(a)
################
# Compute output
################
# [*, N_res, H, C_hidden]
o = torch.matmul(
a, v.transpose(-2, -3).to(dtype=a.dtype)
).transpose(-2, -3)
# [*, N_res, H * C_hidden]
o = flatten_final_dims(o, 2)
# As DeepMind explains, this manual matmul ensures that the operation
# happens in float32.
# [*, H, 3, N_res, P_v]
o_pt = torch.sum(
(
a[..., None, :, :, None]
* permute_final_dims(v_pts, (1, 3, 0, 2))[..., None, :, :]
),
dim=-2,
)
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = r[..., None, None].invert_apply(o_pt)
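        # The attended points are mapped back into each residue's local frame
        # (invert_apply), which keeps the output invariant to global rotations
        # and translations of the input frames.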
# [*, N_res, H * P_v]
o_pt_norm = flatten_final_dims(
torch.sqrt(torch.sum(o_pt ** 2, dim=-1) + self.eps), 2
)
# [*, N_res, H * P_v, 3]
o_pt = o_pt.reshape(*o_pt.shape[:-3], -1, 3)
# [*, N_res, H, C_z]
o_pair = torch.matmul(a.transpose(-2, -3), z.to(dtype=a.dtype))
# [*, N_res, H * C_z]
o_pair = flatten_final_dims(o_pair, 2)
# [*, N_res, C_s]
s = self.linear_out(
torch.cat(
(o, *torch.unbind(o_pt, dim=-1), o_pt_norm, o_pair), dim=-1
).to(dtype=z.dtype)
)
return s
class BackboneUpdate(nn.Module):
"""
Implements part of Algorithm 23.
"""
def __init__(self, c_s):
"""
Args:
c_s:
Single representation channel dimension
"""
super(BackboneUpdate, self).__init__()
self.c_s = c_s
self.linear = Linear(self.c_s, 6, init="final")
    def forward(self, s: torch.Tensor) -> torch.Tensor:
"""
Args:
            s:
                [*, N_res, C_s] single representation
Returns:
[*, N_res, 6] update vector
"""
        # [*, N_res, 6]
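        # Of the 6 outputs, the first 3 parameterize the rotation as the
        # imaginary part of a quaternion with fixed real part 1, and the last
        # 3 are the translation update (Algorithm 23).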
update = self.linear(s)
return update
class StructureModuleTransitionLayer(nn.Module):
def __init__(self, c):
super(StructureModuleTransitionLayer, self).__init__()
self.c = c
self.linear_1 = Linear(self.c, self.c, init="relu")
self.linear_2 = Linear(self.c, self.c, init="relu")
self.linear_3 = Linear(self.c, self.c, init="final")
self.relu = nn.ReLU()
def forward(self, s):
s_initial = s
s = self.linear_1(s)
s = self.relu(s)
s = self.linear_2(s)
s = self.relu(s)
s = self.linear_3(s)
s = s + s_initial
return s
class StructureModuleTransition(nn.Module):
def __init__(self, c, num_layers, dropout_rate):
super(StructureModuleTransition, self).__init__()
self.c = c
self.num_layers = num_layers
self.dropout_rate = dropout_rate
self.layers = nn.ModuleList()
for _ in range(self.num_layers):
l = StructureModuleTransitionLayer(self.c)
self.layers.append(l)
self.dropout = nn.Dropout(self.dropout_rate)
self.layer_norm = LayerNorm(self.c)
def forward(self, s):
for l in self.layers:
s = l(s)
s = self.dropout(s)
s = self.layer_norm(s)
return s
class StructureModule(nn.Module):
def __init__(
self,
c_s,
c_z,
c_ipa,
c_resnet,
no_heads_ipa,
no_qk_points,
no_v_points,
dropout_rate,
no_blocks,
no_transition_layers,
no_resnet_blocks,
no_angles,
trans_scale_factor,
epsilon,
inf,
**kwargs,
):
"""
Args:
c_s:
Single representation channel dimension
c_z:
Pair representation channel dimension
c_ipa:
IPA hidden channel dimension
c_resnet:
                Angle resnet (Alg. 20 lines 11-14) hidden channel dimension
no_heads_ipa:
Number of IPA heads
no_qk_points:
Number of query/key points to generate during IPA
no_v_points:
Number of value points to generate during IPA
dropout_rate:
Dropout rate used throughout the layer
no_blocks:
Number of structure module blocks
no_transition_layers:
Number of layers in the single representation transition
                (Alg. 20 lines 8-9)
no_resnet_blocks:
Number of blocks in the angle resnet
no_angles:
Number of angles to generate in the angle resnet
trans_scale_factor:
Scale of single representation transition hidden dimension
epsilon:
Small number used in angle resnet normalization
inf:
Large number used for attention masking
"""
super(StructureModule, self).__init__()
self.c_s = c_s
self.c_z = c_z
self.c_ipa = c_ipa
self.c_resnet = c_resnet
self.no_heads_ipa = no_heads_ipa
self.no_qk_points = no_qk_points
self.no_v_points = no_v_points
self.dropout_rate = dropout_rate
self.no_blocks = no_blocks
self.no_transition_layers = no_transition_layers
self.no_resnet_blocks = no_resnet_blocks
self.no_angles = no_angles
self.trans_scale_factor = trans_scale_factor
self.epsilon = epsilon
self.inf = inf
# To be lazily initialized later
self.default_frames = None
self.group_idx = None
self.atom_mask = None
self.lit_positions = None
self.layer_norm_s = LayerNorm(self.c_s)
self.layer_norm_z = LayerNorm(self.c_z)
self.linear_in = Linear(self.c_s, self.c_s)
self.ipa = InvariantPointAttention(
self.c_s,
self.c_z,
self.c_ipa,
self.no_heads_ipa,
self.no_qk_points,
self.no_v_points,
inf=self.inf,
eps=self.epsilon,
)
self.ipa_dropout = nn.Dropout(self.dropout_rate)
self.layer_norm_ipa = LayerNorm(self.c_s)
self.transition = StructureModuleTransition(
self.c_s,
self.no_transition_layers,
self.dropout_rate,
)
self.bb_update = BackboneUpdate(self.c_s)
self.angle_resnet = AngleResnet(
self.c_s,
self.c_resnet,
self.no_resnet_blocks,
self.no_angles,
self.epsilon,
)
def forward(
self,
s,
z,
aatype,
mask=None,
):
"""
Args:
s:
[*, N_res, C_s] single representation
z:
[*, N_res, N_res, C_z] pair representation
aatype:
[*, N_res] amino acid indices
mask:
Optional [*, N_res] sequence mask
Returns:
A dictionary of outputs
"""
if mask is None:
# [*, N]
mask = s.new_ones(s.shape[:-1])
# [*, N, C_s]
s = self.layer_norm_s(s)
# [*, N, N, C_z]
z = self.layer_norm_z(z)
# [*, N, C_s]
s_initial = s
s = self.linear_in(s)
# [*, N]
rigids = Rigid.identity(
s.shape[:-1],
s.dtype,
s.device,
self.training,
fmt="quat",
)
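        # All backbone frames start as identity transforms ("black hole"
        # initialization); each block below refines them with IPA and the
        # backbone update.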
outputs = []
for i in range(self.no_blocks):
# [*, N, C_s]
s = s + self.ipa(s, z, rigids, mask)
s = self.ipa_dropout(s)
s = self.layer_norm_ipa(s)
s = self.transition(s)
# [*, N]
rigids = rigids.compose_q_update_vec(self.bb_update(s))
# To hew as closely as possible to AlphaFold, we convert our
# quaternion-based transformations to rotation-matrix ones
# here
backb_to_global = Rigid(
Rotation(
rot_mats=rigids.get_rots().get_rot_mats(),
quats=None
),
rigids.get_trans(),
)
backb_to_global = backb_to_global.scale_translation(
self.trans_scale_factor
)
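            # trans_scale_factor converts the module-internal translation
            # units (nanometers) to Angstroms, a factor of 10 in the standard
            # configuration.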
# [*, N, 7, 2]
unnormalized_angles, angles = self.angle_resnet(s, s_initial)
all_frames_to_global = self.torsion_angles_to_frames(
backb_to_global,
angles,
aatype,
)
pred_xyz = self.frames_and_literature_positions_to_atom14_pos(
all_frames_to_global,
aatype,
)
scaled_rigids = rigids.scale_translation(self.trans_scale_factor)
preds = {
"frames": scaled_rigids.to_tensor_7(),
"sidechain_frames": all_frames_to_global.to_tensor_4x4(),
"unnormalized_angles": unnormalized_angles,
"angles": angles,
"positions": pred_xyz,
}
outputs.append(preds)
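            # As in AlphaFold, gradients are stopped through the rotations
            # between iterations to stabilize training; the final iteration
            # keeps them.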
if i < (self.no_blocks - 1):
rigids = rigids.stop_rot_gradient()
outputs = dict_multimap(torch.stack, outputs)
outputs["single"] = s
return outputs
def _init_residue_constants(self, float_dtype, device):
if self.default_frames is None:
self.default_frames = torch.tensor(
restype_rigid_group_default_frame,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.group_idx is None:
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group,
device=device,
requires_grad=False,
)
if self.atom_mask is None:
self.atom_mask = torch.tensor(
restype_atom14_mask,
dtype=float_dtype,
device=device,
requires_grad=False,
)
if self.lit_positions is None:
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions,
dtype=float_dtype,
device=device,
requires_grad=False,
)
def torsion_angles_to_frames(self, r, alpha, f):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(alpha.dtype, alpha.device)
# Separated purely to make testing less annoying
return torsion_angles_to_frames(r, alpha, f, self.default_frames)
def frames_and_literature_positions_to_atom14_pos(
self, r, f # [*, N, 8] # [*, N]
):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(r.get_rots().dtype, r.get_rots().device)
return frames_and_literature_positions_to_atom14_pos(
r,
f,
self.default_frames,
self.group_idx,
self.atom_mask,
self.lit_positions,
)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import math
from typing import Optional, List
import torch
import torch.nn as nn
from fastfold.model.nn.primitives import Linear, LayerNorm, Attention
from fastfold.model.nn.dropout import (
DropoutRowwise,
DropoutColumnwise,
)
from fastfold.model.nn.pair_transition import PairTransition
from fastfold.model.nn.triangular_attention import (
TriangleAttentionStartingNode,
TriangleAttentionEndingNode,
)
from fastfold.model.nn.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
)
from fastfold.utils.checkpointing import checkpoint_blocks
from fastfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
class TemplatePointwiseAttention(nn.Module):
"""
Implements Algorithm 17.
"""
def __init__(self, c_t, c_z, c_hidden, no_heads, inf, **kwargs):
"""
Args:
c_t:
Template embedding channel dimension
c_z:
Pair embedding channel dimension
c_hidden:
Hidden channel dimension
"""
super(TemplatePointwiseAttention, self).__init__()
self.c_t = c_t
self.c_z = c_z
self.c_hidden = c_hidden
self.no_heads = no_heads
self.inf = inf
self.mha = Attention(
self.c_z,
self.c_t,
self.c_t,
self.c_hidden,
self.no_heads,
gating=False,
)
def _chunk(self,
z: torch.Tensor,
t: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
) -> torch.Tensor:
mha_inputs = {
"q_x": z,
"kv_x": t,
"biases": biases,
}
return chunk_layer(
self.mha,
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(z.shape[:-2]),
)
def forward(self,
t: torch.Tensor,
z: torch.Tensor,
template_mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
) -> torch.Tensor:
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
z:
[*, N_res, N_res, C_t] pair embedding
template_mask:
[*, N_templ] template mask
Returns:
[*, N_res, N_res, C_z] pair embedding update
"""
if template_mask is None:
template_mask = t.new_ones(t.shape[:-3])
bias = self.inf * (template_mask[..., None, None, None, None, :] - 1)
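        # Each (i, j) position of the pair representation attends over the
        # template dimension; absent templates receive a large negative bias
        # and are effectively ignored by the softmax.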
# [*, N_res, N_res, 1, C_z]
z = z.unsqueeze(-2)
# [*, N_res, N_res, N_temp, C_t]
t = permute_final_dims(t, (1, 2, 0, 3))
        # [*, 1, 1, 1, 1, N_templ]
        biases = [bias]
if chunk_size is not None:
z = self._chunk(z, t, biases, chunk_size)
else:
z = self.mha(q_x=z, kv_x=t, biases=biases)
# [*, N_res, N_res, C_z]
z = z.squeeze(-2)
return z
class TemplatePairStackBlock(nn.Module):
def __init__(
self,
c_t: int,
c_hidden_tri_att: int,
c_hidden_tri_mul: int,
no_heads: int,
pair_transition_n: int,
dropout_rate: float,
inf: float,
**kwargs,
):
super(TemplatePairStackBlock, self).__init__()
self.c_t = c_t
self.c_hidden_tri_att = c_hidden_tri_att
self.c_hidden_tri_mul = c_hidden_tri_mul
self.no_heads = no_heads
self.pair_transition_n = pair_transition_n
self.dropout_rate = dropout_rate
self.inf = inf
self.dropout_row = DropoutRowwise(self.dropout_rate)
self.dropout_col = DropoutColumnwise(self.dropout_rate)
self.tri_att_start = TriangleAttentionStartingNode(
self.c_t,
self.c_hidden_tri_att,
self.no_heads,
inf=inf,
)
self.tri_att_end = TriangleAttentionEndingNode(
self.c_t,
self.c_hidden_tri_att,
self.no_heads,
inf=inf,
)
self.tri_mul_out = TriangleMultiplicationOutgoing(
self.c_t,
self.c_hidden_tri_mul,
)
self.tri_mul_in = TriangleMultiplicationIncoming(
self.c_t,
self.c_hidden_tri_mul,
)
self.pair_transition = PairTransition(
self.c_t,
self.pair_transition_n,
)
def forward(self,
z: torch.Tensor,
mask: torch.Tensor,
chunk_size: Optional[int] = None,
_mask_trans: bool = True
):
single_templates = [
t.unsqueeze(-4) for t in torch.unbind(z, dim=-4)
]
single_templates_masks = [
m.unsqueeze(-3) for m in torch.unbind(mask, dim=-3)
]
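        # Templates are processed one at a time and re-concatenated below,
        # which keeps peak activation memory closer to that of a single
        # template.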
for i in range(len(single_templates)):
single = single_templates[i]
single_mask = single_templates_masks[i]
single = single + self.dropout_row(
self.tri_att_start(
single,
chunk_size=chunk_size,
mask=single_mask
)
)
single = single + self.dropout_col(
self.tri_att_end(
single,
chunk_size=chunk_size,
mask=single_mask
)
)
single = single + self.dropout_row(
self.tri_mul_out(
single,
mask=single_mask
)
)
single = single + self.dropout_row(
self.tri_mul_in(
single,
mask=single_mask
)
)
single = single + self.pair_transition(
single,
mask=single_mask if _mask_trans else None,
chunk_size=chunk_size,
)
single_templates[i] = single
z = torch.cat(single_templates, dim=-4)
return z
class TemplatePairStack(nn.Module):
"""
Implements Algorithm 16.
"""
def __init__(
self,
c_t,
c_hidden_tri_att,
c_hidden_tri_mul,
no_blocks,
no_heads,
pair_transition_n,
dropout_rate,
blocks_per_ckpt,
inf=1e9,
**kwargs,
):
"""
Args:
c_t:
Template embedding channel dimension
c_hidden_tri_att:
Per-head hidden dimension for triangular attention
            c_hidden_tri_mul:
Hidden dimension for triangular multiplication
no_blocks:
Number of blocks in the stack
pair_transition_n:
Scale of pair transition (Alg. 15) hidden dimension
dropout_rate:
Dropout rate used throughout the stack
blocks_per_ckpt:
Number of blocks per activation checkpoint. None disables
activation checkpointing
"""
super(TemplatePairStack, self).__init__()
self.blocks_per_ckpt = blocks_per_ckpt
self.blocks = nn.ModuleList()
for _ in range(no_blocks):
block = TemplatePairStackBlock(
c_t=c_t,
c_hidden_tri_att=c_hidden_tri_att,
c_hidden_tri_mul=c_hidden_tri_mul,
no_heads=no_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
inf=inf,
)
self.blocks.append(block)
self.layer_norm = LayerNorm(c_t)
def forward(
self,
        t: torch.Tensor,
        mask: torch.Tensor,
chunk_size: int,
_mask_trans: bool = True,
):
"""
Args:
t:
[*, N_templ, N_res, N_res, C_t] template embedding
mask:
[*, N_templ, N_res, N_res] mask
Returns:
[*, N_templ, N_res, N_res, C_t] template embedding update
"""
if(mask.shape[-3] == 1):
expand_idx = list(mask.shape)
expand_idx[-3] = t.shape[-4]
mask = mask.expand(*expand_idx)
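        # The blocks run sequentially under activation checkpointing
        # (blocks_per_ckpt blocks per checkpoint); checkpointing is skipped
        # outside of training.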
t, = checkpoint_blocks(
blocks=[
partial(
b,
mask=mask,
chunk_size=chunk_size,
_mask_trans=_mask_trans,
)
for b in self.blocks
],
args=(t,),
blocks_per_ckpt=self.blocks_per_ckpt if self.training else None,
)
t = self.layer_norm(t)
return t
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partialmethod, partial
import math
from typing import Optional, List
import torch
import torch.nn as nn
from fastfold.model.nn.primitives import Linear, LayerNorm, Attention
from fastfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
flatten_final_dims,
)
class TriangleAttention(nn.Module):
def __init__(
self, c_in, c_hidden, no_heads, starting, inf=1e9
):
"""
Args:
c_in:
Input channel dimension
c_hidden:
Overall hidden channel dimension (not per-head)
no_heads:
Number of attention heads
"""
super(TriangleAttention, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.starting = starting
self.inf = inf
self.layer_norm = LayerNorm(self.c_in)
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
self.mha = Attention(
self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
)
@torch.jit.ignore
def _chunk(self,
x: torch.Tensor,
biases: List[torch.Tensor],
chunk_size: int,
) -> torch.Tensor:
mha_inputs = {
"q_x": x,
"kv_x": x,
"biases": biases,
}
return chunk_layer(
partial(self.mha),
mha_inputs,
chunk_size=chunk_size,
no_batch_dims=len(x.shape[:-2]),
)
def forward(self,
x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
chunk_size: Optional[int] = None
) -> torch.Tensor:
"""
Args:
x:
[*, I, J, C_in] input tensor (e.g. the pair representation)
Returns:
[*, I, J, C_in] output tensor
"""
if mask is None:
# [*, I, J]
mask = x.new_ones(
x.shape[:-1],
)
# Shape annotations assume self.starting. Else, I and J are flipped
if not self.starting:
x = x.transpose(-2, -3)
mask = mask.transpose(-1, -2)
# [*, I, J, C_in]
x = self.layer_norm(x)
# [*, I, 1, 1, J]
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# [*, H, I, J]
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
# [*, 1, H, I, J]
triangle_bias = triangle_bias.unsqueeze(-4)
biases = [mask_bias, triangle_bias]
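        # Both biases are added to the attention logits inside self.mha:
        # mask_bias hides padded positions, triangle_bias is the learned
        # per-pair term b of Algorithms 13/14.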
if chunk_size is not None:
x = self._chunk(x, biases, chunk_size)
else:
x = self.mha(q_x=x, kv_x=x, biases=biases)
if not self.starting:
x = x.transpose(-2, -3)
return x
class TriangleAttentionStartingNode(TriangleAttention):
"""
Implements Algorithm 13.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=True)
class TriangleAttentionEndingNode(TriangleAttention):
"""
Implements Algorithm 14.
"""
__init__ = partialmethod(TriangleAttention.__init__, starting=False)
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partialmethod
from typing import Optional
import torch
import torch.nn as nn
from fastfold.model.nn.primitives import Linear, LayerNorm
from fastfold.utils.tensor_utils import permute_final_dims
class TriangleMultiplicativeUpdate(nn.Module):
"""
Implements Algorithms 11 and 12.
"""
def __init__(self, c_z, c_hidden, _outgoing=True):
"""
Args:
c_z:
Input channel dimension
            c_hidden:
Hidden channel dimension
"""
super(TriangleMultiplicativeUpdate, self).__init__()
self.c_z = c_z
self.c_hidden = c_hidden
self._outgoing = _outgoing
self.linear_a_p = Linear(self.c_z, self.c_hidden)
self.linear_a_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_b_p = Linear(self.c_z, self.c_hidden)
self.linear_b_g = Linear(self.c_z, self.c_hidden, init="gating")
self.linear_g = Linear(self.c_z, self.c_z, init="gating")
self.linear_z = Linear(self.c_hidden, self.c_z, init="final")
self.layer_norm_in = LayerNorm(self.c_z)
self.layer_norm_out = LayerNorm(self.c_hidden)
self.sigmoid = nn.Sigmoid()
def _combine_projections(self,
a: torch.Tensor,
b: torch.Tensor,
) -> torch.Tensor:
raise NotImplementedError("This method needs to be overridden")
def forward(self,
z: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Args:
            z:
[*, N_res, N_res, C_z] input tensor
mask:
[*, N_res, N_res] input mask
Returns:
[*, N_res, N_res, C_z] output tensor
"""
if mask is None:
mask = z.new_ones(z.shape[:-1])
mask = mask.unsqueeze(-1)
z = self.layer_norm_in(z)
a = self.linear_a_p(z) * self.sigmoid(self.linear_a_g(z))
a = a * mask
b = self.linear_b_p(z) * self.sigmoid(self.linear_b_g(z))
b = b * mask
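        # a and b are gated projections of the pair representation;
        # _combine_projections contracts them over the third residue k
        # (Algorithms 11/12) before the final gate below.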
x = self._combine_projections(a, b)
x = self.layer_norm_out(x)
x = self.linear_z(x)
g = self.sigmoid(self.linear_g(z))
z = x * g
return z
class TriangleMultiplicationOutgoing(TriangleMultiplicativeUpdate):
"""
Implements Algorithm 11.
"""
def _combine_projections(self,
a: torch.Tensor, # [*, N_i, N_k, C]
b: torch.Tensor, # [*, N_j, N_k, C]
):
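        # "Outgoing" edges: for each pair (i, j), sum over k of
        # a_{ik} * b_{jk} per channel, computed as a batched matmul in
        # channel-first layout.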
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, (2, 0, 1)),
permute_final_dims(b, (2, 1, 0)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, (1, 2, 0))
class TriangleMultiplicationIncoming(TriangleMultiplicativeUpdate):
"""
Implements Algorithm 12.
"""
def _combine_projections(self,
a: torch.Tensor, # [*, N_k, N_i, C]
b: torch.Tensor, # [*, N_k, N_j, C]
):
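        # "Incoming" edges: for each pair (i, j), sum over k of
        # a_{ki} * b_{kj} per channel.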
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, (2, 1, 0)),
permute_final_dims(b, (2, 0, 1)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, (1, 2, 0))