OpenDAS / OpenFold / Commits / b2d102cb
Commit b2d102cb authored Sep 27, 2021 by Gustaf Ahdritz

Refactor certain modules for TorchScript, fix recycling bug

parent 4bd4ad93
Showing 13 changed files with 349 additions and 318 deletions (+349 −318)
openfold/model/embedders.py  +8 −6
openfold/model/model.py  +155 −148
openfold/model/msa.py  +24 −72
openfold/model/primitives.py  +79 −16
openfold/model/structure_module.py  +32 −27
openfold/model/template.py  +4 −4
openfold/model/triangular_attention.py  +3 −3
openfold/model/triangular_multiplicative_update.py  +6 −6
openfold/utils/affine_utils.py  +17 −22
openfold/utils/deepspeed.py  +3 −2
openfold/utils/import_weights.py  +1 −1
openfold/utils/loss.py  +1 −1
openfold/utils/tensor_utils.py  +16 −10
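Most of the mechanical churn in the diffs below follows one pattern: call sites that built tensor shapes by star-unpacking a dynamically sized tuple (a construct TorchScript's compiler generally rejects) are rewritten to build a single explicit tuple, and variadic helpers such as permute_final_dims get typed, non-variadic signatures. A minimal eager-mode sketch of the two call styles (toy shapes and names, not the model's real dimensions):

```python
import torch

# Illustrative sketch only: both styles give the same result in eager mode;
# the second avoids star-unpacking inside the call, which is the style this
# commit migrates to.
tf = torch.randn(2, 5, 22)          # hypothetical [*, N_res, tf_dim] tensor
n_clust = 3

old_style = tf.unsqueeze(-3).expand(*(-1,) * len(tf.shape[:-2]), n_clust, -1, -1)
new_style = tf.unsqueeze(-3).expand((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1))

assert old_style.shape == new_style.shape == (2, 3, 5, 22)
```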
openfold/model/embedders.py
...
...
@@ -83,7 +83,7 @@ class InputEmbedder(nn.Module):
         boundaries = torch.arange(
             start=-self.relpos_k, end=self.relpos_k + 1, device=d.device
         )
-        oh = one_hot(d, boundaries)
+        oh = one_hot(d, boundaries).type(ri.dtype)
         return self.linear_relpos(oh)

     def forward(self,
...
...
@@ -112,14 +112,15 @@ class InputEmbedder(nn.Module):
         # [*, N_res, N_res, c_z]
         pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
-        pair_emb += self.relpos(ri)
+        #pair_emb = pair_emb + self.relpos(ri)
+        pair_emb = pair_emb + self.relpos(ri.type(pair_emb.dtype))

         # [*, N_clust, N_res, c_m]
         n_clust = msa.shape[-3]
         tf_m = (self.linear_tf_m(tf)
             .unsqueeze(-3)
-            .expand((*(-1,) * len(tf.shape[:-2]), n_clust, -1, -1)))
+            .expand(((-1,) * len(tf.shape[:-2]) + (n_clust, -1, -1)))
+        )
         msa_emb = self.linear_msa_m(msa) + tf_m

         return msa_emb, pair_emb
...
...
@@ -192,6 +193,7 @@ class RecyclingEmbedder(nn.Module):
             self.min_bin,
             self.max_bin,
             self.no_bins,
+            dtype=x.dtype,
             requires_grad=False,
             device=x.device,
         )
...
...
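The two dtype casts added above (one_hot(...).type(ri.dtype) and relpos(ri.type(pair_emb.dtype))) read as dtype hygiene for reduced-precision runs: adding a default float32 tensor to a half-precision activation silently promotes the result. A small sketch of that promotion behaviour (my own toy tensors, not repo code):

```python
import torch

# Sketch: float16 + float32 promotes to float32; casting first keeps the
# pair embedding in a single dtype.
pair_emb = torch.zeros(4, 4, 8, dtype=torch.float16)
relpos_feat = torch.randn(4, 4, 8)                   # float32 by default

mixed = pair_emb + relpos_feat                       # promoted to float32
aligned = pair_emb + relpos_feat.type(pair_emb.dtype)

print(mixed.dtype, aligned.dtype)                    # torch.float32 torch.float16
```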
openfold/model/model.py
...
...
@@ -170,68 +170,10 @@ class AlphaFold(nn.Module):
             "torsion_angles_mask": angle_feats["torsion_angles_mask"],
         }

-    def forward(self, batch):
-        """
Args:
batch:
Dictionary of arguments outlined in Algorithm 2. Keys must
include the official names of the features in the
supplement subsection 1.2.9.
The final dimension of each input must have length equal to
the number of recycling iterations.
Features (without the recycling dimension):
"aatype" ([*, N_res]):
Contrary to the supplement, this tensor of residue
indices is not one-hot.
"target_feat" ([*, N_res, C_tf])
One-hot encoding of the target sequence. C_tf is
config.model.input_embedder.tf_dim.
"residue_index" ([*, N_res])
Tensor whose final dimension consists of
consecutive indices from 0 to N_res.
"msa_feat" ([*, N_seq, N_res, C_msa])
MSA features, constructed as in the supplement.
C_msa is config.model.input_embedder.msa_dim.
"seq_mask" ([*, N_res])
1-D sequence mask
"msa_mask" ([*, N_seq, N_res])
MSA mask
"pair_mask" ([*, N_res, N_res])
2-D pair mask
"extra_msa_mask" ([*, N_extra, N_res])
Extra MSA mask
"template_mask" ([*, N_templ])
Template mask (on the level of templates, not
residues)
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
"template_pseudo_beta" ([*, N_templ, N_res, 3])
Positions of template carbon "pseudo-beta" atoms
(i.e. C_beta for all residues but glycine, for which C_alpha is used instead)
"template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask
"""
-        # Recycling embeddings
-        m_1_prev, z_prev, x_prev = None, None, None
+    def iteration(self, feats, m_1_prev, z_prev, x_prev):
+        # Primary output dictionary
+        outputs = {}

-        # Main recycling loop
-        for cycle_no in range(self.config.no_cycles):
-            # Select the features for the current recycling cycle
-            fetch_cur_batch = lambda t: t[..., cycle_no]
-            feats = tensor_tree_map(fetch_cur_batch, batch)

         # Grab some data about the input
         batch_dims = feats["target_feat"].shape[:-2]
         no_batch_dims = len(batch_dims)
...
...
@@ -259,24 +201,21 @@ class AlphaFold(nn.Module):
         # Initialize the recycling embeddings, if needs be
         if(None in [m_1_prev, z_prev, x_prev]):
             # [*, N, C_m]
-            m_1_prev = torch.zeros(
+            m_1_prev = m.new_zeros(
                 (*batch_dims, n, self.config.c_m),
                 requires_grad=False,
-                device=device,
             )

             # [*, N, N, C_z]
-            z_prev = torch.zeros(
+            z_prev = z.new_zeros(
                 (*batch_dims, n, n, self.config.c_z),
                 requires_grad=False,
-                device=device,
             )

             # [*, N, 3]
-            x_prev = torch.zeros(
+            x_prev = z.new_zeros(
                 (*batch_dims, n, residue_constants.atom_type_num, 3),
                 requires_grad=False,
-                device=device,
             )

         x_prev = pseudo_beta_fn(
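Swapping torch.zeros(..., device=device) for m.new_zeros(...) / z.new_zeros(...) ties the first-cycle recycling tensors to an existing activation's dtype as well as its device. The diff doesn't say so explicitly, but this looks like the "fix recycling bug" half of the commit message: under mixed precision the old float32 zeros would not match the activations. A toy comparison:

```python
import torch

# Sketch: new_zeros inherits dtype and device from the source tensor, while
# torch.zeros defaults to float32 on the default device. Toy shapes only.
m = torch.randn(2, 10, 16, dtype=torch.float16)      # stand-in MSA activation

old_init = torch.zeros(2, 10, 16)                     # float32
new_init = m.new_zeros((2, 10, 16), requires_grad=False)

print(old_init.dtype, new_init.dtype)                 # torch.float32 torch.float16
```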
...
...
@@ -377,6 +316,74 @@ class AlphaFold(nn.Module):
         # [*, N, 3]
         x_prev = outputs["final_atom_positions"]

         return outputs, m_1_prev, z_prev, x_prev

     def forward(self, batch):
"""
Args:
batch:
Dictionary of arguments outlined in Algorithm 2. Keys must
include the official names of the features in the
supplement subsection 1.2.9.
The final dimension of each input must have length equal to
the number of recycling iterations.
Features (without the recycling dimension):
"aatype" ([*, N_res]):
Contrary to the supplement, this tensor of residue
indices is not one-hot.
"target_feat" ([*, N_res, C_tf])
One-hot encoding of the target sequence. C_tf is
config.model.input_embedder.tf_dim.
"residue_index" ([*, N_res])
Tensor whose final dimension consists of
consecutive indices from 0 to N_res.
"msa_feat" ([*, N_seq, N_res, C_msa])
MSA features, constructed as in the supplement.
C_msa is config.model.input_embedder.msa_dim.
"seq_mask" ([*, N_res])
1-D sequence mask
"msa_mask" ([*, N_seq, N_res])
MSA mask
"pair_mask" ([*, N_res, N_res])
2-D pair mask
"extra_msa_mask" ([*, N_extra, N_res])
Extra MSA mask
"template_mask" ([*, N_templ])
Template mask (on the level of templates, not
residues)
"template_aatype" ([*, N_templ, N_res])
Tensor of template residue indices (indices greater
than 19 are clamped to 20 (Unknown))
"template_all_atom_pos" ([*, N_templ, N_res, 37, 3])
Template atom coordinates in atom37 format
"template_all_atom_mask" ([*, N_templ, N_res, 37])
Template atom coordinate mask
"template_pseudo_beta" ([*, N_templ, N_res, 3])
Positions of template carbon "pseudo-beta" atoms
(i.e. C_beta for all residues but glycine, for which C_alpha is used instead)
"template_pseudo_beta_mask" ([*, N_templ, N_res])
Pseudo-beta mask
"""
         # Recycling embeddings
         m_1_prev, z_prev, x_prev = None, None, None

         # Main recycling loop
         for cycle_no in range(self.config.no_cycles):
             # Select the features for the current recycling cycle
             fetch_cur_batch = lambda t: t[..., cycle_no]
             feats = tensor_tree_map(fetch_cur_batch, batch)

             # Enable grad iff we're training and it's the final recycling layer
             is_final_iter = (cycle_no == self.config.no_cycles - 1)
             with torch.set_grad_enabled(self.training and is_final_iter):
                 outputs, m_1_prev, z_prev, x_prev = self.iteration(
                     feats,
                     m_1_prev,
                     z_prev,
                     x_prev,
                 )

         outputs.update(self.aux_heads(outputs))

         return outputs
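With the refactor, forward() keeps only the per-cycle bookkeeping and defers the heavy lifting to iteration(); gradients are enabled only for the last cycle while training. A stripped-down sketch of that control flow (hypothetical toy module, not the real AlphaFold class):

```python
import torch
import torch.nn as nn

class ToyRecycler(nn.Module):
    """Minimal sketch of the recycling control flow in the new forward()."""

    def __init__(self, no_cycles: int = 3):
        super().__init__()
        self.no_cycles = no_cycles
        self.linear = nn.Linear(8, 8)

    def iteration(self, x, prev):
        # One "cycle": fold the recycled output back into the fresh input.
        prev = torch.zeros_like(x) if prev is None else prev
        return self.linear(x + prev)

    def forward(self, x):
        prev = None
        for cycle_no in range(self.no_cycles):
            is_final_iter = cycle_no == self.no_cycles - 1
            # Backprop only through the final cycle during training.
            with torch.set_grad_enabled(self.training and is_final_iter):
                prev = self.iteration(x, prev)
        return prev

out = ToyRecycler().train()(torch.randn(2, 8))
print(out.requires_grad)  # True: only the last cycle built a graph
```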
openfold/model/msa.py
...
...
@@ -16,8 +16,9 @@
 import math

 import torch
 import torch.nn as nn
+from typing import Optional

-from openfold.model.primitives import Linear, scripted_attention
+from openfold.model.primitives import Linear, Attention, GlobalAttention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
...
...
@@ -69,7 +70,8 @@ class MSAAttention(nn.Module):
             self.c_z, self.no_heads, bias=False, init="normal"
         )

-        self.mha = scripted_attention(
+        self.mha = Attention(
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
...
...
@@ -93,7 +95,7 @@ class MSAAttention(nn.Module):
         if(mask is None):
             # [*, N_seq, N_res]
             mask = torch.ones(
-                (*m.shape[:-3], n_seq, n_res),
+                m.shape[:-3] + (n_seq, n_res),
                 device=m.device,
                 requires_grad=False
             )
...
...
@@ -103,7 +105,7 @@ class MSAAttention(nn.Module):
             # [*, N_seq, no_heads, N_res, N_res]
             bias = bias.expand(
-                (*((-1,) * len(bias.shape[:-4])), -1, self.no_heads, n_res, -1)
+                ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
             )

             biases = [bias]
...
...
@@ -115,7 +117,7 @@ class MSAAttention(nn.Module):
             z = self.linear_z(z)

             # [*, 1, no_heads, N_res, N_res]
-            z = permute_final_dims(z, 2, 0, 1).unsqueeze(-4)
+            z = permute_final_dims(z, (2, 0, 1)).unsqueeze(-4)

             biases.append(z)
...
...
@@ -234,79 +236,29 @@ class MSAColumnGlobalAttention(nn.Module):
         self.inf = inf
         self.eps = eps

-        self.layer_norm_m = nn.LayerNorm(self.c_in)
-
-        self.linear_q = Linear(
-            self.c_in, self.c_hidden * self.no_heads, bias=False, init="glorot"
-        )
-
-        C_hidden = self.c_hidden
-        self.linear_k = Linear(
-            self.c_in, C_hidden, bias=False, init="glorot",
-        )
-        self.linear_v = Linear(
-            self.c_in, C_hidden, bias=False, init="glorot",
-        )
-        self.linear_g = Linear(
-            self.c_in, self.c_hidden * self.no_heads, init="gating"
-        )
-        self.linear_o = Linear(
-            self.c_hidden * self.no_heads, self.c_in, init="final"
-        )
-
-        self.sigmoid = nn.Sigmoid()
-        self.softmax = nn.Softmax(dim=-1)
-
-    def global_attention(self, m, mask):
-        # [*, N_res, C_in]
-        q = (
-            torch.sum(m * mask.unsqueeze(-1), dim=-2)
-            / (torch.sum(mask, dim=-1)[..., None] + self.eps)
-        )
-
-        # [*, N_res, H * C_hidden]
-        q = self.linear_q(q)
-        q = q * self.c_hidden ** (-0.5)
-
-        # [*, N_res, H, C_hidden]
-        q = q.view(*q.shape[:-1], self.no_heads, -1)
+        self.layer_norm_m = nn.LayerNorm(c_in)

-        # [*, N_res, N_seq, C_hidden]
-        k = self.linear_k(m)
-        v = self.linear_v(m)
-
-        # [*, N_res, H, N_seq]
-        a = torch.matmul(
-            q,
-            k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
-        )
-        bias = (self.inf * (mask - 1))[..., :, None, :]
-        a = a + bias
-        a = self.softmax(a)
-
-        # [*, N_res, H, C_hidden]
-        o = torch.matmul(
-            a,
-            v,
-        )
+        self.global_attention = GlobalAttention(
+            c_in=c_in,
+            c_hidden=c_hidden,
+            no_heads=no_heads,
+            inf=inf,
+            eps=eps,
+        )

-        # [*, N_res, N_seq, C_hidden]
-        g = self.sigmoid(self.linear_g(m))
-
-        # [*, N_res, N_seq, H, C_hidden]
-        g = g.view(*g.shape[:-1], self.no_heads, -1)
-
-        # [*, N_res, N_seq, H, C_hidden]
-        o = o.unsqueeze(-3) * g
-
-        # [*, N_res, N_seq, H * C_hidden]
-        o = o.reshape(*o.shape[:-2], -1)
-
-        # [*, N_res, N_seq, C_in]
-        m = self.linear_o(o)
-
-        return m
-
-    def forward(self, m, mask=None):
+    def forward(self,
+        m: torch.Tensor,
+        mask: Optional[torch.Tensor] = None
+    ) -> torch.Tensor:
         n_seq, n_res, c_in = m.shape[-3:]

         if(mask is None):
             # [*, N_seq, N_res]
-            mask = m.new_ones(m.shape[:-1], requires_grad=False)
+            mask = torch.ones(
+                m.shape[:-1],
+                dtype=m.dtype,
+                device=m.device,
+            ).detach()

         # [*, N_res, N_seq, C_in]
         m = m.transpose(-2, -3)
...
...
@@ -327,7 +279,7 @@ class MSAColumnGlobalAttention(nn.Module):
                 no_batch_dims=len(m.shape[:-2])
             )
         else:
-            m = self.global_attention(**mha_input)
+            m = self.global_attention(m=mha_input["m"], mask=mha_input["mask"])

         # [*, N_seq, N_res, C_in]
         m = m.transpose(-2, -3)
...
...
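The change from self.global_attention(**mha_input) to explicit m=/mask= keywords fits the same theme: TorchScript generally cannot compile **dict unpacking at a call site. A toy illustration that the two call styles agree in eager mode (hypothetical module, not the real GlobalAttention):

```python
import torch
import torch.nn as nn

class ToyAttn(nn.Module):
    def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        return m * mask.unsqueeze(-1)

attn = ToyAttn()
mha_input = {"m": torch.randn(2, 4, 8), "mask": torch.ones(2, 4)}

out_unpacked = attn(**mha_input)                               # eager only
out_explicit = attn(m=mha_input["m"], mask=mha_input["mask"])  # script-friendly

assert torch.equal(out_unpacked, out_explicit)
```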
openfold/model/primitives.py
...
...
@@ -235,12 +235,6 @@ class Attention(nn.Module):
         Returns
             [*, Q, C_q] attention update
         """
-        # Flatten batch dims
-        batch_dims = q_x.shape[:-2]
-        q_x = q_x.view((-1,) + q_x.shape[-2:])
-        k_x = k_x.view((-1,) + k_x.shape[-2:])
-        v_x = v_x.view((-1,) + v_x.shape[-2:])
-
         # [*, Q/K/V, H * C_hidden]
         q = self.linear_q(q_x)
         k = self.linear_k(k_x)
...
...
@@ -253,20 +247,20 @@ class Attention(nn.Module):
         # [*, H, Q, K]
         a = torch.matmul(
-            q.permute(0, 2, 1, 3),  # [*, H, Q, C_hidden]
-            k.permute(0, 2, 3, 1),  # [*, H, C_hidden, K]
+            permute_final_dims(q, (0, 2, 1, 3)),  # [*, H, Q, C_hidden]
+            permute_final_dims(k, (0, 2, 3, 1)),  # [*, H, C_hidden, K]
         )

         norm = 1 / math.sqrt(self.c_hidden)  # [1]
-        a = a * norm
+        a *= norm
         if(biases is not None):
             for b in biases:
-                a = a + b
+                a += b
         a = self.softmax(a)

         # [*, H, Q, C_hidden]
         o = torch.matmul(
             a,
-            v.permute(0, 2, 1, 3),  # [*, H, V, C_hidden]
+            permute_final_dims(v, (0, 2, 1, 3)),  # [*, H, V, C_hidden]
         )
# [*, Q, H, C_hidden]
...
...
@@ -282,11 +276,80 @@ class Attention(nn.Module):
         # [*, Q, C_q]
         o = self.linear_o(o)

-        # Restore the batch dims
-        o = o.reshape(batch_dims + o.shape[1:])
-
         return o


-def scripted_attention(*args, **kwargs):
-    return torch.jit.script(Attention(*args, **kwargs))
+class GlobalAttention(nn.Module):
+    def __init__(self, c_in, c_hidden, no_heads, inf, eps):
+        super(GlobalAttention, self).__init__()
+
+        self.c_in = c_in
+        self.c_hidden = c_hidden
+        self.no_heads = no_heads
+        self.inf = inf
+        self.eps = eps
+
+        self.linear_q = Linear(
+            c_in, c_hidden * no_heads, bias=False, init="glorot"
+        )
+        self.linear_k = Linear(
+            c_in, c_hidden, bias=False, init="glorot",
+        )
+        self.linear_v = Linear(
+            c_in, c_hidden, bias=False, init="glorot",
+        )
+        self.linear_g = Linear(c_in, c_hidden * no_heads, init="gating")
+        self.linear_o = Linear(c_hidden * no_heads, c_in, init="final")
+
+        self.sigmoid = nn.Sigmoid()
+        self.softmax = nn.Softmax(dim=-1)
+
+    def forward(self, m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+        # [*, N_res, C_in]
+        q = (
+            torch.sum(m * mask.unsqueeze(-1), dim=-2)
+            / (torch.sum(mask, dim=-1)[..., None] + self.eps)
+        )
+
+        # [*, N_res, H * C_hidden]
+        q = self.linear_q(q)
+        q = q * self.c_hidden ** (-0.5)
+
+        # [*, N_res, H, C_hidden]
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))
+
+        # [*, N_res, N_seq, C_hidden]
+        k = self.linear_k(m)
+        v = self.linear_v(m)
+
+        # [*, N_res, H, N_seq]
+        a = torch.matmul(
+            q,
+            k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
+        )
+        bias = (self.inf * (mask - 1))[..., :, None, :]
+        a += bias
+        a = self.softmax(a)
+
+        # [*, N_res, H, C_hidden]
+        o = torch.matmul(
+            a,
+            v,
+        )
+
+        # [*, N_res, N_seq, C_hidden]
+        g = self.sigmoid(self.linear_g(m))
+
+        # [*, N_res, N_seq, H, C_hidden]
+        g = g.view(g.shape[:-1] + (self.no_heads, -1))
+
+        # [*, N_res, N_seq, H, C_hidden]
+        o = o.unsqueeze(-3) * g
+
+        # [*, N_res, N_seq, H * C_hidden]
+        o = o.reshape(o.shape[:-2] + (-1,))
+
+        # [*, N_res, N_seq, C_in]
+        m = self.linear_o(o)
+
+        return m
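The removed scripted_attention() helper compiled an Attention module with torch.jit.script at construction time; after this commit the modules are constructed plainly and compilation becomes an opt-in step for the caller. A sketch of that pattern on a toy module (mine, not repo code):

```python
import torch
import torch.nn as nn

class TinyGate(nn.Module):
    def __init__(self, c: int):
        super().__init__()
        self.linear_g = nn.Linear(c, c)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.sigmoid(self.linear_g(x)) * x

eager = TinyGate(8)
scripted = torch.jit.script(eager)   # opt-in compilation, shares the weights

x = torch.randn(2, 8)
assert torch.allclose(eager(x), scripted(x))
```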
openfold/model/structure_module.py
...
...
@@ -16,7 +16,7 @@
 import math

 import torch
 import torch.nn as nn
-from typing import Optional
+from typing import Optional, Tuple

 from openfold.model.primitives import Linear, ipa_point_weights_init_
 from openfold.np.residue_constants import (
...
...
@@ -49,7 +49,7 @@ class AngleResnetBlock(nn.Module):
         self.relu = nn.ReLU()

-    def forward(self, a):
+    def forward(self, a: torch.Tensor) -> torch.Tensor:
         s_initial = a
...
...
@@ -85,7 +85,7 @@ class AngleResnet(nn.Module):
         self.c_hidden = c_hidden
         self.no_blocks = no_blocks
         self.no_angles = no_angles
-        self.epsilon = epsilon
+        self.eps = epsilon

         self.linear_in = Linear(self.c_in, self.c_hidden)
         self.linear_initial = Linear(self.c_in, self.c_hidden)
...
...
@@ -99,7 +99,10 @@ class AngleResnet(nn.Module):
         self.relu = nn.ReLU()

-    def forward(self, s, s_initial):
+    def forward(self,
+        s: torch.Tensor,
+        s_initial: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         Args:
             s:
...
...
@@ -130,14 +133,13 @@ class AngleResnet(nn.Module):
         s = self.linear_out(s)

         # [*, no_angles, 2]
-        s = s.view(*s.shape[:-1], -1, 2)
+        s = s.view(s.shape[:-1] + (-1, 2))

         unnormalized_s = s
         norm_denom = torch.sqrt(
             torch.clamp(
-                torch.sum(s ** 2, dim=-1, keepdims=True),
-                min=self.epsilon,
+                torch.sum(s ** 2, dim=-1, keepdim=True),
+                min=self.eps,
             )
         )
         s = s / norm_denom
...
...
@@ -219,7 +221,7 @@ class InvariantPointAttention(nn.Module):
         z: torch.Tensor,
         t: T,
         mask: torch.Tensor,
-    ):
+    ) -> torch.Tensor:
         """
         Args:
             s:
...
...
@@ -236,16 +238,15 @@ class InvariantPointAttention(nn.Module):
         #######################################
         # Generate scalar and point activations
         #######################################
         # [*, N_res, H * C_hidden]
         q = self.linear_q(s)
         kv = self.linear_kv(s)

         # [*, N_res, H, C_hidden]
-        q = q.view(*q.shape[:-1], self.no_heads, -1)
+        q = q.view(q.shape[:-1] + (self.no_heads, -1))

         # [*, N_res, H, 2 * C_hidden]
-        kv = kv.view(*kv.shape[:-1], self.no_heads, -1)
+        kv = kv.view(kv.shape[:-1] + (self.no_heads, -1))

         # [*, N_res, H, C_hidden]
         k, v = torch.split(kv, self.c_hidden, dim=-1)
...
...
@@ -261,7 +262,7 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H, P_q, 3]
         q_pts = q_pts.view(
-            *q_pts.shape[:-2], self.no_heads, self.no_qk_points, 3
+            q_pts.shape[:-2] + (self.no_heads, self.no_qk_points, 3)
         )

         # [*, N_res, H * (P_q + P_v) * 3]
...
...
@@ -274,7 +275,7 @@ class InvariantPointAttention(nn.Module):
         # [*, N_res, H, (P_q + P_v), 3]
         kv_pts = kv_pts.view(
-            *kv_pts.shape[:-2], self.no_heads, -1, 3
+            kv_pts.shape[:-2] + (self.no_heads, -1, 3)
         )

         # [*, N_res, H, P_q/P_v, 3]
...
...
@@ -293,11 +294,11 @@ class InvariantPointAttention(nn.Module):
         # [*, H, N_res, N_res]
         a = torch.matmul(
-            permute_final_dims(q, 1, 0, 2),  # [*, H, N_res, C_hidden]
-            permute_final_dims(k, 1, 2, 0),  # [*, H, C_hidden, N_res]
+            permute_final_dims(q, (1, 0, 2)),  # [*, H, N_res, C_hidden]
+            permute_final_dims(k, (1, 2, 0)),  # [*, H, C_hidden, N_res]
         )
         a = a + math.sqrt(1. / (3 * self.c_hidden))
-        a = a + math.sqrt(1. / 3) * permute_final_dims(b, 2, 0, 1)
+        a = a + math.sqrt(1. / 3) * permute_final_dims(b, (2, 0, 1))

         # [*, N_res, N_res, H, P_q, 3]
         pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
...
...
@@ -321,7 +322,7 @@ class InvariantPointAttention(nn.Module):
         square_mask = self.inf * (square_mask - 1)

         # [*, H, N_res, N_res]
-        pt_att = permute_final_dims(pt_att, 2, 0, 1)
+        pt_att = permute_final_dims(pt_att, (2, 0, 1))

         a = a + pt_att
         a = a + square_mask.unsqueeze(-3)
         a = self.softmax(a)
...
...
@@ -339,11 +340,11 @@ class InvariantPointAttention(nn.Module):
         # [*, H, 3, N_res, P_v]
         o_pt = torch.matmul(
             a.unsqueeze(-3),  # [*, H, 1, N_res, N_res]
-            permute_final_dims(v_pts, 1, 3, 0, 2),  # [*, H, 3, N_res, P_v]
+            permute_final_dims(v_pts, (1, 3, 0, 2)),  # [*, H, 3, N_res, P_v]
         )

         # [*, N_res, H, P_v, 3]
-        o_pt = permute_final_dims(o_pt, 2, 0, 3, 1)
+        o_pt = permute_final_dims(o_pt, (2, 0, 3, 1))
         o_pt = t[..., None, None].invert_apply(o_pt)

         # [*, N_res, H * P_v]
...
...
@@ -758,35 +759,39 @@ class StructureModule(nn.Module):
         return outputs

-    def _init_residue_constants(self, device):
+    def _init_residue_constants(self, dtype, device):
         if(self.default_frames is None):
             self.default_frames = torch.tensor(
                 restype_rigid_group_default_frame,
+                dtype=dtype,
                 device=device,
                 requires_grad=False,
             )
         if(self.group_idx is None):
             self.group_idx = torch.tensor(
                 restype_atom14_to_rigid_group,
+                dtype=dtype,
                 device=device,
                 requires_grad=False,
             )
         if(self.atom_mask is None):
             self.atom_mask = torch.tensor(
                 restype_atom14_mask,
+                dtype=dtype,
                 device=device,
                 requires_grad=False,
             )
         if(self.lit_positions is None):
             self.lit_positions = torch.tensor(
                 restype_atom14_rigid_group_positions,
+                dtype=dtype,
                 device=device,
                 requires_grad=False,
             )

     def torsion_angles_to_frames(self, t, alpha, f):
         # Lazily initialize the residue constants on the correct device
-        self._init_residue_constants(f.device)
+        self._init_residue_constants(f.dtype, f.device)

         # Separated purely to make testing less annoying
         return _torsion_angles_to_frames(t, alpha, f, self.default_frames)
...
...
@@ -797,7 +802,7 @@ class StructureModule(nn.Module):
         # Lazily initialize the residue constants on the correct device
         # TODO: Maybe this stuff should be done on CPU instead (so these
         # arrays
-        self._init_residue_constants(f.device)
+        self._init_residue_constants(f.dtype, f.device)
         return _frames_and_literature_positions_to_atom14_pos(
             t,
...
...
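_init_residue_constants now receives the dtype of the incoming features in addition to their device, so the lazily built constant tensors match half-precision activations instead of defaulting to float32. A toy sketch of the lazy-constant pattern (invented table, not the real residue constants):

```python
import torch
import torch.nn as nn

RESTYPE_TABLE = [[0.0, 1.0], [1.0, 0.0]]   # hypothetical lookup data

class LazyConstants(nn.Module):
    def __init__(self):
        super().__init__()
        self.table = None

    def _init_residue_constants(self, dtype: torch.dtype, device: torch.device):
        # Build the constant once, on the dtype/device of the first input seen.
        if self.table is None:
            self.table = torch.tensor(
                RESTYPE_TABLE, dtype=dtype, device=device, requires_grad=False
            )

    def forward(self, f: torch.Tensor) -> torch.Tensor:
        self._init_residue_constants(f.dtype, f.device)
        return f.unsqueeze(-1) * self.table

m = LazyConstants()
print(m(torch.randn(3, 2, dtype=torch.float16)).dtype)  # torch.float16
```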
openfold/model/template.py
...
...
@@ -18,7 +18,7 @@ import math
 import torch
 import torch.nn as nn

-from openfold.model.primitives import Linear, scripted_attention
+from openfold.model.primitives import Linear, Attention
 from openfold.utils.deepspeed import checkpoint_blocks
 from openfold.model.dropout import (
     DropoutRowwise,
...
...
@@ -69,7 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
         self.no_heads = no_heads
         self.chunk_size = chunk_size

-        self.mha = scripted_attention(
+        self.mha = Attention(
             self.c_z, self.c_t, self.c_t, self.c_hidden, self.no_heads, gating=False,
...
...
@@ -91,7 +91,7 @@ class TemplatePointwiseAttention(nn.Module):
         # NOTE: This is not the "template_mask" from the supplement, but a
         # [*, N_templ] mask from the code. I'm pretty sure it's always just 1,
         # but not sure enough to remove it. It's nice to have, I guess.
-        template_mask = torch.ones(t.shape[:-3], device=t.device)
+        template_mask = t.new_ones(t.shape[:-3])

         bias = (1e9 * (template_mask[..., None, None, None, None, :] - 1))
...
...
@@ -99,7 +99,7 @@ class TemplatePointwiseAttention(nn.Module):
         z = z.unsqueeze(-2)

         # [*, N_res, N_res, N_temp, C_t]
-        t = permute_final_dims(t, 1, 2, 0, 3)
+        t = permute_final_dims(t, (1, 2, 0, 3))

         # [*, N_res, N_res, 1, C_z]
         mha_inputs = {
...
...
openfold/model/triangular_attention.py
...
...
@@ -18,7 +18,7 @@ import math
 import torch
 import torch.nn as nn

-from openfold.model.primitives import Linear, scripted_attention
+from openfold.model.primitives import Linear, Attention
 from openfold.utils.tensor_utils import (
     chunk_layer,
     permute_final_dims,
...
...
@@ -57,7 +57,7 @@ class TriangleAttention(nn.Module):
         self.linear = Linear(c_in, self.no_heads, bias=False, init="normal")

-        self.mha = scripted_attention(
+        self.mha = Attention(
             self.c_in, self.c_in, self.c_in, self.c_hidden, self.no_heads
...
...
@@ -91,7 +91,7 @@ class TriangleAttention(nn.Module):
         mask_bias = (self.inf * (mask - 1))[..., :, None, None, :]

         # [*, H, I, J]
-        triangle_bias = permute_final_dims(self.linear(x), 2, 0, 1)
+        triangle_bias = permute_final_dims(self.linear(x), (2, 0, 1))

         # [*, 1, H, I, J]
         triangle_bias = triangle_bias.unsqueeze(-4)
...
...
openfold/model/triangular_multiplicative_update.py
...
...
@@ -59,12 +59,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
     ):
         # [*, C, N_i, N_j]
         p = torch.matmul(
-            permute_final_dims(a, 2, 0, 1),
-            permute_final_dims(b, 2, 1, 0),
+            permute_final_dims(a, (2, 0, 1)),
+            permute_final_dims(b, (2, 1, 0)),
         )

         # [*, N_i, N_j, C]
-        return permute_final_dims(p, 1, 2, 0)
+        return permute_final_dims(p, (1, 2, 0))

     def _incoming_matmul(self,
         a: torch.Tensor,  # [*, N_k, N_i, C]
...
...
@@ -73,12 +73,12 @@ class TriangleMultiplicativeUpdate(nn.Module):
         # [*, C, N_i, N_j]
         p = torch.matmul(
-            permute_final_dims(a, 2, 1, 0),
-            permute_final_dims(b, 2, 0, 1),
+            permute_final_dims(a, (2, 1, 0)),
+            permute_final_dims(b, (2, 0, 1)),
         )

         # [*, N_i, N_j, C]
-        return permute_final_dims(p, 1, 2, 0)
+        return permute_final_dims(p, (1, 2, 0))

     def forward(self, z, mask=None):
         """
...
...
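With the tuple-style permute_final_dims, the outgoing update is still two permutes around a batched matmul. Assuming a has shape [*, N_i, N_k, C] and b has shape [*, N_j, N_k, C] (my reading of the shape comments; the incoming case is the transposed analogue), the whole contraction collapses to a single einsum, which may be easier to read:

```python
import torch

def outgoing_matmul(a, b):
    # Mirrors the permute + matmul + permute above, for a single batch dim.
    p = torch.matmul(
        a.permute(0, 3, 1, 2),    # [*, C, N_i, N_k]
        b.permute(0, 3, 2, 1),    # [*, C, N_k, N_j]
    )
    return p.permute(0, 2, 3, 1)  # [*, N_i, N_j, C]

a = torch.randn(2, 5, 5, 7)
b = torch.randn(2, 5, 5, 7)

assert torch.allclose(
    outgoing_matmul(a, b),
    torch.einsum("bikc,bjkc->bijc", a, b),
    atol=1e-5,
)
```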
openfold/utils/affine_utils.py
...
...
@@ -13,30 +13,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

 import numpy as np
 import torch


 # According to DeepMind, this prevents rotation compositions from being
 # computed on low-precision tensor cores. I'm personally skeptical that it
 # makes a difference, but to get as close as possible to their outputs, I'm
 # adding it.
 def rot_matmul(a, b):
-    e = ...
     row_1 = torch.stack([
-        a[e, 0, 0] * b[e, 0, 0] + a[e, 0, 1] * b[e, 1, 0] + a[e, 0, 2] * b[e, 2, 0],
-        a[e, 0, 0] * b[e, 0, 1] + a[e, 0, 1] * b[e, 1, 1] + a[e, 0, 2] * b[e, 2, 1],
-        a[e, 0, 0] * b[e, 0, 2] + a[e, 0, 1] * b[e, 1, 2] + a[e, 0, 2] * b[e, 2, 2],
+        a[..., 0, 0] * b[..., 0, 0] + a[..., 0, 1] * b[..., 1, 0] + a[..., 0, 2] * b[..., 2, 0],
+        a[..., 0, 0] * b[..., 0, 1] + a[..., 0, 1] * b[..., 1, 1] + a[..., 0, 2] * b[..., 2, 1],
+        a[..., 0, 0] * b[..., 0, 2] + a[..., 0, 1] * b[..., 1, 2] + a[..., 0, 2] * b[..., 2, 2],
     ], dim=-1)
     row_2 = torch.stack([
-        a[e, 1, 0] * b[e, 0, 0] + a[e, 1, 1] * b[e, 1, 0] + a[e, 1, 2] * b[e, 2, 0],
-        a[e, 1, 0] * b[e, 0, 1] + a[e, 1, 1] * b[e, 1, 1] + a[e, 1, 2] * b[e, 2, 1],
-        a[e, 1, 0] * b[e, 0, 2] + a[e, 1, 1] * b[e, 1, 2] + a[e, 1, 2] * b[e, 2, 2],
+        a[..., 1, 0] * b[..., 0, 0] + a[..., 1, 1] * b[..., 1, 0] + a[..., 1, 2] * b[..., 2, 0],
+        a[..., 1, 0] * b[..., 0, 1] + a[..., 1, 1] * b[..., 1, 1] + a[..., 1, 2] * b[..., 2, 1],
+        a[..., 1, 0] * b[..., 0, 2] + a[..., 1, 1] * b[..., 1, 2] + a[..., 1, 2] * b[..., 2, 2],
     ], dim=-1)
     row_3 = torch.stack([
-        a[e, 2, 0] * b[e, 0, 0] + a[e, 2, 1] * b[e, 1, 0] + a[e, 2, 2] * b[e, 2, 0],
-        a[e, 2, 0] * b[e, 0, 1] + a[e, 2, 1] * b[e, 1, 1] + a[e, 2, 2] * b[e, 2, 1],
-        a[e, 2, 0] * b[e, 0, 2] + a[e, 2, 1] * b[e, 1, 2] + a[e, 2, 2] * b[e, 2, 2],
+        a[..., 2, 0] * b[..., 0, 0] + a[..., 2, 1] * b[..., 1, 0] + a[..., 2, 2] * b[..., 2, 0],
+        a[..., 2, 0] * b[..., 0, 1] + a[..., 2, 1] * b[..., 1, 1] + a[..., 2, 2] * b[..., 2, 1],
+        a[..., 2, 0] * b[..., 0, 2] + a[..., 2, 1] * b[..., 1, 2] + a[..., 2, 2] * b[..., 2, 2],
     ], dim=-1)

     return torch.stack([row_1, row_2, row_3], dim=-2)
...
...
@@ -175,7 +170,7 @@ class T:
         return T(rots, trans)

     def to_4x4(self):
-        tensor = torch.zeros((*self.shape, 4, 4), device=self.rots.device)
+        tensor = self.rots.new_zeros((*self.shape, 4, 4))
         tensor[..., :3, :3] = self.rots
         tensor[..., :3, 3] = self.trans
         tensor[..., 3, 3] = 1
...
...
@@ -311,7 +306,7 @@ def _to_mat(pairs):
     return mat

-_qtr_mat = torch.zeros((4, 4, 3, 3))
+_qtr_mat = np.zeros((4, 4, 3, 3))
 _qtr_mat[..., 0, 0] = _to_mat([('aa', 1), ('bb', 1), ('cc', -1), ('dd', -1)])
 _qtr_mat[..., 0, 1] = _to_mat([('bc', 2), ('ad', -2)])
 _qtr_mat[..., 0, 2] = _to_mat([('bd', 2), ('ac', 2)])
...
...
@@ -328,9 +323,11 @@ def quat_to_rot(
     # [*, 4, 4]
     quat = quat[..., None] * quat[..., None, :]

+    mat = quat.new_tensor(_qtr_mat)
+
     # [*, 4, 4, 3, 3]
-    shaped_qtr_mat = _qtr_mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
-    quat = quat[..., None, None] * shaped_qtr_mat.to(quat.device)
+    shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + (4, 4, 3, 3))
+    quat = quat[..., None, None] * shaped_qtr_mat

     # [*, 3, 3]
     return torch.sum(quat, dim=(-3, -4))
...
...
@@ -339,9 +336,7 @@ def affine_vector_to_4x4(vector):
     quats = vector[..., :4]
     trans = vector[..., 4:]

-    four_by_four = torch.zeros(
-        (*vector.shape[:-1], 4, 4),
-        device=vector.device,
-    )
+    four_by_four = vector.new_zeros((*vector.shape[:-1], 4, 4))

     four_by_four[..., :3, :3] = quat_to_rot(quats)
     four_by_four[..., :3, 3] = trans
     four_by_four[..., 3, 3] = 1
...
...
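rot_matmul writes the 3x3 product out element by element; the in-file comment explains this is meant to keep rotation composition off low-precision tensor cores. Numerically it agrees with an ordinary batched matrix product, which the following sketch (toy data, written by me) checks:

```python
import torch

a = torch.randn(4, 3, 3)
b = torch.randn(4, 3, 3)

# Unrolled 3x3 product, same recipe as rot_matmul.
rows = []
for i in range(3):
    rows.append(torch.stack(
        [sum(a[..., i, k] * b[..., k, j] for k in range(3)) for j in range(3)],
        dim=-1,
    ))
unrolled = torch.stack(rows, dim=-2)

assert torch.allclose(unrolled, a @ b, atol=1e-5)
```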
openfold/utils/deepspeed.py
...
...
@@ -14,6 +14,7 @@
 import deepspeed
 import torch
+from torch.utils.checkpoint import checkpoint
 from typing import Any, Tuple, List, Callable

 BLOCK_ARG = Any
...
...
@@ -55,7 +56,7 @@ def checkpoint_blocks(
             return a

     def chunker(s, e):
-        def exec_sliced(a):
+        def exec_sliced(*a):
             return exec(blocks[s:e], a)

         return exec_sliced
...
...
@@ -69,7 +70,7 @@ def checkpoint_blocks(
     for s in range(0, len(blocks), blocks_per_ckpt):
         e = s + blocks_per_ckpt
-        args = deepspeed.checkpointing.checkpoint(chunker(s, e), args)
+        args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
         args = wrap(args)

     return args
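The *a / *args changes in chunker() match how activation-checkpointing APIs hand the saved inputs back to the wrapped function: as separate positional arguments rather than one tuple. A sketch using torch.utils.checkpoint (a stand-in here for deepspeed.checkpointing.checkpoint, which I'm not assuming is installed):

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x, y):
    return x * 2, y + 1

def exec_sliced(*args):        # mirrors the fixed wrapper signature
    return block(*args)

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)

out = checkpoint(exec_sliced, x, y)   # inputs are re-passed positionally
print([o.shape for o in out])         # [torch.Size([3]), torch.Size([3])]
```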
openfold/utils/import_weights.py
...
...
@@ -231,7 +231,7 @@ def import_jax_weights_(model, npz_path, version="model_1"):
     MSAGlobalAttParams = lambda matt: {
         "query_norm": LayerNormParams(matt.layer_norm_m),
-        "attention": GlobalAttentionParams(matt),
+        "attention": GlobalAttentionParams(matt.global_attention),
     }

     MSAAttPairBiasParams = lambda matt: dict(
...
...
openfold/utils/loss.py
...
...
@@ -356,7 +356,7 @@ def lddt_loss(
     )
     dists_to_score = (
         (dmat_true < cutoff)
         * all_atom_mask
-        * permute_final_dims(all_atom_mask, 1, 0)
+        * permute_final_dims(all_atom_mask, (1, 0))
         * (1. - torch.eye(n, device=all_atom_mask.device))
     )
...
...
openfold/utils/tensor_utils.py
...
...
@@ -16,12 +16,13 @@
 from functools import partial

 import torch
 import torch.nn as nn
 from typing import Tuple, List, Callable, Any, Dict


-def permute_final_dims(tensor, *inds):
+def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
     zero_index = -1 * len(inds)
-    first_inds = range(len(tensor.shape[:zero_index]))
-    return tensor.permute(*first_inds, *[zero_index + i for i in inds])
+    first_inds = list(range(len(tensor.shape[:zero_index])))
+    return tensor.permute(first_inds + [zero_index + i for i in inds])


 def flatten_final_dims(tensor: torch.Tensor, no_dims: int):
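The signature change means every caller now passes the index pattern as one sequence (a tuple at the call sites above, typed as List[int]), and the indices refer only to the trailing dimensions. A usage sketch against the new definition (the example call and shapes are mine):

```python
import torch
from typing import List

def permute_final_dims(tensor: torch.Tensor, inds: List[int]):
    # Same definition as above, repeated here so the sketch is self-contained.
    zero_index = -1 * len(inds)
    first_inds = list(range(len(tensor.shape[:zero_index])))
    return tensor.permute(first_inds + [zero_index + i for i in inds])

z = torch.randn(2, 4, 5, 8)                    # [*, I, J, C]
print(permute_final_dims(z, (2, 0, 1)).shape)  # torch.Size([2, 8, 4, 5])
```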
...
...
@@ -70,7 +71,7 @@ def stack_tensor_dicts(dicts):
 def one_hot(x, v_bins):
-    reshaped_bins = v_bins.view(*((1,) * len(x.shape) + (len(v_bins),)))
+    reshaped_bins = v_bins.view(((1,) * len(x.shape)) + (len(v_bins),))
     diffs = x[..., None] - reshaped_bins
     am = torch.argmin(torch.abs(diffs), dim=-1)
     return nn.functional.one_hot(am, num_classes=len(v_bins)).float()
...
...
@@ -118,7 +119,12 @@ def tree_map(fn, tree, leaf_type):
 tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor)


-def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
+def chunk_layer(
+    layer: Callable,
+    inputs: Dict[str, Any],
+    chunk_size: int,
+    no_batch_dims: int,
+) -> Any:
     """
     Implements the "chunking" procedure described in section 1.11.8.
...
...
@@ -130,8 +136,8 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
layer:
The layer to be applied chunk-wise
inputs:
A (nested) dictionary of keyworded inputs. All leaves must
be
tensors and must share the same batch dimensions.
A (
non-
nested) dictionary of keyworded inputs. All leaves must
be
tensors and must share the same batch dimensions.
chunk_size:
The number of sub-batches per chunk. If multiple batch
dimensions are specified, a "sub-batch" is defined as a single
...
...
@@ -163,7 +169,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
         return shapes

     initial_dims = [shape[:no_batch_dims] for shape in fetch_dims(inputs)]
-    orig_batch_dims = [max(s) for s in zip(*initial_dims)]
+    orig_batch_dims = tuple([max(s) for s in zip(*initial_dims)])

     def prep_inputs(t):
         # TODO: make this more memory efficient. This sucks
...
...
@@ -194,7 +200,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
         # Allocate space for the output
         if(out is None):
-            allocate = lambda t: t.new_zeros(flat_batch_dim, *t.shape[1:])
+            allocate = lambda t: t.new_zeros((flat_batch_dim,) + t.shape[1:])
             out = tensor_tree_map(allocate, output_chunk)

         # Put the chunk in its pre-allocated space
...
...
@@ -217,7 +223,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
         i += chunk_size

-    reshape = lambda t: t.reshape(*orig_batch_dims, *t.shape[1:])
+    reshape = lambda t: t.reshape(orig_batch_dims + t.shape[1:])
     out = tensor_tree_map(reshape, out)

     return out
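A toy usage sketch for the new chunk_layer signature, assuming the openfold package is importable and that the layer is invoked with the dictionary's keys as keyword arguments (which is how msa.py uses it above); the layer, keys and shapes here are invented:

```python
import torch
from openfold.utils.tensor_utils import chunk_layer

def toy_layer(m: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    return m * mask.unsqueeze(-1)

inputs = {
    "m": torch.randn(2, 3, 7, 8),    # two leading batch dims: (2, 3)
    "mask": torch.ones(2, 3, 7),
}

# Processes the 6 flattened sub-batches in chunks of 4, then restores (2, 3).
out = chunk_layer(toy_layer, inputs, chunk_size=4, no_batch_dims=2)
print(out.shape)                      # torch.Size([2, 3, 7, 8])
```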