Commit 56d5e39c authored by Geoffrey Yu

Merge remote-tracking branch 'upstream/multimer' into multimer

parents 56b86074 51556d52
@@ -13,12 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, Optional

+from openfold.utils import all_atom_multimer
+from openfold.utils.feats import (
+    pseudo_beta_fn,
+    dgram_from_positions,
+    build_template_angle_feat,
+    build_template_pair_feat,
+)
from openfold.model.primitives import Linear, LayerNorm
-from openfold.utils.tensor_utils import add, one_hot
+from openfold.model.template import (
+    TemplatePairStack,
+    TemplatePointwiseAttention,
+)
+from openfold.utils import geometry
+from openfold.utils.tensor_utils import add, one_hot, tensor_tree_map, dict_multimap


class InputEmbedder(nn.Module):
@@ -99,12 +113,13 @@ class InputEmbedder(nn.Module):
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
-            tf:
-                "target_feat" features of shape [*, N_res, tf_dim]
-            ri:
-                "residue_index" features of shape [*, N_res]
-            msa:
-                "msa_feat" features of shape [*, N_clust, N_res, msa_dim]
+            batch: Dict containing
+                "target_feat":
+                    Features of shape [*, N_res, tf_dim]
+                "residue_index":
+                    Features of shape [*, N_res]
+                "msa_feat":
+                    Features of shape [*, N_clust, N_res, msa_dim]
        Returns:
            msa_emb:
                [*, N_clust, N_res, C_m] MSA embedding
@@ -139,6 +154,162 @@ class InputEmbedder(nn.Module):
        return msa_emb, pair_emb

class InputEmbedderMultimer(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
max_relative_idx: int,
use_chain_relative: bool,
max_relative_chain: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
max_relative_idx:
Window size used in relative positional encoding
"""
super(InputEmbedderMultimer, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.max_relative_idx = max_relative_idx
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
if(self.use_chain_relative):
self.no_bins = (
2 * max_relative_idx + 2 +
1 +
2 * max_relative_chain + 2
)
else:
self.no_bins = 2 * max_relative_idx + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, batch):
pos = batch["residue_index"]
asym_id = batch["asym_id"]
asym_id_same = (asym_id[..., None] == asym_id[..., None, :])
offset = pos[..., None] - pos[..., None, :]
clipped_offset = torch.clamp(
offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
)
rel_feats = []
if(self.use_chain_relative):
final_offset = torch.where(
asym_id_same,
clipped_offset,
(2 * self.max_relative_idx + 1) *
torch.ones_like(clipped_offset)
)
boundaries = torch.arange(
start=0, end=2 * self.max_relative_idx + 2, device=final_offset.device
)
rel_pos = one_hot(
final_offset,
boundaries,
)
rel_feats.append(rel_pos)
entity_id = batch["entity_id"]
entity_id_same = (entity_id[..., None] == entity_id[..., None, :])
rel_feats.append(entity_id_same[..., None])
sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
max_rel_chain = self.max_relative_chain
clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain,
0,
2 * max_rel_chain,
)
final_rel_chain = torch.where(
entity_id_same,
clipped_rel_chain,
(2 * max_rel_chain + 1) *
torch.ones_like(clipped_rel_chain)
)
boundaries = torch.arange(
start=0, end=2 * max_rel_chain + 2, device=final_rel_chain.device
)
rel_chain = one_hot(
final_rel_chain,
boundaries,
)
rel_feats.append(rel_chain)
else:
boundaries = torch.arange(
start=0, end=2 * self.max_relative_idx + 1, device=clipped_offset.device
)
rel_pos = one_hot(
clipped_offset, boundaries,
)
rel_feats.append(rel_pos)
rel_feat = torch.cat(rel_feats, dim=-1).to(
self.linear_relpos.weight.dtype
)
return self.linear_relpos(rel_feat)
def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
tf = batch["target_feat"]
msa = batch["msa_feat"]
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(batch)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
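
A quick, self-contained sanity check of the relative-position binning used by relpos above may help when reading the feature sizes. The values below (max_relative_idx=32, max_relative_chain=2) are illustrative assumptions, not read from this diff, and plain torch.nn.functional.one_hot stands in for OpenFold's one_hot helper:

import torch

# Illustrative defaults (assumed, not taken from this config).
max_relative_idx, max_relative_chain = 32, 2

# Bin count when use_chain_relative is True:
# 2*32+2 residue-offset bins, 1 same-entity bit, 2*2+2 chain-offset bins.
no_bins = (2 * max_relative_idx + 2) + 1 + (2 * max_relative_chain + 2)
assert no_bins == 73

# Residue offsets are clamped to [0, 2*max_relative_idx]; the extra bin
# (2*max_relative_idx + 1) is reserved for pairs on different chains.
pos = torch.tensor([0, 1, 2, 0, 1])          # hypothetical residue_index
asym_id = torch.tensor([0, 0, 0, 1, 1])      # hypothetical chain ids
offset = pos[..., None] - pos[..., None, :]
clipped = torch.clamp(offset + max_relative_idx, 0, 2 * max_relative_idx)
same_chain = asym_id[..., None] == asym_id[..., None, :]
final_offset = torch.where(
    same_chain, clipped, torch.full_like(clipped, 2 * max_relative_idx + 1)
)
rel_pos = torch.nn.functional.one_hot(final_offset, 2 * max_relative_idx + 2)
print(rel_pos.shape)  # torch.Size([5, 5, 66])
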
class RecyclingEmbedder(nn.Module):
    """
    Embeds the output of an iteration of the model for recycling.
@@ -365,3 +536,345 @@ class ExtraMSAEmbedder(nn.Module):
        x = self.linear(x)

        return x

class TemplateEmbedder(nn.Module):
def __init__(self, config):
super(TemplateEmbedder, self).__init__()
self.config = config
self.template_angle_embedder = TemplateAngleEmbedder(
**config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**config["template_pointwise_attention"],
)
def forward(
self,
batch,
z,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True,
use_lma=False,
inplace_safe=False
):
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds = []
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
if (inplace_safe):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair = z.new_zeros(
z.shape[:-3] +
(n_templ, n, n, self.config.template_pair_embedder.c_t)
)
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch,
)
# [*, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
inf=self.config.inf,
eps=self.config.eps,
**self.config.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
if (inplace_safe):
t_pair[..., i, :, :, :] = t
else:
pair_embeds.append(t)
del t
if (not inplace_safe):
t_pair = torch.stack(pair_embeds, dim=templ_dim)
del pair_embeds
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
del t_pair
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
use_lma=use_lma,
)
t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
# Append singletons
t_mask = t_mask.reshape(
*t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
)
if (inplace_safe):
t *= t_mask
else:
t = t * t_mask
ret = {}
ret.update({"template_pair_embedding": t})
del t
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
batch
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
ret["template_single_embedding"] = a
return ret
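
One detail in the forward pass above that is easy to miss is the singleton-append trick used to broadcast the per-batch "any templates present" flag over the pair embedding. A minimal sketch with hypothetical shapes:

import torch

t = torch.randn(2, 8, 8, 16)                         # [batch, N, N, C_z] (hypothetical)
template_mask = torch.tensor([[1., 0.], [1., 1.]])   # [batch, S_t]
t_mask = torch.sum(template_mask, dim=-1) > 0        # [batch]: any template present?
t_mask = t_mask.reshape(*t_mask.shape, *([1] * (t.dim() - t_mask.dim())))
print(t_mask.shape)                                  # torch.Size([2, 1, 1, 1])
t = t * t_mask                                       # zeroes examples with no templates
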
class TemplatePairEmbedderMultimer(nn.Module):
def __init__(self,
c_z: int,
c_out: int,
c_dgram: int,
c_aatype: int,
):
super(TemplatePairEmbedderMultimer, self).__init__()
self.dgram_linear = Linear(c_dgram, c_out)
self.aatype_linear_1 = Linear(c_aatype, c_out)
self.aatype_linear_2 = Linear(c_aatype, c_out)
self.query_embedding_layer_norm = LayerNorm(c_z)
self.query_embedding_linear = Linear(c_z, c_out)
self.pseudo_beta_mask_linear = Linear(1, c_out)
self.x_linear = Linear(1, c_out)
self.y_linear = Linear(1, c_out)
self.z_linear = Linear(1, c_out)
self.backbone_mask_linear = Linear(1, c_out)
def forward(self,
template_dgram: torch.Tensor,
aatype_one_hot: torch.Tensor,
query_embedding: torch.Tensor,
pseudo_beta_mask: torch.Tensor,
backbone_mask: torch.Tensor,
multichain_mask_2d: torch.Tensor,
unit_vector: geometry.Vec3Array,
) -> torch.Tensor:
act = 0.
pseudo_beta_mask_2d = (
pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
)
pseudo_beta_mask_2d *= multichain_mask_2d
template_dgram *= pseudo_beta_mask_2d[..., None]
act += self.dgram_linear(template_dgram)
act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])
aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
act += self.aatype_linear_2(aatype_one_hot[..., None, :])
backbone_mask_2d = (
backbone_mask[..., None] * backbone_mask[..., None, :]
)
backbone_mask_2d *= multichain_mask_2d
x, y, z = [coord * backbone_mask_2d for coord in unit_vector]
act += self.x_linear(x[..., None])
act += self.y_linear(y[..., None])
act += self.z_linear(z[..., None])
act += self.backbone_mask_linear(backbone_mask_2d[..., None])
query_embedding = self.query_embedding_layer_norm(query_embedding)
act += self.query_embedding_linear(query_embedding)
return act
class TemplateSingleEmbedderMultimer(nn.Module):
def __init__(self,
c_in: int,
c_m: int,
):
super(TemplateSingleEmbedderMultimer, self).__init__()
self.template_single_embedder = Linear(c_in, c_m)
self.template_projector = Linear(c_m, c_m)
def forward(self,
batch,
atom_pos,
aatype_one_hot,
):
out = {}
template_chi_angles, template_chi_mask = (
all_atom_multimer.compute_chi_angles(
atom_pos,
batch["template_all_atom_mask"],
batch["template_aatype"],
)
)
template_features = torch.cat(
[
aatype_one_hot,
torch.sin(template_chi_angles) * template_chi_mask,
torch.cos(template_chi_angles) * template_chi_mask,
template_chi_mask,
],
dim=-1,
)
template_mask = template_chi_mask[..., 0]
template_activations = self.template_single_embedder(
template_features
)
template_activations = torch.nn.functional.relu(
template_activations
)
template_activations = self.template_projector(
template_activations,
)
out["template_single_embedding"] = (
template_activations
)
out["template_mask"] = template_mask
return out
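
The single-template features built above follow the usual torsion featurisation: the (up to four) side-chain chi angles enter as masked sin/cos pairs concatenated to the one-hot amino-acid type. A small shape sketch in plain torch, with hypothetical sizes:

import torch

chi_angles = torch.rand(4, 4) * 2 * torch.pi          # [N_res, 4 chi angles] (hypothetical)
chi_mask = (torch.rand(4, 4) > 0.3).float()           # 0 where a chi angle is undefined
aatype_one_hot = torch.nn.functional.one_hot(
    torch.randint(0, 22, (4,)), 22
).float()
template_features = torch.cat(
    [
        aatype_one_hot,
        torch.sin(chi_angles) * chi_mask,
        torch.cos(chi_angles) * chi_mask,
        chi_mask,
    ],
    dim=-1,
)
print(template_features.shape)  # torch.Size([4, 34]) == 22 + 4 + 4 + 4
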
class TemplateEmbedderMultimer(nn.Module):
def __init__(self, config):
super(TemplateEmbedderMultimer, self).__init__()
self.config = config
self.template_pair_embedder = TemplatePairEmbedderMultimer(
**config["template_pair_embedder"],
)
self.template_single_embedder = TemplateSingleEmbedderMultimer(
**config["template_single_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.linear_t = Linear(config.c_t, config.c_z)
def forward(self,
batch,
z,
padding_mask_2d,
templ_dim,
chunk_size,
multichain_mask_2d,
use_lma=False,
inplace_safe=False
):
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
act = 0.
template_positions, pseudo_beta_mask = (
single_template_feats["template_pseudo_beta"],
single_template_feats["template_pseudo_beta_mask"],
)
template_dgram = dgram_from_positions(
template_positions,
inf=self.config.inf,
**self.config.distogram,
)
aatype_one_hot = torch.nn.functional.one_hot(
single_template_feats["template_aatype"], 22,
)
raw_atom_pos = single_template_feats["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos,
single_template_feats["template_all_atom_mask"],
single_template_feats["template_aatype"],
)
points = rigid.translation
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
pair_act = self.template_pair_embedder(
template_dgram,
aatype_one_hot,
z,
pseudo_beta_mask,
backbone_mask,
multichain_mask_2d,
unit_vector,
)
single_template_embeds["template_pair_embedding"] = pair_act
single_template_embeds.update(
self.template_single_embedder(
single_template_feats,
atom_pos,
aatype_one_hot,
)
)
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)
# [*, N, N, C_z]
t = torch.sum(t, dim=-4) / n_templ
t = torch.nn.functional.relu(t)
t = self.linear_t(t)
template_embeds["template_pair_embedding"] = t
return template_embeds
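
After the per-template loop, the list of embedding dicts is merged key-wise with dict_multimap before entering the pair stack. A simplified stand-in (flat dicts only, not the OpenFold implementation) shows the behaviour:

import torch

def dict_multimap_sketch(fn, dicts):
    # Apply fn key-wise across a list of (flat) dicts of tensors.
    return {k: fn([d[k] for d in dicts]) for k in dicts[0]}

per_template = [
    {"template_pair_embedding": torch.randn(1, 8, 8, 64)},   # hypothetical shapes
    {"template_pair_embedding": torch.randn(1, 8, 8, 64)},
]
merged = dict_multimap_sketch(lambda ts: torch.cat(ts, dim=0), per_template)
print(merged["template_pair_embedding"].shape)  # torch.Size([2, 8, 8, 64])
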
@@ -76,9 +76,17 @@ class AuxiliaryHeads(nn.Module):
        if self.config.tm.enabled:
            tm_logits = self.tm(outputs["pair"])
            aux_out["tm_logits"] = tm_logits
-            aux_out["predicted_tm_score"] = compute_tm(
+            aux_out["ptm_score"] = compute_tm(
                tm_logits, **self.config.tm
            )
+            asym_id = outputs.get("asym_id")
+            if asym_id is not None:
+                aux_out["iptm_score"] = compute_tm(
+                    tm_logits, asym_id=asym_id, interface=True, **self.config.tm
+                )
+                aux_out["weighted_ptm_score"] = (self.config.tm["iptm_weight"] * aux_out["iptm_score"]
+                    + self.config.tm["ptm_weight"] * aux_out["ptm_score"])
            aux_out.update(
                compute_predicted_aligned_error(
                    tm_logits,
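
For reference, the weighted score above combines the interface and overall pTM estimates linearly. With illustrative weights of 0.8 (ipTM) and 0.2 (pTM) — assumptions here, the real values come from self.config.tm:

# Hypothetical weights and scores, for illustration only.
iptm_weight, ptm_weight = 0.8, 0.2
ptm_score, iptm_score = 0.91, 0.74
weighted_ptm_score = iptm_weight * iptm_score + ptm_weight * ptm_score
print(round(weighted_ptm_score, 3))  # 0.774
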
@@ -18,11 +18,20 @@ import weakref
import torch
import torch.nn as nn

+from openfold.data import data_transforms_multimer
+from openfold.utils.feats import (
+    pseudo_beta_fn,
+    build_extra_msa_feat,
+    dgram_from_positions,
+    atom14_to_atom37,
+)
+from openfold.utils.tensor_utils import masked_mean
from openfold.model.embedders import (
    InputEmbedder,
+    InputEmbedderMultimer,
    RecyclingEmbedder,
-    TemplateAngleEmbedder,
-    TemplatePairEmbedder,
+    TemplateEmbedder,
+    TemplateEmbedderMultimer,
    ExtraMSAEmbedder,
)
from openfold.model.evoformer import EvoformerStack, ExtraMSAStack
@@ -73,28 +82,30 @@ class AlphaFold(nn.Module):
        self.extra_msa_config = self.config.extra_msa

        # Main trunk + structure module
-        self.input_embedder = InputEmbedder(
-            **self.config["input_embedder"],
-        )
+        if(self.globals.is_multimer):
+            self.input_embedder = InputEmbedderMultimer(
+                **self.config["input_embedder"],
+            )
+        else:
+            self.input_embedder = InputEmbedder(
+                **self.config["input_embedder"],
+            )

        self.recycling_embedder = RecyclingEmbedder(
            **self.config["recycling_embedder"],
        )

-        if(self.template_config.enabled):
-            self.template_angle_embedder = TemplateAngleEmbedder(
-                **self.template_config["template_angle_embedder"],
-            )
-            self.template_pair_embedder = TemplatePairEmbedder(
-                **self.template_config["template_pair_embedder"],
-            )
-            self.template_pair_stack = TemplatePairStack(
-                **self.template_config["template_pair_stack"],
-            )
-            self.template_pointwise_att = TemplatePointwiseAttention(
-                **self.template_config["template_pointwise_attention"],
-            )
+        if (self.template_config.enabled):
+            if(self.globals.is_multimer):
+                self.template_embedder = TemplateEmbedderMultimer(
+                    self.template_config,
+                )
+            else:
+                self.template_embedder = TemplateEmbedder(
+                    self.template_config,
+                )

-        if(self.extra_msa_config.enabled):
+        if (self.extra_msa_config.enabled):
            self.extra_msa_embedder = ExtraMSAEmbedder(
                **self.extra_msa_config["extra_msa_embedder"],
            )
@@ -105,112 +116,87 @@ class AlphaFold(nn.Module):
        self.evoformer = EvoformerStack(
            **self.config["evoformer_stack"],
        )

        self.structure_module = StructureModule(
+            is_multimer=self.globals.is_multimer,
            **self.config["structure_module"],
        )

        self.aux_heads = AuxiliaryHeads(
            self.config["heads"],
        )
-    def embed_templates(self, batch, z, pair_mask, templ_dim, inplace_safe):
-        if(self.template_config.offload_templates):
-            return embed_templates_offload(self,
-                batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
-            )
-        elif(self.template_config.average_templates):
-            return embed_templates_average(self,
-                batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
-            )
-
-        # Embed the templates one at a time (with a poor man's vmap)
-        pair_embeds = []
-        n = z.shape[-2]
-        n_templ = batch["template_aatype"].shape[templ_dim]
-
-        if(inplace_safe):
-            # We'll preallocate the full pair tensor now to avoid manifesting
-            # a second copy during the stack later on
-            t_pair = z.new_zeros(
-                z.shape[:-3] +
-                (n_templ, n, n, self.globals.c_t)
-            )
-
-        for i in range(n_templ):
-            idx = batch["template_aatype"].new_tensor(i)
-            single_template_feats = tensor_tree_map(
-                lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
-                batch,
-            )
-
-            # [*, N, N, C_t]
-            t = build_template_pair_feat(
-                single_template_feats,
-                use_unit_vector=self.config.template.use_unit_vector,
-                inf=self.config.template.inf,
-                eps=self.config.template.eps,
-                **self.config.template.distogram,
-            ).to(z.dtype)
-            t = self.template_pair_embedder(t)
-
-            if(inplace_safe):
-                t_pair[..., i, :, :, :] = t
-            else:
-                pair_embeds.append(t)
-
-            del t
-
-        if(not inplace_safe):
-            t_pair = torch.stack(pair_embeds, dim=templ_dim)
-
-        del pair_embeds
-
-        # [*, S_t, N, N, C_z]
-        t = self.template_pair_stack(
-            t_pair,
-            pair_mask.unsqueeze(-3).to(dtype=z.dtype),
-            chunk_size=self.globals.chunk_size,
-            use_lma=self.globals.use_lma,
-            inplace_safe=inplace_safe,
-            _mask_trans=self.config._mask_trans,
-        )
-        del t_pair
-
-        # [*, N, N, C_z]
-        t = self.template_pointwise_att(
-            t,
-            z,
-            template_mask=batch["template_mask"].to(dtype=z.dtype),
-            use_lma=self.globals.use_lma,
-        )
-
-        t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
-        # Append singletons
-        t_mask = t_mask.reshape(
-            *t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
-        )
-
-        if(inplace_safe):
-            t *= t_mask
-        else:
-            t = t * t_mask
-
-        ret = {}
-        ret.update({"template_pair_embedding": t})
-
-        del t
-
-        if self.config.template.embed_angles:
-            template_angle_feat = build_template_angle_feat(
-                batch
-            )
-
-            # [*, S_t, N, C_m]
-            a = self.template_angle_embedder(template_angle_feat)
-
-            ret["template_angle_embedding"] = a
-
-        return ret
+    def embed_templates(self, batch, feats, z, pair_mask, templ_dim, inplace_safe):
+        if (self.globals.is_multimer):
+            asym_id = feats["asym_id"]
+            multichain_mask_2d = (
+                asym_id[..., None] == asym_id[..., None, :]
+            )
+            template_embeds = self.template_embedder(
+                batch,
+                z,
+                pair_mask.to(dtype=z.dtype),
+                templ_dim,
+                chunk_size=self.globals.chunk_size,
+                multichain_mask_2d=multichain_mask_2d,
+                use_lma=self.globals.use_lma,
+                inplace_safe=inplace_safe
+            )
+            feats["template_torsion_angles_mask"] = (
+                template_embeds["template_mask"]
+            )
+        else:
+            if (self.template_config.offload_templates):
+                return embed_templates_offload(self,
+                    batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
+                )
+            elif (self.template_config.average_templates):
+                return embed_templates_average(self,
+                    batch, z, pair_mask, templ_dim, inplace_safe=inplace_safe,
+                )
+
+            template_embeds = self.template_embedder(
+                batch,
+                z,
+                pair_mask.to(dtype=z.dtype),
+                templ_dim,
+                chunk_size=self.globals.chunk_size,
+                use_lma=self.globals.use_lma,
+                inplace_safe=inplace_safe
+            )
+
+        return template_embeds
+
+    def tolerance_reached(self, prev_pos, next_pos, mask, no_batch_dims, eps=1e-8) -> bool:
+        """
+        Early stopping criteria based on criteria used in
+        AF2Complex: https://www.nature.com/articles/s41467-022-29394-2
+        Args:
+            prev_pos: Previous atom positions in atom37/14 representation
+            next_pos: Current atom positions in atom37/14 representation
+            mask: 1-D sequence mask
+            eps: Epsilon used in square root calculation
+        Returns:
+            Whether to stop recycling early based on the desired tolerance.
+        """
+        def distances(points):
+            """Compute all pairwise distances for a set of points."""
+            d = points[..., None, :] - points[..., None, :, :]
+            return torch.sqrt(torch.sum(d ** 2, dim=-1))
+
+        if self.config.recycle_early_stop_tolerance < 0:
+            return False
+
+        if no_batch_dims == 0:
+            prev_pos = prev_pos.unsqueeze(dim=0)
+            next_pos = next_pos.unsqueeze(dim=0)
+            mask = mask.unsqueeze(dim=0)
+
+        ca_idx = residue_constants.atom_order['CA']
+        sq_diff = (distances(prev_pos[..., ca_idx, :]) - distances(next_pos[..., ca_idx, :])) ** 2
+        mask = mask[..., None] * mask[..., None, :]
+        sq_diff = masked_mean(mask=mask, value=sq_diff, dim=list(range(len(mask.shape))))
+        diff = torch.sqrt(sq_diff + eps)
+        return diff <= self.config.recycle_early_stop_tolerance
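
The early-stop criterion above is essentially a masked comparison of CA-CA distance matrices between consecutive recycling iterations. A self-contained sketch with made-up coordinates and a 0.5 Å tolerance (both hypothetical, plain torch only):

import torch

def pairwise_distances(points):                   # [N, 3] -> [N, N]
    d = points[:, None, :] - points[None, :, :]
    return torch.sqrt((d ** 2).sum(dim=-1))

prev_ca = torch.randn(10, 3)                       # hypothetical CA coordinates
next_ca = prev_ca + 0.05 * torch.randn(10, 3)      # small change between recycles
mask = torch.ones(10)                              # 1-D sequence mask

sq_diff = (pairwise_distances(prev_ca) - pairwise_distances(next_ca)) ** 2
pair_mask = mask[:, None] * mask[None, :]
mean_sq = (pair_mask * sq_diff).sum() / pair_mask.sum()
diff = torch.sqrt(mean_sq + 1e-8)
early_stop = bool(diff <= 0.5)                     # 0.5 Å tolerance, illustrative only
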
    def iteration(self, feats, prevs, _recycle=True):
        # Primary output dictionary
@@ -240,6 +226,12 @@ class AlphaFold(nn.Module):
        ## Initialize the MSA and pair representations

+        if (self.globals.is_multimer):
+            # m: [*, S_c, N, C_m]
+            # z: [*, N, N, C_z]
+            m, z = self.input_embedder(feats)
+        else:
            # m: [*, S_c, N, C_m]
            # z: [*, N, N, C_z]
            m, z = self.input_embedder(
@@ -304,15 +296,17 @@ class AlphaFold(nn.Module):
        # Deletions like these become significant for inference with large N,
        # where they free unused tensors and remove references to others such
        # that they can be offloaded later
-        del m_1_prev, z_prev, x_prev, m_1_prev_emb, z_prev_emb
+        del m_1_prev, z_prev, m_1_prev_emb, z_prev_emb

        # Embed the templates + merge with MSA/pair embeddings
        if self.config.template.enabled:
            template_feats = {
                k: v for k, v in feats.items() if k.startswith("template_")
            }
            template_embeds = self.embed_templates(
                template_feats,
+                feats,
                z,
                pair_mask.to(dtype=z.dtype),
                no_batch_dims,
@@ -325,24 +319,38 @@ class AlphaFold(nn.Module):
                inplace_safe,
            )
if "template_angle_embedding" in template_embeds: if(
"template_single_embedding" in template_embeds
):
# [*, S = S_c + S_t, N, C_m] # [*, S = S_c + S_t, N, C_m]
m = torch.cat( m = torch.cat(
[m, template_embeds["template_angle_embedding"]], [m, template_embeds["template_single_embedding"]],
dim=-3 dim=-3
) )
# [*, S, N] # [*, S, N]
if(not self.globals.is_multimer):
torsion_angles_mask = feats["template_torsion_angles_mask"] torsion_angles_mask = feats["template_torsion_angles_mask"]
msa_mask = torch.cat( msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], [feats["msa_mask"], torsion_angles_mask[..., 2]],
dim=-2 dim=-2
) )
else:
msa_mask = torch.cat(
[feats["msa_mask"], template_embeds["template_mask"]],
dim=-2,
)
        # Embed extra MSA features + merge with pairwise embeddings
        if self.config.extra_msa.enabled:
+            if(self.globals.is_multimer):
+                extra_msa_fn = data_transforms_multimer.build_extra_msa_feat
+            else:
+                extra_msa_fn = build_extra_msa_feat

            # [*, S_e, N, C_e]
-            a = self.extra_msa_embedder(build_extra_msa_feat(feats))
+            extra_msa_feat = extra_msa_fn(feats)
+            a = self.extra_msa_embedder(extra_msa_feat)

            if(self.globals.offload_inference):
                # To allow the extra MSA stack (and later the evoformer) to
@@ -431,10 +439,34 @@ class AlphaFold(nn.Module):
        # [*, N, N, C_z]
        z_prev = outputs["pair"]

+        early_stop = False
+        if self.globals.is_multimer:
+            early_stop = self.tolerance_reached(x_prev, outputs["final_atom_positions"], seq_mask, no_batch_dims)
+
+        del x_prev

        # [*, N, 3]
        x_prev = outputs["final_atom_positions"]

-        return outputs, m_1_prev, z_prev, x_prev
+        return outputs, m_1_prev, z_prev, x_prev, early_stop

+    def _disable_activation_checkpointing(self):
+        self.template_embedder.template_pair_stack.blocks_per_ckpt = None
+        self.evoformer.blocks_per_ckpt = None
+
+        for b in self.extra_msa_stack.blocks:
+            b.ckpt = False
+
+    def _enable_activation_checkpointing(self):
+        self.template_embedder.template_pair_stack.blocks_per_ckpt = (
+            self.config.template.template_pair_stack.blocks_per_ckpt
+        )
+        self.evoformer.blocks_per_ckpt = (
+            self.config.evoformer_stack.blocks_per_ckpt
+        )
+
+        for b in self.extra_msa_stack.blocks:
+            b.ckpt = self.config.extra_msa.extra_msa_stack.ckpt
    def forward(self, batch):
        """
@@ -495,13 +527,14 @@ class AlphaFold(nn.Module):
        # Main recycling loop
        num_iters = batch["aatype"].shape[-1]
+        early_stop = False
        for cycle_no in range(num_iters):
            # Select the features for the current recycling cycle
            fetch_cur_batch = lambda t: t[..., cycle_no]
            feats = tensor_tree_map(fetch_cur_batch, batch)

            # Enable grad iff we're training and it's the final recycling layer
-            is_final_iter = cycle_no == (num_iters - 1)
+            is_final_iter = cycle_no == (num_iters - 1) or early_stop
            with torch.set_grad_enabled(is_grad_enabled and is_final_iter):
                if is_final_iter:
                    # Sidestep AMP bug (PyTorch issue #65766)
@@ -509,16 +542,21 @@ class AlphaFold(nn.Module):
                    torch.clear_autocast_cache()

                # Run the next iteration of the model
-                outputs, m_1_prev, z_prev, x_prev = self.iteration(
+                outputs, m_1_prev, z_prev, x_prev, early_stop = self.iteration(
                    feats,
                    prevs,
                    _recycle=(num_iters > 1)
                )

-                if(not is_final_iter):
+                if not is_final_iter:
                    del outputs
                    prevs = [m_1_prev, z_prev, x_prev]
                    del m_1_prev, z_prev, x_prev
+                else:
+                    break
+
+        if "asym_id" in batch:
+            outputs["asym_id"] = feats["asym_id"]

        # Run auxiliary heads
        outputs.update(self.aux_heads(outputs))
@@ -571,60 +571,3 @@ def run_pipeline(
        )
        iteration += 1
    return ret
def get_initial_energies(
pdb_strs: Sequence[str],
stiffness: float = 0.0,
restraint_set: str = "non_hydrogen",
exclude_residues: Optional[Sequence[int]] = None,
):
"""Returns initial potential energies for a sequence of PDBs.
Assumes the input PDBs are ready for minimization, and all have the same
topology.
Allows time to be saved by not pdbfixing / rebuilding the system.
Args:
pdb_strs: List of PDB strings.
stiffness: kcal/mol A**2, spring constant of heavy atom restraining
potential.
restraint_set: Which atom types to restrain.
exclude_residues: An optional list of zero-indexed residues to exclude from
restraints.
Returns:
A list of initial energies in the same order as pdb_strs.
"""
exclude_residues = exclude_residues or []
openmm_pdbs = [
openmm_app.PDBFile(PdbStructure(io.StringIO(p))) for p in pdb_strs
]
force_field = openmm_app.ForceField("amber99sb.xml")
system = force_field.createSystem(
openmm_pdbs[0].topology, constraints=openmm_app.HBonds
)
stiffness = stiffness * ENERGY / (LENGTH ** 2)
if stiffness > 0 * ENERGY / (LENGTH ** 2):
_add_restraints(
system, openmm_pdbs[0], stiffness, restraint_set, exclude_residues
)
simulation = openmm_app.Simulation(
openmm_pdbs[0].topology,
system,
openmm.LangevinIntegrator(0, 0.01, 0.0),
openmm.Platform.getPlatformByName("CPU"),
)
energies = []
for pdb in openmm_pdbs:
try:
simulation.context.setPositions(pdb.positions)
state = simulation.context.getState(getEnergy=True)
energies.append(state.getPotentialEnergy().value_in_unit(ENERGY))
except Exception as e: # pylint: disable=broad-except
logging.error(
"Error getting initial energy, returning large value %s", e
)
energies.append(unit.Quantity(1e20, ENERGY))
return energies
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""
from openfold.utils.geometry import rigid_matrix_vector
from openfold.utils.geometry import rotation_matrix
from openfold.utils.geometry import struct_of_array
from openfold.utils.geometry import vector
Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
StructOfArray = struct_of_array.StructOfArray
Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross
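
A minimal usage sketch of the re-exported geometry types, assuming openfold is importable; from_array and normalized are the same calls used by TemplateEmbedderMultimer above, while the shapes are hypothetical:

import torch
from openfold.utils import geometry

coords = torch.randn(10, 3)                    # hypothetical [N, 3] coordinates
vec = geometry.Vec3Array.from_array(coords)    # struct-of-arrays view with x/y/z fields
unit = vec.normalized()                        # same call used for template unit vectors
print(unit.x.shape)                            # torch.Size([10])
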
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utils for geometry library."""
import dataclasses
def get_field_names(cls):
fields = dataclasses.fields(cls)
field_names = [f.name for f in fields]
return field_names
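
Usage sketch for get_field_names with an arbitrary dataclass (illustrative only):

import dataclasses

@dataclasses.dataclass
class Point:
    x: float
    y: float
    z: float

print(get_field_names(Point))  # ['x', 'y', 'z']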