Commit b2d102cb authored by Gustaf Ahdritz

Refactor certain modules for TorchScript, fix recycling bug

parent 4bd4ad93
@@ -83,7 +83,7 @@ class InputEmbedder(nn.Module):
         boundaries = torch.arange(
             start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
         )
-        oh = one_hot(d, boundaries)
+        oh = one_hot(d, boundaries).type(ri.dtype)
         return self.linear_relpos(oh)
 
     def forward(self,
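
The `.type(ri.dtype)` cast above is a mixed-precision fix: `one_hot` returns a hard-coded `float()` tensor, which breaks a `Linear` whose weights have been cast to half precision. A minimal repro sketch (module and sizes here are hypothetical, purely for illustration):

```python
import torch
import torch.nn as nn

# one_hot(...).float() is always fp32; a half-precision Linear rejects it.
linear_relpos = nn.Linear(5, 3).half()
oh = nn.functional.one_hot(torch.tensor([1, 4]), num_classes=5).float()
try:
    linear_relpos(oh)                      # RuntimeError: dtype mismatch
except RuntimeError as e:
    print("mismatch:", e)
out = linear_relpos(oh.type(torch.float16))  # works once dtypes agree
```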
@@ -112,14 +112,15 @@ class InputEmbedder(nn.Module):
         # [*, N_res, N_res, c_z]
         pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
-        pair_emb += self.relpos(ri)
-        #pair_emb = pair_emb + self.relpos(ri)
+        pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))
 
         # [*, N_clust, N_res, c_m]
         n_clust = msa.shape[-3]
-        tf_m = (self.linear_tf_m(tf)
+        tf_m = (
+            self.linear_tf_m(tf)
             .unsqueeze(-3)
-            .expand((*(-1,) * len(tf.shape[:-2]), n_clust, -1, -1)))
+            .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
+        )
         msa_emb = self.linear_msa_m(msa) + tf_m
 
         return msa_emb, pair_emb
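
The `(*(-1,) * n, ...)` to `((-1,) * n + (...,))` rewrite recurs throughout this commit. The likely motivation is TorchScript, which rejects star-unpacking inside tuple displays but handles plain list arithmetic. A scriptable sketch of the pattern (the function name is ours, not the repo's):

```python
import torch
from typing import List

@torch.jit.script
def expand_clusters(tf: torch.Tensor, n_clust: int) -> torch.Tensor:
    # Build the expand() sizes with list arithmetic instead of (*..., ...),
    # which TorchScript cannot compile.
    lead: List[int] = [-1] * len(tf.shape[:-2])
    return tf.unsqueeze(-3).expand(lead + [n_clust, -1, -1])

tf = torch.randn(2, 7, 22)                    # [batch, N_res, tf_dim]
print(expand_clusters(tf, 5).shape)           # torch.Size([2, 5, 7, 22])
```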
@@ -192,6 +193,7 @@ class RecyclingEmbedder(nn.Module):
             self.min_bin,
             self.max_bin,
             self.no_bins,
+            dtype=x.dtype,
             requires_grad=False,
             device=x.device
         )
...
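For context, the `dtype=x.dtype` added above matters under mixed precision: tensor factories default to `torch.float32`, which would otherwise mix dtypes with half- or double-precision activations. A small sketch (the bin values are AlphaFold-style defaults, used here only for illustration):

```python
import torch

# Keep factory tensors in step with the activations' dtype.
x = torch.randn(4, dtype=torch.float64)
bins = torch.linspace(
    3.25, 20.75, 15,   # illustrative recycling-bin bounds and count
    dtype=x.dtype, device=x.device, requires_grad=False,
)
assert bins.dtype == x.dtype  # float64, not the float32 factory default
```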
@@ -170,68 +170,10 @@ class AlphaFold(nn.Module):
             "torsion_angles_mask": angle_feats["torsion_angles_mask"],
         }
 
-    def forward(self, batch):
-        """
-            Args:
-                batch:
-                    Dictionary of arguments outlined in Algorithm 2. Keys must
-                    include the official names of the features in the
-                    supplement subsection 1.2.9.
-
-                    The final dimension of each input must have length equal to
-                    the number of recycling iterations.
-
-                    Features (without the recycling dimension):
-
-                        "aatype" ([*, N_res]):
-                            Contrary to the supplement, this tensor of residue
-                            indices is not one-hot.
-                        "target_feat" ([*, N_res, C_tf])
-                            One-hot encoding of the target sequence. C_tf is
-                            config.model.input_embedder.tf_dim.
-                        "residue_index" ([*, N_res])
-                            Tensor whose final dimension consists of
-                            consecutive indices from 0 to N_res.
-                        "msa_feat" ([*, N_seq, N_res, C_msa])
-                            MSA features, constructed as in the supplement.
-                            C_msa is config.model.input_embedder.msa_dim.
-                        "seq_mask" ([*, N_res])
-                            1-D sequence mask
-                        "msa_mask" ([*, N_seq, N_res])
-                            MSA mask
-                        "pair_mask" ([*, N_res, N_res])
-                            2-D pair mask
-                        "extra_msa_mask" ([*, N_extra, N_res])
-                            Extra MSA mask
-                        "template_mask" ([*, N_templ])
-                            Template mask (on the level of templates, not
-                            residues)
-                        "template_aatype" ([*, N_templ, N_res])
-                            Tensor of template residue indices (indices greater
-                            than 19 are clamped to 20 (Unknown))
-                        "template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
-                            Template atom coordinates in atom37 format
-                        "template_all_atom_mask" ([*, N_templ, N_res, 37])
-                            Template atom coordinate mask
-                        "template_pseudo_beta" ([*, N_templ, N_res, 3])
-                            Positions of template carbon "pseudo-beta" atoms
-                            (i.e. C_beta for all residues but glycine, for
-                            which C_alpha is used instead)
-                        "template_pseudo_beta_mask" ([*, N_templ, N_res])
-                            Pseudo-beta mask
-        """
-        # Recycling embeddings
-        m_1_prev, z_prev, x_prev = None, None, None
-
+    def iteration(self, feats, m_1_prev, z_prev, x_prev):
         # Primary output dictionary
         outputs = {}
 
-        # Main recycling loop
-        for cycle_no in range(self.config.no_cycles):
-            # Select the features for the current recycling cycle
-            fetch_cur_batch = lambda t: t[..., cycle_no]
-            feats = tensor_tree_map(fetch_cur_batch, batch)
-
         # Grab some data about the input
         batch_dims = feats["target_feat"].shape[:-2]
         no_batch_dims = len(batch_dims)
@@ -259,24 +201,21 @@ class AlphaFold(nn.Module):
         # Initialize the recycling embeddings, if needs be
         if(None in [m_1_prev, z_prev, x_prev]):
             # [*, N, C_m]
-            m_1_prev = torch.zeros(
+            m_1_prev = m.new_zeros(
                 (*batch_dims, n, self.config.c_m),
                 requires_grad=False,
-                device=device,
             )
 
             # [*, N, N, C_z]
-            z_prev = torch.zeros(
+            z_prev = z.new_zeros(
                 (*batch_dims, n, n, self.config.c_z),
                 requires_grad=False,
-                device=device,
             )
 
             # [*, N, 3]
-            x_prev = torch.zeros(
+            x_prev = z.new_zeros(
                 (*batch_dims, n, residue_constants.atom_type_num, 3),
                 requires_grad=False,
-                device=device,
             )
 
         x_prev = pseudo_beta_fn(
@@ -377,6 +316,74 @@ class AlphaFold(nn.Module):
         # [*, N, 3]
         x_prev = outputs["final_atom_positions"]
 
+        return outputs, m_1_prev, z_prev, x_prev
+
+    def forward(self, batch):
+        """
+            Args:
+                batch:
+                    Dictionary of arguments outlined in Algorithm 2. Keys must
+                    include the official names of the features in the
+                    supplement subsection 1.2.9.
+
+                    The final dimension of each input must have length equal to
+                    the number of recycling iterations.
+
+                    Features (without the recycling dimension):
+
+                        "aatype" ([*, N_res]):
+                            Contrary to the supplement, this tensor of residue
+                            indices is not one-hot.
+                        "target_feat" ([*, N_res, C_tf])
+                            One-hot encoding of the target sequence. C_tf is
+                            config.model.input_embedder.tf_dim.
+                        "residue_index" ([*, N_res])
+                            Tensor whose final dimension consists of
+                            consecutive indices from 0 to N_res.
+                        "msa_feat" ([*, N_seq, N_res, C_msa])
+                            MSA features, constructed as in the supplement.
+                            C_msa is config.model.input_embedder.msa_dim.
+                        "seq_mask" ([*, N_res])
+                            1-D sequence mask
+                        "msa_mask" ([*, N_seq, N_res])
+                            MSA mask
+                        "pair_mask" ([*, N_res, N_res])
+                            2-D pair mask
+                        "extra_msa_mask" ([*, N_extra, N_res])
+                            Extra MSA mask
+                        "template_mask" ([*, N_templ])
+                            Template mask (on the level of templates, not
+                            residues)
+                        "template_aatype" ([*, N_templ, N_res])
+                            Tensor of template residue indices (indices greater
+                            than 19 are clamped to 20 (Unknown))
+                        "template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
+                            Template atom coordinates in atom37 format
+                        "template_all_atom_mask" ([*, N_templ, N_res, 37])
+                            Template atom coordinate mask
+                        "template_pseudo_beta" ([*, N_templ, N_res, 3])
+                            Positions of template carbon "pseudo-beta" atoms
+                            (i.e. C_beta for all residues but glycine, for
+                            which C_alpha is used instead)
+                        "template_pseudo_beta_mask" ([*, N_templ, N_res])
+                            Pseudo-beta mask
+        """
+        # Recycling embeddings
+        m_1_prev, z_prev, x_prev = None, None, None
+
+        # Main recycling loop
+        for cycle_no in range(self.config.no_cycles):
+            # Select the features for the current recycling cycle
+            fetch_cur_batch = lambda t: t[..., cycle_no]
+            feats = tensor_tree_map(fetch_cur_batch, batch)
+
+            # Enable grad iff we're training and it's the final recycling layer
+            is_final_iter = (cycle_no == self.config.no_cycles - 1)
+            with torch.set_grad_enabled(self.training and is_final_iter):
+                outputs, m_1_prev, z_prev, x_prev = self.iteration(
+                    feats, m_1_prev, z_prev, x_prev,
+                )
 
         outputs.update(self.aux_heads(outputs))
 
         return outputs
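
This restructuring is the recycling-bug fix the commit message refers to: `forward` now delegates each cycle to `iteration` and wraps the call in `torch.set_grad_enabled`, so autograd history is built only for the final cycle; earlier cycles behave as constants, as AlphaFold's recycling intends. A minimal, self-contained sketch of the pattern with a hypothetical toy module:

```python
import torch
import torch.nn as nn

class ToyRecycler(nn.Module):
    # Hypothetical stand-in for AlphaFold: one shared block, run no_cycles times.
    def __init__(self, dim: int = 8, no_cycles: int = 3):
        super().__init__()
        self.linear = nn.Linear(dim, dim)
        self.no_cycles = no_cycles

    def iteration(self, feats, prev):
        # One recycling pass; prev is the embedding carried over from last pass
        if prev is None:
            prev = torch.zeros_like(feats)
        return self.linear(feats + prev)

    def forward(self, batch):
        # batch: [*, dim, no_cycles]; the last dim indexes recycling iterations
        prev = None
        for cycle_no in range(self.no_cycles):
            feats = batch[..., cycle_no]
            is_final_iter = cycle_no == self.no_cycles - 1
            # Grad flows only through the final cycle, as in the diff above
            with torch.set_grad_enabled(self.training and is_final_iter):
                prev = self.iteration(feats, prev)
        return prev

model = ToyRecycler().train()
out = model(torch.randn(2, 8, 3))
out.sum().backward()  # backprop only touches the last iteration's graph
```

Because the earlier cycles run without grad, activation memory stays flat across cycles instead of growing with `no_cycles`.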
@@ -16,8 +16,9 @@
 
 import math
 import torch
 import torch.nn as nn
+from typing import Optional
 
-from openfold.model.primitives import Linear, scripted_attention
+from openfold.model.primitives import Linear, Attention, GlobalAttention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -69,7 +70,8 @@ class MSAAttention(nn.Module):
             self.c_z, self.no_heads, bias=False, init="normal"
         )
 
-        self.mha = scripted_attention(
+        self.mha = Attention(
             self.c_in, self.c_in, self.c_in,
             self.c_hidden,
             self.no_heads
@@ -93,7 +95,7 @@ class MSAAttention(nn.Module):
         if(mask is None):
             # [*, N_seq, N_res]
             mask = torch.ones(
-                (*m.shape[:-3], n_seq, n_res),
+                m.shape[:-3] + (n_seq, n_res),
                 device=m.device,
                 requires_grad=False
             )
@@ -103,7 +105,7 @@ class MSAAttention(nn.Module):
 
         # [*, N_seq, no_heads, N_res, N_res]
         bias = bias.expand(
-            (*((-1,) * len(bias.shape[:-4])), -1, self.no_heads, n_res, -1)
+            ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
         )
 
         biases = [bias]
@@ -115,7 +117,7 @@ class MSAAttention(nn.Module):
 
             z = self.linear_z(z)
 
             # [*, 1, no_heads, N_res, N_res]
-            z = permute_final_dims(z, 2, 0, 1).unsqueeze(-4)
+            z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)
 
             biases.append(z)
@@ -234,79 +236,29 @@ class MSAColumnGlobalAttention(nn.Module):
         self.inf = inf
         self.eps = eps
 
-        self.layer_norm_m = nn.LayerNorm(self.c_in)
-
-        self.linear_q = Linear(
-            self.c_in, self.c_hidden * self.no_heads, bias=False, init="glorot"
-        )
-
-        C_hidden = self.c_hidden
-        self.linear_k = Linear(
-            self.c_in, C_hidden, bias=False, init="glorot",
-        )
-        self.linear_v = Linear(
-            self.c_in, C_hidden, bias=False, init="glorot",
-        )
-        self.linear_g = Linear(self.c_in, self.c_hidden * self.no_heads, init="gating")
-        self.linear_o = Linear(self.c_hidden * self.no_heads, self.c_in, init="final")
-
-        self.sigmoid = nn.Sigmoid()
-        self.softmax = nn.Softmax(dim=-1)
-
-    def global_attention(self, m, mask):
-        # [*, N_res, C_in]
-        q = (torch.sum(m * mask.unsqueeze(-1), dim=-2) /
-            (torch.sum(mask, dim=-1)[..., None] + self.eps))
-
-        # [*, N_res, H * C_hidden]
-        q = self.linear_q(q)
-        q = q * self.c_hidden ** (-0.5)
-
-        # [*, N_res, H, C_hidden]
-        q = q.view(*q.shape[:-1], self.no_heads, -1)
-
-        # [*, N_res, N_seq, C_hidden]
-        k = self.linear_k(m)
-        v = self.linear_v(m)
-
-        # [*, N_res, H, N_seq]
-        a = torch.matmul(
-            q,
-            k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
-        )
-        bias = (self.inf * (mask - 1))[..., :, None, :]
-        a = a + bias
-        a = self.softmax(a)
-
-        # [*, N_res, H, C_hidden]
-        o = torch.matmul(
-            a,
-            v,
-        )
-
-        # [*, N_res, N_seq, C_hidden]
-        g = self.sigmoid(self.linear_g(m))
-
-        # [*, N_res, N_seq, H, C_hidden]
-        g = g.view(*g.shape[:-1], self.no_heads, -1)
-
-        # [*, N_res, N_seq, H, C_hidden]
-        o = o.unsqueeze(-3) * g
-
-        # [*, N_res, N_seq, H * C_hidden]
-        o = o.reshape(*o.shape[:-2], -1)
-
-        # [*, N_res, N_seq, C_in]
-        m = self.linear_o(o)
-
-        return m
-
-    def forward(self, m, mask=None):
+        self.layer_norm_m = nn.LayerNorm(c_in)
+
+        self.global_attention = GlobalAttention(
+            c_in=c_in,
+            c_hidden=c_hidden,
+            no_heads=no_heads,
+            inf=inf,
+            eps=eps,
+        )
+
+    def forward(self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]
 
         if(mask is None):
             # [*, N_seq, N_res]
-            mask = m.new_ones(m.shape[:-1], requires_grad=False)
+            mask = torch.ones(
+                m.shape[:-1],
+                dtype=m.dtype,
+                device=m.device,
+            ).detach()
 
         # [*, N_res, N_seq, C_in]
         m = m.transpose(-2, -3)
@@ -327,7 +279,7 @@ class MSAColumnGlobalAttention(nn.Module):
                 no_batch_dims=len(m.shape[:-2])
             )
         else:
-            m = self.global_attention(**mha_input)
+            m = self.global_attention(m=mha_input["m"], mask=mha_input["mask"])
 
         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
...
@@ -235,12 +235,6 @@ class Attention(nn.Module):
             Returns
                 [*, Q, C_q] attention update
         """
-        # Flatten batch dims
-        batch_dims = q_x.shape[:-2]
-        q_x = q_x.view((-1,) + q_x.shape[-2:])
-        k_x = k_x.view((-1,) + k_x.shape[-2:])
-        v_x = v_x.view((-1,) + v_x.shape[-2:])
-
         # [*, Q/K/V, H * C_hidden]
         q = self.linear_q(q_x)
         k = self.linear_k(k_x)
@@ -253,20 +247,20 @@ class Attention(nn.Module):
 
         # [*, H, Q, K]
         a = torch.matmul(
-            q.permute(0, 2, 1, 3),  # [*, H, Q, C_hidden]
-            k.permute(0, 2, 3, 1),  # [*, H, C_hidden, K]
+            permute_final_dims(q, (0, 2, 1, 3)),  # [*, H, Q, C_hidden]
+            permute_final_dims(k, (0, 2, 3, 1)),  # [*, H, C_hidden, K]
         )
         norm = 1 / math.sqrt(self.c_hidden)  # [1]
-        a = a * norm
+        a *= norm
         if(biases is not None):
             for b in biases:
-                a = a + b
+                a += b
         a = self.softmax(a)
 
         # [*, H, Q, C_hidden]
         o = torch.matmul(
             a,
-            v.permute(0, 2, 1, 3),  # [*, H, V, C_hidden]
+            permute_final_dims(v, (0, 2, 1, 3)),  # [*, H, V, C_hidden]
         )
 
         # [*, Q, H, C_hidden]
@@ -282,11 +276,80 @@ class Attention(nn.Module):
 
         # [*, Q, C_q]
         o = self.linear_o(o)
 
-        # Restore the batch dims
-        o = o.reshape(batch_dims + o.shape[1:])
-
         return o
 
-def scripted_attention(*args, **kwargs):
-    return torch.jit.script(Attention(*args, **kwargs))
+
+class GlobalAttention(nn.Module):
+    def __init__(self, c_in, c_hidden, no_heads, inf, eps):
+        super(GlobalAttention, self).__init__()
+
+        self.c_in = c_in
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.inf = inf
+        self.eps = eps
+
+        self.linear_q = Linear(
+            c_in, c_hidden * no_heads, bias=False, init="glorot"
+        )
+        self.linear_k = Linear(
+            c_in, c_hidden, bias=False, init="glorot",
+        )
+        self.linear_v = Linear(
+            c_in, c_hidden, bias=False, init="glorot",
+        )
+        self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
+        self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
+
+        self.sigmoid = nn.Sigmoid()
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+        # [*, N_res, C_in]
+        q = (torch.sum(m * mask.unsqueeze(-1), dim=-2) /
+            (torch.sum(mask, dim=-1)[..., None] + self.eps))
+
+        # [*, N_res, H * C_hidden]
+        q = self.linear_q(q)
+        q = q * self.c_hidden ** (-0.5)
+
+        # [*, N_res, H, C_hidden]
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+
+        # [*, N_res, N_seq, C_hidden]
+        k = self.linear_k(m)
+        v = self.linear_v(m)
+
+        # [*, N_res, H, N_seq]
+        a = torch.matmul(
+            q,
+            k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
+        )
+        bias = (self.inf * (mask - 1))[..., :, None, :]
+        a += bias
+        a = self.softmax(a)
+
+        # [*, N_res, H, C_hidden]
+        o = torch.matmul(
+            a,
+            v,
+        )
+
+        # [*, N_res, N_seq, C_hidden]
+        g = self.sigmoid(self.linear_g(m))
+
+        # [*, N_res, N_seq, H, C_hidden]
+        g = g.view(g.shape[:-1] + (self.no_heads, -1))
+
+        # [*, N_res, N_seq, H, C_hidden]
+        o = o.unsqueeze(-3) * g
+
+        # [*, N_res, N_seq, H * C_hidden]
+        o = o.reshape(o.shape[:-2] + (-1,))
+
+        # [*, N_res, N_seq, C_in]
+        m = self.linear_o(o)
+
+        return m
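
As a shape sanity check on the new module, a hypothetical usage sketch (the sizes are invented; `GlobalAttention` is the class added just above). Note that it expects the transposed MSA layout, `[*, N_res, N_seq, C_in]`; `MSAColumnGlobalAttention` performs that transpose before calling it:

```python
import torch
from openfold.model.primitives import GlobalAttention  # added in this commit

# Hypothetical sizes: batch 2, N_res 16, N_seq 32, c_in 8, 2 heads of width 4
ga = GlobalAttention(c_in=8, c_hidden=4, no_heads=2, inf=1e9, eps=1e-10)
m = torch.randn(2, 16, 32, 8)   # [*, N_res, N_seq, C_in]
mask = torch.ones(2, 16, 32)    # [*, N_res, N_seq]
print(ga(m, mask).shape)        # torch.Size([2, 16, 32, 8])
```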
@@ -16,7 +16,7 @@
 
 import math
 import torch
 import torch.nn as nn
-from typing import Optional
+from typing import Optional, Tuple
 
 from openfold.model.primitives import Linear, ipa_point_weights_init_
 from openfold.np.residue_constants import (
@@ -49,7 +49,7 @@ class AngleResnetBlock(nn.Module):
 
         self.relu = nn.ReLU()
 
-    def forward(self, a):
+    def forward(self, a: torch.Tensor) -> torch.Tensor:
         s_initial = a
@@ -85,7 +85,7 @@ class AngleResnet(nn.Module):
         self.c_hidden = c_hidden
         self.no_blocks = no_blocks
         self.no_angles = no_angles
-        self.epsilon = epsilon
+        self.eps = epsilon
 
         self.linear_in = Linear(self.c_in, self.c_hidden)
         self.linear_initial = Linear(self.c_in, self.c_hidden)
@@ -99,7 +99,10 @@ class AngleResnet(nn.Module):
 
         self.relu = nn.ReLU()
 
-    def forward(self, s, s_initial):
+    def forward(self,
+        s: torch.Tensor,
+        s_initial: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
             Args:
                 s:
@@ -130,14 +133,13 @@ class AngleResnet(nn.Module):
 
         s = self.linear_out(s)
 
         # [*, no_angles, 2]
-        s = s.view(*s.shape[:-1], -1, 2)
+        s = s.view(s.shape[:-1] + (-1, 2))
 
         unnormalized_s = s
         norm_denom = torch.sqrt(
             torch.clamp(
-                torch.sum(s ** 2, dim=-1, keepdims=True),
-                min=self.epsilon,
+                torch.sum(s ** 2, dim=-1, keepdim=True),
+                min=self.eps,
             )
         )
         s = s / norm_denom
@@ -219,7 +221,7 @@ class InvariantPointAttention(nn.Module):
         z: torch.Tensor,
         t: T,
         mask: torch.Tensor,
-    ):
+    ) -> torch.Tensor:
         """
             Args:
                 s:
@@ -236,16 +238,15 @@ class InvariantPointAttention(nn.Module):
         #######################################
         # Generate scalar and point activations
         #######################################
-
         # [*, N_res, H * C_hidden]
         q = self.linear_q(s)
         kv = self.linear_kv(s)
 
         # [*, N_res, H, C_hidden]
-        q = q.view(*q.shape[:-1], self.no_heads, -1)
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
 
         # [*, N_res, H, 2 * C_hidden]
-        kv = kv.view(*kv.shape[:-1], self.no_heads, -1)
+        kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))
 
         # [*, N_res, H, C_hidden]
         k, v = torch.split(kv, self.c_hidden, dim=-1)
@@ -261,7 +262,7 @@ class InvariantPointAttention(nn.Module):
 
         # [*, N_res, H, P_q, 3]
         q_pts = q_pts.view(
-            *q_pts.shape[:-2], self.no_heads, self.no_qk_points, 3
+            q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
         )
 
         # [*, N_res, H * (P_q + P_v) * 3]
@@ -274,7 +275,7 @@ class InvariantPointAttention(nn.Module):
 
         # [*, N_res, H, (P_q + P_v), 3]
         kv_pts = kv_pts.view(
-            *kv_pts.shape[:-2], self.no_heads, -1, 3
+            kv_pts.shape[:-2] + (self.no_heads, -1, 3)
         )
 
         # [*, N_res, H, P_q/P_v, 3]
@@ -293,11 +294,11 @@ class InvariantPointAttention(nn.Module):
 
         # [*, H, N_res, N_res]
         a = torch.matmul(
-            permute_final_dims(q, 1, 0, 2),  # [*, H, N_res, C_hidden]
-            permute_final_dims(k, 1, 2, 0),  # [*, H, C_hidden, N_res]
+            permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
+            permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
         )
         a = a + math.sqrt(1. / (3 * self.c_hidden))
-        a = a + math.sqrt(1. / 3) * permute_final_dims(b, 2, 0, 1)
+        a = a + math.sqrt(1. / 3) * permute_final_dims(b, (2, 0, 1))
 
         # [*, N_res, N_res, H, P_q, 3]
         pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
@@ -321,7 +322,7 @@ class InvariantPointAttention(nn.Module):
         square_mask = self.inf * (square_mask - 1)
 
         # [*, H, N_res, N_res]
-        pt_att = permute_final_dims(pt_att, 2, 0, 1)
+        pt_att = permute_final_dims(pt_att, (2, 0, 1))
 
         a = a + pt_att
         a = a + square_mask.unsqueeze(-3)
         a = self.softmax(a)
@@ -339,11 +340,11 @@ class InvariantPointAttention(nn.Module):
 
         # [*, H, 3, N_res, P_v]
         o_pt = torch.matmul(
             a.unsqueeze(-3),  # [*, H, 1, N_res, N_res]
-            permute_final_dims(v_pts, 1, 3, 0, 2),  # [*, H, 3, N_res, P_v]
+            permute_final_dims(v_pts, (1, 3, 0, 2)),  # [*, H, 3, N_res, P_v]
         )
 
         # [*, N_res, H, P_v, 3]
-        o_pt = permute_final_dims(o_pt, 2, 0, 3, 1)
+        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
         o_pt = t[..., None, None].invert_apply(o_pt)
 
         # [*, N_res, H * P_v]
@@ -758,35 +759,39 @@ class StructureModule(nn.Module):
         return outputs
 
-    def _init_residue_constants(self, device):
+    def _init_residue_constants(self, dtype, device):
         if(self.default_frames is None):
             self.default_frames = torch.tensor(
                 restype_rigid_group_default_frame,
+                dtype=dtype,
                 device=device,
                 requires_grad=False,
             )
         if(self.group_idx is None):
             self.group_idx = torch.tensor(
                 restype_atom14_to_rigid_group,
+                dtype=dtype,
                 device=device,
                 requires_grad=False,
             )
         if(self.atom_mask is None):
             self.atom_mask = torch.tensor(
                 restype_atom14_mask,
+                dtype=dtype,
                 device=device,
                 requires_grad=False,
             )
         if(self.lit_positions is None):
             self.lit_positions = torch.tensor(
                 restype_atom14_rigid_group_positions,
+                dtype=dtype,
                 device=device,
                 requires_grad=False,
             )
 
     def torsion_angles_to_frames(self, t, alpha, f):
         # Lazily initialize the residue constants on the correct device
-        self._init_residue_constants(f.device)
+        self._init_residue_constants(f.dtype, f.device)
 
         # Separated purely to make testing less annoying
         return _torsion_angles_to_frames(t, alpha, f, self.default_frames)
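
The new `dtype` parameter threads the activation dtype into the lazily built constant tables, so a half- or double-precision model gets constants that match its activations. The pattern in miniature (a toy class of ours, not the repo's):

```python
import torch
import torch.nn as nn

class LazyConsts(nn.Module):
    # Build the lookup tensor on first use, matching the dtype/device
    # of the activations that consume it.
    def __init__(self):
        super().__init__()
        self.table = None

    def _init_consts(self, dtype, device):
        if self.table is None:
            self.table = torch.tensor(
                [[0.0, 1.0], [1.0, 0.0]],
                dtype=dtype, device=device, requires_grad=False,
            )

    def forward(self, f):
        self._init_consts(f.dtype, f.device)
        return f @ self.table

m = LazyConsts()
print(m(torch.randn(3, 2, dtype=torch.float64)).dtype)  # torch.float64
```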
@@ -797,7 +802,7 @@ class StructureModule(nn.Module):
         # Lazily initialize the residue constants on the correct device
         # TODO: Maybe this stuff should be done on CPU instead (so these
         # arrays
-        self._init_residue_constants(f.device)
+        self._init_residue_constants(f.dtype, f.device)
         return _frames_and_literature_positions_to_atom14_pos(
             t,
...
@@ -18,7 +18,7 @@ import math
 
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear, scripted_attention
+from openfold.model.primitives import Linear, Attention
 from openfold.utils.deepspeed import checkpoint_blocks
 from openfold.model.dropout import (
     DropoutRowwise,
@@ -69,7 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
         self.no_heads = no_heads
         self.chunk_size = chunk_size
 
-        self.mha = scripted_attention(
+        self.mha = Attention(
             self.c_z, self.c_t, self.c_t,
             self.c_hidden, self.no_heads,
             gating=False,
@@ -91,7 +91,7 @@ class TemplatePointwiseAttention(nn.Module):
         # NOTE: This is not the "template_mask" from the supplement, but a
         # [*, N_templ] mask from the code. I'm pretty sure it's always just 1,
         # but not sure enough to remove it. It's nice to have, I guess.
-        template_mask = torch.ones(t.shape[:-3], device=t.device)
+        template_mask = t.new_ones(t.shape[:-3])
 
         bias = (1e9 * (template_mask[..., None, None, None, None, :] - 1))
@@ -99,7 +99,7 @@ class TemplatePointwiseAttention(nn.Module):
         z = z.unsqueeze(-2)
 
         # [*, N_res, N_res, N_temp, C_t]
-        t = permute_final_dims(t, 1, 2, 0, 3)
+        t = permute_final_dims(t, (1, 2, 0, 3))
 
         # [*, N_res, N_res, 1, C_z]
         mha_inputs = {
...
@@ -18,7 +18,7 @@ import math
 
 import torch
 import torch.nn as nn
 
-from openfold.model.primitives import Linear, scripted_attention
+from openfold.model.primitives import Linear, Attention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
@@ -57,7 +57,7 @@ class TriangleAttention(nn.Module):
 
         self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")
 
-        self.mha = scripted_attention(
+        self.mha = Attention(
             self.c_in, self.c_in, self.c_in,
             self.c_hidden,
             self.no_heads
@@ -91,7 +91,7 @@ class TriangleAttention(nn.Module):
         mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]
 
         # [*, H, I, J]
-        triangle_bias = permute_final_dims(self.linear(x), 2, 0, 1)
+        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))
 
         # [*, 1, H, I, J]
         triangle_bias = triangle_bias.unsqueeze(-4)
...
@@ -59,12 +59,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
     ):
         # [*, C, N_i, N_j]
         p = torch.matmul(
-            permute_final_dims(a, 2, 0, 1),
-            permute_final_dims(b, 2, 1, 0),
+            permute_final_dims(a, (2, 0, 1)),
+            permute_final_dims(b, (2, 1, 0)),
         )
 
         # [*, N_i, N_j, C]
-        return permute_final_dims(p, 1, 2, 0)
+        return permute_final_dims(p, (1, 2, 0))
 
     def _incoming_matmul(self,
         a: torch.Tensor,  # [*, N_k, N_i, C]
@@ -73,12 +73,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
         # [*, C, N_i, N_j]
         p = torch.matmul(
-            permute_final_dims(a, 2, 1, 0),
-            permute_final_dims(b, 2, 0, 1),
+            permute_final_dims(a, (2, 1, 0)),
+            permute_final_dims(b, (2, 0, 1)),
         )
 
         # [*, N_i, N_j, C]
-        return permute_final_dims(p, 1, 2, 0)
+        return permute_final_dims(p, (1, 2, 0))
 
     def forward(self, z, mask=None):
         """
...
@@ -13,30 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import numpy as np
 import torch
 
-# According to DeepMind, this prevents rotation compositions from being
-# computed on low-precision tensor cores. I'm personally skeptical that it
-# makes a difference, but to get as close as possible to their outputs, I'm
-# adding it.
 def rot_matmul(a, b):
-    e = ...
     row_1 = torch.stack([
-        a[e,0,0]*b[e,0,0] + a[e,0,1]*b[e,1,0] + a[e,0,2]*b[e,2,0],
-        a[e,0,0]*b[e,0,1] + a[e,0,1]*b[e,1,1] + a[e,0,2]*b[e,2,1],
-        a[e,0,0]*b[e,0,2] + a[e,0,1]*b[e,1,2] + a[e,0,2]*b[e,2,2],
+        a[...,0,0]*b[...,0,0] + a[...,0,1]*b[...,1,0] + a[...,0,2]*b[...,2,0],
+        a[...,0,0]*b[...,0,1] + a[...,0,1]*b[...,1,1] + a[...,0,2]*b[...,2,1],
+        a[...,0,0]*b[...,0,2] + a[...,0,1]*b[...,1,2] + a[...,0,2]*b[...,2,2],
     ], dim=-1)
     row_2 = torch.stack([
-        a[e,1,0]*b[e,0,0] + a[e,1,1]*b[e,1,0] + a[e,1,2]*b[e,2,0],
-        a[e,1,0]*b[e,0,1] + a[e,1,1]*b[e,1,1] + a[e,1,2]*b[e,2,1],
-        a[e,1,0]*b[e,0,2] + a[e,1,1]*b[e,1,2] + a[e,1,2]*b[e,2,2],
+        a[...,1,0]*b[...,0,0] + a[...,1,1]*b[...,1,0] + a[...,1,2]*b[...,2,0],
+        a[...,1,0]*b[...,0,1] + a[...,1,1]*b[...,1,1] + a[...,1,2]*b[...,2,1],
+        a[...,1,0]*b[...,0,2] + a[...,1,1]*b[...,1,2] + a[...,1,2]*b[...,2,2],
     ], dim=-1)
     row_3 = torch.stack([
-        a[e,2,0]*b[e,0,0] + a[e,2,1]*b[e,1,0] + a[e,2,2]*b[e,2,0],
-        a[e,2,0]*b[e,0,1] + a[e,2,1]*b[e,1,1] + a[e,2,2]*b[e,2,1],
-        a[e,2,0]*b[e,0,2] + a[e,2,1]*b[e,1,2] + a[e,2,2]*b[e,2,2],
+        a[...,2,0]*b[...,0,0] + a[...,2,1]*b[...,1,0] + a[...,2,2]*b[...,2,0],
+        a[...,2,0]*b[...,0,1] + a[...,2,1]*b[...,1,1] + a[...,2,2]*b[...,2,1],
+        a[...,2,0]*b[...,0,2] + a[...,2,1]*b[...,1,2] + a[...,2,2]*b[...,2,2],
     ], dim=-1)
 
     return torch.stack([row_1, row_2, row_3], dim=-2)
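
The `e = ...` alias presumably had to go because TorchScript cannot bind the Ellipsis literal to a name, while an inline `...` inside an index expression compiles fine. A minimal check (the function is ours, for illustration only):

```python
import torch

# e = ... would fail to script; inline `...` indexing is supported.
@torch.jit.script
def first_entries(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    return a[..., 0, 0] * b[..., 0, 0]

r = torch.eye(3).expand(5, 3, 3)
print(first_entries(r, r).shape)  # torch.Size([5])
```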
@@ -175,7 +170,7 @@ class T:
         return T(rots, trans)
 
     def to_4x4(self):
-        tensor = torch.zeros((*self.shape, 4, 4), device=self.rots.device)
+        tensor = self.rots.new_zeros((*self.shape, 4, 4))
         tensor[..., :3, :3] = self.rots
         tensor[..., :3, 3] = self.trans
         tensor[..., 3, 3] = 1
@@ -311,7 +306,7 @@ def _to_mat(pairs):
     return mat
 
-_qtr_mat = torch.zeros((4, 4, 3, 3))
+_qtr_mat = np.zeros((4, 4, 3, 3))
 _qtr_mat[..., 0, 0] = _to_mat([('aa', 1), ('bb', 1), ('cc', -1), ('dd', -1)])
 _qtr_mat[..., 0, 1] = _to_mat([('bc', 2), ('ad', -2)])
 _qtr_mat[..., 0, 2] = _to_mat([('bd', 2), ('ac', 2)])
@@ -328,9 +323,11 @@ def quat_to_rot(
     # [*, 4, 4]
     quat = quat[..., None] * quat[..., None, :]
 
+    mat = quat.new_tensor(_qtr_mat)
+
     # [*, 4, 4, 3, 3]
-    shaped_qtr_mat = _qtr_mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
-    quat = quat[..., None, None] * shaped_qtr_mat.to(quat.device)
+    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
+    quat = quat[..., None, None] * shaped_qtr_mat
 
     # [*, 3, 3]
     return torch.sum(quat, dim=(-3, -4))
@@ -339,9 +336,7 @@ def affine_vector_to_4x4(vector):
     quats = vector[..., :4]
     trans = vector[..., 4:]
 
-    four_by_four = torch.zeros(
-        (*vector.shape[:-1], 4, 4), device=vector.device
-    )
+    four_by_four = vector.new_zeros((*vector.shape[:-1], 4, 4))
     four_by_four[..., :3, :3] = quat_to_rot(quats)
     four_by_four[..., :3, 3] = trans
     four_by_four[..., 3, 3] = 1
...
@@ -14,6 +14,7 @@
 
 import deepspeed
 import torch
+from torch.utils.checkpoint import checkpoint
 from typing import Any, Tuple, List, Callable
 
 BLOCK_ARG = Any
@@ -55,7 +56,7 @@ def checkpoint_blocks(
         return a
 
     def chunker(s, e):
-        def exec_sliced(a):
+        def exec_sliced(*a):
             return exec(blocks[s:e], a)
 
         return exec_sliced
@@ -69,7 +70,7 @@ def checkpoint_blocks(
     for s in range(0, len(blocks), blocks_per_ckpt):
         e = s + blocks_per_ckpt
-        args = deepspeed.checkpointing.checkpoint(chunker(s, e), args)
+        args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
         args = wrap(args)
 
     return args
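
Why the variadic change matters: checkpointing re-invokes the wrapped function with the saved inputs as separate positional arguments, so the wrapper must accept `*a`; the old single-parameter `exec_sliced(a)` would be called with the wrong arity. A self-contained sketch using `torch.utils.checkpoint` (imported above), which shares this calling convention with `deepspeed.checkpointing.checkpoint`:

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x, y):
    return x + 1, y * 2

def exec_sliced(*a):       # must be variadic: checkpoint calls fn(*inputs)
    return block(*a)

x = torch.randn(4, requires_grad=True)
y = torch.randn(4, requires_grad=True)
out = checkpoint(exec_sliced, x, y)
(out[0].sum() + out[1].sum()).backward()
print(x.grad.shape)        # torch.Size([4])
```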
@@ -231,7 +231,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
     MSAGlobalAttParams = lambda matt: {
         "query_norm": LayerNormParams(matt.layer_norm_m),
-        "attention": GlobalAttentionParams(matt)
+        "attention": GlobalAttentionParams(matt.global_attention)
     }
 
     MSAAttPairBiasParams = lambda matt: dict(
...
@@ -356,7 +356,7 @@ def lddt_loss(
     )
 
     dists_to_score = (
         (dmat_true < cutoff) * all_atom_mask *
-        permute_final_dims(all_atom_mask, 1, 0) *
+        permute_final_dims(all_atom_mask, (1, 0)) *
         (1. - torch.eye(n, device=all_atom_mask.device))
     )
...
@@ -16,12 +16,13 @@
 
 from functools import partial
 
 import torch
 import torch.nn as nn
+from typing import Tuple, List, Callable, Any, Dict
 
-def permute_final_dims(tensor, *inds):
+def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
     zero_index = -1 * len(inds)
-    first_inds = range(len(tensor.shape[:zero_index]))
-    return tensor.permute(*first_inds, *[zero_index + i for i in inds])
+    first_inds = list(range(len(tensor.shape[:zero_index])))
+    return tensor.permute(first_inds + [zero_index + i for i in inds])
 
 def flatten_final_dims(tensor: torch.Tensor, no_dims: int):
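
This is the root of the `permute_final_dims(x, 2, 0, 1)` to `permute_final_dims(x, (2, 0, 1))` churn across the diff: TorchScript cannot compile variadic integer arguments, but it handles a `List[int]` annotation directly. A standalone copy of the refactored helper to demonstrate that it now scripts:

```python
import torch
from typing import List

@torch.jit.script
def permute_final_dims(tensor: torch.Tensor, inds: List[int]) -> torch.Tensor:
    # Permute only the trailing len(inds) dims, leaving batch dims in place.
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])

x = torch.randn(2, 5, 7, 3)
print(permute_final_dims(x, [2, 0, 1]).shape)  # torch.Size([2, 3, 5, 7])
```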
@@ -70,7 +71,7 @@ def stack_tensor_dicts(dicts):
 
 def one_hot(x, v_bins):
-    reshaped_bins = v_bins.view(*((1,) * len(x.shape) + (len(v_bins),)))
+    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
     diffs = x[..., None] - reshaped_bins
     am = torch.argmin(torch.abs(diffs), dim=-1)
     return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
@@ -118,7 +119,12 @@ def tree_map(fn, tree, leaf_type):
 
 tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)
 
-def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
+def chunk_layer(
+    layer: Callable,
+    inputs: Dict[str, Any],
+    chunk_size: int,
+    no_batch_dims: int,
+) -> Any:
     """
         Implements the "chunking" procedure described in section 1.11.8.
@@ -130,8 +136,8 @@ def chunk_layer(
             layer:
                 The layer to be applied chunk-wise
             inputs:
-                A (nested) dictionary of keyworded inputs. All leaves must be
-                tensors and must share the same batch dimensions.
+                A (non-nested) dictionary of keyworded inputs. All leaves must
+                be tensors and must share the same batch dimensions.
             chunk_size:
                 The number of sub-batches per chunk. If multiple batch
                 dimensions are specified, a "sub-batch" is defined as a single
@@ -163,7 +169,7 @@ def chunk_layer(
         return shapes
 
     initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
-    orig_batch_dims = [max(s) for s in zip(*initial_dims)]
+    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])
 
     def prep_inputs(t):
         # TODO: make this more memory efficient. This sucks
@@ -194,7 +200,7 @@ def chunk_layer(
 
         # Allocate space for the output
         if(out is None):
-            allocate = lambda t: t.new_zeros(flat_batch_dim, *t.shape[1:])
+            allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
             out = tensor_tree_map(allocate, output_chunk)
 
         # Put the chunk in its pre-allocated space
@@ -217,7 +223,7 @@ def chunk_layer(
         i += chunk_size
 
-    reshape = lambda t: t.reshape(*orig_batch_dims, *t.shape[1:])
+    reshape = lambda t: t.reshape(orig_batch_dims + t.shape[1:])
     out = tensor_tree_map(reshape, out)
 
     return out
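
To make the annotated signature concrete, a hypothetical call (shapes invented; behavior as per the docstring: the batch dims are flattened, the layer is applied `chunk_size` sub-batches at a time, and the outputs are reassembled to the original batch shape):

```python
import torch

# Assumes chunk_layer from openfold.utils.tensor_utils, as refactored above.
lin = torch.nn.Linear(8, 8)
out = chunk_layer(
    lambda x: lin(x),                 # keyword name matches the inputs key
    {"x": torch.randn(6, 10, 8)},     # two batch dims: (6, 10)
    chunk_size=4,
    no_batch_dims=2,
)
print(out.shape)  # torch.Size([6, 10, 8])
```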