Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
4c7c30b1
"cmd/vscode:/vscode.git/clone" did not exist on "3abd184e05929f8961daf4ecac8a01263d5c158f"
Commit
4c7c30b1
authored
Feb 22, 2022
by
Gustaf Ahdritz
Browse files
Fix chain data cache script bug, improve logging
parent
61d004a2
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
64 additions
and
48 deletions
+64
-48
openfold/utils/loss.py
openfold/utils/loss.py
+2
-0
scripts/generate_chain_data_cache.py
scripts/generate_chain_data_cache.py
+8
-8
train_openfold.py
train_openfold.py
+54
-40
No files found.
openfold/utils/loss.py
View file @
4c7c30b1
...
...
@@ -1627,6 +1627,8 @@ class AlphaFoldLoss(nn.Module):
crop_len
=
batch
[
"aatype"
].
shape
[
-
1
]
cum_loss
=
cum_loss
*
torch
.
sqrt
(
min
(
seq_len
,
crop_len
))
losses
[
"loss"
]
=
cum_loss
.
detach
().
clone
()
if
(
not
_return_breakdown
):
return
cum_loss
...
...
scripts/generate_chain_data_cache.py
View file @
4c7c30b1
...
...
@@ -39,12 +39,7 @@ def parse_file(
local_data
[
"seq"
]
=
seq
local_data
[
"resolution"
]
=
mmcif
.
header
[
"resolution"
]
cluster_size
=
chain_cluster_size_dict
.
get
(
full_name
.
upper
(),
None
)
if
(
cluster_size
is
None
):
print
(
file_id
)
out
.
pop
(
full_name
)
continue
else
:
cluster_size
=
chain_cluster_size_dict
.
get
(
full_name
.
upper
(),
-
1
)
local_data
[
"cluster_size"
]
=
cluster_size
elif
(
ext
==
".pdb"
):
with
open
(
os
.
path
.
join
(
args
.
data_dir
,
f
),
"r"
)
as
fp
:
...
...
@@ -112,7 +107,12 @@ if __name__ == "__main__":
)
parser
.
add_argument
(
"--cluster_file"
,
type
=
str
,
default
=
None
,
help
=
"Path to a cluster file (e.g. PDB40), one cluster per line"
help
=
(
"Path to a cluster file (e.g. PDB40), one cluster "
"({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
"Chains not in this cluster file will NOT be filtered by cluster "
"size."
)
)
parser
.
add_argument
(
"--no_workers"
,
type
=
int
,
default
=
4
,
...
...
train_openfold.py
View file @
4c7c30b1
...
...
@@ -63,6 +63,36 @@ class OpenFoldWrapper(pl.LightningModule):
def
forward
(
self
,
batch
):
return
self
.
model
(
batch
)
def _log(self, loss, loss_breakdown, train=True, batch=None, outputs=None):
    """Log the loss breakdown and, when inputs are available, extra metrics.

    Args:
        loss: Scalar total loss. NOTE(review): currently unused — the total
            is read from ``loss_breakdown["loss"]`` instead; kept in the
            signature for caller compatibility.
        loss_breakdown: Dict mapping loss names to scalar tensors; must
            contain a ``"loss"`` entry when ``train`` is True.
        train: If True, log per-step under the ``train/`` prefix; otherwise
            log per-epoch under the ``val/`` prefix.
        batch: Input batch used to compute additional validation metrics.
            BUGFIX: the original body referenced ``batch``/``outputs``
            without either being a parameter or otherwise in scope, which
            raised ``NameError`` at runtime; they are now explicit optional
            parameters, and the metric computation is skipped when they are
            not provided.
        outputs: Model outputs corresponding to ``batch``.
    """
    phase = "train" if train else "val"
    for loss_name, indiv_loss in loss_breakdown.items():
        self.log(
            f"{phase}/{loss_name}",
            indiv_loss,
            # Per-step during training, per-epoch during validation
            on_step=train, on_epoch=(not train), logger=True,
        )

    if train:
        # Also accumulate the total training loss as an epoch-level metric
        # (was an f-string with no placeholders; plain literal is equivalent)
        self.log(
            "train/loss_epoch",
            loss_breakdown["loss"],
            on_step=False, on_epoch=True, logger=True,
        )

    # Additional (non-loss) metrics; gradients are never needed here, so
    # disable autograd tracking for the computation.
    if batch is not None and outputs is not None:
        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                # Superimposition-based metrics are only computed for
                # validation, where the extra cost is acceptable
                superimposition_metrics=(not train)
            )

        for k, v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                v,
                on_step=False, on_epoch=True, logger=True
            )
def
training_step
(
self
,
batch
,
batch_idx
):
if
(
self
.
ema
.
device
!=
batch
[
"aatype"
].
device
):
self
.
ema
.
to
(
batch
[
"aatype"
].
device
)
...
...
@@ -79,28 +109,7 @@ class OpenFoldWrapper(pl.LightningModule):
)
# Log it
self
.
log
(
"train/loss"
,
loss
,
on_step
=
True
,
logger
=
True
,
)
self
.
log
(
"train/loss_epoch"
,
loss
,
on_step
=
False
,
on_epoch
=
True
,
logger
=
True
,
)
for
loss_name
,
indiv_loss
in
loss_breakdown
.
items
():
self
.
log
(
f
"train/
{
loss_name
}
"
,
indiv_loss
,
on_step
=
True
,
logger
=
True
,
)
with
torch
.
no_grad
():
other_metrics
=
self
.
compute_validation_metrics
(
batch
,
outputs
)
for
k
,
v
in
other_metrics
.
items
():
self
.
log
(
f
"train/
{
k
}
"
,
v
,
on_step
=
False
,
on_epoch
=
True
,
logger
=
True
)
self
.
_log
(
loss_breakdown
)
return
loss
...
...
@@ -125,22 +134,11 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss and other metrics
batch
[
"use_clamped_fape"
]
=
0.
loss
,
loss_breakdown
=
self
.
loss
(
_
,
loss_breakdown
=
self
.
loss
(
outputs
,
batch
,
_return_breakdown
=
True
)
self
.
log
(
"val/loss"
,
loss
,
on_step
=
False
,
on_epoch
=
True
,
logger
=
True
)
for
loss_name
,
indiv_loss
in
loss_breakdown
.
items
():
self
.
log
(
f
"val/
{
loss_name
}
"
,
indiv_loss
,
on_step
=
False
,
on_epoch
=
True
,
logger
=
True
,
)
other_metrics
=
self
.
compute_validation_metrics
(
batch
,
outputs
,
superimposition_metrics
=
True
,
)
for
k
,
v
in
other_metrics
.
items
():
self
.
log
(
f
"val/
{
k
}
"
,
v
,
on_step
=
False
,
on_epoch
=
True
,
logger
=
True
)
self
.
_log
(
loss_breakdown
,
train
=
False
)
def
validation_epoch_end
(
self
,
_
):
# Restore the model weights to normal
...
...
@@ -440,18 +438,23 @@ if __name__ == "__main__":
)
parser
.
add_argument
(
"--wandb"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to log metrics to Weights & Biases"
)
parser
.
add_argument
(
"--experiment_name"
,
type
=
str
,
default
=
None
,
help
=
"Name of the current experiment. Used for wandb logging"
)
parser
.
add_argument
(
"--wandb_id"
,
type
=
str
,
default
=
None
,
help
=
"ID of a previous run to be resumed"
)
parser
.
add_argument
(
"--wandb_project"
,
type
=
str
,
default
=
None
,
help
=
"Name of the wandb project to which this run will belong"
)
parser
.
add_argument
(
"--wandb_entity"
,
type
=
str
,
default
=
None
,
help
=
"wandb username or team name to which runs are attributed"
)
parser
.
add_argument
(
"--script_modules"
,
type
=
bool_type
,
default
=
False
,
...
...
@@ -465,16 +468,27 @@ if __name__ == "__main__":
)
parser
.
add_argument
(
"--train_epoch_len"
,
type
=
int
,
default
=
10000
,
help
=
(
"The virtual length of each training epoch. Stochastic filtering "
"of training data means that training datasets have no "
"well-defined length. This virtual length affects frequency of "
"validation & checkpointing (by default, one of each per epoch)."
)
parser
.
add_argument
(
"--_alignment_index_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--log_lr"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to log the actual learning rate"
)
parser
.
add_argument
(
"--config_preset"
,
type
=
str
,
default
=
"initial_training"
,
help
=
'Config setting. Choose e.g. "initial_training", "finetuning", "model_1", etc.'
help
=
(
'Config setting. Choose e.g. "initial_training", "finetuning", '
'"model_1", etc. By default, the actual values in the config are '
'used.'
)
)
parser
.
add_argument
(
"--_alignment_index_path"
,
type
=
str
,
default
=
None
,
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment