"container/vscode:/vscode.git/clone" did not exist on "20b3684387645d0f27895fcbf80e9ead88ba86b5"
Commit b2d102cb authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Refactor certain modules for TorchScript, fix recycling bug

parent 4bd4ad93
......@@ -83,7 +83,7 @@ class InputEmbedder(nn.Module):
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
oh = one_hot(d, boundaries)
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
def forward(self,
......@@ -112,14 +112,15 @@ class InputEmbedder(nn.Module):
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb += self.relpos(ri)
#pair_emb = pair_emb + self.relpos(ri)
pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (self.linear_tf_m(tf)
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand((*(-1,) * len(tf.shape[:-2]), n_clust, -1, -1)))
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
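
The InputEmbedder change above casts the one-hot relative-position encoding to the dtype of the residue indices before the linear projection, so the pair embedding stays in a single dtype under mixed precision. A minimal standalone sketch of the pattern; the dimensions are illustrative and the one_hot helper mirrors the openfold.utils.tensor_utils version shown further down in this diff:

import torch
import torch.nn as nn

def one_hot(x: torch.Tensor, v_bins: torch.Tensor) -> torch.Tensor:
    # Assign each value to its nearest bin and one-hot encode the bin index
    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
    diffs = x[..., None] - reshaped_bins
    am = torch.argmin(torch.abs(diffs), dim=-1)
    return nn.functional.one_hot(am, num_classes=len(v_bins)).float()

relpos_k = 32          # illustrative, not the real config value
c_z = 128
linear_relpos = nn.Linear(2 * relpos_k + 1, c_z)

def relpos(ri: torch.Tensor) -> torch.Tensor:
    # ri: [*, N_res] residue indices; may arrive in half precision under AMP
    d = ri[..., None] - ri[..., None, :]
    boundaries = torch.arange(-relpos_k, relpos_k + 1, device=d.device)
    # Cast the float32 one-hot back to ri's dtype before the linear layer
    oh = one_hot(d, boundaries).type(ri.dtype)
    return linear_relpos(oh)

ri = torch.arange(16).float()
print(relpos(ri).shape)  # torch.Size([16, 16, 128])
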
......@@ -192,6 +193,7 @@ class RecyclingEmbedder(nn.Module):
self.min_bin,
self.max_bin,
self.no_bins,
dtype=x.dtype,
requires_grad=False,
device=x.device
)
......
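
The RecyclingEmbedder change adds dtype=x.dtype when building the distance bins, so the recycled distogram matches the pseudo-beta coordinates under half precision. A sketch of that binning step, with illustrative bin bounds:

import torch

def recycle_distogram(x: torch.Tensor,
                      min_bin: float = 3.25,
                      max_bin: float = 20.75,
                      no_bins: int = 15) -> torch.Tensor:
    # x: [*, N_res, 3] pseudo-beta coordinates
    bins = torch.linspace(
        min_bin, max_bin, no_bins,
        dtype=x.dtype, device=x.device, requires_grad=False,
    )
    squared_bins = bins ** 2
    upper = torch.cat([squared_bins[1:], squared_bins.new_tensor([1e9])], dim=-1)
    d2 = torch.sum(
        (x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True
    )
    # [*, N_res, N_res, no_bins] one-hot distogram in x's dtype
    return ((d2 > squared_bins) * (d2 < upper)).type(x.dtype)

x = torch.randn(10, 3)
print(recycle_distogram(x).shape)  # torch.Size([10, 10, 15])
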
......@@ -170,68 +170,10 @@ class AlphaFold(nn.Module):
"torsion_angles_mask": angle_feats["torsion_angles_mask"],
}
def forward(self, batch):
"""
Args:
batch:
Dictionary of arguments outlined in Algorithm 2. Keys must
include the official names of the features in the
supplement subsection 1.2.9.
The final dimension of each input must have length equal to
the number of recycling iterations.
Features (without the recycling dimension):
"aatype" ([*, N_res]):
Contrary to the supplement, this tensor of residue
indices is not one-hot.
"target_feat" ([*, N_res, C_tf])
One-hot encoding of the target sequence. C_tf is
config.model.input_embedder.tf_dim.
"residue_index" ([*, N_res])
Tensor whose final dimension consists of
consecutive indices from 0 to N_res.
"msa_feat" ([*, N_seq, N_res, C_msa])
MSA features, constructed as in the supplement.
C_msa is config.model.input_embedder.msa_dim.
"seq_mask" ([*, N_res])
1-D sequence mask
"msa_mask" ([*, N_seq, N_res])
MSA mask
"pair_mask" ([*, N_res, N_res])
2-D pair mask
"extra_msa_mask" ([*, N_extra, N_res])
Extra MSA mask
"template_mask" ([*, N_templ])
Template mask (on the level of templates, not
residues)
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
"template_pseudo_beta" ([*, N_templ, N_res, 3])
Positions of template carbon "pseudo-beta" atoms
(i.e. C_beta for all residues but glycine, for
which C_alpha is used instead)
"template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask
"""
# Recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
def iteration(self, feats, m_1_prev, z_prev, x_prev):
# Primary output dictionary
outputs = {}
# Main recycling loop
for cycle_no in range(self.config.no_cycles):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Grab some data about the input
batch_dims = feats["target_feat"].shape[:-2]
no_batch_dims = len(batch_dims)
......@@ -259,24 +201,21 @@ class AlphaFold(nn.Module):
# Initialize the recycling embeddings, if needs be
if(None in [m_1_prev, z_prev, x_prev]):
# [*, N, C_m]
m_1_prev = torch.zeros(
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.c_m),
requires_grad=False,
device=device,
)
# [*, N, N, C_z]
z_prev = torch.zeros(
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.c_z),
requires_grad=False,
device=device,
)
# [*, N, 3]
x_prev = torch.zeros(
x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
device=device,
)
x_prev = pseudo_beta_fn(
......@@ -377,6 +316,74 @@ class AlphaFold(nn.Module):
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
return outputs, m_1_prev, z_prev, x_prev
def forward(self, batch):
"""
Args:
batch:
Dictionary of arguments outlined in Algorithm 2. Keys must
include the official names of the features in the
supplement subsection 1.2.9.
The final dimension of each input must have length equal to
the number of recycling iterations.
Features (without the recycling dimension):
"aatype" ([*, N_res]):
Contrary to the supplement, this tensor of residue
indices is not one-hot.
"target_feat" ([*, N_res, C_tf])
One-hot encoding of the target sequence. C_tf is
config.model.input_embedder.tf_dim.
"residue_index" ([*, N_res])
Tensor whose final dimension consists of
consecutive indices from 0 to N_res.
"msa_feat" ([*, N_seq, N_res, C_msa])
MSA features, constructed as in the supplement.
C_msa is config.model.input_embedder.msa_dim.
"seq_mask" ([*, N_res])
1-D sequence mask
"msa_mask" ([*, N_seq, N_res])
MSA mask
"pair_mask" ([*, N_res, N_res])
2-D pair mask
"extra_msa_mask" ([*, N_extra, N_res])
Extra MSA mask
"template_mask" ([*, N_templ])
Template mask (on the level of templates, not
residues)
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
"template_pseudo_beta" ([*, N_templ, N_res, 3])
Positions of template carbon "pseudo-beta" atoms
(i.e. C_beta for all residues but glycine, for
which C_alpha is used instead)
"template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask
"""
# Recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
# Main recycling loop
for cycle_no in range(self.config.no_cycles):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = (cycle_no == self.config.no_cycles - 1)
with torch.set_grad_enabled(self.training and is_final_iter):
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, m_1_prev, z_prev, x_prev,
)
outputs.update(self.aux_heads(outputs))
return outputs
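
The recycling fix is this split into iteration() plus a thin forward(): every cycle except the last now runs with gradients disabled, and only the final iteration builds a graph during training. A toy skeleton of that control flow; the module below is a stand-in, not the real AlphaFold:

import torch

class ToyRecycler(torch.nn.Module):
    def __init__(self, no_cycles: int = 3):
        super().__init__()
        self.no_cycles = no_cycles
        self.linear = torch.nn.Linear(8, 8)

    def iteration(self, feats, prev):
        prev = torch.zeros_like(feats) if prev is None else prev
        out = self.linear(feats + prev)
        return out, out.detach()  # recycled state carries no graph of its own

    def forward(self, batch):
        prev = None
        for cycle_no in range(self.no_cycles):
            feats = batch[..., cycle_no]  # select this cycle's features
            is_final_iter = cycle_no == self.no_cycles - 1
            with torch.set_grad_enabled(self.training and is_final_iter):
                out, prev = self.iteration(feats, prev)
        return out

model = ToyRecycler().train()
batch = torch.randn(4, 8, 3)  # final dim = number of recycling iterations
out = model(batch)
print(out.requires_grad)  # True: only the last cycle built a graph
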
......@@ -16,8 +16,9 @@
import math
import torch
import torch.nn as nn
from typing import Optional
from openfold.model.primitives import Linear, scripted_attention
from openfold.model.primitives import Linear, Attention, GlobalAttention
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
......@@ -69,7 +70,8 @@ class MSAAttention(nn.Module):
self.c_z, self.no_heads, bias=False, init="normal"
)
self.mha = scripted_attention(
self.mha = Attention(
self.c_in, self.c_in, self.c_in,
self.c_hidden,
self.no_heads
......@@ -93,7 +95,7 @@ class MSAAttention(nn.Module):
if(mask is None):
# [*, N_seq, N_res]
mask = torch.ones(
(*m.shape[:-3], n_seq, n_res),
m.shape[:-3] + (n_seq, n_res),
device=m.device,
requires_grad=False
)
......@@ -103,7 +105,7 @@ class MSAAttention(nn.Module):
# [*, N_seq, no_heads, N_res, N_res]
bias = bias.expand(
(*((-1,) * len(bias.shape[:-4])), -1, self.no_heads, n_res, -1)
((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
)
biases = [bias]
......@@ -115,7 +117,7 @@ class MSAAttention(nn.Module):
z = self.linear_z(z)
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, 2, 0, 1).unsqueeze(-4)
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
biases.append(z)
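
The shape arguments in this file are rewritten from star-unpacked generators to plain tuple concatenation, the form this commit standardizes on for TorchScript. A quick eager check that the two expand() spellings agree:

import torch

no_heads, n_res = 4, 7
bias = torch.randn(2, 5, 1, 1, n_res)   # [*, N_seq, 1, 1, N_res]

old = bias.expand(
    (*((-1,) * len(bias.shape[:-4])), -1, no_heads, n_res, -1)
)
new = bias.expand(
    ((-1,) * len(bias.shape[:-4])) + (-1, no_heads, n_res, -1)
)
assert old.shape == new.shape == (2, 5, no_heads, n_res, n_res)
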
......@@ -234,79 +236,29 @@ class MSAColumnGlobalAttention(nn.Module):
self.inf = inf
self.eps = eps
self.layer_norm_m = nn.LayerNorm(self.c_in)
self.linear_q = Linear(
self.c_in, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
C_hidden = self.c_hidden
self.linear_k = Linear(
self.c_in, C_hidden, bias=False, init="glorot",
)
self.linear_v = Linear(
self.c_in, C_hidden, bias=False, init="glorot",
)
self.linear_g = Linear(self.c_in, self.c_hidden * self.no_heads, init="gating")
self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_in, init="final")
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def global_attention(self, m, mask):
# [*, N_res, C_in]
q = (torch.sum(m * mask.unsqueeze(-1), dim=-2) /
(torch.sum(mask, dim=-1)[..., None] + self.eps))
# [*, N_res, H * C_hidden]
q = self.linear_q(q)
q = q * self.c_hidden ** (-0.5)
# [*, N_res, H, C_hidden]
q = q.view(*q.shape[:-1], self.no_heads, -1)
self.layer_norm_m = nn.LayerNorm(c_in)
# [*, N_res, N_seq, C_hidden]
k = self.linear_k(m)
v = self.linear_v(m)
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a = a + bias
a = self.softmax(a)
# [*, N_res, H, C_hidden]
o = torch.matmul(
a,
v,
self.global_attention = GlobalAttention(
c_in=c_in,
c_hidden=c_hidden,
no_heads=no_heads,
inf=inf,
eps=eps,
)
# [*, N_res, N_seq, C_hidden]
g = self.sigmoid(self.linear_g(m))
# [*, N_res, N_seq, H, C_hidden]
g = g.view(*g.shape[:-1], self.no_heads, -1)
# [*, N_res, N_seq, H, C_hidden]
o = o.unsqueeze(-3) * g
# [*, N_res, N_seq, H * C_hidden]
o = o.reshape(*o.shape[:-2], -1)
# [*, N_res, N_seq, C_in]
m = self.linear_o(o)
return m
def forward(self, m, mask=None):
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:]
if(mask is None):
# [*, N_seq, N_res]
mask = m.new_ones(m.shape[:-1], requires_grad=False)
mask = torch.ones(
m.shape[:-1],
dtype=m.dtype,
device=m.device,
).detach()
# [*, N_res, N_seq, C_in]
m = m.transpose(-2, -3)
......@@ -327,7 +279,7 @@ class MSAColumnGlobalAttention(nn.Module):
no_batch_dims=len(m.shape[:-2])
)
else:
m = self.global_attention(**mha_input)
m = self.global_attention(m=mha_input["m"], mask=mha_input["mask"])
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
......
......@@ -235,12 +235,6 @@ class Attention(nn.Module):
Returns
[*, Q, C_q] attention update
"""
# Flatten batch dims
batch_dims = q_x.shape[:-2]
q_x = q_x.view((-1,) + q_x.shape[-2:])
k_x = k_x.view((-1,) + k_x.shape[-2:])
v_x = v_x.view((-1,) + v_x.shape[-2:])
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(k_x)
......@@ -253,20 +247,20 @@ class Attention(nn.Module):
# [*, H, Q, K]
a = torch.matmul(
q.permute(0, 2, 1, 3), # [*, H, Q, C_hidden]
k.permute(0, 2, 3, 1), # [*, H, C_hidden, K]
permute_final_dims(q, (0, 2, 1, 3)), # [*, H, Q, C_hidden]
permute_final_dims(k, (0, 2, 3, 1)), # [*, H, C_hidden, K]
)
norm = 1 / math.sqrt(self.c_hidden) # [1]
a = a * norm
a *= norm
if(biases is not None):
for b in biases:
a = a + b
a += b
a = self.softmax(a)
# [*, H, Q, C_hidden]
o = torch.matmul(
a,
v.permute(0, 2, 1, 3), # [*, H, V, C_hidden]
permute_final_dims(v, (0, 2, 1, 3)), # [*, H, V, C_hidden]
)
# [*, Q, H, C_hidden]
......@@ -282,11 +276,80 @@ class Attention(nn.Module):
# [*, Q, C_q]
o = self.linear_o(o)
# Restore the batch dims
o = o.reshape(batch_dims + o.shape[1:])
return o
def scripted_attention(*args, **kwargs):
return torch.jit.script(Attention(*args, **kwargs))
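
scripted_attention, which this commit drops in favor of using the modules directly, simply compiled an eagerly constructed Attention with torch.jit.script. A toy version of that wrapper pattern; the attention module below is a stand-in, not the real Attention class:

import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, c_in: int, c_hidden: int):
        super().__init__()
        self.linear_q = nn.Linear(c_in, c_hidden, bias=False)
        self.linear_k = nn.Linear(c_in, c_hidden, bias=False)
        self.linear_v = nn.Linear(c_in, c_hidden, bias=False)
        self.scale = c_hidden ** -0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        q = self.linear_q(x) * self.scale
        k = self.linear_k(x)
        v = self.linear_v(x)
        a = torch.softmax(torch.matmul(q, k.transpose(-1, -2)), dim=-1)
        return torch.matmul(a, v)

def scripted_toy_attention(*args, **kwargs):
    # same shape as the removed scripted_attention helper
    return torch.jit.script(ToyAttention(*args, **kwargs))

mha = scripted_toy_attention(c_in=32, c_hidden=16)
print(mha(torch.randn(2, 10, 32)).shape)  # torch.Size([2, 10, 16])
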
class GlobalAttention(nn.Module):
def __init__(self, c_in, c_hidden, no_heads, inf, eps):
super(GlobalAttention, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.inf = inf
self.eps = eps
self.linear_q = Linear(
c_in, c_hidden * no_heads, bias=False, init="glorot"
)
self.linear_k = Linear(
c_in, c_hidden, bias=False, init="glorot",
)
self.linear_v = Linear(
c_in, c_hidden, bias=False, init="glorot",
)
self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# [*, N_res, C_in]
q = (torch.sum(m * mask.unsqueeze(-1), dim=-2) /
(torch.sum(mask, dim=-1)[..., None] + self.eps))
# [*, N_res, H * C_hidden]
q = self.linear_q(q)
q = q * self.c_hidden ** (-0.5)
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
# [*, N_res, N_seq, C_hidden]
k = self.linear_k(m)
v = self.linear_v(m)
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a = self.softmax(a)
# [*, N_res, H, C_hidden]
o = torch.matmul(
a,
v,
)
# [*, N_res, N_seq, C_hidden]
g = self.sigmoid(self.linear_g(m))
# [*, N_res, N_seq, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
# [*, N_res, N_seq, H, C_hidden]
o = o.unsqueeze(-3) * g
# [*, N_res, N_seq, H * C_hidden]
o = o.reshape(o.shape[:-2] + (-1,))
# [*, N_res, N_seq, C_in]
m = self.linear_o(o)
return m
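
A usage sketch for the new GlobalAttention module, assuming the openfold package at this commit is importable. The query is the mask-weighted mean over the sequence dimension, so each residue column gets a single multi-head query attending over its N_seq entries; sizes are illustrative:

import torch
from openfold.model.primitives import GlobalAttention  # added by this commit

c_in, c_hidden, no_heads = 64, 16, 8
ga = GlobalAttention(c_in=c_in, c_hidden=c_hidden, no_heads=no_heads,
                     inf=1e9, eps=1e-10)

m = torch.randn(1, 20, 32, c_in)        # [*, N_res, N_seq, C_in] (MSA transposed)
mask = m.new_ones(m.shape[:-1])         # [*, N_res, N_seq]
out = ga(m=m, mask=mask)
print(out.shape)                        # torch.Size([1, 20, 32, 64]) == m.shape
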
......@@ -16,7 +16,7 @@
import math
import torch
import torch.nn as nn
from typing import Optional
from typing import Optional, Tuple
from openfold.model.primitives import Linear, ipa_point_weights_init_
from openfold.np.residue_constants import (
......@@ -49,7 +49,7 @@ class AngleResnetBlock(nn.Module):
self.relu = nn.ReLU()
def forward(self, a):
def forward(self, a: torch.Tensor) -> torch.Tensor:
s_initial = a
......@@ -85,7 +85,7 @@ class AngleResnet(nn.Module):
self.c_hidden = c_hidden
self.no_blocks = no_blocks
self.no_angles = no_angles
self.epsilon = epsilon
self.eps = epsilon
self.linear_in = Linear(self.c_in, self.c_hidden)
self.linear_initial = Linear(self.c_in, self.c_hidden)
......@@ -99,7 +99,10 @@ class AngleResnet(nn.Module):
self.relu = nn.ReLU()
def forward(self, s, s_initial):
def forward(self,
s: torch.Tensor,
s_initial: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
s:
......@@ -130,14 +133,13 @@ class AngleResnet(nn.Module):
s = self.linear_out(s)
# [*, no_angles, 2]
s = s.view(*s.shape[:-1], -1, 2)
s = s.view(s.shape[:-1] + (-1, 2))
unnormalized_s = s
norm_denom = torch.sqrt(
torch.clamp(
torch.sum(s ** 2, dim=-1, keepdims=True),
min=self.epsilon,
torch.sum(s ** 2, dim=-1, keepdim=True),
min=self.eps,
)
)
s = s / norm_denom
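
The AngleResnet output head reshapes to (sin, cos) pairs and normalizes each pair by an eps-clamped L2 norm, returning both the raw and normalized angles. A standalone sketch of that step with illustrative sizes:

import torch

def normalize_angles(s: torch.Tensor, eps: float = 1e-8):
    # s: [*, no_angles * 2] raw output of the final linear layer
    s = s.view(s.shape[:-1] + (-1, 2))              # [*, no_angles, 2]
    unnormalized_s = s
    norm_denom = torch.sqrt(
        torch.clamp(torch.sum(s ** 2, dim=-1, keepdim=True), min=eps)
    )
    return unnormalized_s, s / norm_denom

s = torch.randn(4, 7 * 2)
unnorm, norm = normalize_angles(s)
print(norm.shape, torch.linalg.norm(norm, dim=-1)[0, 0])  # [4, 7, 2], ~1.0
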
......@@ -219,7 +221,7 @@ class InvariantPointAttention(nn.Module):
z: torch.Tensor,
t: T,
mask: torch.Tensor,
):
) -> torch.Tensor:
"""
Args:
s:
......@@ -236,16 +238,15 @@ class InvariantPointAttention(nn.Module):
#######################################
# Generate scalar and point activations
#######################################
# [*, N_res, H * C_hidden]
q = self.linear_q(s)
kv = self.linear_kv(s)
# [*, N_res, H, C_hidden]
q = q.view(*q.shape[:-1], self.no_heads, -1)
q = q.view(q.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, 2 * C_hidden]
kv = kv.view(*kv.shape[:-1], self.no_heads, -1)
kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, C_hidden]
k, v = torch.split(kv, self.c_hidden, dim=-1)
......@@ -261,7 +262,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, P_q, 3]
q_pts = q_pts.view(
*q_pts.shape[:-2], self.no_heads, self.no_qk_points, 3
q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
)
# [*, N_res, H * (P_q + P_v) * 3]
......@@ -274,7 +275,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(
*kv_pts.shape[:-2], self.no_heads, -1, 3
kv_pts.shape[:-2] + (self.no_heads, -1, 3)
)
# [*, N_res, H, P_q/P_v, 3]
......@@ -293,11 +294,11 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
a = torch.matmul(
permute_final_dims(q, 1, 0, 2), # [*, H, N_res, C_hidden]
permute_final_dims(k, 1, 2, 0), # [*, H, C_hidden, N_res]
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
a = a * math.sqrt(1. / (3 * self.c_hidden))
a = a + math.sqrt(1. / 3) * permute_final_dims(b, 2, 0, 1)
a = a + math.sqrt(1. / 3) * permute_final_dims(b, (2, 0, 1))
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
......@@ -321,7 +322,7 @@ class InvariantPointAttention(nn.Module):
square_mask = self.inf * (square_mask - 1)
# [*, H, N_res, N_res]
pt_att = permute_final_dims(pt_att, 2, 0, 1)
pt_att = permute_final_dims(pt_att, (2, 0, 1))
a = a + pt_att
a = a + square_mask.unsqueeze(-3)
a = self.softmax(a)
......@@ -339,11 +340,11 @@ class InvariantPointAttention(nn.Module):
# [*, H, 3, N_res, P_v]
o_pt = torch.matmul(
a.unsqueeze(-3), # [*, H, 1, N_res, N_res]
permute_final_dims(v_pts, 1, 3, 0, 2), # [*, H, 3, N_res, P_v]
permute_final_dims(v_pts, (1, 3, 0, 2)), # [*, H, 3, N_res, P_v]
)
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, 2, 0, 3, 1)
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = t[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v]
......@@ -758,35 +759,39 @@ class StructureModule(nn.Module):
return outputs
def _init_residue_constants(self, device):
def _init_residue_constants(self, dtype, device):
if(self.default_frames is None):
self.default_frames = torch.tensor(
restype_rigid_group_default_frame,
dtype=dtype,
device=device,
requires_grad=False,
)
if(self.group_idx is None):
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group,
dtype=dtype,
device=device,
requires_grad=False,
)
if(self.atom_mask is None):
self.atom_mask = torch.tensor(
restype_atom14_mask,
dtype=dtype,
device=device,
requires_grad=False,
)
if(self.lit_positions is None):
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions,
dtype=dtype,
device=device,
requires_grad=False,
)
def torsion_angles_to_frames(self, t, alpha, f):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(f.device)
self._init_residue_constants(f.dtype, f.device)
# Separated purely to make testing less annoying
return _torsion_angles_to_frames(t, alpha, f, self.default_frames)
......@@ -797,7 +802,7 @@ class StructureModule(nn.Module):
# Lazily initialize the residue constants on the correct device
# TODO: Maybe this stuff should be done on CPU instead (so these
# arrays
self._init_residue_constants(f.device)
self._init_residue_constants(f.dtype, f.device)
return _frames_and_literature_positions_to_atom14_pos(
t,
......
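
_init_residue_constants now takes the dtype as well as the device of the activations, so the lazily created constant tables follow the module under .half() or .to(device). A sketch of the pattern with a dummy placeholder table, not the real residue constants:

import numpy as np
import torch
import torch.nn as nn

class LazyConstants(nn.Module):
    def __init__(self):
        super().__init__()
        self.default_frames = None

    def _init_residue_constants(self, dtype: torch.dtype, device: torch.device):
        # Build the lookup table once, in the dtype/device of the caller
        if self.default_frames is None:
            self.default_frames = torch.tensor(
                np.zeros((21, 8, 4, 4)),       # placeholder table
                dtype=dtype,
                device=device,
                requires_grad=False,
            )

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        self._init_residue_constants(f.dtype, f.device)
        # index the table with per-residue ids (dummy lookup here)
        return self.default_frames[f.long() % 21]

mod = LazyConstants()
print(mod(torch.arange(5).float()).shape)   # torch.Size([5, 8, 4, 4])
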
......@@ -18,7 +18,7 @@ import math
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, scripted_attention
from openfold.model.primitives import Linear, Attention
from openfold.utils.deepspeed import checkpoint_blocks
from openfold.model.dropout import (
DropoutRowwise,
......@@ -69,7 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
self.no_heads = no_heads
self.chunk_size = chunk_size
self.mha = scripted_attention(
self.mha = Attention(
self.c_z, self.c_t, self.c_t,
self.c_hidden, self.no_heads,
gating=False,
......@@ -91,7 +91,7 @@ class TemplatePointwiseAttention(nn.Module):
# NOTE: This is not the "template_mask" from the supplement, but a
# [*, N_templ] mask from the code. I'm pretty sure it's always just 1,
# but not sure enough to remove it. It's nice to have, I guess.
template_mask = torch.ones(t.shape[:-3], device=t.device)
template_mask = t.new_ones(t.shape[:-3])
bias = (1e9 * (template_mask[..., None, None, None, None, :] - 1))
......@@ -99,7 +99,7 @@ class TemplatePointwiseAttention(nn.Module):
z = z.unsqueeze(-2)
# [*, N_res, N_res, N_temp, C_t]
t = permute_final_dims(t, 1, 2, 0, 3)
t = permute_final_dims(t, (1, 2, 0, 3))
# [*, N_res, N_res, 1, C_z]
mha_inputs = {
......
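
The template_mask change swaps torch.ones(..., device=...) for t.new_ones(...), which inherits the dtype as well as the device of the source tensor; under half precision the old form would silently create a float32 mask. A small illustration:

import torch

t = torch.randn(2, 4, 10, 10, 8).half()      # [*, N_templ, N_res, N_res, C_t]

template_mask = t.new_ones(t.shape[:-3])     # [*, N_templ], float16, same device
bias = 1e9 * (template_mask[..., None, None, None, None, :] - 1)

print(template_mask.dtype, bias.shape)  # torch.float16 torch.Size([2, 1, 1, 1, 1, 4])
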
......@@ -18,7 +18,7 @@ import math
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, scripted_attention
from openfold.model.primitives import Linear, Attention
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
......@@ -57,7 +57,7 @@ class TriangleAttention(nn.Module):
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
self.mha = scripted_attention(
self.mha = Attention(
self.c_in, self.c_in, self.c_in,
self.c_hidden,
self.no_heads
......@@ -91,7 +91,7 @@ class TriangleAttention(nn.Module):
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# [*, H, I, J]
triangle_bias = permute_final_dims(self.linear(x), 2, 0, 1)
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
# [*, 1, H, I, J]
triangle_bias = triangle_bias.unsqueeze(-4)
......
......@@ -59,12 +59,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, 2, 0, 1),
permute_final_dims(b, 2, 1, 0),
permute_final_dims(a, (2, 0, 1)),
permute_final_dims(b, (2, 1, 0)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, 1, 2, 0)
return permute_final_dims(p, (1, 2, 0))
def _incoming_matmul(self,
a: torch.Tensor, # [*, N_k, N_i, C]
......@@ -73,12 +73,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, 2, 1, 0),
permute_final_dims(b, 2, 0, 1),
permute_final_dims(a, (2, 1, 0)),
permute_final_dims(b, (2, 0, 1)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, 1, 2, 0)
return permute_final_dims(p, (1, 2, 0))
def forward(self, z, mask=None):
"""
......
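
For the outgoing-edge update above, the permute/matmul/permute pipeline computes p[i, j, c] = sum_k a[i, k, c] * b[j, k, c]. A sanity sketch checking it against the equivalent einsum, with a local copy of the permute_final_dims helper and illustrative sizes:

import torch

def permute_final_dims(tensor: torch.Tensor, inds):
    # Local copy of the helper (see tensor_utils below): permute only the
    # trailing dims named by inds, leaving leading batch dims in place.
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])

a = torch.randn(2, 5, 5, 3)   # [*, N_i, N_k, C]
b = torch.randn(2, 5, 5, 3)   # [*, N_j, N_k, C]

p = torch.matmul(
    permute_final_dims(a, (2, 0, 1)),   # [*, C, N_i, N_k]
    permute_final_dims(b, (2, 1, 0)),   # [*, C, N_k, N_j]
)
p = permute_final_dims(p, (1, 2, 0))    # [*, N_i, N_j, C]

ref = torch.einsum("...ikc,...jkc->...ijc", a, b)
print(torch.allclose(p, ref, atol=1e-5))  # True
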
......@@ -13,30 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
# According to DeepMind, this prevents rotation compositions from being
# computed on low-precision tensor cores. I'm personally skeptical that it
# makes a difference, but to get as close as possible to their outputs, I'm
# adding it.
def rot_matmul(a, b):
e = ...
row_1 = torch.stack([
a[e,0,0]*b[e,0,0] + a[e,0,1]*b[e,1,0] + a[e,0,2]*b[e,2,0],
a[e,0,0]*b[e,0,1] + a[e,0,1]*b[e,1,1] + a[e,0,2]*b[e,2,1],
a[e,0,0]*b[e,0,2] + a[e,0,1]*b[e,1,2] + a[e,0,2]*b[e,2,2],
a[...,0,0]*b[...,0,0] + a[...,0,1]*b[...,1,0] + a[...,0,2]*b[...,2,0],
a[...,0,0]*b[...,0,1] + a[...,0,1]*b[...,1,1] + a[...,0,2]*b[...,2,1],
a[...,0,0]*b[...,0,2] + a[...,0,1]*b[...,1,2] + a[...,0,2]*b[...,2,2],
], dim=-1)
row_2 = torch.stack([
a[e,1,0]*b[e,0,0] + a[e,1,1]*b[e,1,0] + a[e,1,2]*b[e,2,0],
a[e,1,0]*b[e,0,1] + a[e,1,1]*b[e,1,1] + a[e,1,2]*b[e,2,1],
a[e,1,0]*b[e,0,2] + a[e,1,1]*b[e,1,2] + a[e,1,2]*b[e,2,2],
a[...,1,0]*b[...,0,0] + a[...,1,1]*b[...,1,0] + a[...,1,2]*b[...,2,0],
a[...,1,0]*b[...,0,1] + a[...,1,1]*b[...,1,1] + a[...,1,2]*b[...,2,1],
a[...,1,0]*b[...,0,2] + a[...,1,1]*b[...,1,2] + a[...,1,2]*b[...,2,2],
], dim=-1)
row_3 = torch.stack([
a[e,2,0]*b[e,0,0] + a[e,2,1]*b[e,1,0] + a[e,2,2]*b[e,2,0],
a[e,2,0]*b[e,0,1] + a[e,2,1]*b[e,1,1] + a[e,2,2]*b[e,2,1],
a[e,2,0]*b[e,0,2] + a[e,2,1]*b[e,1,2] + a[e,2,2]*b[e,2,2],
a[...,2,0]*b[...,0,0] + a[...,2,1]*b[...,1,0] + a[...,2,2]*b[...,2,0],
a[...,2,0]*b[...,0,1] + a[...,2,1]*b[...,1,1] + a[...,2,2]*b[...,2,1],
a[...,2,0]*b[...,0,2] + a[...,2,1]*b[...,1,2] + a[...,2,2]*b[...,2,2],
], dim=-1)
return torch.stack([row_1, row_2, row_3], dim=-2)
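
Since rot_matmul exists only to keep rotation composition off low-precision tensor-core kernels, it should agree with torch.matmul exactly in full precision. A compact local re-implementation (not the library function) used for that check:

import torch

def rot_matmul_unrolled(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Hand-unrolled 3x3 matrix product over the last two dims
    rows = []
    for i in range(3):
        rows.append(torch.stack([
            a[..., i, 0] * b[..., 0, j]
            + a[..., i, 1] * b[..., 1, j]
            + a[..., i, 2] * b[..., 2, j]
            for j in range(3)
        ], dim=-1))
    return torch.stack(rows, dim=-2)

a = torch.randn(4, 3, 3)
b = torch.randn(4, 3, 3)
print(torch.allclose(rot_matmul_unrolled(a, b), a @ b, atol=1e-6))  # True
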
......@@ -175,7 +170,7 @@ class T:
return T(rots, trans)
def to_4x4(self):
tensor = torch.zeros((*self.shape, 4, 4), device=self.rots.device)
tensor = self.rots.new_zeros((*self.shape, 4, 4))
tensor[..., :3, :3] = self.rots
tensor[..., :3, 3] = self.trans
tensor[..., 3, 3] = 1
......@@ -311,7 +306,7 @@ def _to_mat(pairs):
return mat
_qtr_mat = torch.zeros((4, 4, 3, 3))
_qtr_mat = np.zeros((4, 4, 3, 3))
_qtr_mat[..., 0, 0] = _to_mat([('aa', 1), ('bb', 1), ('cc', -1), ('dd', -1)])
_qtr_mat[..., 0, 1] = _to_mat([('bc', 2), ('ad', -2)])
_qtr_mat[..., 0, 2] = _to_mat([('bd', 2), ('ac', 2)])
......@@ -328,9 +323,11 @@ def quat_to_rot(
# [*, 4, 4]
quat = quat[..., None] * quat[..., None, :]
mat = quat.new_tensor(_qtr_mat)
# [*, 4, 4, 3, 3]
shaped_qtr_mat = _qtr_mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
quat = quat[..., None, None] * shaped_qtr_mat.to(quat.device)
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
quat = quat[..., None, None] * shaped_qtr_mat
# [*, 3, 3]
return torch.sum(quat, dim=(-3, -4))
......@@ -339,9 +336,7 @@ def affine_vector_to_4x4(vector):
quats = vector[..., :4]
trans = vector[..., 4:]
four_by_four = torch.zeros(
(*vector.shape[:-1], 4, 4), device=vector.device
)
four_by_four = vector.new_zeros((*vector.shape[:-1], 4, 4))
four_by_four[..., :3, :3] = quat_to_rot(quats)
four_by_four[..., :3, 3] = trans
four_by_four[..., 3, 3] = 1
......
......@@ -14,6 +14,7 @@
import deepspeed
import torch
from torch.utils.checkpoint import checkpoint
from typing import Any, Tuple, List, Callable
BLOCK_ARG = Any
......@@ -55,7 +56,7 @@ def checkpoint_blocks(
return a
def chunker(s, e):
def exec_sliced(a):
def exec_sliced(*a):
return exec(blocks[s:e], a)
return exec_sliced
......@@ -69,7 +70,7 @@ def checkpoint_blocks(
for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt
args = deepspeed.checkpointing.checkpoint(chunker(s, e), args)
args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
args = wrap(args)
return args
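
The checkpoint_blocks fix makes the sliced executor accept positional arguments and splats the argument tuple into the checkpoint call, since activation checkpointing passes each tensor argument separately. A runnable sketch using torch.utils.checkpoint in place of the DeepSpeed call, which is analogous for this purpose:

import torch
from torch.utils.checkpoint import checkpoint

blocks = [torch.nn.Linear(8, 8) for _ in range(4)]

def exec_blocks(blocks_slice, *args):
    # run a slice of blocks; here each block takes and returns a single tensor
    (x,) = args
    for block in blocks_slice:
        x = block(x)
    return x

def chunker(s, e):
    def exec_sliced(*a):
        return exec_blocks(blocks[s:e], *a)
    return exec_sliced

x = torch.randn(2, 8, requires_grad=True)
args = (x,)
for s in range(0, len(blocks), 2):
    args = (checkpoint(chunker(s, s + 2), *args, use_reentrant=False),)
print(args[0].shape)  # torch.Size([2, 8])
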
......@@ -231,7 +231,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
MSAGlobalAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": GlobalAttentionParams(matt)
"attention": GlobalAttentionParams(matt.global_attention)
}
MSAAttPairBiasParams = lambda matt: dict(
......
......@@ -356,7 +356,7 @@ def lddt_loss(
)
dists_to_score = (
(dmat_true < cutoff) * all_atom_mask *
permute_final_dims(all_atom_mask, 1, 0) *
permute_final_dims(all_atom_mask, (1, 0)) *
(1. - torch.eye(n, device=all_atom_mask.device))
)
......
......@@ -16,12 +16,13 @@
from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict
def permute_final_dims(tensor, *inds):
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
zero_index = -1 * len(inds)
first_inds = range(len(tensor.shape[:zero_index]))
return tensor.permute(*first_inds, *[zero_index + i for i in inds])
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
def flatten_final_dims(tensor: torch.Tensor, no_dims: int):
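
permute_final_dims now takes a single sequence of trailing-dimension indices instead of *inds and builds the permutation as a plain list. A usage sketch, assuming the openfold package at this commit is importable; the indices refer only to the trailing dims, so the same call works for any number of leading batch dims:

import torch
from openfold.utils.tensor_utils import permute_final_dims

z = torch.randn(2, 3, 10, 10, 16)                    # [*, I, J, C]
print(permute_final_dims(z, (2, 0, 1)).shape)        # torch.Size([2, 3, 16, 10, 10])
print(permute_final_dims(z[0], (2, 0, 1)).shape)     # torch.Size([3, 16, 10, 10])
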
......@@ -70,7 +71,7 @@ def stack_tensor_dicts(dicts):
def one_hot(x, v_bins):
reshaped_bins = v_bins.view(*((1,) * len(x.shape) + (len(v_bins),)))
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
diffs = x[..., None] - reshaped_bins
am = torch.argmin(torch.abs(diffs), dim=-1)
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
......@@ -118,7 +119,12 @@ def tree_map(fn, tree, leaf_type):
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8.
......@@ -130,8 +136,8 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
layer:
The layer to be applied chunk-wise
inputs:
A (nested) dictionary of keyworded inputs. All leaves must be
tensors and must share the same batch dimensions.
A (non-nested) dictionary of keyworded inputs. All leaves must
be tensors and must share the same batch dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
......@@ -163,7 +169,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
return shapes
initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
orig_batch_dims = [max(s) for s in zip(*initial_dims)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def prep_inputs(t):
# TODO: make this more memory efficient. This sucks
......@@ -194,7 +200,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
# Allocate space for the output
if(out is None):
allocate = lambda t: t.new_zeros(flat_batch_dim, *t.shape[1:])
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
out = tensor_tree_map(allocate, output_chunk)
# Put the chunk in its pre-allocated space
......@@ -217,7 +223,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
i += chunk_size
reshape = lambda t: t.reshape(*orig_batch_dims, *t.shape[1:])
reshape = lambda t: t.reshape(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out)
return out
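
A usage sketch for chunk_layer, assuming the openfold package at this commit is importable: inputs is a flat dictionary of keyword arguments whose tensors share the leading batch dims; the layer is applied chunk_size sub-batches at a time and the outputs are stitched back together. The layer below is a toy stand-in:

import torch
from openfold.utils.tensor_utils import chunk_layer

def toy_layer(m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    return m * mask.unsqueeze(-1)

m = torch.randn(4, 128, 32)      # [batch, N, C]
mask = torch.ones(4, 128)        # [batch, N]

out = chunk_layer(
    toy_layer,
    {"m": m, "mask": mask},
    chunk_size=16,
    no_batch_dims=2,             # treat [batch, N] jointly as the batch dims
)
print(torch.allclose(out, toy_layer(m, mask)))  # True
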