OpenDAS / OpenFold / Commits / b066dd55

Commit b066dd55, authored Jan 21, 2022 by Gustaf Ahdritz

Fix no-DeepSpeed DDP

Parent: 8ffec72a
Showing 2 changed files with 23 additions and 18 deletions (+23 −18):

openfold/config.py       +10 −6
openfold/utils/loss.py   +13 −12
openfold/config.py @ b066dd55

@@ -64,6 +64,7 @@ c_s = mlc.FieldReference(384, field_type=int)
 blocks_per_ckpt = mlc.FieldReference(None, field_type=int)
 chunk_size = mlc.FieldReference(4, field_type=int)
 aux_distogram_bins = mlc.FieldReference(64, field_type=int)
+tm_enabled = mlc.FieldReference(False, field_type=bool)
 eps = mlc.FieldReference(1e-8, field_type=float)
 templates_enabled = mlc.FieldReference(True, field_type=bool)
 embed_template_torsion_angles = mlc.FieldReference(True, field_type=bool)
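The new tm_enabled reference works like the other FieldReference fields above: a single value wired into several places in the ConfigDict at once (here, the tm head and the tm loss, as the later hunks show). A minimal sketch of that sharing behavior, using illustrative paths rather than the real OpenFold config:

import ml_collections as mlc

tm_enabled = mlc.FieldReference(False, field_type=bool)

# Illustrative ConfigDict: both entries hold the same FieldReference.
cfg = mlc.ConfigDict({
    "heads": {"tm": {"enabled": tm_enabled}},
    "loss": {"tm": {"enabled": tm_enabled}},
})

# Assigning through one path updates the shared reference, so the
# other path sees the change as well.
cfg.heads.tm.enabled = True
print(cfg.loss.tm.enabled)  # True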
@@ -148,7 +149,7 @@ config = mlc.ConfigDict(
                     "same_prob": 0.1,
                     "uniform_prob": 0.1,
                 },
-                "max_extra_msa": 5120,
+                "max_extra_msa": 1024,
                 "max_recycling_iters": 3,
                 "msa_cluster_features": True,
                 "reduce_msa_clusters_by_max_templates": False,
@@ -174,7 +175,6 @@ config = mlc.ConfigDict(
             },
             "supervised": {
                 "clamp_prob": 0.9,
-                "uniform_recycling": True,
                 "supervised_features": [
                     "all_atom_mask",
                     "all_atom_positions",
@@ -194,6 +194,7 @@ config = mlc.ConfigDict(
                 "crop_size": None,
                 "supervised": False,
                 "subsample_recycling": False,
+                "uniform_recycling": False,
             },
             "eval": {
                 "fixed_size": True,
@@ -206,27 +207,29 @@ config = mlc.ConfigDict(
                 "crop_size": None,
                 "supervised": True,
                 "subsample_recycling": False,
+                "uniform_recycling": False,
             },
             "train": {
                 "fixed_size": True,
                 "subsample_templates": True,
                 "masked_msa_replace_fraction": 0.15,
-                "max_msa_clusters": 512,
+                "max_msa_clusters": 128,
                 "max_template_hits": 4,
                 "max_templates": 4,
                 "shuffle_top_k_prefiltered": 20,
                 "crop": True,
-                "crop_size": 384,
+                "crop_size": 256,
                 "supervised": True,
                 "clamp_prob": 0.9,
                 "subsample_recycling": True,
                 "max_distillation_msa_clusters": 1000,
+                "uniform_recycling": True,
             },
             "data_module": {
                 "use_small_bfd": False,
                 "data_loaders": {
                     "batch_size": 1,
-                    "num_workers": 2,
+                    "num_workers": 4,
                 },
             },
         },
@@ -374,7 +377,7 @@ config = mlc.ConfigDict(
                 "tm": {
                     "c_z": c_z,
                     "no_bins": aux_distogram_bins,
-                    "enabled": False,
+                    "enabled": tm_enabled,
                 },
                 "masked_msa": {
                     "c_m": c_m,

@@ -452,6 +455,7 @@ config = mlc.ConfigDict(
                 "max_resolution": 3.0,
                 "eps": eps,  # 1e-8,
                 "weight": 0.0,
+                "enabled": tm_enabled,
             },
             "eps": eps,
         },
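The smaller training defaults above (max_extra_msa, max_msa_clusters, crop_size, num_workers) are ordinary ConfigDict fields, so a training script can still restore the larger values when memory allows. A hypothetical override sketch; the attribute paths assume the data / train / data_module nesting visible in the hunks, and the location of max_extra_msa under data.common is an assumption not shown in the diff:

from openfold.config import config

# Hypothetical overrides restoring the pre-commit values.
config.data.train.crop_size = 384
config.data.train.max_msa_clusters = 512
config.data.common.max_extra_msa = 5120   # "common" section name is assumed
config.data.data_module.data_loaders.num_workers = 2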
openfold/utils/loss.py @ b066dd55
@@ -43,8 +43,8 @@ def softmax_cross_entropy(logits, labels):
 def sigmoid_cross_entropy(logits, labels):
-    log_p = torch.nn.functional.logsigmoid(logits)
-    log_not_p = torch.nn.functional.logsigmoid(-logits)
+    log_p = torch.log(torch.sigmoid(logits))
+    log_not_p = torch.log(torch.sigmoid(-logits))
     loss = -labels * log_p - (1 - labels) * log_not_p
     return loss
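Both versions compute the same quantity: logsigmoid(x) is log(sigmoid(x)), and 1 - sigmoid(x) = sigmoid(-x), so this is the standard binary cross-entropy on logits, -y * log σ(x) - (1 - y) * log σ(-x). A quick self-contained check against PyTorch's reference implementation (the test values are mine, not from the repo):

import torch

def sigmoid_cross_entropy(logits, labels):
    # -y * log(sigmoid(x)) - (1 - y) * log(1 - sigmoid(x)),
    # using the identity 1 - sigmoid(x) == sigmoid(-x).
    log_p = torch.log(torch.sigmoid(logits))
    log_not_p = torch.log(torch.sigmoid(-logits))
    return -labels * log_p - (1 - labels) * log_not_p

logits = torch.tensor([-2.0, -0.5, 0.5, 2.0])
labels = torch.tensor([0.0, 1.0, 0.0, 1.0])
reference = torch.nn.functional.binary_cross_entropy_with_logits(
    logits, labels, reduction="none"
)
assert torch.allclose(sigmoid_cross_entropy(logits, labels), reference, atol=1e-6)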
@@ -1462,7 +1462,7 @@ class AlphaFoldLoss(nn.Module):
         self.config = config

     def forward(self, out, batch):
-        if "violation" not in out.keys() and self.config.violation.weight:
+        if "violation" not in out.keys():
             out["violation"] = find_structural_violations(
                 batch,
                 out["sm"]["positions"][-1],
@@ -1509,20 +1509,21 @@ class AlphaFoldLoss(nn.Module):
                 out["violation"],
                 **batch,
             ),
-            "tm": lambda: tm_loss(
+        }
+
+        if(self.config.tm.enabled):
+            loss_fns["tm"] = lambda: tm_loss(
                 logits=out["tm_logits"],
                 **{**batch, **out, **self.config.tm},
-            ),
-        }
+            )

         cum_loss = 0.
         for loss_name, loss_fn in loss_fns.items():
             weight = self.config[loss_name].weight
-            if weight:
-                loss = loss_fn()
-                if(torch.isnan(loss) or torch.isinf(loss)):
-                    logging.warning(f"{loss_name} loss is NaN. Skipping...")
-                    loss = loss.new_tensor(0., requires_grad=True)
-                cum_loss = cum_loss + weight * loss
+            loss = loss_fn()
+            if(torch.isnan(loss) or torch.isinf(loss)):
+                logging.warning(f"{loss_name} loss is NaN. Skipping...")
+                loss = loss.new_tensor(0., requires_grad=True)
+            cum_loss = cum_loss + weight * loss

         return cum_loss
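Taken together, these two hunks are the DDP fix the commit title refers to. Previously, any loss whose weight was 0.0 (e.g. the tm loss, whose config weight is 0.0 above) was never evaluated, so the corresponding head's parameters received no gradients, and DistributedDataParallel without DeepSpeed errors out on parameters that did not participate in producing the loss unless find_unused_parameters=True is set, at extra cost. Now every registered loss is always evaluated and scaled by its weight, keeping every head in the autograd graph; the tm loss is instead registered only when config.tm.enabled is set, presumably because out["tm_logits"] only exists when that head is enabled. The same reasoning drops self.config.violation.weight from the guard around find_structural_violations. A standalone toy sketch of the mechanism (my illustration, not OpenFold code):

import torch

# A "head" whose loss is skipped contributes no gradients, which DDP
# flags as an unused parameter. Evaluating the loss and scaling it by a
# zero weight keeps the parameters in the graph; the gradients are just zero.
head = torch.nn.Linear(8, 1)
x = torch.randn(2, 8)

weight = 0.0
loss = weight * head(x).mean()  # still evaluated, unlike under `if weight:`
loss.backward()
print(head.weight.grad is not None)  # True: zero gradients were produced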