OpenDAS / OpenFold · Commits · 1d47c1e7

Commit 1d47c1e7, authored Sep 28, 2021 by Gustaf Ahdritz
Parent: 3d9d977a

    Finish accommodating FP16. FAPE does not decrease

Showing 9 changed files with 141 additions and 78 deletions (+141 -78)
config.py                        +30  -22
openfold/model/model.py          +16   -7
openfold/model/msa.py             +0   -1
openfold/model/primitives.py      +3   -1
openfold/model/template.py       +12   -3
openfold/utils/affine_utils.py   +26  -18
openfold/utils/deepspeed.py       +3   -0
openfold/utils/feats.py           +2   -2
openfold/utils/loss.py           +49  -24
config.py

@@ -19,10 +19,13 @@ c_m = mlc.FieldReference(256)
 c_t = mlc.FieldReference(64)
 c_e = mlc.FieldReference(64)
 c_s = mlc.FieldReference(384)
-blocks_per_ckpt = mlc.FieldReference(1)
-chunk_size = mlc.FieldReference(4) #1280
+blocks_per_ckpt = mlc.FieldReference(1, field_type=int)
+chunk_size = mlc.FieldReference(None, field_type=int)
 aux_distogram_bins = mlc.FieldReference(64)
+eps = 1e-4
+inf = 1e4

 config = mlc.ConfigDict({
     "model": {
         "c_z": c_z,
@@ -45,7 +48,7 @@ config = mlc.ConfigDict({
             "min_bin": 3.25,
             "max_bin": 20.75,
             "no_bins": 15,
-            "inf": 1e8,
+            "inf": inf,  # 1e8,
         },
         "template": {
             "distogram": {
@@ -74,6 +77,7 @@ config = mlc.ConfigDict({
                 "dropout_rate": 0.25,
                 "blocks_per_ckpt": blocks_per_ckpt,
                 "chunk_size": chunk_size,
+                "inf": inf,
             },
             "template_pointwise_attention": {
                 "c_t": c_t,
@@ -83,8 +87,10 @@ config = mlc.ConfigDict({
                 "c_hidden": 16,
                 "no_heads": 4,
                 "chunk_size": chunk_size,
+                "inf": inf,  #1e-9,
             },
-            "eps": 1e-6,
+            "inf": inf,
+            "eps": eps,  #1e-6,
             "enabled": True,
             "embed_angles": True,
         },
@@ -108,10 +114,10 @@ config = mlc.ConfigDict({
                 "pair_dropout": 0.25,
                 "blocks_per_ckpt": blocks_per_ckpt,
                 "chunk_size": chunk_size,
-                "inf": 1e9,
-                "eps": 1e-10,
+                "inf": inf,  # 1e9,
+                "eps": eps,  # 1e-10,
             },
-            "enabled": True,
+            "enabled": False,  # True,
         },
         "evoformer_stack": {
             "c_m": c_m,
@@ -129,8 +135,8 @@ config = mlc.ConfigDict({
             "pair_dropout": 0.25,
             "blocks_per_ckpt": blocks_per_ckpt,
             "chunk_size": chunk_size,
-            "inf": 1e9,
-            "eps": 1e-10,
+            "inf": inf,  # 1e9,
+            "eps": eps,  # 1e-10,
         },
         "structure_module": {
             "c_s": c_s,
@@ -146,8 +152,8 @@ config = mlc.ConfigDict({
             "no_resnet_blocks": 2,
             "no_angles": 7,
             "trans_scale_factor": 10,
-            "epsilon": 1e-12,
-            "inf": 1e5,
+            "epsilon": eps,  # 1e-12,
+            "inf": inf,  # 1e5,
         },
         "heads": {
             "lddt": {
@@ -186,11 +192,11 @@ config = mlc.ConfigDict({
             "min_bin": 2.3125,
             "max_bin": 21.6875,
             "no_bins": 64,
-            "eps": 1e-6,
-            "weight": 0.3,
+            "eps": eps,  # 1e-6,
+            "weight": 0.,  # 0.3,
         },
         "experimentally_resolved": {
-            "eps": 1e-8,
+            "eps": eps,  # 1e-8,
             "min_resolution": 0.1,
             "max_resolution": 3.0,
             "weight": 0.,
@@ -206,6 +212,7 @@ config = mlc.ConfigDict({
             "length_scale": 10.,
             "weight": 0.5,
         },
+        "eps": 1e-4,
         "weight": 1.0,
     },
     "lddt": {
@@ -213,24 +220,25 @@ config = mlc.ConfigDict({
         "max_resolution": 3.0,
         "cutoff": 15.,
         "no_bins": 50,
-        "eps": 1e-10,
-        "weight": 0.01,
+        "eps": eps,  # 1e-10,
+        "weight": 0.,  # 0.01,
     },
     "masked_msa": {
-        "eps": 1e-8,
-        "weight": 2.0,
+        "eps": eps,  # 1e-8,
+        "weight": 0.,  # 2.0,
     },
     "supervised_chi": {
         "chi_weight": 0.5,
         "angle_norm_weight": 0.01,
-        "eps": 1e-6,
-        "weight": 1.0,
+        "eps": eps,  # 1e-6,
+        "weight": 0.,  # 1.0,
     },
     "violation": {
         "violation_tolerance_factor": 12.0,
         "clash_overlap_tolerance": 1.5,
-        "eps": 1e-6,
+        "eps": eps,  # 1e-6,
         "weight": 0.,
     },
+    "eps": eps,
 },
 })
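Taken together, the config changes replace the scattered inf and eps literals with module-level values of 1e4 and 1e-4, both safely inside float16's representable range. The motivation is easy to demonstrate (a standalone sketch, not from the commit):

    import torch

    print(torch.finfo(torch.float16).max)    # 65504.0: largest finite float16
    print(torch.tensor(1e9).half())          # inf -- the old "inf" literals overflow
    print(torch.tensor(1e-10).half())        # 0.  -- the old "eps" values underflow

The blocks_per_ckpt and chunk_size references also gain field_type=int, which lets a FieldReference default to None while still declaring a type for later overrides. A hedged sketch of that ml_collections behavior:

    import ml_collections as mlc

    chunk_size = mlc.FieldReference(None, field_type=int)
    cfg = mlc.ConfigDict({"chunk_size": chunk_size})
    cfg.chunk_size = 4       # OK: matches the declared int type
    # cfg.chunk_size = 4.0   # would raise a TypeError under ConfigDict type safety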
openfold/model/model.py

@@ -115,14 +115,22 @@ class AlphaFold(nn.Module):
                 batch,
             )

+            #tensor_dtype = (
+            #    single_template_feats["template_all_atom_masks"].dtype
+            #)
+
             # Build template angle feats
             angle_feats = atom37_to_torsion_angles(
                 single_template_feats["template_aatype"],
-                single_template_feats["template_all_atom_positions"],
-                single_template_feats["template_all_atom_masks"],
-                eps=1e-8
+                single_template_feats["template_all_atom_positions"],  #.float(),
+                single_template_feats["template_all_atom_masks"],  #.float(),
+                eps=self.config.template.eps,
             )

+            #angle_feats = tensor_tree_map(
+            #    lambda t: t.type(tensor_dtype), angle_feats
+            #)
+
             template_angle_feat = build_template_angle_feat(
                 angle_feats,
                 single_template_feats["template_aatype"],
@@ -134,6 +142,7 @@ class AlphaFold(nn.Module):
             # [*, S_t, N, N, C_t]
             t = build_template_pair_feat(
                 single_template_feats,
+                inf=self.config.template.inf,
                 eps=self.config.template.eps,
                 **self.config.template.distogram
             )
@@ -164,9 +173,9 @@ class AlphaFold(nn.Module):
         t = t * (torch.sum(batch["template_mask"]) > 0)

         return {
-            "template_angle_embedding": a,
+            "template_angle_embedding": template_embeds["angle"],
             "template_pair_embedding": t,
-            "torsion_angles_mask": angle_feats["torsion_angles_mask"],
+            "torsion_angles_mask": template_embeds["torsion_mask"],
         }

     def iteration(self, feats, m_1_prev, z_prev, x_prev):
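The commented-out scaffolding sketches a fallback the author evidently experimented with: run atom37_to_torsion_angles in a wider dtype and cast the resulting feature tree back afterwards. A minimal stand-in for that tensor_tree_map cast (the real helper lives in OpenFold's tensor utilities; this version assumes plain dict/list trees):

    import torch

    def tensor_tree_map(fn, tree):
        # recursively apply fn to every tensor leaf in a nested structure
        if isinstance(tree, dict):
            return {k: tensor_tree_map(fn, v) for k, v in tree.items()}
        if isinstance(tree, (list, tuple)):
            return type(tree)(tensor_tree_map(fn, v) for v in tree)
        return fn(tree)

    angle_feats = {"torsion_angles_sin_cos": torch.randn(4, 7, 2)}
    tensor_dtype = torch.float16
    angle_feats = tensor_tree_map(lambda t: t.type(tensor_dtype), angle_feats)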
openfold/model/msa.py

@@ -107,7 +107,6 @@ class MSAAttention(nn.Module):
         bias = bias.expand(
             ((-1,) * len(bias.shape[:-4])) + (-1, self.no_heads, n_res, -1)
         )
-
         biases = [bias]
         if(self.pair_bias):
             # [*, N_res, N_res, C_z]
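The surviving context is worth a gloss: Tensor.expand never copies data, and -1 keeps a dimension as-is, so prepending ((-1,) * len(bias.shape[:-4])) handles any number of leading batch dimensions. A small sketch:

    import torch

    no_heads, n_res = 4, 6
    bias = torch.zeros(2, 1, 1, 1, n_res)    # batch dims, then [1, 1, 1, N_res]
    shape = ((-1,) * len(bias.shape[:-4])) + (-1, no_heads, n_res, -1)
    expanded = bias.expand(shape)
    print(expanded.shape)                    # torch.Size([2, 1, 4, 6, 6]), no copy made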
openfold/model/primitives.py

@@ -257,6 +257,8 @@ class Attention(nn.Module):
         a += b

         a = self.softmax(a)
+
+        #print(torch.any(torch.isnan(a)))

         # [*, H, Q, C_hidden]
         o = torch.matmul(
             a,
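The added (commented-out) isnan probe is the manual version of a check that can be automated. A hedged sketch using forward hooks, assuming a generic nn.Module model rather than OpenFold's Attention:

    import torch
    import torch.nn as nn

    def nan_hook(module, inputs, output):
        # flag the first module whose output goes NaN/Inf under FP16
        if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
            raise RuntimeError(f"non-finite output in {module.__class__.__name__}")

    model = nn.Sequential(nn.Linear(8, 8), nn.Softmax(dim=-1))
    for m in model.modules():
        m.register_forward_hook(nan_hook)
    model(torch.randn(2, 8))  # raises only if some output is non-finite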
openfold/model/template.py

@@ -50,6 +50,7 @@ class TemplatePointwiseAttention(nn.Module):
         c_hidden,
         no_heads,
         chunk_size,
+        inf,
         **kwargs
     ):
         """
@@ -68,6 +69,7 @@ class TemplatePointwiseAttention(nn.Module):
         self.c_hidden = c_hidden
         self.no_heads = no_heads
         self.chunk_size = chunk_size
+        self.inf = inf

         self.mha = Attention(
             self.c_z, self.c_t, self.c_t,
@@ -89,11 +91,11 @@ class TemplatePointwiseAttention(nn.Module):
         """
         if(template_mask is None):
             # NOTE: This is not the "template_mask" from the supplement, but a
-            # [*, N_templ] mask from the code. I'm pretty sure it's always just 1,
-            # but not sure enough to remove it. It's nice to have, I guess.
+            # [*, N_templ] mask from the code. I'm pretty sure it's always just
+            # 1, but not sure enough to remove it. It's nice to have, I guess.
             template_mask = t.new_ones(t.shape[:-3])

-        bias = (1e9 * (template_mask[..., None, None, None, None, :] - 1))
+        bias = (self.inf * (template_mask[..., None, None, None, None, :] - 1))

         # [*, N_res, N_res, 1, C_z]
         z = z.unsqueeze(-2)
@@ -133,6 +135,8 @@ class TemplatePairStackBlock(nn.Module):
         pair_transition_n,
         dropout_rate,
         chunk_size,
+        inf,
+        **kwargs,
     ):
         super(TemplatePairStackBlock, self).__init__()
@@ -143,6 +147,7 @@ class TemplatePairStackBlock(nn.Module):
         self.pair_transition_n = pair_transition_n
         self.dropout_rate = dropout_rate
         self.chunk_size = chunk_size
+        self.inf = inf

         self.dropout_row = DropoutRowwise(self.dropout_rate)
         self.dropout_col = DropoutColumnwise(self.dropout_rate)
@@ -152,12 +157,14 @@ class TemplatePairStackBlock(nn.Module):
             self.c_hidden_tri_att,
             self.no_heads,
             chunk_size=chunk_size,
+            inf=inf,
         )
         self.tri_att_end = TriangleAttentionEndingNode(
             self.c_t,
             self.c_hidden_tri_att,
             self.no_heads,
             chunk_size=chunk_size,
+            inf=inf,
         )
         self.tri_mul_out = TriangleMultiplicationOutgoing(
@@ -200,6 +207,7 @@ class TemplatePairStack(nn.Module):
         dropout_rate,
         blocks_per_ckpt,
         chunk_size,
+        inf=1e9,
         **kwargs,
     ):
         """
@@ -237,6 +245,7 @@ class TemplatePairStack(nn.Module):
             pair_transition_n=pair_transition_n,
             dropout_rate=dropout_rate,
             chunk_size=chunk_size,
+            inf=inf,
         )
         self.blocks.append(block)
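The bias construction that now uses self.inf deserves a note: for a {0, 1} mask, inf * (mask - 1) is 0 where the mask is set and a large negative number where it isn't, so adding it to attention logits zeroes out masked entries after softmax, and 1e4 is already large enough to do that in float16. A standalone sketch:

    import torch

    inf = 1e4                                    # FP16-safe stand-in for the old 1e9
    template_mask = torch.tensor([1., 1., 0.])   # third template is padding
    bias = inf * (template_mask - 1.)            # tensor([0., 0., -10000.])

    logits = torch.randn(3)
    probs = torch.softmax(logits + bias, dim=-1)
    print(probs[-1])                             # ~0: the masked template gets no weight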
openfold/utils/affine_utils.py

@@ -57,11 +57,17 @@ class T:
             raise ValueError("Only one of rots and trans can be None")
         elif(self.rots is None):
             self.rots = T.identity_rot(
-                self.trans.shape[:-1], self.trans.dtype, self.trans.device
+                self.trans.shape[:-1],
+                self.trans.dtype,
+                self.trans.device,
+                self.trans.requires_grad,
             )
         elif(self.trans is None):
             self.trans = T.identity_trans(
-                self.rots.shape[:-2], self.rots.dtype, self.rots.device
+                self.rots.shape[:-2],
+                self.rots.dtype,
+                self.rots.device,
+                self.rots.requires_grad
             )

         if(self.rots.shape[-2:] != (3, 3) or
@@ -137,7 +143,7 @@ class T:
         return T(rots, trans)

     @staticmethod
-    def identity_rot(shape, dtype, device, requires_grad=False):
+    def identity_rot(shape, dtype, device, requires_grad):
         rots = torch.eye(
             3, dtype=dtype, device=device, requires_grad=requires_grad
         )
@@ -147,7 +153,7 @@ class T:
         return rots

     @staticmethod
-    def identity_trans(shape, dtype, device, requires_grad=False):
+    def identity_trans(shape, dtype, device, requires_grad):
         trans = torch.zeros(
             (*shape, 3),
             dtype=dtype,
@@ -182,20 +188,22 @@ class T:
     @staticmethod
     def from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
-        v1 = origin - p_neg_x_axis
-        v2 = p_xy_plane - origin
-        e1 = v1 / torch.sqrt(torch.sum(v1 ** 2, dim=-1) + eps)[..., None]
-        u2 = v2 - e1 * (torch.einsum('...i,...i->...', v2, e1)[..., None])
-        e2 = u2 / torch.sqrt(torch.sum(u2 ** 2, dim=-1) + eps)[..., None]
-        e3 = torch.cross(e1, e2, dim=-1)
-        rots = torch.cat(
-            (
-                e1.unsqueeze(-1),
-                e2.unsqueeze(-1),
-                e3.unsqueeze(-1),
-            ),
-            dim=-1,
-        )
+        e0 = origin - p_neg_x_axis
+        e1 = p_xy_plane - origin
+
+        # Angle norming is very sensitive to floating point imprecisions
+        #float_type = e0.dtype
+        #e0 = e0.float()
+        #e1 = e1.float()
+
+        e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdims=True) + eps)
+        e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdims=True)
+        e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdims=True) + eps)
+        e2 = torch.cross(e0, e1)
+
+        rots = torch.stack([e0, e1, e2], dim=-1)
+
+        #rots = rots.type(float_type)

         return T(rots, origin)
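The from_3_points rewrite is textbook Gram-Schmidt: normalize e0, project it out of e1 and renormalize, then take the cross product for e2. A self-contained check (mirroring the new code, with keepdim spelled in its canonical form) confirms the stacked result is a rotation:

    import torch

    def frames_from_3_points(p_neg_x_axis, origin, p_xy_plane, eps=1e-8):
        e0 = origin - p_neg_x_axis
        e1 = p_xy_plane - origin
        e0 = e0 / torch.sqrt(torch.sum(e0 ** 2, dim=-1, keepdim=True) + eps)
        e1 = e1 - e0 * torch.sum(e0 * e1, dim=-1, keepdim=True)
        e1 = e1 / torch.sqrt(torch.sum(e1 ** 2, dim=-1, keepdim=True) + eps)
        e2 = torch.cross(e0, e1, dim=-1)
        return torch.stack([e0, e1, e2], dim=-1)

    R = frames_from_3_points(torch.randn(3), torch.randn(3), torch.randn(3))
    # True for non-degenerate inputs
    print(torch.allclose(R.transpose(-1, -2) @ R, torch.eye(3), atol=1e-4))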
openfold/utils/deepspeed.py

@@ -70,6 +70,9 @@ def checkpoint_blocks(
     for s in range(0, len(blocks), blocks_per_ckpt):
         e = s + blocks_per_ckpt
+        #print(len(args))
+        #for a in args:
+        #    print(a.requires_grad)
         args = checkpoint(chunker(s, e), *args)
         #args = deepspeed.checkpointing.checkpoint(chunker(s, e), *args)
         args = wrap(args)
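For context on the function these debug comments landed in: checkpoint_blocks runs the block list in groups of blocks_per_ckpt, checkpointing each group so its activations are recomputed during backward rather than stored. A simplified single-tensor sketch of the idea (the real function threads a tuple of args through and can dispatch to deepspeed's checkpointing):

    import torch
    from torch.utils.checkpoint import checkpoint

    def run_checkpointed(blocks, x, blocks_per_ckpt=1):
        def chunker(s, e):
            def run(x):
                for block in blocks[s:e]:
                    x = block(x)
                return x
            return run

        for s in range(0, len(blocks), blocks_per_ckpt):
            # each chunk's activations are recomputed in the backward pass
            x = checkpoint(chunker(s, s + blocks_per_ckpt), x)
        return x

    blocks = [torch.nn.Linear(8, 8) for _ in range(4)]
    out = run_checkpointed(
        blocks, torch.randn(2, 8, requires_grad=True), blocks_per_ckpt=2
    )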
openfold/utils/feats.py

@@ -286,7 +286,7 @@ def atom37_to_torsion_angles(
             torch.square(torsion_angles_sin_cos), dim=-1, keepdims=True
         ) + eps
     )
-    torsion_angles_sin_cos /= denom
+    torsion_angles_sin_cos = torsion_angles_sin_cos / denom

     torsion_angles_sin_cos *= torch.tensor(
         [1., 1., -1., 1., 1., 1., 1.],
         device=aatype.device,
@@ -298,7 +298,7 @@ def atom37_to_torsion_angles(
     mirror_torsion_angles = torch.cat(
         [
-            aatype.new_ones(*aatype.shape, 3),
+            all_atom_mask.new_ones(*aatype.shape, 3),
             1. - 2. * chi_is_ambiguous
         ],
         dim=-1
     )
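Both feats.py changes are subtler than they look. Replacing /= with an out-of-place division avoids mutating a tensor that autograd may have saved for backward, and switching new_ones from aatype (an integer tensor) to all_atom_mask (a floating one) gives the ones a dtype that can be concatenated with the float chi term. That reading of the motivation is an inference, but both pitfalls are easy to reproduce:

    import torch

    # new_ones inherits the source tensor's dtype
    aatype = torch.zeros(4, dtype=torch.int64)
    all_atom_mask = torch.zeros(4, 37, dtype=torch.float16)
    print(aatype.new_ones(4, 3).dtype)          # torch.int64
    print(all_atom_mask.new_ones(4, 3).dtype)   # torch.float16

    # in-place ops can invalidate tensors saved for the backward pass
    x = torch.randn(3, requires_grad=True)
    y = x.exp()       # autograd saves y to compute exp's gradient
    try:
        y /= 2.       # bumps y's version counter
        y.sum().backward()
    except RuntimeError as err:
        print("in-place op broke backward:", err)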
openfold/utils/loss.py

@@ -80,7 +80,7 @@ def compute_fape(
     positions_mask: torch.Tensor,
     length_scale: float,
     l1_clamp_distance: Optional[float] = None,
-    eps=1e-4
+    eps=1e-8
 ) -> torch.Tensor:
     # [*, N_frames, N_pts, 3]
     local_pred_pos = pred_frames.invert()[..., None].apply(
@@ -100,12 +100,19 @@ def compute_fape(
     normed_error *= frames_mask[..., None]
     normed_error *= positions_mask[..., None, :]

-    norm_factor = (
-        torch.sum(frames_mask, dim=-1) *
-        torch.sum(positions_mask, dim=-1)
-    )
-    normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
+    # FP16-friendly averaging. Roughly equivalent to:
+    #
+    # norm_factor = (
+    #     torch.sum(frames_mask, dim=-1) *
+    #     torch.sum(positions_mask, dim=-1)
+    # )
+    # normed_error = torch.sum(normed_error, dim=(-1, -2)) / (eps + norm_factor)
+    #
+    # ("roughly" because eps is necessarily duplicated in the latter)
+    normed_error = torch.sum(normed_error, dim=-1)
+    normed_error = normed_error / (eps + torch.sum(frames_mask, dim=-1))[..., None]
+    normed_error = torch.sum(normed_error, dim=-1)
+    normed_error = normed_error / (eps + torch.sum(positions_mask, dim=-1))

     return normed_error
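This is the pattern the commit applies to each loss: a joint sum over an [N, N]-shaped mask can exceed float16's maximum finite value even when each summand is small, while normalizing after every single-axis reduction keeps intermediates near the mean. Shown here in float32 for clarity:

    import torch

    fp16_max = torch.finfo(torch.float16).max       # 65504.0
    errors = torch.full((512, 512), 20.0)           # modest per-element losses

    joint = torch.sum(errors, dim=(-1, -2))         # ~5.2e6: would be inf in fp16
    print(joint > fp16_max)                         # tensor(True)

    staged = torch.sum(errors, dim=-1) / 512.       # ~20 per row
    staged = torch.sum(staged, dim=-1) / 512.       # ~20 overall; nothing ever blew up
    print(staged)                                   # tensor(20.)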
@@ -118,6 +125,7 @@ def backbone_loss(
     use_clamped_fape: Optional[torch.Tensor] = None,
     clamp_distance: float = 10.,
     loss_unit_distance: float = 10.,
+    eps: float = 1e-4,
     **kwargs,
 ) -> torch.Tensor:
     pred_aff = T.from_tensor(traj)
@@ -132,6 +140,7 @@ def backbone_loss(
         backbone_affine_mask[..., None, :],
         l1_clamp_distance=clamp_distance,
         length_scale=loss_unit_distance,
+        eps=eps,
     )

     if(use_clamped_fape is not None):
@@ -144,6 +153,7 @@ def backbone_loss(
             backbone_affine_mask[..., None, :],
             l1_clamp_distance=None,
             length_scale=loss_unit_distance,
+            eps=eps,
         )

         fape_loss = (
@@ -167,6 +177,7 @@ def sidechain_loss(
     alt_naming_is_better: torch.Tensor,
     clamp_distance: float = 10.,
     length_scale: float = 10.,
+    eps: float = 1e-4,
     **kwargs,
 ) -> torch.Tensor:
     renamed_gt_frames = (
@@ -210,6 +221,7 @@ def sidechain_loss(
         renamed_atom14_gt_exists,
         l1_clamp_distance=clamp_distance,
         length_scale=length_scale,
+        eps=eps,
     )

     return fape
@@ -428,10 +440,14 @@ def distogram_loss(
     square_mask = pseudo_beta_mask[..., None] * pseudo_beta_mask[..., None, :]

-    mean = (
-        torch.sum(errors * square_mask, dim=(-1, -2)) /
-        (eps + torch.sum(square_mask, dim=(-1, -2)))
-    )
+    # FP16-friendly sum. Equivalent to:
+    # mean = (torch.sum(errors * square_mask, dim=(-1, -2)) /
+    #        (eps + torch.sum(square_mask, dim=(-1, -2))))
+    denom = eps + torch.sum(square_mask, dim=(-1, -2))
+    mean = errors * square_mask
+    mean = torch.sum(mean, dim=-1)
+    mean = mean / denom[..., None]
+    mean = torch.sum(mean, dim=-1)

     return mean
@@ -1285,10 +1301,18 @@ def masked_msa_loss(logits, true_msa, bert_mask, eps=1e-8, **kwargs):
         logits,
         torch.nn.functional.one_hot(true_msa, num_classes=23)
     )

-    loss = (
-        torch.sum(errors * bert_mask, dim=(-1, -2)) /
-        (eps + torch.sum(bert_mask, dim=(-1, -2)))
-    )
+    # FP16-friendly averaging. Equivalent to:
+    # loss = (
+    #     torch.sum(errors * bert_mask, dim=(-1, -2)) /
+    #     (eps + torch.sum(bert_mask, dim=(-1, -2)))
+    # )
+    denom = eps + torch.sum(bert_mask, dim=(-1, -2))
+    loss = errors * bert_mask
+    loss = torch.sum(loss, dim=-1)
+    loss = loss / denom[..., None]
+    loss = torch.sum(loss, dim=-1)

     return loss
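One detail shared by the distogram and masked-MSA rewrites: denom is still the joint mask sum, so after the first single-axis reduction the running tensor keeps a trailing dimension and the division needs denom[..., None] to broadcast; the result is exactly the old quotient. A quick check:

    import torch

    errors = torch.rand(2, 8, 8)
    mask = torch.randint(0, 2, (2, 8, 8)).float()
    eps = 1e-8

    denom = eps + torch.sum(mask, dim=(-1, -2))    # [2]
    loss = errors * mask                           # [2, 8, 8]
    loss = torch.sum(loss, dim=-1)                 # [2, 8]
    loss = loss / denom[..., None]                 # [2, 1] broadcasts over rows
    loss = torch.sum(loss, dim=-1)                 # [2]

    reference = torch.sum(errors * mask, dim=(-1, -2)) / denom
    print(torch.allclose(loss, reference))         # True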
@@ -1299,8 +1323,6 @@ class AlphaFoldLoss(nn.Module):
         self.config = config

     def forward(self, out, batch):
-        cum_loss = 0
-
         if("violation" not in out.keys() and self.config.violation.weight):
             out["violation"] = find_structural_violations(
                 batch,
@@ -1331,6 +1353,7 @@ class AlphaFoldLoss(nn.Module):
         if("chi_angles_sin_cos" not in batch.keys()):
             batch.update(feats.atom37_to_torsion_angles(
                 **batch,
+                eps=self.config.eps,
             ))

         # TODO: Verify that this is correct
@@ -1382,12 +1405,14 @@ class AlphaFoldLoss(nn.Module):
             ),
         }

+        cum_loss = 0
         for k, loss_fn in loss_fns.items():
             weight = self.config[k].weight
             if(weight):
+                print(k)
                 loss = loss_fn()
-                #print(k)
-                #print(loss)
-                cum_loss = cum_loss + weight * loss
+                print(weight * loss)
+                cum_loss += weight * loss

+        print(cum_loss)
         return cum_loss
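The forward pass now builds the running total next to the loop that consumes it, and the bare print calls are temporary FP16 debugging, consistent with the commit message's admission that FAPE does not yet decrease. For readers skimming the diff, the registry pattern in miniature (illustrative weights, not OpenFold's):

    import torch

    weights = {"fape": 1.0, "lddt": 0.0, "distogram": 0.3}
    loss_fns = {
        "fape": lambda: torch.tensor(1.2),
        "lddt": lambda: torch.tensor(0.8),    # never evaluated: weight is zero
        "distogram": lambda: torch.tensor(0.5),
    }

    cum_loss = 0
    for k, loss_fn in loss_fns.items():
        weight = weights[k]
        if(weight):
            cum_loss += weight * loss_fn()
    print(cum_loss)                           # tensor(1.3500)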