Commit b2d102cb authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Refactor certain modules for TorchScript, fix recycling bug

parent 4bd4ad93
...@@ -83,7 +83,7 @@ class InputEmbedder(nn.Module): ...@@ -83,7 +83,7 @@ class InputEmbedder(nn.Module):
boundaries = torch.arange( boundaries = torch.arange(
start=-self.relpos_k, end=self.relpos_k + 1, device=d.device start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
) )
oh = one_hot(d, boundaries) oh = one_hot(d, boundaries).type(ri.dtype)
return self.linear_relpos(oh) return self.linear_relpos(oh)
def forward(self, def forward(self,
...@@ -112,14 +112,15 @@ class InputEmbedder(nn.Module): ...@@ -112,14 +112,15 @@ class InputEmbedder(nn.Module):
# [*, N_res, N_res, c_z] # [*, N_res, N_res, c_z]
pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :] pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
pair_emb += self.relpos(ri) pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
#pair_emb = pair_emb + self.relpos(ri)
# [*, N_clust, N_res, c_m] # [*, N_clust, N_res, c_m]
n_clust = msa.shape[-3] n_clust = msa.shape[-3]
tf_m = (self.linear_tf_m(tf) tf_m = (
.unsqueeze(-3) self.linear_tf_m(tf)
.expand((*(-1,) * len(tf.shape[:-2]), n_clust, -1, -1))) .unsqueeze(-3)
.expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
)
msa_emb = self.linear_msa_m(msa) + tf_m msa_emb = self.linear_msa_m(msa) + tf_m
return msa_emb, pair_emb return msa_emb, pair_emb
...@@ -192,6 +193,7 @@ class RecyclingEmbedder(nn.Module): ...@@ -192,6 +193,7 @@ class RecyclingEmbedder(nn.Module):
self.min_bin, self.min_bin,
self.max_bin, self.max_bin,
self.no_bins, self.no_bins,
dtype=x.dtype,
requires_grad=False, requires_grad=False,
device=x.device device=x.device
) )
......
...@@ -109,7 +109,7 @@ class AlphaFold(nn.Module): ...@@ -109,7 +109,7 @@ class AlphaFold(nn.Module):
# Embed the templates one at a time (with a poor man's vmap) # Embed the templates one at a time (with a poor man's vmap)
template_embeds = [] template_embeds = []
n_templ = batch["template_aatype"].shape[templ_dim] n_templ = batch["template_aatype"].shape[templ_dim]
for i in range(n_templ): for i in range(n_templ):
idx = batch["template_aatype"].new_tensor(i) idx = batch["template_aatype"].new_tensor(i)
single_template_feats = tensor_tree_map( single_template_feats = tensor_tree_map(
lambda t: torch.index_select(t, templ_dim, idx), lambda t: torch.index_select(t, templ_dim, idx),
...@@ -170,6 +170,154 @@ class AlphaFold(nn.Module): ...@@ -170,6 +170,154 @@ class AlphaFold(nn.Module):
"torsion_angles_mask": angle_feats["torsion_angles_mask"], "torsion_angles_mask": angle_feats["torsion_angles_mask"],
} }
def iteration(self, feats, m_1_prev, z_prev, x_prev):
    """Run one recycling iteration of the model trunk.

    Args:
        feats: dict of feature tensors for the current cycle; keys read
            below include "target_feat", "residue_index", "msa_feat",
            "seq_mask", "msa_mask", "aatype", "atom37_atom_exists", and
            (conditionally) template_* / extra-MSA keys.
        m_1_prev: first-row MSA embedding from the previous cycle, or
            None on the first cycle.
        z_prev: pair embedding from the previous cycle, or None.
        x_prev: predicted atom positions from the previous cycle, or None.

    Returns:
        A 4-tuple (outputs, m_1_prev, z_prev, x_prev): the per-cycle
        output dict plus the three embeddings fed to the next cycle.
    """
    # Primary output dictionary
    outputs = {}

    # Grab some data about the input
    batch_dims = feats["target_feat"].shape[:-2]
    no_batch_dims = len(batch_dims)
    n = feats["target_feat"].shape[-2]
    # Number of clustered MSA sequences, captured BEFORE template angle
    # rows are concatenated so outputs["msa"] can be sliced back out.
    n_seq = feats["msa_feat"].shape[-3]
    # NOTE(review): `device` is unused below — candidate for removal.
    device = feats["target_feat"].device

    # Prep some features
    seq_mask = feats["seq_mask"]
    pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
    msa_mask = feats["msa_mask"]

    # Initialize the MSA and pair representations
    # m: [*, S_c, N, C_m]
    # z: [*, N, N, C_z]
    m, z = self.input_embedder(
        feats["target_feat"],
        feats["residue_index"],
        feats["msa_feat"],
    )

    # Inject information from previous recycling iterations
    if(self.config.no_cycles > 1):
        # Initialize the recycling embeddings, if needs be
        if(None in [m_1_prev, z_prev, x_prev]):
            # [*, N, C_m]
            # new_zeros inherits m's dtype/device, keeping the recycling
            # state consistent under mixed precision.
            m_1_prev = m.new_zeros(
                (*batch_dims, n, self.config.c_m),
                requires_grad=False,
            )

            # [*, N, N, C_z]
            z_prev = z.new_zeros(
                (*batch_dims, n, n, self.config.c_z),
                requires_grad=False,
            )

            # [*, N, atom_type_num, 3] — full atom37 coordinates;
            # reduced to pseudo-beta positions just below.
            x_prev = z.new_zeros(
                (*batch_dims, n, residue_constants.atom_type_num, 3),
                requires_grad=False,
            )

        # Applied every cycle (not only on initialization): the recycling
        # embedder consumes pseudo-beta coordinates, not raw atom37
        # positions. NOTE(review): placement relative to the init block
        # reconstructed from the commit ("fix recycling bug") — confirm.
        x_prev = pseudo_beta_fn(
            feats["aatype"],
            x_prev,
            None
        )

        # m_1_prev_emb: [*, N, C_m]
        # z_prev_emb: [*, N, N, C_z]
        m_1_prev_emb, z_prev_emb = self.recycling_embedder(
            m_1_prev,
            z_prev,
            x_prev,
        )

        # [*, S_c, N, C_m] — only the first MSA row receives the update
        m[..., 0, :, :] += m_1_prev_emb

        # [*, N, N, C_z]
        z = z + z_prev_emb

    # Embed the templates + merge with MSA/pair embeddings
    if(self.config.template.enabled):
        template_feats = {
            k:v for k,v in feats.items() if "template_" in k
        }
        template_embeds = self.embed_templates(
            template_feats,
            z,
            pair_mask,
            no_batch_dims,
        )

        # [*, N, N, C_z]
        z = z + template_embeds["template_pair_embedding"]

        if(self.config.template.embed_angles):
            # [*, S = S_c + S_t, N, C_m]
            m = torch.cat(
                [m, template_embeds["template_angle_embedding"]],
                dim=-3
            )

            # [*, S, N]
            torsion_angles_mask = template_embeds["torsion_angles_mask"]
            # NOTE(review): `axis` is the numpy-style alias for `dim`.
            msa_mask = torch.cat(
                [feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
            )

    # Embed extra MSA features + merge with pairwise embeddings
    if(self.config.extra_msa.enabled):
        # [*, S_e, N, C_e]
        a = self.extra_msa_embedder(build_extra_msa_feat(feats))

        # [*, N, N, C_z]
        z = self.extra_msa_stack(
            a,
            z,
            msa_mask=feats["extra_msa_mask"],
            pair_mask=pair_mask,
            _mask_trans=self.config._mask_trans,
        )

    # Run MSA + pair embeddings through the trunk of the network
    # m: [*, S, N, C_m]
    # z: [*, N, N, C_z]
    # s: [*, N, C_s]
    m, z, s = self.evoformer(
        m,
        z,
        msa_mask=msa_mask,
        pair_mask=pair_mask,
        _mask_trans=self.config._mask_trans
    )

    # Drop any appended template-angle rows from the reported MSA
    outputs["msa"] = m[..., :n_seq, :, :]
    outputs["pair"] = z
    outputs["single"] = s

    # Predict 3D structure
    outputs["sm"] = self.structure_module(
        s, z, feats["aatype"], mask=feats["seq_mask"],
    )
    outputs["final_atom_positions"] = atom14_to_atom37(
        outputs["sm"]["positions"][-1], feats
    )
    outputs["final_atom_mask"] = feats["atom37_atom_exists"]

    # Save embeddings for use during the next recycling iteration
    # [*, N, C_m]
    m_1_prev = m[..., 0, :, :]

    # [*, N, N, C_z]
    z_prev = z

    # [*, N, atom_type_num, 3]
    x_prev = outputs["final_atom_positions"]

    return outputs, m_1_prev, z_prev, x_prev
def forward(self, batch): def forward(self, batch):
""" """
Args: Args:
...@@ -223,160 +371,19 @@ class AlphaFold(nn.Module): ...@@ -223,160 +371,19 @@ class AlphaFold(nn.Module):
# Recycling embeddings # Recycling embeddings
m_1_prev, z_prev, x_prev = None, None, None m_1_prev, z_prev, x_prev = None, None, None
# Primary output dictionary
outputs = {}
# Main recycling loop # Main recycling loop
for cycle_no in range(self.config.no_cycles): for cycle_no in range(self.config.no_cycles):
# Select the features for the current recycling cycle # Select the features for the current recycling cycle
fetch_cur_batch = lambda t: t[..., cycle_no] fetch_cur_batch = lambda t: t[..., cycle_no]
feats = tensor_tree_map(fetch_cur_batch, batch) feats = tensor_tree_map(fetch_cur_batch, batch)
# Grab some data about the input
batch_dims = feats["target_feat"].shape[:-2]
no_batch_dims = len(batch_dims)
n = feats["target_feat"].shape[-2]
n_seq = feats["msa_feat"].shape[-3]
device = feats["target_feat"].device
# Prep some features
seq_mask = feats["seq_mask"]
pair_mask = seq_mask[..., None] * seq_mask[..., None, :]
msa_mask = feats["msa_mask"]
# Initialize the MSA and pair representations
# m: [*, S_c, N, C_m]
# z: [*, N, N, C_z]
m, z = self.input_embedder(
feats["target_feat"],
feats["residue_index"],
feats["msa_feat"],
)
# Inject information from previous recycling iterations
if(self.config.no_cycles > 1):
# Initialize the recycling embeddings, if needs be
if(None in [m_1_prev, z_prev, x_prev]):
# [*, N, C_m]
m_1_prev = torch.zeros(
(*batch_dims, n, self.config.c_m),
requires_grad=False,
device=device,
)
# [*, N, N, C_z]
z_prev = torch.zeros(
(*batch_dims, n, n, self.config.c_z),
requires_grad=False,
device=device,
)
# [*, N, 3]
x_prev = torch.zeros(
(*batch_dims, n, residue_constants.atom_type_num, 3),
requires_grad=False,
device=device,
)
x_prev = pseudo_beta_fn(
feats["aatype"],
x_prev,
None
)
# m_1_prev_emb: [*, N, C_m]
# z_prev_emb: [*, N, N, C_z]
m_1_prev_emb, z_prev_emb = self.recycling_embedder(
m_1_prev,
z_prev,
x_prev,
)
# [*, S_c, N, C_m]
m[..., 0, :, :] += m_1_prev_emb
# [*, N, N, C_z]
z = z + z_prev_emb
# Embed the templates + merge with MSA/pair embeddings
if(self.config.template.enabled):
template_feats = {
k:v for k,v in feats.items() if "template_" in k
}
template_embeds = self.embed_templates(
template_feats,
z,
pair_mask,
no_batch_dims,
)
# [*, N, N, C_z]
z = z + template_embeds["template_pair_embedding"]
if(self.config.template.embed_angles):
# [*, S = S_c + S_t, N, C_m]
m = torch.cat(
[m, template_embeds["template_angle_embedding"]],
dim=-3
)
# [*, S, N]
torsion_angles_mask = template_embeds["torsion_angles_mask"]
msa_mask = torch.cat(
[feats["msa_mask"], torsion_angles_mask[..., 2]], axis=-2
)
# Embed extra MSA features + merge with pairwise embeddings
if(self.config.extra_msa.enabled):
# [*, S_e, N, C_e]
a = self.extra_msa_embedder(build_extra_msa_feat(feats))
# [*, N, N, C_z] # Enable grad iff we're training and it's the final recycling layer
z = self.extra_msa_stack( is_final_iter = (cycle_no == self.config.no_cycles - 1)
a, with torch.set_grad_enabled(self.training and is_final_iter):
z, outputs, m_1_prev, z_prev, x_prev = self.iteration(
msa_mask=feats["extra_msa_mask"], feats, m_1_prev, z_prev, x_prev,
pair_mask=pair_mask,
_mask_trans=self.config._mask_trans,
) )
# Run MSA + pair embeddings through the trunk of the network
# m: [*, S, N, C_m]
# z: [*, N, N, C_z]
# s: [*, N, C_s]
m, z, s = self.evoformer(
m,
z,
msa_mask=msa_mask,
pair_mask=pair_mask,
_mask_trans=self.config._mask_trans
)
outputs["msa"] = m[..., :n_seq, :, :]
outputs["pair"] = z
outputs["single"] = s
# Predict 3D structure
outputs["sm"] = self.structure_module(
s, z, feats["aatype"], mask=feats["seq_mask"],
)
outputs["final_atom_positions"] = atom14_to_atom37(
outputs["sm"]["positions"][-1], feats
)
outputs["final_atom_mask"] = feats["atom37_atom_exists"]
# Save embeddings for use during the next recycling iteration
# [*, N, C_m]
m_1_prev = m[..., 0, :, :]
# [* N, N, C_z]
z_prev = z
# [*, N, 3]
x_prev = outputs["final_atom_positions"]
outputs.update(self.aux_heads(outputs)) outputs.update(self.aux_heads(outputs))
return outputs return outputs
...@@ -16,8 +16,9 @@ ...@@ -16,8 +16,9 @@
import math import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Optional
from openfold.model.primitives import Linear, scripted_attention from openfold.model.primitives import Linear, Attention, GlobalAttention
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
chunk_layer, chunk_layer,
permute_final_dims, permute_final_dims,
...@@ -69,7 +70,8 @@ class MSAAttention(nn.Module): ...@@ -69,7 +70,8 @@ class MSAAttention(nn.Module):
self.c_z, self.no_heads, bias=False, init="normal" self.c_z, self.no_heads, bias=False, init="normal"
) )
self.mha = scripted_attention(
self.mha = Attention(
self.c_in, self.c_in, self.c_in, self.c_in, self.c_in, self.c_in,
self.c_hidden, self.c_hidden,
self.no_heads self.no_heads
...@@ -93,7 +95,7 @@ class MSAAttention(nn.Module): ...@@ -93,7 +95,7 @@ class MSAAttention(nn.Module):
if(mask is None): if(mask is None):
# [*, N_seq, N_res] # [*, N_seq, N_res]
mask = torch.ones( mask = torch.ones(
(*m.shape[:-3], n_seq, n_res), m.shape[:-3] + (n_seq, n_res),
device=m.device, device=m.device,
requires_grad=False requires_grad=False
) )
...@@ -103,7 +105,7 @@ class MSAAttention(nn.Module): ...@@ -103,7 +105,7 @@ class MSAAttention(nn.Module):
# [*, N_seq, no_heads, N_res, N_res] # [*, N_seq, no_heads, N_res, N_res]
bias = bias.expand( bias = bias.expand(
(*((-1,) * len(bias.shape[:-4])), -1, self.no_heads, n_res, -1) ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
) )
biases = [bias] biases = [bias]
...@@ -115,7 +117,7 @@ class MSAAttention(nn.Module): ...@@ -115,7 +117,7 @@ class MSAAttention(nn.Module):
z = self.linear_z(z) z = self.linear_z(z)
# [*, 1, no_heads, N_res, N_res] # [*, 1, no_heads, N_res, N_res]
z = permute_final_dims(z, 2, 0, 1).unsqueeze(-4) z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
biases.append(z) biases.append(z)
...@@ -234,79 +236,29 @@ class MSAColumnGlobalAttention(nn.Module): ...@@ -234,79 +236,29 @@ class MSAColumnGlobalAttention(nn.Module):
self.inf = inf self.inf = inf
self.eps = eps self.eps = eps
self.layer_norm_m = nn.LayerNorm(self.c_in) self.layer_norm_m = nn.LayerNorm(c_in)
self.linear_q = Linear(
self.c_in, self.c_hidden * self.no_heads, bias=False, init="glorot"
)
C_hidden = self.c_hidden
self.linear_k = Linear(
self.c_in, C_hidden, bias=False, init="glorot",
)
self.linear_v = Linear(
self.c_in, C_hidden, bias=False, init="glorot",
)
self.linear_g = Linear(self.c_in, self.c_hidden * self.no_heads, init="gating")
self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_in, init="final")
self.sigmoid = nn.Sigmoid() self.global_attention = GlobalAttention(
self.softmax = nn.Softmax(dim=-1) c_in=c_in,
c_hidden=c_hidden,
def global_attention(self, m, mask): no_heads=no_heads,
# [*, N_res, C_in] inf=inf,
q = (torch.sum(m * mask.unsqueeze(-1), dim=-2) / eps=eps,
(torch.sum(mask, dim=-1)[..., None] + self.eps))
# [*, N_res, H * C_hidden]
q = self.linear_q(q)
q = q * self.c_hidden ** (-0.5)
# [*, N_res, H, C_hidden]
q = q.view(*q.shape[:-1], self.no_heads, -1)
# [*, N_res, N_seq, C_hidden]
k = self.linear_k(m)
v = self.linear_v(m)
# [*, N_res, H, N_seq]
a = torch.matmul(
q,
k.transpose(-1, -2), # [*, N_res, C_hidden, N_seq]
)
bias = (self.inf * (mask - 1))[..., :, None, :]
a = a + bias
a = self.softmax(a)
# [*, N_res, H, C_hidden]
o = torch.matmul(
a,
v,
) )
# [*, N_res, N_seq, C_hidden] def forward(self,
g = self.sigmoid(self.linear_g(m)) m: torch.Tensor,
mask: Optional[torch.Tensor] = None
# [*, N_res, N_seq, H, C_hidden] ) -> torch.Tensor:
g = g.view(*g.shape[:-1], self.no_heads, -1)
# [*, N_res, N_seq, H, C_hidden]
o = o.unsqueeze(-3) * g
# [*, N_res, N_seq, H * C_hidden]
o = o.reshape(*o.shape[:-2], -1)
# [*, N_res, N_seq, C_in]
m = self.linear_o(o)
return m
def forward(self, m, mask=None):
n_seq, n_res, c_in = m.shape[-3:] n_seq, n_res, c_in = m.shape[-3:]
if(mask is None): if(mask is None):
# [*, N_seq, N_res] # [*, N_seq, N_res]
mask = m.new_ones(m.shape[:-1], requires_grad=False) mask = torch.ones(
m.shape[:-1],
dtype=m.dtype,
device=m.device,
).detach()
# [*, N_res, N_seq, C_in] # [*, N_res, N_seq, C_in]
m = m.transpose(-2, -3) m = m.transpose(-2, -3)
...@@ -327,7 +279,7 @@ class MSAColumnGlobalAttention(nn.Module): ...@@ -327,7 +279,7 @@ class MSAColumnGlobalAttention(nn.Module):
no_batch_dims=len(m.shape[:-2]) no_batch_dims=len(m.shape[:-2])
) )
else: else:
m = self.global_attention(**mha_input) m = self.global_attention(m=mha_input["m"], mask=mha_input["mask"])
# [*, N_seq, N_res, C_in] # [*, N_seq, N_res, C_in]
m = m.transpose(-2, -3) m = m.transpose(-2, -3)
......
...@@ -235,12 +235,6 @@ class Attention(nn.Module): ...@@ -235,12 +235,6 @@ class Attention(nn.Module):
Returns Returns
[*, Q, C_q] attention update [*, Q, C_q] attention update
""" """
# Flatten batch dims
batch_dims = q_x.shape[:-2]
q_x = q_x.view((-1,) + q_x.shape[-2:])
k_x = k_x.view((-1,) + k_x.shape[-2:])
v_x = v_x.view((-1,) + v_x.shape[-2:])
# [*, Q/K/V, H * C_hidden] # [*, Q/K/V, H * C_hidden]
q = self.linear_q(q_x) q = self.linear_q(q_x)
k = self.linear_k(k_x) k = self.linear_k(k_x)
...@@ -253,20 +247,20 @@ class Attention(nn.Module): ...@@ -253,20 +247,20 @@ class Attention(nn.Module):
# [*, H, Q, K] # [*, H, Q, K]
a = torch.matmul( a = torch.matmul(
q.permute(0, 2, 1, 3), # [*, H, Q, C_hidden] permute_final_dims(q, (0, 2, 1, 3)), # [*, H, Q, C_hidden]
k.permute(0, 2, 3, 1), # [*, H, C_hidden, K] permute_final_dims(k, (0, 2, 3, 1)), # [*, H, C_hidden, K]
) )
norm = 1 / math.sqrt(self.c_hidden) # [1] norm = 1 / math.sqrt(self.c_hidden) # [1]
a = a * norm a *= norm
if(biases is not None): if(biases is not None):
for b in biases: for b in biases:
a = a + b a += b
a = self.softmax(a) a = self.softmax(a)
# [*, H, Q, C_hidden] # [*, H, Q, C_hidden]
o = torch.matmul( o = torch.matmul(
a, a,
v.permute(0, 2, 1, 3), # [*, H, V, C_hidden] permute_final_dims(v, (0, 2, 1, 3)), # [*, H, V, C_hidden]
) )
# [*, Q, H, C_hidden] # [*, Q, H, C_hidden]
...@@ -282,11 +276,80 @@ class Attention(nn.Module): ...@@ -282,11 +276,80 @@ class Attention(nn.Module):
# [*, Q, C_q] # [*, Q, C_q]
o = self.linear_o(o) o = self.linear_o(o)
# Restore the batch dims
o = o.reshape(batch_dims + o.shape[1:])
return o return o
class GlobalAttention(nn.Module):
    """Global attention with a single, sequence-averaged query.

    Instead of one query per sequence element, the query is the masked
    mean over the sequence dimension, so attention weights have shape
    [*, N_res, H, N_seq]. Output is gated per-element (sigmoid gate).
    Replaces the former module-local implementation so it can be shared
    and TorchScript-compiled (annotated forward signature).
    """

    def __init__(self, c_in, c_hidden, no_heads, inf, eps):
        """
        Args:
            c_in: input channel dimension
            c_hidden: per-head hidden dimension
            no_heads: number of attention heads
            inf: large constant used to mask attention logits
            eps: small constant guarding the masked-mean division
        """
        super(GlobalAttention, self).__init__()

        self.c_in = c_in
        self.c_hidden = c_hidden
        self.no_heads = no_heads
        self.inf = inf
        self.eps = eps

        self.linear_q = Linear(
            c_in, c_hidden * no_heads, bias=False, init="glorot"
        )

        # Keys/values are shared across heads (single c_hidden copy)
        self.linear_k = Linear(
            c_in, c_hidden, bias=False, init="glorot",
        )
        self.linear_v = Linear(
            c_in, c_hidden, bias=False, init="glorot",
        )

        self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
        self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")

        self.sigmoid = nn.Sigmoid()
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Args:
            m: [*, N_res, N_seq, C_in] input embedding
            mask: [*, N_res, N_seq] binary mask (1 = valid)
        Returns:
            [*, N_res, N_seq, C_in] attention update
        """
        # [*, N_res, C_in] — masked mean over the sequence dimension
        q = (torch.sum(m * mask.unsqueeze(-1), dim=-2) /
             (torch.sum(mask, dim=-1)[..., None] + self.eps))

        # [*, N_res, H * C_hidden]
        q = self.linear_q(q)
        q = q * self.c_hidden ** (-0.5)

        # [*, N_res, H, C_hidden]
        q = q.view(q.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, N_seq, C_hidden]
        k = self.linear_k(m)
        v = self.linear_v(m)

        # [*, N_res, H, N_seq]
        a = torch.matmul(
            q,
            k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
        )

        # Masked positions receive a large negative logit
        bias = (self.inf * (mask - 1))[..., :, None, :]
        a += bias
        a = self.softmax(a)

        # [*, N_res, H, C_hidden]
        o = torch.matmul(
            a,
            v,
        )

        # [*, N_res, N_seq, H * C_hidden]
        g = self.sigmoid(self.linear_g(m))

        # [*, N_res, N_seq, H, C_hidden]
        g = g.view(g.shape[:-1] + (self.no_heads, -1))

        # [*, N_res, N_seq, H, C_hidden] — broadcast heads over N_seq, gate
        o = o.unsqueeze(-3) * g

        # [*, N_res, N_seq, H * C_hidden]
        o = o.reshape(o.shape[:-2] + (-1,))

        # [*, N_res, N_seq, C_in]
        m = self.linear_o(o)

        return m
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
import math import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Optional from typing import Optional, Tuple
from openfold.model.primitives import Linear, ipa_point_weights_init_ from openfold.model.primitives import Linear, ipa_point_weights_init_
from openfold.np.residue_constants import ( from openfold.np.residue_constants import (
...@@ -49,7 +49,7 @@ class AngleResnetBlock(nn.Module): ...@@ -49,7 +49,7 @@ class AngleResnetBlock(nn.Module):
self.relu = nn.ReLU() self.relu = nn.ReLU()
def forward(self, a): def forward(self, a: torch.Tensor) -> torch.Tensor:
s_initial = a s_initial = a
...@@ -85,7 +85,7 @@ class AngleResnet(nn.Module): ...@@ -85,7 +85,7 @@ class AngleResnet(nn.Module):
self.c_hidden = c_hidden self.c_hidden = c_hidden
self.no_blocks = no_blocks self.no_blocks = no_blocks
self.no_angles = no_angles self.no_angles = no_angles
self.epsilon = epsilon self.eps = epsilon
self.linear_in = Linear(self.c_in, self.c_hidden) self.linear_in = Linear(self.c_in, self.c_hidden)
self.linear_initial = Linear(self.c_in, self.c_hidden) self.linear_initial = Linear(self.c_in, self.c_hidden)
...@@ -99,7 +99,10 @@ class AngleResnet(nn.Module): ...@@ -99,7 +99,10 @@ class AngleResnet(nn.Module):
self.relu = nn.ReLU() self.relu = nn.ReLU()
def forward(self, s, s_initial): def forward(self,
s: torch.Tensor,
s_initial: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Args: Args:
s: s:
...@@ -130,14 +133,13 @@ class AngleResnet(nn.Module): ...@@ -130,14 +133,13 @@ class AngleResnet(nn.Module):
s = self.linear_out(s) s = self.linear_out(s)
# [*, no_angles, 2] # [*, no_angles, 2]
s = s.view(*s.shape[:-1], -1, 2) s = s.view(s.shape[:-1] + (-1, 2))
unnormalized_s = s unnormalized_s = s
norm_denom = torch.sqrt( norm_denom = torch.sqrt(
torch.clamp( torch.clamp(
torch.sum(s ** 2, dim=-1, keepdims=True), torch.sum(s ** 2, dim=-1, keepdim=True),
min=self.epsilon, min=self.eps,
) )
) )
s = s / norm_denom s = s / norm_denom
...@@ -219,7 +221,7 @@ class InvariantPointAttention(nn.Module): ...@@ -219,7 +221,7 @@ class InvariantPointAttention(nn.Module):
z: torch.Tensor, z: torch.Tensor,
t: T, t: T,
mask: torch.Tensor, mask: torch.Tensor,
): ) -> torch.Tensor:
""" """
Args: Args:
s: s:
...@@ -236,16 +238,15 @@ class InvariantPointAttention(nn.Module): ...@@ -236,16 +238,15 @@ class InvariantPointAttention(nn.Module):
####################################### #######################################
# Generate scalar and point activations # Generate scalar and point activations
####################################### #######################################
# [*, N_res, H * C_hidden] # [*, N_res, H * C_hidden]
q = self.linear_q(s) q = self.linear_q(s)
kv = self.linear_kv(s) kv = self.linear_kv(s)
# [*, N_res, H, C_hidden] # [*, N_res, H, C_hidden]
q = q.view(*q.shape[:-1], self.no_heads, -1) q = q.view(q.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, 2 * C_hidden] # [*, N_res, H, 2 * C_hidden]
kv = kv.view(*kv.shape[:-1], self.no_heads, -1) kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
# [*, N_res, H, C_hidden] # [*, N_res, H, C_hidden]
k, v = torch.split(kv, self.c_hidden, dim=-1) k, v = torch.split(kv, self.c_hidden, dim=-1)
...@@ -261,7 +262,7 @@ class InvariantPointAttention(nn.Module): ...@@ -261,7 +262,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, P_q, 3] # [*, N_res, H, P_q, 3]
q_pts = q_pts.view( q_pts = q_pts.view(
*q_pts.shape[:-2], self.no_heads, self.no_qk_points, 3 q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
) )
# [*, N_res, H * (P_q + P_v) * 3] # [*, N_res, H * (P_q + P_v) * 3]
...@@ -274,7 +275,7 @@ class InvariantPointAttention(nn.Module): ...@@ -274,7 +275,7 @@ class InvariantPointAttention(nn.Module):
# [*, N_res, H, (P_q + P_v), 3] # [*, N_res, H, (P_q + P_v), 3]
kv_pts = kv_pts.view( kv_pts = kv_pts.view(
*kv_pts.shape[:-2], self.no_heads, -1, 3 kv_pts.shape[:-2] + (self.no_heads, -1, 3)
) )
# [*, N_res, H, P_q/P_v, 3] # [*, N_res, H, P_q/P_v, 3]
...@@ -293,11 +294,11 @@ class InvariantPointAttention(nn.Module): ...@@ -293,11 +294,11 @@ class InvariantPointAttention(nn.Module):
# [*, H, N_res, N_res] # [*, H, N_res, N_res]
a = torch.matmul( a = torch.matmul(
permute_final_dims(q, 1, 0, 2), # [*, H, N_res, C_hidden] permute_final_dims(q, (1, 0, 2)), # [*, H, N_res, C_hidden]
permute_final_dims(k, 1, 2, 0), # [*, H, C_hidden, N_res] permute_final_dims(k, (1, 2, 0)), # [*, H, C_hidden, N_res]
) )
a = a + math.sqrt(1. / (3 * self.c_hidden)) a = a + math.sqrt(1. / (3 * self.c_hidden))
a = a + math.sqrt(1. / 3) * permute_final_dims(b, 2, 0, 1) a = a + math.sqrt(1. / 3) * permute_final_dims(b, (2, 0, 1))
# [*, N_res, N_res, H, P_q, 3] # [*, N_res, N_res, H, P_q, 3]
pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5) pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
...@@ -321,7 +322,7 @@ class InvariantPointAttention(nn.Module): ...@@ -321,7 +322,7 @@ class InvariantPointAttention(nn.Module):
square_mask = self.inf * (square_mask - 1) square_mask = self.inf * (square_mask - 1)
# [*, H, N_res, N_res] # [*, H, N_res, N_res]
pt_att = permute_final_dims(pt_att, 2, 0, 1) pt_att = permute_final_dims(pt_att, (2, 0, 1))
a = a + pt_att a = a + pt_att
a = a + square_mask.unsqueeze(-3) a = a + square_mask.unsqueeze(-3)
a = self.softmax(a) a = self.softmax(a)
...@@ -339,11 +340,11 @@ class InvariantPointAttention(nn.Module): ...@@ -339,11 +340,11 @@ class InvariantPointAttention(nn.Module):
# [*, H, 3, N_res, P_v] # [*, H, 3, N_res, P_v]
o_pt = torch.matmul( o_pt = torch.matmul(
a.unsqueeze(-3), # [*, H, 1, N_res, N_res] a.unsqueeze(-3), # [*, H, 1, N_res, N_res]
permute_final_dims(v_pts, 1, 3, 0, 2), # [*, H, 3, N_res, P_v] permute_final_dims(v_pts, (1, 3, 0, 2)), # [*, H, 3, N_res, P_v]
) )
# [*, N_res, H, P_v, 3] # [*, N_res, H, P_v, 3]
o_pt = permute_final_dims(o_pt, 2, 0, 3, 1) o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
o_pt = t[..., None, None].invert_apply(o_pt) o_pt = t[..., None, None].invert_apply(o_pt)
# [*, N_res, H * P_v] # [*, N_res, H * P_v]
...@@ -758,35 +759,39 @@ class StructureModule(nn.Module): ...@@ -758,35 +759,39 @@ class StructureModule(nn.Module):
return outputs return outputs
def _init_residue_constants(self, device): def _init_residue_constants(self, dtype, device):
if(self.default_frames is None): if(self.default_frames is None):
self.default_frames = torch.tensor( self.default_frames = torch.tensor(
restype_rigid_group_default_frame, restype_rigid_group_default_frame,
dtype=dtype,
device=device, device=device,
requires_grad=False, requires_grad=False,
) )
if(self.group_idx is None): if(self.group_idx is None):
self.group_idx = torch.tensor( self.group_idx = torch.tensor(
restype_atom14_to_rigid_group, restype_atom14_to_rigid_group,
dtype=dtype,
device=device, device=device,
requires_grad=False, requires_grad=False,
) )
if(self.atom_mask is None): if(self.atom_mask is None):
self.atom_mask = torch.tensor( self.atom_mask = torch.tensor(
restype_atom14_mask, restype_atom14_mask,
dtype=dtype,
device=device, device=device,
requires_grad=False, requires_grad=False,
) )
if(self.lit_positions is None): if(self.lit_positions is None):
self.lit_positions = torch.tensor( self.lit_positions = torch.tensor(
restype_atom14_rigid_group_positions, restype_atom14_rigid_group_positions,
dtype=dtype,
device=device, device=device,
requires_grad=False, requires_grad=False,
) )
def torsion_angles_to_frames(self, t, alpha, f): def torsion_angles_to_frames(self, t, alpha, f):
# Lazily initialize the residue constants on the correct device # Lazily initialize the residue constants on the correct device
self._init_residue_constants(f.device) self._init_residue_constants(f.dtype, f.device)
# Separated purely to make testing less annoying # Separated purely to make testing less annoying
return _torsion_angles_to_frames(t, alpha, f, self.default_frames) return _torsion_angles_to_frames(t, alpha, f, self.default_frames)
...@@ -797,7 +802,7 @@ class StructureModule(nn.Module): ...@@ -797,7 +802,7 @@ class StructureModule(nn.Module):
# Lazily initialize the residue constants on the correct device # Lazily initialize the residue constants on the correct device
# TODO: Maybe this stuff should be done on CPU instead (so these # TODO: Maybe this stuff should be done on CPU instead (so these
# arrays # arrays
self._init_residue_constants(f.device) self._init_residue_constants(f.dtype, f.device)
return _frames_and_literature_positions_to_atom14_pos( return _frames_and_literature_positions_to_atom14_pos(
t, t,
......
...@@ -18,7 +18,7 @@ import math ...@@ -18,7 +18,7 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear, scripted_attention from openfold.model.primitives import Linear, Attention
from openfold.utils.deepspeed import checkpoint_blocks from openfold.utils.deepspeed import checkpoint_blocks
from openfold.model.dropout import ( from openfold.model.dropout import (
DropoutRowwise, DropoutRowwise,
...@@ -69,7 +69,7 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -69,7 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
self.no_heads = no_heads self.no_heads = no_heads
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.mha = scripted_attention( self.mha = Attention(
self.c_z, self.c_t, self.c_t, self.c_z, self.c_t, self.c_t,
self.c_hidden, self.no_heads, self.c_hidden, self.no_heads,
gating=False, gating=False,
...@@ -91,7 +91,7 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -91,7 +91,7 @@ class TemplatePointwiseAttention(nn.Module):
# NOTE: This is not the "template_mask" from the supplement, but a # NOTE: This is not the "template_mask" from the supplement, but a
# [*, N_templ] mask from the code. I'm pretty sure it's always just 1, # [*, N_templ] mask from the code. I'm pretty sure it's always just 1,
# but not sure enough to remove it. It's nice to have, I guess. # but not sure enough to remove it. It's nice to have, I guess.
template_mask = torch.ones(t.shape[:-3], device=t.device) template_mask = t.new_ones(t.shape[:-3])
bias = (1e9 * (template_mask[..., None, None, None, None, :] - 1)) bias = (1e9 * (template_mask[..., None, None, None, None, :] - 1))
...@@ -99,7 +99,7 @@ class TemplatePointwiseAttention(nn.Module): ...@@ -99,7 +99,7 @@ class TemplatePointwiseAttention(nn.Module):
z = z.unsqueeze(-2) z = z.unsqueeze(-2)
# [*, N_res, N_res, N_temp, C_t] # [*, N_res, N_res, N_temp, C_t]
t = permute_final_dims(t, 1, 2, 0, 3) t = permute_final_dims(t, (1, 2, 0, 3))
# [*, N_res, N_res, 1, C_z] # [*, N_res, N_res, 1, C_z]
mha_inputs = { mha_inputs = {
......
...@@ -18,7 +18,7 @@ import math ...@@ -18,7 +18,7 @@ import math
import torch import torch
import torch.nn as nn import torch.nn as nn
from openfold.model.primitives import Linear, scripted_attention from openfold.model.primitives import Linear, Attention
from openfold.utils.tensor_utils import ( from openfold.utils.tensor_utils import (
chunk_layer, chunk_layer,
permute_final_dims, permute_final_dims,
...@@ -57,7 +57,7 @@ class TriangleAttention(nn.Module): ...@@ -57,7 +57,7 @@ class TriangleAttention(nn.Module):
self.linear = Linear(c_in, self.no_heads, bias=False, init="normal") self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
self.mha = scripted_attention( self.mha = Attention(
self.c_in, self.c_in, self.c_in, self.c_in, self.c_in, self.c_in,
self.c_hidden, self.c_hidden,
self.no_heads self.no_heads
...@@ -91,7 +91,7 @@ class TriangleAttention(nn.Module): ...@@ -91,7 +91,7 @@ class TriangleAttention(nn.Module):
mask_bias = (self.inf * (mask - 1))[..., :, None, None, :] mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
# [*, H, I, J] # [*, H, I, J]
triangle_bias = permute_final_dims(self.linear(x), 2, 0, 1) triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
# [*, 1, H, I, J] # [*, 1, H, I, J]
triangle_bias = triangle_bias.unsqueeze(-4) triangle_bias = triangle_bias.unsqueeze(-4)
......
...@@ -59,12 +59,12 @@ class TriangleMultiplicativeUpdate(nn.Module): ...@@ -59,12 +59,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
): ):
# [*, C, N_i, N_j] # [*, C, N_i, N_j]
p = torch.matmul( p = torch.matmul(
permute_final_dims(a, 2, 0, 1), permute_final_dims(a, (2, 0, 1)),
permute_final_dims(b, 2, 1, 0), permute_final_dims(b, (2, 1, 0)),
) )
# [*, N_i, N_j, C] # [*, N_i, N_j, C]
return permute_final_dims(p, 1, 2, 0) return permute_final_dims(p, (1, 2, 0))
def _incoming_matmul(self, def _incoming_matmul(self,
a: torch.Tensor, # [*, N_k, N_i, C] a: torch.Tensor, # [*, N_k, N_i, C]
...@@ -73,12 +73,12 @@ class TriangleMultiplicativeUpdate(nn.Module): ...@@ -73,12 +73,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
# [*, C, N_i, N_j] # [*, C, N_i, N_j]
p = torch.matmul( p = torch.matmul(
permute_final_dims(a, 2, 1, 0), permute_final_dims(a, (2, 1, 0)),
permute_final_dims(b, 2, 0, 1), permute_final_dims(b, (2, 0, 1)),
) )
# [*, N_i, N_j, C] # [*, N_i, N_j, C]
return permute_final_dims(p, 1, 2, 0) return permute_final_dims(p, (1, 2, 0))
def forward(self, z, mask=None): def forward(self, z, mask=None):
""" """
......
...@@ -13,30 +13,25 @@ ...@@ -13,30 +13,25 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
import torch import torch
# According to DeepMind, this prevents rotation compositions from being
# computed on low-precision tensor cores. I'm personally skeptical that it
# makes a difference, but to get as close as possible to their outputs, I'm
# adding it.
def rot_matmul(a, b): def rot_matmul(a, b):
e = ...
row_1 = torch.stack([ row_1 = torch.stack([
a[e,0,0]*b[e,0,0] + a[e,0,1]*b[e,1,0] + a[e,0,2]*b[e,2,0], a[...,0,0]*b[...,0,0] + a[...,0,1]*b[...,1,0] + a[...,0,2]*b[...,2,0],
a[e,0,0]*b[e,0,1] + a[e,0,1]*b[e,1,1] + a[e,0,2]*b[e,2,1], a[...,0,0]*b[...,0,1] + a[...,0,1]*b[...,1,1] + a[...,0,2]*b[...,2,1],
a[e,0,0]*b[e,0,2] + a[e,0,1]*b[e,1,2] + a[e,0,2]*b[e,2,2], a[...,0,0]*b[...,0,2] + a[...,0,1]*b[...,1,2] + a[...,0,2]*b[...,2,2],
], dim=-1) ], dim=-1)
row_2 = torch.stack([ row_2 = torch.stack([
a[e,1,0]*b[e,0,0] + a[e,1,1]*b[e,1,0] + a[e,1,2]*b[e,2,0], a[...,1,0]*b[...,0,0] + a[...,1,1]*b[...,1,0] + a[...,1,2]*b[...,2,0],
a[e,1,0]*b[e,0,1] + a[e,1,1]*b[e,1,1] + a[e,1,2]*b[e,2,1], a[...,1,0]*b[...,0,1] + a[...,1,1]*b[...,1,1] + a[...,1,2]*b[...,2,1],
a[e,1,0]*b[e,0,2] + a[e,1,1]*b[e,1,2] + a[e,1,2]*b[e,2,2], a[...,1,0]*b[...,0,2] + a[...,1,1]*b[...,1,2] + a[...,1,2]*b[...,2,2],
], dim=-1) ], dim=-1)
row_3 = torch.stack([ row_3 = torch.stack([
a[e,2,0]*b[e,0,0] + a[e,2,1]*b[e,1,0] + a[e,2,2]*b[e,2,0], a[...,2,0]*b[...,0,0] + a[...,2,1]*b[...,1,0] + a[...,2,2]*b[...,2,0],
a[e,2,0]*b[e,0,1] + a[e,2,1]*b[e,1,1] + a[e,2,2]*b[e,2,1], a[...,2,0]*b[...,0,1] + a[...,2,1]*b[...,1,1] + a[...,2,2]*b[...,2,1],
a[e,2,0]*b[e,0,2] + a[e,2,1]*b[e,1,2] + a[e,2,2]*b[e,2,2], a[...,2,0]*b[...,0,2] + a[...,2,1]*b[...,1,2] + a[...,2,2]*b[...,2,2],
], dim=-1) ], dim=-1)
return torch.stack([row_1, row_2, row_3], dim=-2) return torch.stack([row_1, row_2, row_3], dim=-2)
...@@ -175,7 +170,7 @@ class T: ...@@ -175,7 +170,7 @@ class T:
return T(rots, trans) return T(rots, trans)
def to_4x4(self): def to_4x4(self):
tensor = torch.zeros((*self.shape, 4, 4), device=self.rots.device) tensor = self.rots.new_zeros((*self.shape, 4, 4))
tensor[..., :3, :3] = self.rots tensor[..., :3, :3] = self.rots
tensor[..., :3, 3] = self.trans tensor[..., :3, 3] = self.trans
tensor[..., 3, 3] = 1 tensor[..., 3, 3] = 1
...@@ -311,7 +306,7 @@ def _to_mat(pairs): ...@@ -311,7 +306,7 @@ def _to_mat(pairs):
return mat return mat
_qtr_mat = torch.zeros((4, 4, 3, 3)) _qtr_mat = np.zeros((4, 4, 3, 3))
_qtr_mat[..., 0, 0] = _to_mat([('aa', 1), ('bb', 1), ('cc', -1), ('dd', -1)]) _qtr_mat[..., 0, 0] = _to_mat([('aa', 1), ('bb', 1), ('cc', -1), ('dd', -1)])
_qtr_mat[..., 0, 1] = _to_mat([('bc', 2), ('ad', -2)]) _qtr_mat[..., 0, 1] = _to_mat([('bc', 2), ('ad', -2)])
_qtr_mat[..., 0, 2] = _to_mat([('bd', 2), ('ac', 2)]) _qtr_mat[..., 0, 2] = _to_mat([('bd', 2), ('ac', 2)])
...@@ -328,9 +323,11 @@ def quat_to_rot( ...@@ -328,9 +323,11 @@ def quat_to_rot(
# [*, 4, 4] # [*, 4, 4]
quat = quat[..., None] * quat[..., None, :] quat = quat[..., None] * quat[..., None, :]
mat = quat.new_tensor(_qtr_mat)
# [*, 4, 4, 3, 3] # [*, 4, 4, 3, 3]
shaped_qtr_mat = _qtr_mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3)) shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
quat = quat[..., None, None] * shaped_qtr_mat.to(quat.device) quat = quat[..., None, None] * shaped_qtr_mat
# [*, 3, 3] # [*, 3, 3]
return torch.sum(quat, dim=(-3, -4)) return torch.sum(quat, dim=(-3, -4))
...@@ -339,9 +336,7 @@ def affine_vector_to_4x4(vector): ...@@ -339,9 +336,7 @@ def affine_vector_to_4x4(vector):
quats = vector[..., :4] quats = vector[..., :4]
trans = vector[..., 4:] trans = vector[..., 4:]
four_by_four = torch.zeros( four_by_four = vector.new_zeros((*vector.shape[:-1], 4, 4))
(*vector.shape[:-1], 4, 4), device=vector.device
)
four_by_four[..., :3, :3] = quat_to_rot(quats) four_by_four[..., :3, :3] = quat_to_rot(quats)
four_by_four[..., :3, 3] = trans four_by_four[..., :3, 3] = trans
four_by_four[..., 3, 3] = 1 four_by_four[..., 3, 3] = 1
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import deepspeed import deepspeed
import torch import torch
from torch.utils.checkpoint import checkpoint
from typing import Any, Tuple, List, Callable from typing import Any, Tuple, List, Callable
BLOCK_ARG = Any BLOCK_ARG = Any
...@@ -55,7 +56,7 @@ def checkpoint_blocks( ...@@ -55,7 +56,7 @@ def checkpoint_blocks(
return a return a
def chunker(s, e): def chunker(s, e):
def exec_sliced(a): def exec_sliced(*a):
return exec(blocks[s:e], a) return exec(blocks[s:e], a)
return exec_sliced return exec_sliced
...@@ -69,7 +70,7 @@ def checkpoint_blocks( ...@@ -69,7 +70,7 @@ def checkpoint_blocks(
for s in range(0, len(blocks), blocks_per_ckpt): for s in range(0, len(blocks), blocks_per_ckpt):
e = s + blocks_per_ckpt e = s + blocks_per_ckpt
args = deepspeed.checkpointing.checkpoint(chunker(s, e), args) args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
args = wrap(args) args = wrap(args)
return args return args
...@@ -231,7 +231,7 @@ def import_jax_weights_(model, npz_path, version="model_1"): ...@@ -231,7 +231,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
MSAGlobalAttParams = lambda matt: { MSAGlobalAttParams = lambda matt: {
"query_norm": LayerNormParams(matt.layer_norm_m), "query_norm": LayerNormParams(matt.layer_norm_m),
"attention": GlobalAttentionParams(matt) "attention": GlobalAttentionParams(matt.global_attention)
} }
MSAAttPairBiasParams = lambda matt: dict( MSAAttPairBiasParams = lambda matt: dict(
......
...@@ -356,7 +356,7 @@ def lddt_loss( ...@@ -356,7 +356,7 @@ def lddt_loss(
) )
dists_to_score = ( dists_to_score = (
(dmat_true < cutoff) * all_atom_mask * (dmat_true < cutoff) * all_atom_mask *
permute_final_dims(all_atom_mask, 1, 0) * permute_final_dims(all_atom_mask, (1, 0)) *
(1. - torch.eye(n, device=all_atom_mask.device)) (1. - torch.eye(n, device=all_atom_mask.device))
) )
......
...@@ -16,12 +16,13 @@ ...@@ -16,12 +16,13 @@
from functools import partial from functools import partial
import torch import torch
import torch.nn as nn import torch.nn as nn
from typing import Tuple, List, Callable, Any, Dict
def permute_final_dims(tensor, *inds): def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
zero_index = -1 * len(inds) zero_index = -1 * len(inds)
first_inds = range(len(tensor.shape[:zero_index])) first_inds = list(range(len(tensor.shape[:zero_index])))
return tensor.permute(*first_inds, *[zero_index + i for i in inds]) return tensor.permute(first_inds + [zero_index + i for i in inds])
def flatten_final_dims(tensor: torch.Tensor, no_dims: int): def flatten_final_dims(tensor: torch.Tensor, no_dims: int):
...@@ -70,7 +71,7 @@ def stack_tensor_dicts(dicts): ...@@ -70,7 +71,7 @@ def stack_tensor_dicts(dicts):
def one_hot(x, v_bins): def one_hot(x, v_bins):
reshaped_bins = v_bins.view(*((1,) * len(x.shape) + (len(v_bins),))) reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
diffs = x[..., None] - reshaped_bins diffs = x[..., None] - reshaped_bins
am = torch.argmin(torch.abs(diffs), dim=-1) am = torch.argmin(torch.abs(diffs), dim=-1)
return nn.functional.one_hot(am, num_classes=len(v_bins)).float() return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
...@@ -118,7 +119,12 @@ def tree_map(fn, tree, leaf_type): ...@@ -118,7 +119,12 @@ def tree_map(fn, tree, leaf_type):
tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
def chunk_layer(layer, inputs, chunk_size, no_batch_dims): def chunk_layer(
layer: Callable,
inputs: Dict[str, Any],
chunk_size: int,
no_batch_dims: int,
) -> Any:
""" """
Implements the "chunking" procedure described in section 1.11.8. Implements the "chunking" procedure described in section 1.11.8.
...@@ -130,8 +136,8 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims): ...@@ -130,8 +136,8 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
layer: layer:
The layer to be applied chunk-wise The layer to be applied chunk-wise
inputs: inputs:
A (nested) dictionary of keyworded inputs. All leaves must be A (non-nested) dictionary of keyworded inputs. All leaves must
tensors and must share the same batch dimensions. be tensors and must share the same batch dimensions.
chunk_size: chunk_size:
The number of sub-batches per chunk. If multiple batch The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single dimensions are specified, a "sub-batch" is defined as a single
...@@ -163,7 +169,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims): ...@@ -163,7 +169,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
return shapes return shapes
initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)] initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
orig_batch_dims = [max(s) for s in zip(*initial_dims)] orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
def prep_inputs(t): def prep_inputs(t):
# TODO: make this more memory efficient. This sucks # TODO: make this more memory efficient. This sucks
...@@ -194,7 +200,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims): ...@@ -194,7 +200,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
# Allocate space for the output # Allocate space for the output
if(out is None): if(out is None):
allocate = lambda t: t.new_zeros(flat_batch_dim, *t.shape[1:]) allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
out = tensor_tree_map(allocate, output_chunk) out = tensor_tree_map(allocate, output_chunk)
# Put the chunk in its pre-allocated space # Put the chunk in its pre-allocated space
...@@ -217,7 +223,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims): ...@@ -217,7 +223,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
i += chunk_size i += chunk_size
reshape = lambda t: t.reshape(*orig_batch_dims, *t.shape[1:]) reshape = lambda t: t.reshape(orig_batch_dims + t.shape[1:])
out = tensor_tree_map(reshape, out) out = tensor_tree_map(reshape, out)
return out return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment