Commit b2d102cb authored by Gustaf Ahdritz

Refactor certain modules for TorchScript, fix recycling bug

parent 4bd4ad93
......@@ -83,7 +83,7 @@ class InputEmbedder(nn.Module):
boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
)
oh = one_hot(d, boundaries)
oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh)
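(Not part of the diff: a minimal sketch of what the relpos() path above computes, i.e. a nearest-bin one-hot of residue-index offsets. relpos_k and the shapes are illustrative, and the final linear_relpos projection is omitted.)
import torch
def relpos_sketch(ri, relpos_k: int = 32):
    # ri: [*, N_res] residue indices
    d = ri[..., None] - ri[..., None, :]  # [*, N_res, N_res] pairwise offsets
    boundaries = torch.arange(-relpos_k, relpos_k + 1, device=ri.device)
    # Nearest-bin encoding; offsets beyond +/- relpos_k land in the edge bins
    am = torch.argmin(torch.abs(d[..., None] - boundaries), dim=-1)
    return torch.nn.functional.one_hot(am, num_classes=len(boundaries)).float()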
def forward(self,
......@@ -112,14 +112,15 @@ class InputEmbedder(nn.Module):
# [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb += self.relpos(ri)
#pair_emb = pair_emb + self.relpos(ri)
pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
# [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3]
tf_m = (self.linear_tf_m(tf)
.unsqueeze(-3)
.expand((*(-1,) * len(tf.shape[:-2]), n_clust, -1, -1)))
tf_m = (
self.linear_tf_m(tf)
.unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb
......@@ -192,6 +193,7 @@ class RecyclingEmbedder(nn.Module):
self.min_bin,
self.max_bin,
self.no_bins,
dtype=x.dtype,
requires_grad=False,
device=x.device
)
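(Not part of the diff: a rough sketch of the distogram the recycling embedder builds from the recycled pseudo-beta coordinates, showing why the bin boundaries above are created in x's dtype. The min_bin/max_bin/no_bins values are illustrative.)
import torch
def recycled_distogram_sketch(x, min_bin=3.25, max_bin=20.75, no_bins=15):
    # x: [*, N_res, 3] pseudo-beta coordinates from the previous iteration.
    # Keeping the boundaries in x's dtype avoids mixed-precision surprises
    # when x is half precision under autocast.
    bins = torch.linspace(min_bin, max_bin, no_bins, dtype=x.dtype, device=x.device)
    sq_bins = bins ** 2
    upper = torch.cat([sq_bins[1:], sq_bins.new_tensor([1e9])], dim=-1)
    # [*, N_res, N_res, 1] squared pairwise distances
    d2 = torch.sum((x[..., None, :] - x[..., None, :, :]) ** 2, dim=-1, keepdim=True)
    # [*, N_res, N_res, no_bins] one-hot distogram
    return ((d2 > sq_bins) * (d2 < upper)).type(x.dtype)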
......@@ -109,7 +109,7 @@ class AlphaFold(nn.Module):
# Embed the templates one at a time (with a poor man's vmap)
template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ):
for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx),
......@@ -170,6 +170,154 @@ class AlphaFold(nn.Module):
"torsion_angles_mask": angle_feats["torsion_angles_mask"],
}
def iteration(self, feats, m_1_prev, z_prev, x_prev):
# Primary output dictionary
outputs = {}
# Grab some data about the input
batch_dims = feats["target_feat"].shape[:-2]
no_batch_dims = len(batch_dims)
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Prep some features
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
# Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
)
# Inject information from previous recycling iterations
if(self.config.no_cycles > 1):
# Initialize the recycling embeddings, if needs be
if(None in [m_1_prev, z_prev, x_prev]):
# [*, N, C_m]
m_1_prev = m.new_zeros(
(*batch_dims, n, self.config.c_m),
requires_grad=False,
)
# [*, N, N, C_z]
z_prev = z.new_zeros(
(*batch_dims, n, n, self.config.c_z),
requires_grad=False,
)
# [*, N, 3]
x_prev = z.new_zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
)
x_prev = pseudo_beta_fn(
feats["aatype"],
x_prev,
None
)
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
)
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z]
z = z + z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if(self.config.template.enabled):
template_feats = {
k:v for k,v in feats.items() if "template_" in k
}
template_embeds = self.embed_templates(
template_feats,
z,
pair_mask,
no_batch_dims,
)
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
if(self.config.template.embed_angles):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
dim=-3
)
# [*, S, N]
torsion_angles_mask = template_embeds["torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
)
# Embed extra MSA features + merge with pairwise embeddings
if(self.config.extra_msa.enabled):
# [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
# [*, N, N, C_z]
z = self.extra_msa_stack(
a,
z,
msa_mask=feats["extra_msa_mask"],
pair_mask=pair_mask,
_mask_trans=self.config._mask_trans,
)
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
_mask_trans=self.config._mask_trans
)
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
outputs["single"] = s
# Predict 3D structure
outputs["sm"] = self.structure_module(
s, z, feats["aatype"], mask=feats["seq_mask"],
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
)
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
# Save embeddings for use during the next recycling iteration
# [*, N, C_m]
m_1_prev = m[..., 0, :, :]
# [*, N, N, C_z]
z_prev = z
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
return outputs, m_1_prev, z_prev, x_prev
def forward(self, batch):
"""
Args:
......@@ -223,160 +371,19 @@ class AlphaFold(nn.Module):
# Recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None
# Primary output dictionary
outputs = {}
# Main recycling loop
for cycle_no in range(self.config.no_cycles):
# Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch)
# Grab some data about the input
batch_dims = feats["target_feat"].shape[:-2]
no_batch_dims = len(batch_dims)
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Prep some features
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
# Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
)
# Inject information from previous recycling iterations
if(self.config.no_cycles > 1):
# Initialize the recycling embeddings, if needs be
if(None in [m_1_prev, z_prev, x_prev]):
# [*, N, C_m]
m_1_prev = torch.zeros(
(*batch_dims, n, self.config.c_m),
requires_grad=False,
device=device,
)
# [*, N, N, C_z]
z_prev = torch.zeros(
(*batch_dims, n, n, self.config.c_z),
requires_grad=False,
device=device,
)
# [*, N, 3]
x_prev = torch.zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
device=device,
)
x_prev = pseudo_beta_fn(
feats["aatype"],
x_prev,
None
)
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
)
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z]
z = z + z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if(self.config.template.enabled):
template_feats = {
k:v for k,v in feats.items() if "template_" in k
}
template_embeds = self.embed_templates(
template_feats,
z,
pair_mask,
no_batch_dims,
)
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
if(self.config.template.embed_angles):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
dim=-3
)
# [*, S, N]
torsion_angles_mask = template_embeds["torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
)
# Embed extra MSA features + merge with pairwise embeddings
if(self.config.extra_msa.enabled):
# [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
# [*, N, N, C_z]
z = self.extra_msa_stack(
a,
z,
msa_mask=feats["extra_msa_mask"],
pair_mask=pair_mask,
_mask_trans=self.config._mask_trans,
# Enable grad iff we're training and it's the final recycling layer
is_final_iter = (cycle_no == self.config.no_cycles - 1)
with torch.set_grad_enabled(self.training and is_final_iter):
outputs, m_1_prev, z_prev, x_prev = self.iteration(
feats, m_1_prev, z_prev, x_prev,
)
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
_mask_trans=self.config._mask_trans
)
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
outputs["single"] = s
# Predict 3D structure
outputs["sm"] = self.structure_module(
s, z, feats["aatype"], mask=feats["seq_mask"],
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
)
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
# Save embeddings for use during the next recycling iteration
# [*, N, C_m]
m_1_prev = m[..., 0, :, :]
# [*, N, N, C_z]
z_prev = z
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
outputs.update(self.aux_heads(outputs))
return outputs
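(Not part of the diff: a minimal sketch of the control flow the new iteration()/forward() split arrives at, namely that gradients are enabled only for the final recycling iteration, and only while training. iteration_fn stands in for AlphaFold.iteration and is illustrative.)
import torch
def recycle_sketch(iteration_fn, feats_per_cycle, no_cycles, training):
    m_1_prev = z_prev = x_prev = None
    outputs = None
    for cycle_no in range(no_cycles):
        feats = feats_per_cycle[cycle_no]
        is_final_iter = cycle_no == no_cycles - 1
        # Earlier cycles run without grad, so only the last one is backpropped
        with torch.set_grad_enabled(training and is_final_iter):
            outputs, m_1_prev, z_prev, x_prev = iteration_fn(
                feats, m_1_prev, z_prev, x_prev
            )
    return outputs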
......@@ -16,8 +16,9 @@
import math
import torch
import torch.nn as nn
from typing import Optional
from openfold.model.primitives import Linear, scripted_attention
from openfold.model.primitives import Linear, Attention, GlobalAttention
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
......@@ -69,7 +70,8 @@ class MSAAttention(nn.Module):
self.c_z, self.no_heads, bias=False, init="normal"
)
self.mha = scripted_attention(
self.mha = Attention(
self.c_in, self.c_in, self.c_in,
self.c_hidden,
self.no_heads
......@@ -93,7 +95,7 @@ class MSAAttention(nn.Module):
if(mask is None):
# [*, N_seq, N_res]
mask = torch.ones(
(*m.shape[:-3], n_seq, n_res),
m.shape[:-3] + (n_seq, n_res),
device=m.device,
requires_grad=False
)
......@@ -103,7 +105,7 @@ class MSAAttention(nn.Module):
# [*, N_seq, no_heads, N_res, N_res]
bias = bias.expand(
(*((-1,) * len(bias.shape[:-4])), -1, self.no_heads, n_res, -1)
((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
)
biases = [bias]
......@@ -115,7 +117,7 @@ class MSAAttention(nn.Module):
z = self.linear_z(z)
# [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, 2, 0, 1).unsqueeze(-4)
z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
biases.append(z)
......@@ -234,79 +236,29 @@ class MSAColumnGlobalAttention(nn.Module):
self.inf = inf
self.eps = eps
self.layer_norm_m = nn.LayerNorm(self.c_in)
self.linear_q = Linear(
self.c_in, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
C_hidden = self.c_hidden
self.linear_k = Linear(
self.c_in, C_hidden, bias=False, init="glorot",
)
self.linear_v = Linear(
self.c_in, C_hidden, bias=False, init="glorot",
)
self.linear_g = Linear(self.c_in, self.c_hidden * self.no_heads, init="gating")
self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_in, init="final")
self.layer_norm_m = nn.LayerNorm(c_in)
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def global_attention(self, m, mask):
# [*, N_res, C_in]
q = (torch.sum(m * mask.unsqueeze(-1), dim=-2) /
(torch.sum(mask, dim=-1)[..., None] + self.eps))
# [*, N_res, H * C_hidden]
q = self.linear_q(q)
q = q * self.c_hidden ** (-0.5)
# [*, N_res, H, C_hidden]
q = q.view(*q.shape[:-1], self.no_heads, -1)
# [*, N_res, N_seq, C_hidden]
k = self.linear_k(m)
v = self.linear_v(m)
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a = a + bias
a = self.softmax(a)
# [*, N_res, H, C_hidden]
o = torch.matmul(
a,
v,
self.global_attention = GlobalAttention(
c_in=c_in,
c_hidden=c_hidden,
no_heads=no_heads,
inf=inf,
eps=eps,
)
# [*, N_res, N_seq, C_hidden]
g = self.sigmoid(self.linear_g(m))
# [*, N_res, N_seq, H, C_hidden]
g = g.view(*g.shape[:-1], self.no_heads, -1)
# [*, N_res, N_seq, H, C_hidden]
o = o.unsqueeze(-3) * g
# [*, N_res, N_seq, H * C_hidden]
o = o.reshape(*o.shape[:-2], -1)
# [*, N_res, N_seq, C_in]
m = self.linear_o(o)
return m
def forward(self, m, mask=None):
def forward(self,
m: torch.Tensor,
mask: Optional[torch.Tensor] = None
) -> torch.Tensor:
n_seq, n_res, c_in = m.shape[-3:]
if(mask is None):
# [*, N_seq, N_res]
mask = m.new_ones(m.shape[:-1], requires_grad=False)
mask = torch.ones(
m.shape[:-1],
dtype=m.dtype,
device=m.device,
).detach()
# [*, N_res, N_seq, C_in]
m = m.transpose(-2, -3)
......@@ -327,7 +279,7 @@ class MSAColumnGlobalAttention(nn.Module):
no_batch_dims=len(m.shape[:-2])
)
else:
m = self.global_attention(**mha_input)
m = self.global_attention(m=mha_input["m"], mask=mha_input["mask"])
# [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3)
......@@ -235,12 +235,6 @@ class Attention(nn.Module):
Returns
[*, Q, C_q] attention update
"""
# Flatten batch dims
batch_dims = q_x.shape[:-2]
q_x = q_x.view((-1,) + q_x.shape[-2:])
k_x = k_x.view((-1,) + k_x.shape[-2:])
v_x = v_x.view((-1,) + v_x.shape[-2:])
# [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x)
k = self.linear_k(k_x)
......@@ -253,20 +247,20 @@ class Attention(nn.Module):
# [*, H, Q, K]
a = torch.matmul(
q.permute(0, 2, 1, 3), # [*, H, Q, C_hidden]
k.permute(0, 2, 3, 1), # [*, H, C_hidden, K]
permute_final_dims(q, (0, 2, 1, 3)), # [*, H, Q, C_hidden]
permute_final_dims(k, (0, 2, 3, 1)), # [*, H, C_hidden, K]
)
norm = 1 / math.sqrt(self.c_hidden) # [1]
a = a * norm
a *= norm
if(biases is not None):
for b in biases:
a = a + b
a += b
a = self.softmax(a)
# [*, H, Q, C_hidden]
o = torch.matmul(
a,
v.permute(0, 2, 1, 3), # [*, H, V, C_hidden]
permute_final_dims(v, (0, 2, 1, 3)), # [*, H, V, C_hidden]
)
# [*, Q, H, C_hidden]
......@@ -282,11 +276,80 @@ class Attention(nn.Module):
# [*, Q, C_q]
o = self.linear_o(o)
# Restore the batch dims
o = o.reshape(batch_dims + o.shape[1:])
return o
def scripted_attention(*args, **kwargs):
return torch.jit.script(Attention(*args, **kwargs))
class GlobalAttention(nn.Module):
def __init__(self, c_in, c_hidden, no_heads, inf, eps):
super(GlobalAttention, self).__init__()
self.c_in = c_in
self.c_hidden = c_hidden
self.no_heads = no_heads
self.inf = inf
self.eps = eps
self.linear_q = Linear(
c_in, c_hidden * no_heads, bias=False, init="glorot"
)
self.linear_k = Linear(
c_in, c_hidden, bias=False, init="glorot",
)
self.linear_v = Linear(
c_in, c_hidden, bias=False, init="glorot",
)
self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
self.sigmoid = nn.Sigmoid()
self.softmax = nn.Softmax(dim=-1)
def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# [*, N_res, C_in]
q = (torch.sum(m * mask.unsqueeze(-1), dim=-2) /
(torch.sum(mask, dim=-1)[..., None] + self.eps))
# [*, N_res, H * C_hidden]
q = self.linear_q(q)
q = q * self.c_hidden ** (-0.5)
# [*, N_res, H, C_hidden]
q = q.view(q.shape[:-1] + (self.no_heads, -1))
# [*, N_res, N_seq, C_hidden]
k = self.linear_k(m)
v = self.linear_v(m)
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a += bias
a = self.softmax(a)
# [*, N_res, H, C_hidden]
o = torch.matmul(
a,
v,
)
# [*, N_res, N_seq, C_hidden]
g = self.sigmoid(self.linear_g(m))
# [*, N_res, N_seq, H, C_hidden]
g = g.view(g.shape[:-1] + (self.no_heads, -1))
# [*, N_res, N_seq, H, C_hidden]
o = o.unsqueeze(-3) * g
# [*, N_res, N_seq, H * C_hidden]
o = o.reshape(o.shape[:-2] + (-1,))
# [*, N_res, N_seq, C_in]
m = self.linear_o(o)
return m
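(Not part of the diff: a quick smoke test one might run against the new GlobalAttention module. Since forward() now has concrete tensor types it should be a candidate for torch.jit.script, though whether scripting succeeds depends on the rest of primitives.py; constructor values and shapes below are illustrative.)
import torch
from openfold.model.primitives import GlobalAttention
ga = GlobalAttention(c_in=32, c_hidden=16, no_heads=4, inf=1e9, eps=1e-8)
scripted = torch.jit.script(ga)
m = torch.rand(1, 8, 10, 32)   # [*, N_res, N_seq, C_in] (already transposed)
mask = torch.ones(1, 8, 10)    # [*, N_res, N_seq]
out = scripted(m, mask)
assert out.shape == m.shape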
......@@ -16,7 +16,7 @@
import math
import torch
import torch.nn as nn
from typing import Optional
from typing import Optional, Tuple
from openfold.model.primitives import Linear, ipa_point_weights_init_
from openfold.np.residue_constants import (
......@@ -49,7 +49,7 @@ class AngleResnetBlock(nn.Module):
self.relu = nn.ReLU()
def forward(self, a):
def forward(self, a: torch.Tensor) -> torch.Tensor:
s_initial = a
......@@ -85,7 +85,7 @@ class AngleResnet(nn.Module):
self.c_hidden = c_hidden
self.no_blocks = no_blocks
self.no_angles = no_angles
self.epsilon = epsilon
self.eps = epsilon
self.linear_in = Linear(self.c_in, self.c_hidden)
self.linear_initial = Linear(self.c_in, self.c_hidden)
......@@ -99,7 +99,10 @@ class AngleResnet(nn.Module):
self.relu = nn.ReLU()
def forward(self, s, s_initial):
def forward(self,
s: torch.Tensor,
s_initial: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Args:
s:
......@@ -130,14 +133,13 @@ class AngleResnet(nn.Module):
s = self.linear_out(s)
# [*, no_angles, 2]
s = s.view(*s.shape[:-1], -1, 2)
s = s.view(s.shape[:-1] + (-1, 2))
unnormalized_s = s
norm_denom = torch.sqrt(
torch.clamp(
torch.sum(s ** 2, dim=-1, keepdims=True),
min=self.epsilon,
torch.sum(s ** 2, dim=-1, keepdim=True),
min=self.eps,
)
)
s = s / norm_denom
......@@ -219,7 +221,7 @@ class InvariantPointAttention(nn.Module):
z: torch.Tensor,
t: T,
mask: torch.Tensor,
):
) -> torch.Tensor:
"""
Args:
s:
......@@ -236,16 +238,15 @@ class InvariantPointAttention(nn.Module):
#######################################
# Generate scalar and point activations
#######################################
# [*, N_res, H * C_hidden]
q = self.linear_q(s)
kv = self.linear_kv(s)
# [*, N_res, H, C_hidden]
q = q.view(*q.shape[:-1], self.no_heads, -1)
q = q.view(q.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, 2 * C_hidden]
kv = kv.view(*kv.shape[:-1], self.no_heads, -1)
kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, C_hidden]
k, v = torch.split(kv, self.c_hidden, dim=-1)
......@@ -261,7 +262,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, P_q, 3]
q_pts = q_pts.view(
*q_pts.shape[:-2], self.no_heads, self.no_qk_points, 3
q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
)
# [*, N_res, H * (P_q + P_v) * 3]
......@@ -274,7 +275,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view(
*kv_pts.shape[:-2], self.no_heads, -1, 3
kv_pts.shape[:-2] + (self.no_heads, -1, 3)
)
# [*, N_res, H, P_q/P_v, 3]
......@@ -293,11 +294,11 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res]
a = torch.matmul(
permute_final_dims(q, 1, 0, 2), # [*, H, N_res, C_hidden]
permute_final_dims(k, 1, 2, 0), # [*, H, C_hidden, N_res]
permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
)
a = a + math.sqrt(1. / (3 * self.c_hidden))
a = a + math.sqrt(1. / 3) * permute_final_dims(b, 2, 0, 1)
a = a + math.sqrt(1. / 3) * permute_final_dims(b, (2, 0, 1))
# [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
......@@ -321,7 +322,7 @@ class InvariantPointAttention(nn.Module):
square_mask = self.inf * (square_mask - 1)
# [*, H, N_res, N_res]
pt_att = permute_final_dims(pt_att, 2, 0, 1)
pt_att = permute_final_dims(pt_att, (2, 0, 1))
a = a + pt_att
a = a + square_mask.unsqueeze(-3)
a = self.softmax(a)
......@@ -339,11 +340,11 @@ class InvariantPointAttention(nn.Module):
# [*, H, 3, N_res, P_v]
o_pt = torch.matmul(
a.unsqueeze(-3), # [*, H, 1, N_res, N_res]
permute_final_dims(v_pts, 1, 3, 0, 2), # [*, H, 3, N_res, P_v]
permute_final_dims(v_pts, (1, 3, 0, 2)), # [*, H, 3, N_res, P_v]
)
# [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, 2, 0, 3, 1)
o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = t[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v]
......@@ -758,35 +759,39 @@ class StructureModule(nn.Module):
return outputs
def _init_residue_constants(self, device):
def _init_residue_constants(self, dtype, device):
if(self.default_frames is None):
self.default_frames = torch.tensor(
restype_rigid_group_default_frame,
restype_rigid_group_default_frame,
dtype=dtype,
device=device,
requires_grad=False,
)
if(self.group_idx is None):
self.group_idx = torch.tensor(
restype_atom14_to_rigid_group,
restype_atom14_to_rigid_group,
dtype=dtype,
device=device,
requires_grad=False,
)
if(self.atom_mask is None):
self.atom_mask = torch.tensor(
restype_atom14_mask,
restype_atom14_mask,
dtype=dtype,
device=device,
requires_grad=False,
)
if(self.lit_positions is None):
self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions,
restype_atom14_rigid_group_positions,
dtype=dtype,
device=device,
requires_grad=False,
)
def torsion_angles_to_frames(self, t, alpha, f):
# Lazily initialize the residue constants on the correct device
self._init_residue_constants(f.device)
self._init_residue_constants(f.dtype, f.device)
# Separated purely to make testing less annoying
return _torsion_angles_to_frames(t, alpha, f, self.default_frames)
......@@ -797,7 +802,7 @@ class StructureModule(nn.Module):
# Lazily initialize the residue constants on the correct device
# TODO: Maybe this stuff should be done on CPU instead (so these
# arrays
self._init_residue_constants(f.device)
self._init_residue_constants(f.dtype, f.device)
return _frames_and_literature_positions_to_atom14_pos(
t,
......@@ -18,7 +18,7 @@ import math
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, scripted_attention
from openfold.model.primitives import Linear, Attention
from openfold.utils.deepspeed import checkpoint_blocks
from openfold.model.dropout import (
DropoutRowwise,
......@@ -69,7 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
self.no_heads = no_heads
self.chunk_size = chunk_size
self.mha = scripted_attention(
self.mha = Attention(
self.c_z, self.c_t, self.c_t,
self.c_hidden, self.no_heads,
gating=False,
......@@ -91,7 +91,7 @@ class TemplatePointwiseAttention(nn.Module):
# NOTE: This is not the "template_mask" from the supplement, but a
# [*, N_templ] mask from the code. I'm pretty sure it's always just 1,
# but not sure enough to remove it. It's nice to have, I guess.
template_mask = torch.ones(t.shape[:-3], device=t.device)
template_mask = t.new_ones(t.shape[:-3])
bias = (1e9 * (template_mask[..., None, None, None, None, :] - 1))
......@@ -99,7 +99,7 @@ class TemplatePointwiseAttention(nn.Module):
z = z.unsqueeze(-2)
# [*, N_res, N_res, N_temp, C_t]
t = permute_final_dims(t, 1, 2, 0, 3)
t = permute_final_dims(t, (1, 2, 0, 3))
# [*, N_res, N_res, 1, C_z]
mha_inputs = {
......@@ -18,7 +18,7 @@ import math
import torch
import torch.nn as nn
from openfold.model.primitives import Linear, scripted_attention
from openfold.model.primitives import Linear, Attention
from openfold.utils.tensor_utils import (
chunk_layer,
permute_final_dims,
......@@ -57,7 +57,7 @@ class TriangleAttention(nn.Module):
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
self.mha = scripted_attention(
self.mha = Attention(
self.c_in, self.c_in, self.c_in,
self.c_hidden,
self.no_heads
......@@ -91,7 +91,7 @@ class TriangleAttention(nn.Module):
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# [*, H, I, J]
triangle_bias = permute_final_dims(self.linear(x), 2, 0, 1)
triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
# [*, 1, H, I, J]
triangle_bias = triangle_bias.unsqueeze(-4)
......@@ -59,12 +59,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, 2, 0, 1),
permute_final_dims(b, 2, 1, 0),
permute_final_dims(a, (2, 0, 1)),
permute_final_dims(b, (2, 1, 0)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, 1, 2, 0)
return permute_final_dims(p, (1, 2, 0))
def _incoming_matmul(self,
a: torch.Tensor, # [*, N_k, N_i, C]
......@@ -73,12 +73,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
# [*, C, N_i, N_j]
p = torch.matmul(
permute_final_dims(a, 2, 1, 0),
permute_final_dims(b, 2, 0, 1),
permute_final_dims(a, (2, 1, 0)),
permute_final_dims(b, (2, 0, 1)),
)
# [*, N_i, N_j, C]
return permute_final_dims(p, 1, 2, 0)
return permute_final_dims(p, (1, 2, 0))
def forward(self, z, mask=None):
"""
......@@ -13,30 +13,25 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
# According to DeepMind, this prevents rotation compositions from being
# computed on low-precision tensor cores. I'm personally skeptical that it
# makes a difference, but to get as close as possible to their outputs, I'm
# adding it.
def rot_matmul(a, b):
e = ...
row_1 = torch.stack([
a[e,0,0]*b[e,0,0] + a[e,0,1]*b[e,1,0] + a[e,0,2]*b[e,2,0],
a[e,0,0]*b[e,0,1] + a[e,0,1]*b[e,1,1] + a[e,0,2]*b[e,2,1],
a[e,0,0]*b[e,0,2] + a[e,0,1]*b[e,1,2] + a[e,0,2]*b[e,2,2],
a[...,0,0]*b[...,0,0] + a[...,0,1]*b[...,1,0] + a[...,0,2]*b[...,2,0],
a[...,0,0]*b[...,0,1] + a[...,0,1]*b[...,1,1] + a[...,0,2]*b[...,2,1],
a[...,0,0]*b[...,0,2] + a[...,0,1]*b[...,1,2] + a[...,0,2]*b[...,2,2],
], dim=-1)
row_2 = torch.stack([
a[e,1,0]*b[e,0,0] + a[e,1,1]*b[e,1,0] + a[e,1,2]*b[e,2,0],
a[e,1,0]*b[e,0,1] + a[e,1,1]*b[e,1,1] + a[e,1,2]*b[e,2,1],
a[e,1,0]*b[e,0,2] + a[e,1,1]*b[e,1,2] + a[e,1,2]*b[e,2,2],
a[...,1,0]*b[...,0,0] + a[...,1,1]*b[...,1,0] + a[...,1,2]*b[...,2,0],
a[...,1,0]*b[...,0,1] + a[...,1,1]*b[...,1,1] + a[...,1,2]*b[...,2,1],
a[...,1,0]*b[...,0,2] + a[...,1,1]*b[...,1,2] + a[...,1,2]*b[...,2,2],
], dim=-1)
row_3 = torch.stack([
a[e,2,0]*b[e,0,0] + a[e,2,1]*b[e,1,0] + a[e,2,2]*b[e,2,0],
a[e,2,0]*b[e,0,1] + a[e,2,1]*b[e,1,1] + a[e,2,2]*b[e,2,1],
a[e,2,0]*b[e,0,2] + a[e,2,1]*b[e,1,2] + a[e,2,2]*b[e,2,2],
a[...,2,0]*b[...,0,0] + a[...,2,1]*b[...,1,0] + a[...,2,2]*b[...,2,0],
a[...,2,0]*b[...,0,1] + a[...,2,1]*b[...,1,1] + a[...,2,2]*b[...,2,1],
a[...,2,0]*b[...,0,2] + a[...,2,1]*b[...,1,2] + a[...,2,2]*b[...,2,2],
], dim=-1)
return torch.stack([row_1, row_2, row_3], dim=-2)
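(Not part of the diff: rot_matmul above unrolls an ordinary 3x3 matrix product so that rotation composition never goes through matmul/tensor cores; numerically it should agree with torch's batched matmul. A quick self-check in double precision, with made-up inputs.)
import torch
a = torch.rand(5, 3, 3, dtype=torch.float64)
b = torch.rand(5, 3, 3, dtype=torch.float64)
assert torch.allclose(rot_matmul(a, b), a @ b)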
......@@ -175,7 +170,7 @@ class T:
return T(rots, trans)
def to_4x4(self):
tensor = torch.zeros((*self.shape, 4, 4), device=self.rots.device)
tensor = self.rots.new_zeros((*self.shape, 4, 4))
tensor[..., :3, :3] = self.rots
tensor[..., :3, 3] = self.trans
tensor[..., 3, 3] = 1
......@@ -311,7 +306,7 @@ def _to_mat(pairs):
return mat
_qtr_mat = torch.zeros((4, 4, 3, 3))
_qtr_mat = np.zeros((4, 4, 3, 3))
_qtr_mat[..., 0, 0] = _to_mat([('aa', 1), ('bb', 1), ('cc', -1), ('dd', -1)])
_qtr_mat[..., 0, 1] = _to_mat([('bc', 2), ('ad', -2)])
_qtr_mat[..., 0, 2] = _to_mat([('bd', 2), ('ac', 2)])
......@@ -328,9 +323,11 @@ def quat_to_rot(
# [*, 4, 4]
quat = quat[..., None] * quat[..., None, :]
mat = quat.new_tensor(_qtr_mat)
# [*, 4, 4, 3, 3]
shaped_qtr_mat = _qtr_mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
quat = quat[..., None, None] * shaped_qtr_mat.to(quat.device)
shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
quat = quat[..., None, None] * shaped_qtr_mat
# [*, 3, 3]
return torch.sum(quat, dim=(-3, -4))
......@@ -339,9 +336,7 @@ def affine_vector_to_4x4(vector):
quats = vector[..., :4]
trans = vector[..., 4:]
four_by_four = torch.zeros(
(*vector.shape[:-1], 4, 4), device=vector.device
)
four_by_four = vector.new_zeros((*vector.shape[:-1], 4, 4))
four_by_four[..., :3, :3] = quat_to_rot(quats)
four_by_four[..., :3, 3] = trans
four_by_four[..., 3, 3] = 1
......@@ -14,6 +14,7 @@
import deepspeed
import torch
from torch.utils.checkpoint import checkpoint
from typing import Any, Tuple, List, Callable
BLOCK_ARG = Any
......@@ -55,7 +56,7 @@ def checkpoint_blocks(
return a
def chunker(s, e):
def exec_sliced(a):
def exec_sliced(*a):
return exec(blocks[s:e], a)
return exec_sliced
......@@ -69,7 +70,7 @@ def checkpoint_blocks(
for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt
args = deepspeed.checkpointing.checkpoint(chunker(s, e), args)
args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
args = wrap(args)
return args
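(Not part of the diff: the *a / *args change above matters because checkpointing APIs hand the saved inputs back to the wrapped callable as positional arguments. A toy illustration with torch.utils.checkpoint, which this file already imports; the blocks are made up.)
import torch
from torch.utils.checkpoint import checkpoint
blocks = [lambda x, y: (x + 1, y * 2), lambda x, y: (x * y, y)]
def chunker(s, e):
    def exec_sliced(*a):      # must accept unpacked positional args...
        for block in blocks[s:e]:
            a = block(*a)
        return a
    return exec_sliced
args = (torch.ones(2, requires_grad=True), torch.ones(2, requires_grad=True))
args = checkpoint(chunker(0, 2), *args)   # ...because checkpoint splats them back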
......@@ -231,7 +231,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
MSAGlobalAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m),
"attention": GlobalAttentionParams(matt)
"attention": GlobalAttentionParams(matt.global_attention)
}
MSAAttPairBiasParams = lambda matt: dict(
......@@ -356,7 +356,7 @@ def lddt_loss(
)
dists_to_score = (
(dmat_true < cutoff) * all_atom_mask *
permute_final_dims(all_atom_mask, 1, 0) *
permute_final_dims(all_atom_mask, (1, 0)) *
(1. - torch.eye(n, device=all_atom_mask.device))
)
......@@ -16,12 +16,13 @@
from functools import partial
import torch
import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict
def permute_final_dims(tensor, *inds):
def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
zero_index = -1 * len(inds)
first_inds = range(len(tensor.shape[:zero_index]))
return tensor.permute(*first_inds, *[zero_index + i for i in inds])
first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(first_inds + [zero_index + i for i in inds])
def flatten_final_dims(tensor: torch.Tensor, no_dims: int):
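(Not part of the diff: permute_final_dims now takes a single List[int] instead of *inds so that TorchScript can type the call; behavior is unchanged. A quick check with an illustrative tensor.)
import torch
t = torch.rand(2, 4, 5, 6, 7)
out = permute_final_dims(t, (2, 0, 1))    # permute only the last three dims
assert out.shape == (2, 4, 7, 5, 6)
assert torch.equal(out, t.permute(0, 1, 4, 2, 3))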
......@@ -70,7 +71,7 @@ def stack_tensor_dicts(dicts):
def one_hot(x, v_bins):
reshaped_bins = v_bins.view(*((1,) * len(x.shape) + (len(v_bins),)))
reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
diffs = x[..., None] - reshaped_bins
am = torch.argmin(torch.abs(diffs), dim=-1)
return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
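(Not part of the diff: one_hot here encodes by nearest bin rather than exact match, so out-of-range values clip to the edge bins. A tiny worked example with made-up values.)
import torch
v_bins = torch.tensor([0.0, 1.0, 2.0])
x = torch.tensor([0.1, 1.6, 5.0])
print(one_hot(x, v_bins))
# 0.1 -> bin 0, 1.6 -> bin 2 (nearest), 5.0 -> bin 2 (clipped):
# [[1., 0., 0.], [0., 0., 1.], [0., 0., 1.]]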
......@@ -118,7 +119,12 @@ def tree_map(fn, tree, leaf_type):
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
) -> Any:
"""
Implements the "chunking" procedure described in section 1.11.8.
......@@ -130,8 +136,8 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
layer:
The layer to be applied chunk-wise
inputs:
A (nested) dictionary of keyworded inputs. All leaves must be
tensors and must share the same batch dimensions.
A (non-nested) dictionary of keyworded inputs. All leaves must
be tensors and must share the same batch dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
......@@ -163,7 +169,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
return shapes
initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
orig_batch_dims = [max(s) for s in zip(*initial_dims)]
orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def prep_inputs(t):
# TODO: make this more memory efficient. This sucks
......@@ -194,7 +200,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
# Allocate space for the output
if(out is None):
allocate = lambda t: t.new_zeros(flat_batch_dim, *t.shape[1:])
allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
out = tensor_tree_map(allocate, output_chunk)
# Put the chunk in its pre-allocated space
......@@ -217,7 +223,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
i += chunk_size
reshape = lambda t: t.reshape(*orig_batch_dims, *t.shape[1:])
reshape = lambda t: t.reshape(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out)
return out
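(Not part of the diff: a minimal usage sketch of chunk_layer with the updated signature, using a toy layer; sizes are illustrative.)
import torch
def toy_layer(a, b):
    return a + b
a = torch.rand(4, 16, 8)
b = torch.rand(4, 16, 8)
out = chunk_layer(
    toy_layer,
    {"a": a, "b": b},
    chunk_size=16,       # sub-batches processed per chunk
    no_batch_dims=2,     # the leading (4, 16) dims are the shared batch dims
)
assert torch.allclose(out, a + b)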