OpenDAS / OpenFold · Commits

Commit 4c7c30b1, authored Feb 22, 2022 by Gustaf Ahdritz
Parent: 61d004a2

Fix chain data cache script bug, improve logging
Showing 3 changed files with 64 additions and 48 deletions:

- openfold/utils/loss.py (+2, -0)
- scripts/generate_chain_data_cache.py (+8, -8)
- train_openfold.py (+54, -40)
openfold/utils/loss.py

```diff
@@ -1627,6 +1627,8 @@ class AlphaFoldLoss(nn.Module):
+        crop_len = batch["aatype"].shape[-1]
+        cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
 
         losses["loss"] = cum_loss.detach().clone()
 
         if(not _return_breakdown):
             return cum_loss
```
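The two added lines apply the loss scaling described in the AlphaFold 2 supplement: the accumulated loss is multiplied by the square root of the number of residues the model actually sees, i.e. the minimum of the true sequence length and the crop size. A minimal sketch with made-up values (`seq_len` itself is defined earlier in the method, above this hunk):

```python
import torch

# Illustrative values only. seq_len is the mean true sequence length in
# the batch; crop_len is the crop size read off the residue-type tensor.
batch = {
    "aatype": torch.zeros(1, 256),        # crop of 256 residues
    "seq_length": torch.tensor([192.]),   # true length of the chain
}
seq_len = torch.mean(batch["seq_length"].float())
crop_len = batch["aatype"].shape[-1]

cum_loss = torch.tensor(3.7)  # stand-in for the weighted sum of losses
# Here the sequence is shorter than the crop, so min() returns the
# seq_len tensor and the loss is scaled by sqrt(192) ~= 13.86.
cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
```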
scripts/generate_chain_data_cache.py

```diff
@@ -39,13 +39,8 @@ def parse_file(
             local_data["seq"] = seq
             local_data["resolution"] = mmcif.header["resolution"]
-            cluster_size = chain_cluster_size_dict.get(full_name.upper(), None)
-            if(cluster_size is None):
-                print(file_id)
-                out.pop(full_name)
-                continue
-            else:
-                local_data["cluster_size"] = cluster_size
+            cluster_size = chain_cluster_size_dict.get(full_name.upper(), -1)
+            local_data["cluster_size"] = cluster_size
     elif(ext == ".pdb"):
         with open(os.path.join(args.data_dir, f), "r") as fp:
             pdb_string = fp.read()
```
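This is the script bug from the commit message: chains absent from the cluster file were printed and then popped from the output dict, so they vanished from the cache entirely. The fix records a sentinel cluster size of -1 instead and leaves the decision to downstream code. A self-contained sketch of a consumer honoring that sentinel (all names here are illustrative, not from the OpenFold codebase):

```python
# Illustrative only: a downstream filter that treats the sentinel -1
# ("chain absent from the cluster file") as "do not filter by cluster
# size", matching the behavior the new --cluster_file help describes.
cache = {
    "1abc_A": {"resolution": 1.8, "cluster_size": 42},
    "2xyz_B": {"resolution": 2.5, "cluster_size": -1},  # not in cluster file
}

def keep_chain(entry, max_cluster_size=100):
    if entry["cluster_size"] == -1:
        return True  # unknown cluster: never filtered by cluster size
    return entry["cluster_size"] <= max_cluster_size

filtered = {k: v for k, v in cache.items() if keep_chain(v)}
assert set(filtered) == {"1abc_A", "2xyz_B"}
```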
```diff
@@ -112,7 +107,12 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--cluster_file", type=str, default=None,
-        help="Path to a cluster file (e.g. PDB40), one cluster per line"
+        help=(
+            "Path to a cluster file (e.g. PDB40), one cluster "
+            "({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
+            "Chains not in this cluster file will NOT be filtered by cluster "
+            "size."
+        )
     )
     parser.add_argument(
         "--no_workers", type=int, default=4,
```
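The expanded help text also pins down the expected file format: one cluster per line, each a whitespace-separated list of {ID}_{CHAIN} names. As a hedged sketch, a dictionary like the `chain_cluster_size_dict` used above could be built from such a file as follows (the script's real parsing code is not part of this diff):

```python
# Sketch under the format the help string describes. Keys are
# upper-cased to match the .upper() lookup in parse_file above.
def make_chain_cluster_size_dict(cluster_file_path):
    chain_cluster_size_dict = {}
    with open(cluster_file_path, "r") as fp:
        for line in fp:
            names = line.strip().split()
            for name in names:
                chain_cluster_size_dict[name.upper()] = len(names)
    return chain_cluster_size_dict
```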
train_openfold.py

```diff
@@ -63,6 +63,36 @@ class OpenFoldWrapper(pl.LightningModule):
     def forward(self, batch):
         return self.model(batch)
 
+    def _log(self, loss, loss_breakdown, train=True):
+        phase = "train" if train else "val"
+        for loss_name, indiv_loss in loss_breakdown.items():
+            self.log(
+                f"{phase}/{loss_name}",
+                indiv_loss,
+                on_step=train, on_epoch=(not train), logger=True,
+            )
+
+        if(train):
+            self.log(
+                f"train/loss_epoch",
+                loss_breakdown["loss"],
+                on_step=False, on_epoch=True, logger=True,
+            )
+
+        with torch.no_grad():
+            other_metrics = self._compute_validation_metrics(
+                batch,
+                outputs,
+                superimposition_metrics=(not train)
+            )
+
+        for k, v in other_metrics.items():
+            self.log(
+                f"{phase}/{k}",
+                v,
+                on_step=False, on_epoch=True, logger=True
+            )
+
     def training_step(self, batch, batch_idx):
         if(self.ema.device != batch["aatype"].device):
             self.ema.to(batch["aatype"].device)
```
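The new `_log` helper folds the near-duplicate logging blocks of `training_step` and `validation_step` (removed in the next two hunks) into one method keyed on a `train` flag: per-step logging during training, per-epoch aggregation during validation. A minimal, self-contained illustration of that `on_step`/`on_epoch` pattern in PyTorch Lightning (not OpenFold code):

```python
import pytorch_lightning as pl
import torch

# Toy module demonstrating the same flag pattern as _log above.
class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def _log_metrics(self, metrics, train=True):
        phase = "train" if train else "val"
        for name, value in metrics.items():
            self.log(
                f"{phase}/{name}", value,
                # Lightning aggregation: per-step in training,
                # per-epoch mean in validation.
                on_step=train, on_epoch=(not train), logger=True,
            )

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).pow(2).mean()
        self._log_metrics({"loss": loss})
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch).pow(2).mean()
        self._log_metrics({"loss": loss}, train=False)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```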
```diff
@@ -79,28 +109,7 @@ class OpenFoldWrapper(pl.LightningModule):
         )
 
         # Log it
-        self.log(
-            "train/loss", loss,
-            on_step=True, logger=True,
-        )
-        self.log(
-            "train/loss_epoch",
-            loss,
-            on_step=False, on_epoch=True, logger=True,
-        )
-
-        for loss_name, indiv_loss in loss_breakdown.items():
-            self.log(
-                f"train/{loss_name}",
-                indiv_loss,
-                on_step=True, logger=True,
-            )
-
-        with torch.no_grad():
-            other_metrics = self.compute_validation_metrics(
-                batch, outputs
-            )
-
-        for k, v in other_metrics.items():
-            self.log(
-                f"train/{k}", v,
-                on_step=False, on_epoch=True, logger=True
-            )
+        self._log(loss_breakdown)
 
         return loss
```
```diff
@@ -125,23 +134,12 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss and other metrics
         batch["use_clamped_fape"] = 0.
-        loss, loss_breakdown = self.loss(
+        _, loss_breakdown = self.loss(
             outputs, batch, _return_breakdown=True
         )
 
-        self.log(
-            "val/loss", loss,
-            on_step=False, on_epoch=True, logger=True
-        )
-
-        for loss_name, indiv_loss in loss_breakdown.items():
-            self.log(
-                f"val/{loss_name}",
-                indiv_loss,
-                on_step=False, on_epoch=True, logger=True,
-            )
-
-        other_metrics = self.compute_validation_metrics(
-            batch, outputs, superimposition_metrics=True,
-        )
-
-        for k, v in other_metrics.items():
-            self.log(
-                f"val/{k}", v,
-                on_step=False, on_epoch=True, logger=True
-            )
+        self._log(loss_breakdown, train=False)
 
     def validation_epoch_end(self, _):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
```
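`validation_epoch_end` restores weights that were cached before validation, the usual pattern when validating with EMA weights. Only the restore line appears in this diff; the surrounding swap below is a hedged sketch, and `ema_state_dict` is a hypothetical attribute standing in for however the EMA parameters are stored:

```python
import copy

# Illustrative EMA weight-swap around validation; only the restore step
# (load_state_dict(self.cached_weights)) comes from this diff.
class EMASwapMixin:
    def on_validation_epoch_start(self):
        # Cache the live training weights, then validate with EMA weights.
        self.cached_weights = copy.deepcopy(self.model.state_dict())
        self.model.load_state_dict(self.ema_state_dict)

    def validation_epoch_end(self, _):
        # Restore the model weights to normal
        self.model.load_state_dict(self.cached_weights)
```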
```diff
@@ -440,18 +438,23 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--wandb", action="store_true", default=False,
         help="Whether to log metrics to Weights & Biases"
     )
     parser.add_argument(
         "--experiment_name", type=str, default=None,
         help="Name of the current experiment. Used for wandb logging"
     )
     parser.add_argument(
         "--wandb_id", type=str, default=None,
         help="ID of a previous run to be resumed"
     )
     parser.add_argument(
         "--wandb_project", type=str, default=None,
         help="Name of the wandb project to which this run will belong"
     )
     parser.add_argument(
         "--wandb_entity", type=str, default=None,
         help="wandb username or team name to which runs are attributed"
     )
     parser.add_argument(
         "--script_modules", type=bool_type, default=False,
```
)
parser
.
add_argument
(
"--train_epoch_len"
,
type
=
int
,
default
=
10000
,
)
parser
.
add_argument
(
"--_alignment_index_path"
,
type
=
str
,
default
=
None
,
help
=
(
"The virtual length of each training epoch. Stochastic filtering "
"of training data means that training datasets have no "
"well-defined length. This virtual length affects frequency of "
"validation & checkpointing (by default, one of each per epoch)."
)
)
parser
.
add_argument
(
"--log_lr"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to log the actual learning rate"
)
parser
.
add_argument
(
"--config_preset"
,
type
=
str
,
default
=
"initial_training"
,
help
=
'Config setting. Choose e.g. "initial_training", "finetuning", "model_1", etc.'
help
=
(
'Config setting. Choose e.g. "initial_training", "finetuning", '
'"model_1", etc. By default, the actual values in the config are '
'used.'
)
)
parser
.
add_argument
(
"--_alignment_index_path"
,
type
=
str
,
default
=
None
,
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
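The new help text explains why `--train_epoch_len` exists: with stochastic filtering, the training dataset has no well-defined length, so a fixed "virtual" epoch length keeps per-epoch hooks such as validation and checkpointing firing at a predictable rate. One way to impose such a virtual length, sketched here with a toy dataset rather than OpenFold's actual data pipeline:

```python
import torch
from torch.utils.data import DataLoader, Dataset, RandomSampler

# Toy stand-in for a dataset whose effective length is ill-defined.
class ToyDataset(Dataset):
    def __len__(self):
        return 1_000_000  # nominal size; filtering makes it meaningless

    def __getitem__(self, idx):
        return torch.randn(4)

train_epoch_len = 10000
dataset = ToyDataset()
# Sample with replacement so each "epoch" draws exactly train_epoch_len
# examples regardless of the dataset's nominal length.
sampler = RandomSampler(dataset, replacement=True, num_samples=train_epoch_len)
loader = DataLoader(dataset, sampler=sampler, batch_size=8)
# len(loader) == train_epoch_len / 8, so hooks that fire "once per
# epoch" now fire at a predictable frequency.
```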