Commit 56d5e39c authored by Geoffrey Yu

Merge remote-tracking branch 'upstream/multimer' into multimer

parents 56b86074 51556d52
@@ -13,12 +13,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, Optional
from openfold.utils import all_atom_multimer
from openfold.utils.feats import (
pseudo_beta_fn,
dgram_from_positions,
build_template_angle_feat,
build_template_pair_feat,
)
from openfold.model.primitives import Linear, LayerNorm
from openfold.model.template import (
    TemplatePairStack,
    TemplatePointwiseAttention,
)
from openfold.utils import geometry
from openfold.utils.tensor_utils import add, one_hot, tensor_tree_map, dict_multimap
class InputEmbedder(nn.Module):
@@ -99,12 +113,13 @@ class InputEmbedder(nn.Module):
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
            batch: Dict containing
                "target_feat":
                    Features of shape [*, N_res, tf_dim]
                "residue_index":
                    Features of shape [*, N_res]
                "msa_feat":
                    Features of shape [*, N_clust, N_res, msa_dim]
Returns:
msa_emb:
[*, N_clust, N_res, C_m] MSA embedding
@@ -139,6 +154,162 @@ class InputEmbedder(nn.Module):
return msa_emb, pair_emb
class InputEmbedderMultimer(nn.Module):
"""
Embeds a subset of the input features.
Implements Algorithms 3 (InputEmbedder) and 4 (relpos).
"""
def __init__(
self,
tf_dim: int,
msa_dim: int,
c_z: int,
c_m: int,
max_relative_idx: int,
use_chain_relative: bool,
max_relative_chain: int,
**kwargs,
):
"""
Args:
tf_dim:
Final dimension of the target features
msa_dim:
Final dimension of the MSA features
c_z:
Pair embedding dimension
c_m:
MSA embedding dimension
            max_relative_idx:
                Window size used in relative positional encoding
            use_chain_relative:
                Whether to add chain-relative positional features
            max_relative_chain:
                Window size used in the relative chain (sym_id) encoding
"""
super(InputEmbedderMultimer, self).__init__()
self.tf_dim = tf_dim
self.msa_dim = msa_dim
self.c_z = c_z
self.c_m = c_m
self.linear_tf_z_i = Linear(tf_dim, c_z)
self.linear_tf_z_j = Linear(tf_dim, c_z)
self.linear_tf_m = Linear(tf_dim, c_m)
self.linear_msa_m = Linear(msa_dim, c_m)
# RPE stuff
self.max_relative_idx = max_relative_idx
self.use_chain_relative = use_chain_relative
self.max_relative_chain = max_relative_chain
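        # Bin count when use_chain_relative is True: (2 * max_relative_idx + 1)
        # clipped residue offsets plus one "different chain" bin, one
        # same-entity flag, and (2 * max_relative_chain + 1) clipped chain
        # offsets plus one "different entity" bin (AF-Multimer relpos).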
if(self.use_chain_relative):
self.no_bins = (
2 * max_relative_idx + 2 +
1 +
2 * max_relative_chain + 2
)
else:
self.no_bins = 2 * max_relative_idx + 1
self.linear_relpos = Linear(self.no_bins, c_z)
def relpos(self, batch):
pos = batch["residue_index"]
asym_id = batch["asym_id"]
asym_id_same = (asym_id[..., None] == asym_id[..., None, :])
offset = pos[..., None] - pos[..., None, :]
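        # offset[..., i, j] = residue_index[i] - residue_index[j]; shifted and
        # clamped below so that offsets outside the [-max_relative_idx,
        # max_relative_idx] window saturate at the edges before one-hot encoding.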
clipped_offset = torch.clamp(
offset + self.max_relative_idx, 0, 2 * self.max_relative_idx
)
rel_feats = []
if(self.use_chain_relative):
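            # Within-chain pairs keep their clipped offset; pairs of residues
            # from different chains all share the extra
            # (2 * max_relative_idx + 1) bin, so relative sequence position is
            # only encoded within a chain.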
final_offset = torch.where(
asym_id_same,
clipped_offset,
(2 * self.max_relative_idx + 1) *
torch.ones_like(clipped_offset)
)
boundaries = torch.arange(
start=0, end=2 * self.max_relative_idx + 2, device=final_offset.device
)
rel_pos = one_hot(
final_offset,
boundaries,
)
rel_feats.append(rel_pos)
entity_id = batch["entity_id"]
entity_id_same = (entity_id[..., None] == entity_id[..., None, :])
rel_feats.append(entity_id_same[..., None])
sym_id = batch["sym_id"]
rel_sym_id = sym_id[..., None] - sym_id[..., None, :]
max_rel_chain = self.max_relative_chain
clipped_rel_chain = torch.clamp(
rel_sym_id + max_rel_chain,
0,
2 * max_rel_chain,
)
final_rel_chain = torch.where(
entity_id_same,
clipped_rel_chain,
(2 * max_rel_chain + 1) *
torch.ones_like(clipped_rel_chain)
)
boundaries = torch.arange(
start=0, end=2 * max_rel_chain + 2, device=final_rel_chain.device
)
rel_chain = one_hot(
final_rel_chain,
boundaries,
)
rel_feats.append(rel_chain)
else:
boundaries = torch.arange(
start=0, end=2 * self.max_relative_idx + 1, device=clipped_offset.device
)
rel_pos = one_hot(
clipped_offset, boundaries,
)
rel_feats.append(rel_pos)
rel_feat = torch.cat(rel_feats, dim=-1).to(
self.linear_relpos.weight.dtype
)
return self.linear_relpos(rel_feat)
def forward(self, batch) -> Tuple[torch.Tensor, torch.Tensor]:
tf = batch["target_feat"]
msa = batch["msa_feat"]
# [*, N_res, c_z]
tf_emb_i = self.linear_tf_z_i(tf)
tf_emb_j = self.linear_tf_z_j(tf)
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb = pair_emb + self.relpos(batch)
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
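        # msa_emb: [*, N_clust, N_res, c_m]; pair_emb: [*, N_res, N_res, c_z]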
return msa_emb, pair_emb
class RecyclingEmbedder(nn.Module):
"""
Embeds the output of an iteration of the model for recycling.
@@ -365,3 +536,345 @@ class ExtraMSAEmbedder(nn.Module):
x = self.linear(x)
return x
class TemplateEmbedder(nn.Module):
def __init__(self, config):
super(TemplateEmbedder, self).__init__()
self.config = config
self.template_angle_embedder = TemplateAngleEmbedder(
**config["template_angle_embedder"],
)
self.template_pair_embedder = TemplatePairEmbedder(
**config["template_pair_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.template_pointwise_att = TemplatePointwiseAttention(
**config["template_pointwise_attention"],
)
def forward(
self,
batch,
z,
pair_mask,
templ_dim,
chunk_size,
_mask_trans=True,
use_lma=False,
inplace_safe=False
):
# Embed the templates one at a time (with a poor man's vmap)
pair_embeds = []
n = z.shape[-2]
n_templ = batch["template_aatype"].shape[templ_dim]
if (inplace_safe):
# We'll preallocate the full pair tensor now to avoid manifesting
# a second copy during the stack later on
t_pair = z.new_zeros(
z.shape[:-3] +
(n_templ, n, n, self.config.template_pair_embedder.c_t)
)
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx).squeeze(templ_dim),
batch,
)
# [*, N, N, C_t]
t = build_template_pair_feat(
single_template_feats,
use_unit_vector=self.config.use_unit_vector,
inf=self.config.inf,
eps=self.config.eps,
**self.config.distogram,
).to(z.dtype)
t = self.template_pair_embedder(t)
if (inplace_safe):
t_pair[..., i, :, :, :] = t
else:
pair_embeds.append(t)
del t
if (not inplace_safe):
t_pair = torch.stack(pair_embeds, dim=templ_dim)
del pair_embeds
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
t_pair,
pair_mask.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
use_lma=use_lma,
inplace_safe=inplace_safe,
_mask_trans=_mask_trans,
)
del t_pair
# [*, N, N, C_z]
t = self.template_pointwise_att(
t,
z,
template_mask=batch["template_mask"].to(dtype=z.dtype),
use_lma=use_lma,
)
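        # Mask the combined template embedding so it is all zeros whenever no
        # templates are provided.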
t_mask = torch.sum(batch["template_mask"], dim=-1) > 0
# Append singletons
t_mask = t_mask.reshape(
*t_mask.shape, *([1] * (len(t.shape) - len(t_mask.shape)))
)
if (inplace_safe):
t *= t_mask
else:
t = t * t_mask
ret = {}
ret.update({"template_pair_embedding": t})
del t
if self.config.embed_angles:
template_angle_feat = build_template_angle_feat(
batch
)
# [*, S_t, N, C_m]
a = self.template_angle_embedder(template_angle_feat)
ret["template_single_embedding"] = a
return ret
class TemplatePairEmbedderMultimer(nn.Module):
def __init__(self,
c_z: int,
c_out: int,
c_dgram: int,
c_aatype: int,
):
super(TemplatePairEmbedderMultimer, self).__init__()
self.dgram_linear = Linear(c_dgram, c_out)
self.aatype_linear_1 = Linear(c_aatype, c_out)
self.aatype_linear_2 = Linear(c_aatype, c_out)
self.query_embedding_layer_norm = LayerNorm(c_z)
self.query_embedding_linear = Linear(c_z, c_out)
self.pseudo_beta_mask_linear = Linear(1, c_out)
self.x_linear = Linear(1, c_out)
self.y_linear = Linear(1, c_out)
self.z_linear = Linear(1, c_out)
self.backbone_mask_linear = Linear(1, c_out)
def forward(self,
template_dgram: torch.Tensor,
aatype_one_hot: torch.Tensor,
query_embedding: torch.Tensor,
pseudo_beta_mask: torch.Tensor,
backbone_mask: torch.Tensor,
multichain_mask_2d: torch.Tensor,
unit_vector: geometry.Vec3Array,
) -> torch.Tensor:
act = 0.
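        # Each template feature gets its own linear projection; the projections
        # are summed into a single [*, N_res, N_res, c_out] activation.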
pseudo_beta_mask_2d = (
pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]
)
pseudo_beta_mask_2d *= multichain_mask_2d
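        # multichain_mask_2d restricts template information to intra-chain
        # residue pairs; inter-chain template geometry is not used.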
template_dgram *= pseudo_beta_mask_2d[..., None]
act += self.dgram_linear(template_dgram)
act += self.pseudo_beta_mask_linear(pseudo_beta_mask_2d[..., None])
aatype_one_hot = aatype_one_hot.to(template_dgram.dtype)
act += self.aatype_linear_1(aatype_one_hot[..., None, :, :])
act += self.aatype_linear_2(aatype_one_hot[..., None, :])
backbone_mask_2d = (
backbone_mask[..., None] * backbone_mask[..., None, :]
)
backbone_mask_2d *= multichain_mask_2d
x, y, z = [coord * backbone_mask_2d for coord in unit_vector]
act += self.x_linear(x[..., None])
act += self.y_linear(y[..., None])
act += self.z_linear(z[..., None])
act += self.backbone_mask_linear(backbone_mask_2d[..., None])
query_embedding = self.query_embedding_layer_norm(query_embedding)
act += self.query_embedding_linear(query_embedding)
return act
class TemplateSingleEmbedderMultimer(nn.Module):
def __init__(self,
c_in: int,
c_m: int,
):
super(TemplateSingleEmbedderMultimer, self).__init__()
self.template_single_embedder = Linear(c_in, c_m)
self.template_projector = Linear(c_m, c_m)
def forward(self,
batch,
atom_pos,
aatype_one_hot,
):
out = {}
template_chi_angles, template_chi_mask = (
all_atom_multimer.compute_chi_angles(
atom_pos,
batch["template_all_atom_mask"],
batch["template_aatype"],
)
)
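        # Per-residue template features: one-hot aatype concatenated with the
        # masked sin/cos of the chi angles and the chi-angle mask itself
        # (AF-Multimer-style template single representation input).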
template_features = torch.cat(
[
aatype_one_hot,
torch.sin(template_chi_angles) * template_chi_mask,
torch.cos(template_chi_angles) * template_chi_mask,
template_chi_mask,
],
dim=-1,
)
template_mask = template_chi_mask[..., 0]
template_activations = self.template_single_embedder(
template_features
)
template_activations = torch.nn.functional.relu(
template_activations
)
template_activations = self.template_projector(
template_activations,
)
out["template_single_embedding"] = (
template_activations
)
out["template_mask"] = template_mask
return out
class TemplateEmbedderMultimer(nn.Module):
def __init__(self, config):
super(TemplateEmbedderMultimer, self).__init__()
self.config = config
self.template_pair_embedder = TemplatePairEmbedderMultimer(
**config["template_pair_embedder"],
)
self.template_single_embedder = TemplateSingleEmbedderMultimer(
**config["template_single_embedder"],
)
self.template_pair_stack = TemplatePairStack(
**config["template_pair_stack"],
)
self.linear_t = Linear(config.c_t, config.c_z)
def forward(self,
batch,
z,
padding_mask_2d,
templ_dim,
chunk_size,
multichain_mask_2d,
use_lma=False,
inplace_safe=False
):
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
batch,
)
single_template_embeds = {}
act = 0.
template_positions, pseudo_beta_mask = (
single_template_feats["template_pseudo_beta"],
single_template_feats["template_pseudo_beta_mask"],
)
template_dgram = dgram_from_positions(
template_positions,
inf=self.config.inf,
**self.config.distogram,
)
aatype_one_hot = torch.nn.functional.one_hot(
single_template_feats["template_aatype"], 22,
)
raw_atom_pos = single_template_feats["template_all_atom_positions"]
atom_pos = geometry.Vec3Array.from_array(raw_atom_pos)
rigid, backbone_mask = all_atom_multimer.make_backbone_affine(
atom_pos,
single_template_feats["template_all_atom_mask"],
single_template_feats["template_aatype"],
)
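            # Pairwise orientation features: each residue's frame origin
            # expressed in every other residue's local backbone frame,
            # normalized to unit vectors.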
points = rigid.translation
rigid_vec = rigid[..., None].inverse().apply_to_point(points)
unit_vector = rigid_vec.normalized()
pair_act = self.template_pair_embedder(
template_dgram,
aatype_one_hot,
z,
pseudo_beta_mask,
backbone_mask,
multichain_mask_2d,
unit_vector,
)
single_template_embeds["template_pair_embedding"] = pair_act
single_template_embeds.update(
self.template_single_embedder(
single_template_feats,
atom_pos,
aatype_one_hot,
)
)
template_embeds.append(single_template_embeds)
template_embeds = dict_multimap(
partial(torch.cat, dim=templ_dim),
template_embeds,
)
# [*, S_t, N, N, C_z]
t = self.template_pair_stack(
template_embeds["template_pair_embedding"],
padding_mask_2d.unsqueeze(-3).to(dtype=z.dtype),
chunk_size=chunk_size,
_mask_trans=False,
)
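        # Average over the template dimension, then ReLU + linear projection to
        # c_z; the multimer path has no template pointwise attention.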
# [*, N, N, C_z]
t = torch.sum(t, dim=-4) / n_templ
t = torch.nn.functional.relu(t)
t = self.linear_t(t)
template_embeds["template_pair_embedding"] = t
return template_embeds
@@ -76,9 +76,17 @@ class AuxiliaryHeads(nn.Module):
if self.config.tm.enabled:
tm_logits = self.tm(outputs["pair"])
aux_out["tm_logits"] = tm_logits
aux_out["predicted_tm_score"] = compute_tm(
aux_out["ptm_score"] = compute_tm(
tm_logits, **self.config.tm
)
asym_id = outputs.get("asym_id")
if asym_id is not None:
aux_out["iptm_score"] = compute_tm(
tm_logits, asym_id=asym_id, interface=True, **self.config.tm
)
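                # Combined ranking confidence: a config-weighted sum of the
                # interface and overall pTM scores (the AlphaFold-Multimer
                # reference uses 0.8 * ipTM + 0.2 * pTM).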
aux_out["weighted_ptm_score"] = (self.config.tm["iptm_weight"] * aux_out["iptm_score"]
+ self.config.tm["ptm_weight"] * aux_out["ptm_score"])
aux_out.update(
compute_predicted_aligned_error(
tm_logits,
@@ -33,6 +33,8 @@ from openfold.model.triangular_attention import (
from openfold.model.triangular_multiplicative_update import (
TriangleMultiplicationOutgoing,
TriangleMultiplicationIncoming,
FusedTriangleMultiplicationOutgoing,
FusedTriangleMultiplicationIncoming
)
from openfold.utils.checkpointing import checkpoint_blocks
from openfold.utils.chunk_utils import (
@@ -154,6 +156,8 @@ class TemplatePairStackBlock(nn.Module):
no_heads: int,
pair_transition_n: int,
dropout_rate: float,
tri_mul_first: bool,
fuse_projection_weights: bool,
inf: float,
**kwargs,
):
@@ -166,6 +170,7 @@
self.pair_transition_n = pair_transition_n
self.dropout_rate = dropout_rate
self.inf = inf
self.tri_mul_first = tri_mul_first
self.dropout_row = DropoutRowwise(self.dropout_rate)
self.dropout_col = DropoutColumnwise(self.dropout_rate)
@@ -183,20 +188,88 @@
inf=inf,
)
        if fuse_projection_weights:
            self.tri_mul_out = FusedTriangleMultiplicationOutgoing(
                self.c_t,
                self.c_hidden_tri_mul,
            )
            self.tri_mul_in = FusedTriangleMultiplicationIncoming(
                self.c_t,
                self.c_hidden_tri_mul,
            )
        else:
            self.tri_mul_out = TriangleMultiplicationOutgoing(
                self.c_t,
                self.c_hidden_tri_mul,
            )
            self.tri_mul_in = TriangleMultiplicationIncoming(
                self.c_t,
                self.c_hidden_tri_mul,
            )
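        # The fused variants pack the left/right projections of the triangle
        # multiplicative update into single linear layers, matching the
        # parameter layout used by newer AlphaFold-Multimer model versions.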
self.pair_transition = PairTransition(
self.c_t,
self.pair_transition_n,
)
def tri_att_start_end(self, single, _attn_chunk_size, single_mask, use_lma, inplace_safe):
single = add(single,
self.dropout_row(
self.tri_att_start(
single,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace_safe,
)
single = add(single,
self.dropout_col(
self.tri_att_end(
single,
chunk_size=_attn_chunk_size,
mask=single_mask,
use_lma=use_lma,
inplace_safe=inplace_safe,
)
),
inplace_safe,
)
return single
def tri_mul_out_in(self, single, single_mask, inplace_safe):
tmu_update = self.tri_mul_out(
single,
mask=single_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if (not inplace_safe):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
del tmu_update
tmu_update = self.tri_mul_in(
single,
mask=single_mask,
inplace_safe=inplace_safe,
_add_with_inplace=True,
)
if (not inplace_safe):
single = single + self.dropout_row(tmu_update)
else:
single = tmu_update
del tmu_update
return single
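    # tri_mul_first selects the block ordering: when True, the triangle
    # multiplicative updates run before the triangle attention layers (the
    # ordering used by the AF-Multimer template stack); otherwise attention
    # runs first, matching the original AlphaFold 2 ordering.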
def forward(self,
z: torch.Tensor,
mask: torch.Tensor,
@@ -219,72 +292,37 @@
for i in range(len(single_templates)):
single = single_templates[i]
single_mask = single_templates_masks[i]
            if self.tri_mul_first:
                single = self.tri_att_start_end(
                    single=self.tri_mul_out_in(
                        single=single,
                        single_mask=single_mask,
                        inplace_safe=inplace_safe,
                    ),
                    _attn_chunk_size=_attn_chunk_size,
                    single_mask=single_mask,
                    use_lma=use_lma,
                    inplace_safe=inplace_safe,
                )
            else:
                single = self.tri_mul_out_in(
                    single=self.tri_att_start_end(
                        single=single,
                        _attn_chunk_size=_attn_chunk_size,
                        single_mask=single_mask,
                        use_lma=use_lma,
                        inplace_safe=inplace_safe,
                    ),
                    single_mask=single_mask,
                    inplace_safe=inplace_safe,
                )
single = add(single,
self.pair_transition(
single,
mask=single_mask if _mask_trans else None,
chunk_size=chunk_size,
),
inplace_safe,
)
            if (not inplace_safe):
                single_templates[i] = single
        if (not inplace_safe):
            z = torch.cat(single_templates, dim=-4)
        return z
@@ -303,6 +341,8 @@ class TemplatePairStack(nn.Module):
no_heads,
pair_transition_n,
dropout_rate,
tri_mul_first,
fuse_projection_weights,
blocks_per_ckpt,
tune_chunk_size: bool = False,
inf=1e9,
@@ -339,6 +379,8 @@
no_heads=no_heads,
pair_transition_n=pair_transition_n,
dropout_rate=dropout_rate,
tri_mul_first=tri_mul_first,
fuse_projection_weights=fuse_projection_weights,
inf=inf,
)
self.blocks.append(block)
@@ -512,7 +554,7 @@ def embed_templates_offload(
# [*, N, C_m]
a = model.template_angle_embedder(template_angle_feat)
ret["template_angle_embedding"] = a
ret["template_single_embedding"] = a
ret.update({"template_pair_embedding": t})
@@ -623,7 +665,7 @@ def embed_templates_average(
# [*, N, C_m]
a = model.template_angle_embedder(template_angle_feat)
ret["template_angle_embedding"] = a
ret["template_single_embedding"] = a
ret.update({"template_pair_embedding": out_tensor})
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Geometry Module."""
from openfold.utils.geometry import rigid_matrix_vector
from openfold.utils.geometry import rotation_matrix
from openfold.utils.geometry import struct_of_array
from openfold.utils.geometry import vector
Rot3Array = rotation_matrix.Rot3Array
Rigid3Array = rigid_matrix_vector.Rigid3Array
StructOfArray = struct_of_array.StructOfArray
Vec3Array = vector.Vec3Array
square_euclidean_distance = vector.square_euclidean_distance
euclidean_distance = vector.euclidean_distance
dihedral_angle = vector.dihedral_angle
dot = vector.dot
cross = vector.cross