OpenDAS / OpenFold — Commits

Commit ab0c6977, authored Sep 26, 2021 by Gustaf Ahdritz
Continue debugging loss functions, remove in-place ops
parent 85c0a9a9
Showing 14 changed files with 133 additions and 85 deletions.
Changed files:

  openfold/model/embedders.py              +1   -0
  openfold/model/evoformer.py              +2   -1
  openfold/model/model.py                  +5   -5
  openfold/model/msa.py                    +7   -6
  openfold/model/outer_product_mean.py     +2   -2
  openfold/model/pair_transition.py        +1   -1
  openfold/model/primitives.py             +2   -2
  openfold/model/structure_module.py       +11  -8
  openfold/model/template.py               +1   -1
  openfold/model/triangular_attention.py   +1   -1
  openfold/utils/deepspeed.py              +1   -0
  openfold/utils/feats.py                  +19  -14
  openfold/utils/loss.py                   +79  -42
  openfold/utils/tensor_utils.py           +1   -2
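Note: the recurring change in this commit is replacing in-place tensor ops (+=, *=, /=) with their out-of-place forms. Modifying a tensor in place after autograd has saved it for the backward pass fails at backward time; the out-of-place form allocates a new tensor and avoids this. A minimal illustration of the failure mode (not part of the commit):

import torch

x = torch.ones(3, requires_grad=True)
y = torch.exp(x)    # autograd saves y to compute exp's backward
y += 1              # in-place: overwrites the saved activation
y.sum().backward()  # RuntimeError: ... modified by an inplace operation

# The out-of-place form keeps the saved activation intact:
x = torch.ones(3, requires_grad=True)
y = torch.exp(x)
y = y + 1           # allocates a new tensor instead
y.sum().backward()  # fine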
openfold/model/embedders.py
...
@@ -113,6 +113,7 @@ class InputEmbedder(nn.Module):
         # [*, N_res, N_res, c_z]
         pair_emb = tf_emb_i[..., None, :] + tf_emb_j[..., None, :, :]
         pair_emb += self.relpos(ri)
+        #pair_emb = pair_emb + self.relpos(ri)

         # [*, N_clust, N_res, c_m]
         n_clust = msa.shape[-3]
...
openfold/model/evoformer.py
...
@@ -94,7 +94,7 @@ class MSATransition(nn.Module):
         m = self.layer_norm(m)

         inp = {"m": m, "mask": mask}
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             m = chunk_layer(
                 self._transition,
                 inp,
...
@@ -132,6 +132,7 @@ class EvoformerBlock(nn.Module):
             c_z=c_z,
             c_hidden=c_hidden_msa_att,
             no_heads=no_heads_msa,
+            chunk_size=chunk_size,
             inf=inf,
         )
...
openfold/model/model.py
...
@@ -108,7 +108,7 @@ class AlphaFold(nn.Module):
     def embed_templates(self, batch, z, pair_mask, templ_dim):
         # Embed the templates one at a time (with a poor man's vmap)
         template_embeds = []
-        n_templ = batch["template_aatype"].shape[-2]
+        n_templ = batch["template_aatype"].shape[templ_dim]
         for i in range(n_templ):
             idx = batch["template_aatype"].new_tensor(i)
             single_template_feats = tensor_tree_map(
...
@@ -155,14 +155,14 @@ class AlphaFold(nn.Module):
             partial(torch.cat, dim=templ_dim),
             template_embeds,
         )

         # [*, N, N, C_z]
         t = self.template_pointwise_att(
             template_embeds["pair"],
             z,
             template_mask=batch["template_mask"]
         )
-        t *= torch.sum(batch["template_mask"]) > 0
+        t = t * torch.sum(batch["template_mask"]) > 0

         return {
             "template_angle_embedding": a,
...
@@ -297,7 +297,7 @@ class AlphaFold(nn.Module):
         m[..., 0, :, :] += m_1_prev_emb

         # [*, N, N, C_z]
-        z += z_prev_emb
+        z = z + z_prev_emb

         # Embed the templates + merge with MSA/pair embeddings
         if(self.config.template.enabled):
...
@@ -312,7 +312,7 @@ class AlphaFold(nn.Module):
             )

             # [*, N, N, C_z]
-            z += template_embeds["template_pair_embedding"]
+            z = z + template_embeds["template_pair_embedding"]

             if(self.config.template.embed_angles):
                 # [*, S = S_c + S_t, N, C_m]
...
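Note: augmented assignments evaluate their entire right-hand side before applying the operator, so a mechanical rewrite can change parsing. Above, t *= torch.sum(batch["template_mask"]) > 0 meant t = t * (torch.sum(...) > 0), while the unparenthesized t = t * torch.sum(...) > 0 parses as (t * torch.sum(...)) > 0; the same applies to the outer /= self.eps + norm rewrite in openfold/model/outer_product_mean.py below. A quick demonstration (illustrative only, not from the commit):

import torch

t = torch.tensor([4.0])
s = torch.tensor(2.0)

a = t.clone()
a *= s > 0             # RHS evaluated first: a = a * (s > 0) -> tensor([4.])

b = t.clone() * s > 0  # comparison binds last: (t * s) > 0 -> tensor([True])
print(a, b)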
openfold/model/msa.py
...
@@ -125,7 +125,7 @@ class MSAAttention(nn.Module):
             "v_x": m,
             "biases": biases
         }
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             m = chunk_layer(
                 self.mha,
                 mha_inputs,
...
@@ -142,7 +142,7 @@ class MSARowAttentionWithPairBias(MSAAttention):
     """
     Implements Algorithm 7.
     """
-    def __init__(self, c_m, c_z, c_hidden, no_heads, inf=1e9):
+    def __init__(self, c_m, c_z, c_hidden, no_heads, chunk_size, inf=1e9):
         """
         Args:
             c_m:
...
@@ -161,7 +161,8 @@ class MSARowAttentionWithPairBias(MSAAttention):
             c_hidden,
             no_heads,
             pair_bias=True,
             c_z=c_z,
+            chunk_size=chunk_size,
             inf=inf,
         )
...
@@ -259,7 +260,7 @@ class MSAColumnGlobalAttention(nn.Module):
         # [*, N_res, H * C_hidden]
         q = self.linear_q(q)
-        q *= self.c_hidden ** (-0.5)
+        q = q * self.c_hidden ** (-0.5)

         # [*, N_res, H, C_hidden]
         q = q.view(*q.shape[:-1], self.no_heads, -1)
...
@@ -274,7 +275,7 @@ class MSAColumnGlobalAttention(nn.Module):
             k.transpose(-1, -2),  # [*, N_res, C_hidden, N_seq]
         )
         bias = (self.inf * (mask - 1))[..., :, None, :]
-        a += bias
+        a = a + bias
         a = self.softmax(a)

         # [*, N_res, H, C_hidden]
...
@@ -318,7 +319,7 @@ class MSAColumnGlobalAttention(nn.Module):
             "m": m,
             "mask": mask,
         }
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             m = chunk_layer(
                 self.global_attention,
                 mha_input,
...
openfold/model/outer_product_mean.py
...
@@ -83,7 +83,7 @@ class OuterProductMean(nn.Module):
         a = a.transpose(-2, -3)
         b = b.transpose(-2, -3)

-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
            # Since the "batch dim" in this case is not a true batch dimension
            # (in that the shape of the output depends on it), we need to
            # iterate over it ourselves
...
@@ -107,7 +107,7 @@ class OuterProductMean(nn.Module):
         norm = torch.einsum("...abc,...adc->...bdc", mask, mask)

         # [*, N_res, N_res, C_z]
-        outer /= self.eps + norm
+        outer = outer / self.eps + norm

         return outer
...
openfold/model/pair_transition.py
...
@@ -73,7 +73,7 @@ class PairTransition(nn.Module):
         z = self.layer_norm(z)

         inp = {"z": z, "mask": mask}
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             z = chunk_layer(
                 self._transition,
                 inp,
...
openfold/model/primitives.py
...
@@ -251,10 +251,10 @@ class Attention(nn.Module):
             permute_final_dims(k, 1, 2, 0),  # [*, H, C_hidden, K]
         )
         norm = 1 / math.sqrt(self.c_hidden)  # [1]
-        a *= norm
+        a = a * norm
         if(biases is not None):
             for b in biases:
-                a += b
+                a = a + b
         a = self.softmax(a)

         # [*, H, Q, C_hidden]
...
openfold/model/structure_module.py
...
@@ -129,10 +129,11 @@ class AngleResnet(nn.Module):
         # [*, no_angles * 2]
         s = self.linear_out(s)
-        unnormalized_s = s

         # [*, no_angles, 2]
         s = s.view(*s.shape[:-1], -1, 2)
+        unnormalized_s = s
         norm_denom = torch.sqrt(
             torch.clamp(
                 torch.sum(s ** 2, dim=-1, keepdims=True),
...
@@ -295,8 +296,8 @@ class InvariantPointAttention(nn.Module):
             permute_final_dims(q, 1, 0, 2),  # [*, H, N_res, C_hidden]
             permute_final_dims(k, 1, 2, 0),  # [*, H, C_hidden, N_res]
         )
-        a *= math.sqrt(1. / (3 * self.c_hidden))
-        a += math.sqrt(1. / 3) * permute_final_dims(b, 2, 0, 1)
+        a = a + math.sqrt(1. / (3 * self.c_hidden))
+        a = a + math.sqrt(1. / 3) * permute_final_dims(b, 2, 0, 1)

         # [*, N_res, N_res, H, P_q, 3]
         pt_att = q_pts.unsqueeze(-4) - k_pts.unsqueeze(-5)
...
@@ -307,7 +308,9 @@ class InvariantPointAttention(nn.Module):
         head_weights = self.softplus(self.head_weights).view(
             *((1,) * len(pt_att.shape[:-2]) + (-1, 1))
         )
-        head_weights *= math.sqrt(1. / (3 * (self.no_qk_points * 9. / 2)))
+        head_weights = (
+            head_weights * math.sqrt(1. / (3 * (self.no_qk_points * 9. / 2)))
+        )
         pt_att = pt_att * head_weights

         # [*, N_res, N_res, H]
...
@@ -319,8 +322,8 @@ class InvariantPointAttention(nn.Module):
         # [*, H, N_res, N_res]
         pt_att = permute_final_dims(pt_att, 2, 0, 1)
-        a += pt_att
-        a += square_mask.unsqueeze(-3)
+        a = a + pt_att
+        a = a + square_mask.unsqueeze(-3)
         a = self.softmax(a)

         ################
...
@@ -510,7 +513,7 @@ def _frames_and_literature_positions_to_atom14_pos(
     # [*, N, 14, 3]
     lit_positions = lit_positions[f, ...]
     pred_positions = t_atoms_to_global.apply(lit_positions)
-    pred_positions *= atom_mask
+    pred_positions = pred_positions * atom_mask

     return pred_positions
...
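The head_weights reshape above builds a view shape that is all ones over the leading dimensions of pt_att, so a per-head parameter vector broadcasts cleanly against it. A standalone sketch of the same pattern (tensor shapes invented for illustration):

import torch

pt_att = torch.randn(2, 7, 7, 4, 5)   # [*, N_res, N_res, H, P_q] (example shape)
head_weights = torch.randn(4)         # [H]

# (1,) for every leading dim, then (-1, 1) to line up with the (H, P_q) tail
shape = (1,) * len(pt_att.shape[:-2]) + (-1, 1)
print(shape)                          # (1, 1, 1, -1, 1)

w = head_weights.view(*shape)         # [1, 1, 1, H, 1]
out = pt_att * w                      # broadcasts over all leading dims
print(out.shape)                      # torch.Size([2, 7, 7, 4, 5])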
openfold/model/template.py
...
@@ -108,7 +108,7 @@ class TemplatePointwiseAttention(nn.Module):
             "v_x": t,
             "biases": [bias],
         }
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             z = chunk_layer(
                 self.mha,
                 mha_inputs,
...
openfold/model/triangular_attention.py
...
@@ -102,7 +102,7 @@ class TriangleAttention(nn.Module):
             "v_x": x,
             "biases": [mask_bias, triangle_bias],
         }
-        if(not self.training and self.chunk_size is not None):
+        if(self.chunk_size is not None):
             x = chunk_layer(
                 self.mha,
                 mha_inputs,
...
openfold/utils/deepspeed.py
...
@@ -70,5 +70,6 @@ def checkpoint_blocks(
     for s in range(0, len(blocks), blocks_per_ckpt):
         e = s + blocks_per_ckpt
         args = deepspeed.checkpointing.checkpoint(chunker(s, e), args)
+        args = wrap(args)

     return args
openfold/utils/feats.py
...
@@ -158,15 +158,16 @@ def atom14_to_atom37(atom14, batch):
 def atom37_to_torsion_angles(
     aatype: torch.Tensor,
-    all_atom_pos: torch.Tensor,
+    all_atom_positions: torch.Tensor,
     all_atom_mask: torch.Tensor,
     eps: float = 1e-8,
+    **kwargs,
 ) -> Dict[str, torch.Tensor]:
     """
     Args:
         aatype:
             [*, N_res] residue indices
-        all_atom_pos:
+        all_atom_positions:
             [*, N_res, 37, 3] atom positions (in atom37
             format)
         all_atom_mask:
...
@@ -183,28 +184,32 @@ def atom37_to_torsion_angles(
     """
     aatype = torch.clamp(aatype, max=20)

-    pad = all_atom_pos.new_zeros([*all_atom_pos.shape[:-3], 1, 37, 3])
-    prev_all_atom_pos = torch.cat([pad, all_atom_pos[..., :-1, :, :]], dim=-3)
+    pad = all_atom_positions.new_zeros(
+        [*all_atom_positions.shape[:-3], 1, 37, 3]
+    )
+    prev_all_atom_positions = torch.cat(
+        [pad, all_atom_positions[..., :-1, :, :]], dim=-3
+    )

     pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
     prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)

     pre_omega_atom_pos = torch.cat(
         [
-            prev_all_atom_pos[..., 1:3, :],
-            all_atom_pos[..., :2, :]
+            prev_all_atom_positions[..., 1:3, :],
+            all_atom_positions[..., :2, :]
         ], dim=-2
     )
     phi_atom_pos = torch.cat(
         [
-            prev_all_atom_pos[..., 2:3, :],
-            all_atom_pos[..., :3, :]
+            prev_all_atom_positions[..., 2:3, :],
+            all_atom_positions[..., :3, :]
         ], dim=-2
     )
     psi_atom_pos = torch.cat(
         [
-            all_atom_pos[..., :3, :],
-            all_atom_pos[..., 4:5, :]
+            all_atom_positions[..., :3, :],
+            all_atom_positions[..., 4:5, :]
         ], dim=-2
     )
...
@@ -227,7 +232,7 @@ def atom37_to_torsion_angles(
     atom_indices = chi_atom_indices[..., aatype, :, :]
     chis_atom_pos = batched_gather(
-        all_atom_pos, atom_indices, -2, len(atom_indices.shape[:-2])
+        all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
     )

     chi_angles_mask = list(rc.chi_angles_mask)
...
@@ -335,9 +340,9 @@ def atom37_to_frames(
         device=aatype.device,
         requires_grad=False
     )
-    restype_rigidgroup_mask[:, 0] = 1
-    restype_rigidgroup_mask[:, 3] = 1
-    restype_rigidgroup_mask[:20, 4:] = (
+    restype_rigidgroup_mask[..., 0] = 1
+    restype_rigidgroup_mask[..., 3] = 1
+    restype_rigidgroup_mask[..., :20, 4:] = (
         all_atom_mask.new_tensor(rc.chi_angles_mask)
     )
...
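The prev_all_atom_positions construction above pads a zero "residue" onto the front of the residue dimension and drops the last entry, so index i holds the atoms of residue i - 1; the omega and phi torsions then reach back one residue. A self-contained sketch of the shift (shapes invented for illustration):

import torch

all_atom_positions = torch.randn(4, 37, 3)   # [N_res, 37, 3]
pad = all_atom_positions.new_zeros((1, 37, 3))

# prev[i] == all_atom_positions[i - 1], with zeros at i == 0
prev = torch.cat([pad, all_atom_positions[..., :-1, :, :]], dim=-3)

assert prev.shape == all_atom_positions.shape
assert torch.equal(prev[1:], all_atom_positions[:-1])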
openfold/utils/loss.py
...
@@ -22,6 +22,7 @@ from typing import Dict, Optional
 from openfold.np import residue_constants
 from openfold.model.primitives import Linear
+from openfold.utils import feats
 from openfold.utils.affine_utils import T
 from openfold.utils.tensor_utils import (
     tree_map,
...
@@ -150,7 +151,9 @@ def backbone_loss(
         unclamped_fape_loss * (1 - use_clamped_fape)
     )

-    return torch.mean(fape_loss, dim=-1)
+    # Take the mean over the layer dimension
+    fape_loss = torch.mean(fape_loss, dim=0)
+
+    return fape_loss

 def sidechain_loss(
...
@@ -172,11 +175,10 @@ def sidechain_loss(
         alt_naming_is_better[..., None, None, None] *
         rigidgroups_alt_gt_frames
     )
-    batch_dims = sidechain_frames.shape[:-5]

     # Steamroll the inputs
     sidechain_frames = sidechain_frames[-1]
+    batch_dims = sidechain_frames.shape[:-4]
     sidechain_frames = sidechain_frames.view(
         *batch_dims, -1, 4, 4
     )
...
@@ -198,7 +200,7 @@ def sidechain_loss(
     renamed_atom14_gt_exists = renamed_atom14_gt_exists.view(
         *batch_dims, -1
     )
     fape = compute_fape(
         sidechain_frames,
         renamed_gt_frames,
...
@@ -240,7 +242,7 @@ def supervised_chi_loss(
     aatype: torch.Tensor,
     seq_mask: torch.Tensor,
     chi_mask: torch.Tensor,
-    chi_angles: torch.Tensor,
+    chi_angles_sin_cos: torch.Tensor,
     chi_weight: float,
     angle_norm_weight: float,
     eps=1e-6,
...
@@ -256,24 +258,24 @@ def supervised_chi_loss(
         angles_sin_cos.new_tensor(residue_constants.chi_pi_periodic),
     )

-    true_chi = chi_angles
-    sin_true_chi = torch.sin(true_chi)
-    cos_true_chi = torch.cos(true_chi)
-    sin_cos_true_chi = torch.stack([sin_true_chi, cos_true_chi], dim=-1)
+    true_chi = chi_angles_sin_cos.unsqueeze(-4)

     shifted_mask = (1 - 2 * chi_pi_periodic).unsqueeze(-1)
-    sin_cos_true_chi_shifted = shifted_mask * sin_cos_true_chi
+    true_chi_shifted = shifted_mask * true_chi
     sq_chi_error = torch.sum(
-        (sin_cos_true_chi - pred_angles) ** 2, dim=-1
+        (true_chi - pred_angles) ** 2, dim=-1
     )
     sq_chi_error_shifted = torch.sum(
-        (sin_cos_true_chi_shifted - pred_angles) ** 2, dim=-1
+        (true_chi_shifted - pred_angles) ** 2, dim=-1
     )
     sq_chi_error = torch.minimum(sq_chi_error, sq_chi_error_shifted)
+    # The ol' switcheroo
+    sq_chi_error = sq_chi_error.permute(
+        *range(len(sq_chi_error.shape))[1:-2], 0, -2, -1
+    )
     sq_chi_loss = masked_mean(
-        chi_mask, sq_chi_error, dim=(-1, -2)
+        chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)
     )

     loss = 0
...
@@ -283,8 +285,11 @@ def supervised_chi_loss(
         torch.sum(unnormalized_angles_sin_cos ** 2, dim=-1) + eps
     )
     norm_error = torch.abs(angle_norm - 1.)
+    norm_error = norm_error.permute(
+        *range(len(norm_error.shape))[1:-2], 0, -2, -1
+    )
     angle_norm_loss = masked_mean(
-        seq_mask[..., None], norm_error, dim=(-1, -2)
+        seq_mask[..., None, :, None], norm_error, dim=(-1, -2, -3)
     )

     loss += angle_norm_weight * angle_norm_loss
...
@@ -377,10 +382,10 @@ def lddt_loss(
     )

     errors = softmax_cross_entropy(logits, lddt_ca_one_hot)
     all_atom_mask = all_atom_mask.squeeze(-1)
     loss = (
-        torch.sum(errors * all_atom_mask) / (torch.sum(all_atom_mask) + 1e-8)
+        torch.sum(errors * all_atom_mask, dim=-1) / (eps + torch.sum(all_atom_mask, dim=-1))
     )

     loss *= (
...
@@ -483,10 +488,10 @@ def tm_score(
 def between_residue_bond_loss(
-    pred_atom_positions: torch.Tensor,  # (N, 37(14), 3)
-    pred_atom_mask: torch.Tensor,  # (N, 37(14))
-    residue_index: torch.Tensor,  # (N)
-    aatype: torch.Tensor,  # (N)
+    pred_atom_positions: torch.Tensor,  # (*, N, 37/14, 3)
+    pred_atom_mask: torch.Tensor,  # (*, N, 37/14)
+    residue_index: torch.Tensor,  # (*, N)
+    aatype: torch.Tensor,  # (*, N)
     tolerance_factor_soft=12.0,
     tolerance_factor_hard=12.0,
     eps=1e-6,
...
@@ -561,7 +566,10 @@ def between_residue_bond_loss(
         c_n_bond_length_error - tolerance_factor_soft * gt_stddev
     )
     mask = this_c_mask * next_n_mask * has_no_gap_mask
-    c_n_loss = torch.sum(mask * c_n_loss_per_residue) / (torch.sum(mask) + eps)
+    c_n_loss = (
+        torch.sum(mask * c_n_loss_per_residue, dim=-1) /
+        (torch.sum(mask, dim=-1) + eps)
+    )
     c_n_violation_mask = mask * (
         c_n_bond_length_error > (tolerance_factor_hard * gt_stddev)
     )
...
@@ -589,7 +597,8 @@ def between_residue_bond_loss(
     )
     mask = this_ca_mask * this_c_mask * next_n_mask * has_no_gap_mask
     ca_c_n_loss = (
-        torch.sum(mask * ca_c_n_loss_per_residue) / (torch.sum(mask) + eps)
+        torch.sum(mask * ca_c_n_loss_per_residue, dim=-1) /
+        (torch.sum(mask, dim=-1) + eps)
     )
     ca_c_n_violation_mask = mask * (ca_c_n_cos_angle_error >
         (tolerance_factor_hard * gt_stddev))
...
@@ -604,7 +613,8 @@ def between_residue_bond_loss(
     )
     mask = this_c_mask * next_n_mask * next_ca_mask * has_no_gap_mask
     c_n_ca_loss = (
-        torch.sum(mask * c_n_ca_loss_per_residue) / (torch.sum(mask) + eps)
+        torch.sum(mask * c_n_ca_loss_per_residue, dim=-1) /
+        (torch.sum(mask, dim=-1) + eps)
     )
     c_n_ca_violation_mask = mask * (
         c_n_ca_cos_angle_error > (tolerance_factor_hard * gt_stddev)
...
@@ -619,7 +629,7 @@ def between_residue_bond_loss(
         torch.nn.functional.pad(per_residue_loss_sum, (0, 1)) +
         torch.nn.functional.pad(per_residue_loss_sum, (1, 0))
     )

     # Compute hard violations.
     violation_mask = torch.max(
         torch.stack(
...
@@ -627,7 +637,8 @@ def between_residue_bond_loss(
                 c_n_violation_mask,
                 ca_c_n_violation_mask,
                 c_n_ca_violation_mask
-            ]
+            ],
+            dim=-2,
         ),
         dim=-2
     )[0]
...
@@ -635,7 +646,7 @@ def between_residue_bond_loss(
         torch.nn.functional.pad(violation_mask, (0, 1)),
         torch.nn.functional.pad(violation_mask, (1, 0))
     )

     return {
         'c_n_loss_mean': c_n_loss,
         'ca_c_n_loss_mean': ca_c_n_loss,
...
@@ -708,7 +719,7 @@ def between_residue_clash_loss(
     # Backbone C--N bond between subsequent residues is no clash.
     c_one_hot = torch.nn.functional.one_hot(
-        residue_index.new_tensor(2), num_classes=14
+        residue_index.new_tensor(2.), num_classes=14
     )
     c_one_hot = c_one_hot.reshape(
         *((1,) * len(residue_index.shape[:-1])), *c_one_hot.shape
...
@@ -958,7 +969,7 @@ def find_structural_violations(
         atom14_dists_upper_bound=atom14_dists_upper_bound,
         tighten_bounds_for_loss=0.0
     )

     # Combine them to a single per-residue violation mask (used later for LDDT).
     per_residue_violations_mask = torch.max(
         torch.stack(
...
@@ -1255,6 +1266,7 @@ def experimentally_resolved_loss(
     min_resolution: float,
     max_resolution: float,
     eps: float = 1e-8,
+    **kwargs,
 ) -> torch.Tensor:
     errors = sigmoid_cross_entropy(logits, all_atom_mask)
     loss_num = torch.sum(errors * atom37_atom_exists, dim=(-1, -2))
...
@@ -1268,7 +1280,7 @@ def experimentally_resolved_loss(
     return loss

-def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8):
+def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
     errors = softmax_cross_entropy(
         logits,
         torch.nn.functional.one_hot(true_msa, num_classes=23)
...
@@ -1296,24 +1308,48 @@ class AlphaFoldLoss(nn.Module):
                 **self.config.violation,
             )

+        if("atom14_atom_is_ambiguous" not in batch.keys()):
+            batch.update(feats.build_ambiguity_feats(batch))
+
         if("renamed_atom14_gt_positions" not in out.keys()):
             batch.update(compute_renamed_ground_truth(
                 batch,
                 out["sm"]["positions"][-1],
             ))

+        if("backbone_affine_tensor" not in batch.keys()):
+            batch.update(feats.atom37_to_frames(**batch))
+            # TODO: Verify that this is correct
+            batch["backbone_affine_tensor"] = (
+                batch["rigidgroups_gt_frames"][..., 0, :, :]
+            )
+            batch["backbone_affine_mask"] = (
+                batch["rigidgroups_gt_exists"][..., 0]
+            )
+
+        if("chi_angles_sin_cos" not in batch.keys()):
+            batch.update(feats.atom37_to_torsion_angles(
+                **batch,
+            ))
+            # TODO: Verify that this is correct
+            batch["chi_angles_sin_cos"] = (
+                batch["torsion_angles_sin_cos"][..., 3:, :]
+            )
+            batch["chi_mask"] = batch["torsion_angles_mask"][..., 3:]
+
         loss_fns = {
             "distogram":
                 lambda: distogram_loss(
-                    out["distogram_logits"],
+                    logits=out["distogram_logits"],
                     **{**batch,
                     **self.config.distogram},
                 ),
             "experimentally_resolved":
                 lambda: experimentally_resolved_loss(
-                    out["experimentally_resolved"],
+                    logits=out["experimentally_resolved_logits"],
                     **{**batch,
                     **self.config.experimentally_resolved},
                 ),
             "fape":
                 lambda: fape_loss(
...
@@ -1323,14 +1359,13 @@ class AlphaFoldLoss(nn.Module):
                 ),
             "lddt":
                 lambda: lddt_loss(
-                    out["lddt_logits"],
-                    all_atom_pred_pos=out["final_atom_positions"]
+                    logits=out["lddt_logits"],
+                    all_atom_pred_pos=out["final_atom_positions"],
                     **{**batch,
                     **self.config.lddt},
                 ),
             "masked_msa":
                 lambda: masked_msa_loss(
-                    out["masked_msa_logits"],
+                    logits=out["masked_msa_logits"],
                     **{**batch,
                     **self.config.masked_msa},
                 ),
...
@@ -1338,8 +1373,7 @@ class AlphaFoldLoss(nn.Module):
                 lambda: supervised_chi_loss(
                     out["sm"]["angles"],
                     out["sm"]["unnormalized_angles"],
                     **{**batch,
                     **self.config.supervised_chi},
                 ),
             "violation":
                 lambda: violation_loss(
...
@@ -1351,6 +1385,9 @@ class AlphaFoldLoss(nn.Module):
         for k, loss_fn in loss_fns.items():
             weight = self.config[k].weight
             if(weight):
-                cum_loss += weight * loss_fn()
+                print(k)
+                loss = loss_fn()
+                print(loss)
+                cum_loss += weight * loss

         return cum_loss
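Several of the reworked calls above pass extra mask dimensions and a dim tuple to masked_mean, e.g. masked_mean(chi_mask[..., None, :, :], sq_chi_error, dim=(-1, -2, -3)). masked_mean is an OpenFold helper defined elsewhere; below is a minimal sketch consistent with how it is called here (the signature and eps default are assumptions, not the library's actual code):

import torch

def masked_mean(mask, value, dim, eps=1e-10):
    # mask broadcasts against value; the mean counts only unmasked entries
    return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim))

# e.g. averaging a [no_layers, N_res, n_chi] error with a [N_res, n_chi] mask:
err = torch.rand(8, 10, 4)
mask = torch.randint(0, 2, (10, 4)).float()
loss = masked_mean(mask[..., None, :, :], err, dim=(-1, -2, -3))
print(loss)  # scalar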
openfold/utils/tensor_utils.py
...
@@ -142,8 +142,7 @@ def chunk_layer(layer, inputs, chunk_size, no_batch_dims):
             be considered batch dimensions.
     Returns:
         The reassembled output of the layer on the inputs.
     """
     if(not (len(inputs) > 0)):
         raise ValueError("Must provide at least one input")
...
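The chunk_size guards changed throughout this commit all route through chunk_layer, whose docstring appears above: slice the batch dimensions of every input into chunks, run the layer chunk by chunk, and reassemble the outputs, trading scheduling overhead for lower peak memory. A simplified single-batch-dim sketch of the idea (the real helper handles multiple batch dims and nested input/output structures):

import torch

def chunk_layer_1d(layer, inputs, chunk_size):
    # inputs: dict of tensors sharing one leading batch dim
    if not len(inputs) > 0:
        raise ValueError("Must provide at least one input")
    n = next(iter(inputs.values())).shape[0]
    outs = []
    for s in range(0, n, chunk_size):
        chunk = {k: v[s:s + chunk_size] for k, v in inputs.items()}
        outs.append(layer(**chunk))
    return torch.cat(outs, dim=0)

# usage: same result as layer(**inputs), but with smaller intermediates
layer = lambda m, mask: m * mask
out = chunk_layer_1d(
    layer, {"m": torch.randn(6, 3), "mask": torch.ones(6, 3)}, chunk_size=2
)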