OpenDAS / OpenFold

Commit e1b69c13
Authored Feb 23, 2022 by Gustaf Ahdritz

Merge branch 'main' of ssh://github.com/aqlaboratory/openfold into multimer

Parents: e699d7d2 4a613bbe
Showing 3 changed files with 77 additions and 57 deletions (+77, -57):

  openfold/utils/loss.py                 (+2,  -0)
  scripts/generate_chain_data_cache.py   (+19, -15)
  train_openfold.py                      (+56, -42)
openfold/utils/loss.py

@@ -1640,6 +1640,8 @@ class AlphaFoldLoss(nn.Module):
+        crop_len = batch["aatype"].shape[-1]
+        cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
         losses["loss"] = cum_loss.detach().clone()

         if(not _return_breakdown):
             return cum_loss
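The two added lines implement the loss rescaling from subsection 1.9 of the AlphaFold supplement: the cumulative loss is multiplied by the square root of the smaller of the (average) true sequence length and the crop length, so that short crops are not over-weighted relative to long ones. A minimal sketch of the arithmetic, assuming seq_len is the tensor computed from batch["seq_length"] just above this hunk; the standalone helper and its explicit tensor cast are illustrative, not OpenFold's own code:

    import torch

    def scale_cum_loss(cum_loss, seq_len, crop_len):
        # AlphaFold suppl. subsec. 1.9: weight each example by the square
        # root of its effective length, min(seq_len, crop_len).
        effective_len = torch.minimum(
            seq_len, torch.as_tensor(crop_len, dtype=seq_len.dtype)
        )
        return cum_loss * torch.sqrt(effective_len)

    seq_len = torch.tensor(300.)   # mean true sequence length in the batch
    crop_len = 256                 # batch["aatype"].shape[-1]
    print(scale_cum_loss(torch.tensor(1.), seq_len, crop_len))  # tensor(16.)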
scripts/generate_chain_data_cache.py

@@ -39,12 +39,10 @@ def parse_file(
             local_data["seq"] = seq
             local_data["resolution"] = mmcif.header["resolution"]
-            cluster_size = chain_cluster_size_dict.get(full_name.upper(), None)
-            if(cluster_size is None):
-                print(file_id)
-                out.pop(full_name)
-                continue
-            else:
+            if(chain_cluster_size_dict is not None):
+                cluster_size = chain_cluster_size_dict.get(full_name.upper(), -1)
                 local_data["cluster_size"] = cluster_size
     elif(ext == ".pdb"):
         with open(os.path.join(args.data_dir, f), "r") as fp:
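This hunk inverts the filtering semantics of the chain data cache. Previously a chain missing from the cluster file was printed and removed from the output; now cluster sizes are only recorded when a cluster file was supplied at all, and a missing chain gets the sentinel value -1 instead of being dropped. A small sketch of the new lookup; the chain names and sizes are made up:

    # Illustrative data; real keys are {PDB_ID}_{CHAIN_ID} strings.
    chain_cluster_size_dict = {"1ABC_A": 40, "2XYZ_B": 7}

    def lookup_cluster_size(full_name, chain_cluster_size_dict):
        # None: no cluster file was given, so nothing is recorded.
        if chain_cluster_size_dict is None:
            return None
        # -1: cluster file given but the chain is absent; keep the chain
        # with a sentinel size rather than filtering it out.
        return chain_cluster_size_dict.get(full_name.upper(), -1)

    print(lookup_cluster_size("1abc_a", chain_cluster_size_dict))  # 40
    print(lookup_cluster_size("9zzz_c", chain_cluster_size_dict))  # -1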
@@ -58,12 +56,12 @@ def parse_file(
         )
         local_data["resolution"] = 0.
-        cluster_size = chain_cluster_size_dict.get(file_id.upper(), None)
-        if(cluster_size is None):
-            print(file_id)
-            return {}
-        else:
-            local_data["cluster_size"] = cluster_size
-        cluster_size = chain_cluster_size_dict.get(file_id.upper(), -1)
+        if(chain_cluster_size_dict is not None):
+            cluster_size = chain_cluster_size_dict.get(full_name.upper(), -1)
             chain_dict["cluster_size"] = cluster_size
         out = {file_id: chain_dict}
@@ -71,8 +69,9 @@
 def main(args):
-    chain_cluster_size_dict = {}
+    chain_cluster_size_dict = None
     if(args.cluster_file is not None):
+        chain_cluster_size_dict = {}
         with open(args.cluster_file, "r") as fp:
             clusters = [l.strip() for l in fp.readlines()]
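main() now distinguishes "no cluster file" (chain_cluster_size_dict stays None) from "cluster file present" (it becomes a dict populated from the stripped lines read here). A hedged sketch of the whole flow; the population loop lies outside the shown hunk, so its body is an assumption consistent with the lookups above:

    def build_cluster_size_dict(cluster_file):
        chain_cluster_size_dict = None        # None: no file was provided
        if cluster_file is not None:
            chain_cluster_size_dict = {}
            with open(cluster_file, "r") as fp:
                clusters = [l.strip() for l in fp.readlines()]
            # Assumed population step: each chain in a cluster maps to
            # that cluster's size.
            for cluster in clusters:
                chain_ids = cluster.split()
                for chain_id in chain_ids:
                    chain_cluster_size_dict[chain_id.upper()] = len(chain_ids)
        return chain_cluster_size_dict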
@@ -112,7 +111,12 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--cluster_file", type=str, default=None,
-        help="Path to a cluster file (e.g. PDB40), one cluster per line"
+        help=(
+            "Path to a cluster file (e.g. PDB40), one cluster "
+            "({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
+            "Chains not in this cluster file will NOT be filtered by cluster "
+            "size."
+        )
     )
     parser.add_argument(
         "--no_workers", type=int, default=4,
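The expanded help string pins down the cluster file format: one cluster per line, each entry a {PDB_ID}_{CHAIN_ID} token. An illustrative file (the IDs are invented) would look like:

    1ABC_A 2XYZ_B 3DEF_A
    4GHI_B
    5JKL_A 5JKL_B

Chains that never appear in such a file are no longer filtered out; per the change above, they are cached with cluster_size = -1.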
train_openfold.py

@@ -63,6 +63,36 @@ class OpenFoldWrapper(pl.LightningModule):
     def forward(self, batch):
         return self.model(batch)

+    def _log(self, loss_breakdown, batch, outputs, train=True):
+        phase = "train" if train else "val"
+        for loss_name, indiv_loss in loss_breakdown.items():
+            self.log(
+                f"{phase}/{loss_name}",
+                indiv_loss,
+                on_step=train, on_epoch=(not train), logger=True,
+            )
+
+        if(train):
+            self.log(
+                f"train/loss_epoch",
+                loss_breakdown["loss"],
+                on_step=False, on_epoch=True, logger=True,
+            )
+
+        with torch.no_grad():
+            other_metrics = self._compute_validation_metrics(
+                batch,
+                outputs,
+                superimposition_metrics=(not train)
+            )
+
+        for k, v in other_metrics.items():
+            self.log(
+                f"{phase}/{k}",
+                v,
+                on_step=False, on_epoch=True, logger=True
+            )
+
     def training_step(self, batch, batch_idx):
         if(self.ema.device != batch["aatype"].device):
             self.ema.to(batch["aatype"].device)
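The new _log helper centralizes the logging that training_step and validation_step previously duplicated: per-loss terms are logged per step during training and aggregated per epoch during validation, an extra train/loss_epoch aggregate is kept for training, and superimposition metrics are only computed for validation. A cut-down sketch of the same pattern in a generic LightningModule; the metric name and value stand in for _compute_validation_metrics:

    import torch
    import pytorch_lightning as pl

    class LoggingSketch(pl.LightningModule):
        def _toy_metrics(self, batch, outputs):
            # Placeholder for _compute_validation_metrics.
            return {"lddt_ca": torch.tensor(0.9)}

        def _log(self, loss_breakdown, batch, outputs, train=True):
            phase = "train" if train else "val"
            for loss_name, indiv_loss in loss_breakdown.items():
                # Train: log every step; val: aggregate over the epoch.
                self.log(
                    f"{phase}/{loss_name}", indiv_loss,
                    on_step=train, on_epoch=(not train), logger=True,
                )
            with torch.no_grad():
                other_metrics = self._toy_metrics(batch, outputs)
            for k, v in other_metrics.items():
                self.log(f"{phase}/{k}", v, on_step=False, on_epoch=True, logger=True)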
@@ -79,28 +109,7 @@ class OpenFoldWrapper(pl.LightningModule):
         )

         # Log it
-        self.log(
-            "train/loss", loss, on_step=True, logger=True,
-        )
-
-        self.log(
-            "train/loss_epoch", loss,
-            on_step=False, on_epoch=True, logger=True,
-        )
-
-        for loss_name, indiv_loss in loss_breakdown.items():
-            self.log(
-                f"train/{loss_name}",
-                indiv_loss,
-                on_step=True, logger=True,
-            )
-
-        with torch.no_grad():
-            other_metrics = self.compute_validation_metrics(batch, outputs)
-
-        for k, v in other_metrics.items():
-            self.log(f"train/{k}", v, on_step=False, on_epoch=True, logger=True)
+        self._log(loss_breakdown, batch, outputs)

         return loss
@@ -108,7 +117,7 @@ class OpenFoldWrapper(pl.LightningModule):
         self.ema.update(self.model)

     # def training_step_end(self, outputs):
-    #     # Temporary measure to address DeepSpeed scheduler bug
+    #     # Temporary measure to address DeepSpeed scheduler bug (PL issue 11694)
     #     if(self.trainer.global_step != self.last_lr_step):
     #         self.lr_schedulers().step()
     #         self.last_lr_step = self.trainer.global_step
@@ -125,29 +134,18 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss and other metrics
         batch["use_clamped_fape"] = 0.
-        loss, loss_breakdown = self.loss(
+        _, loss_breakdown = self.loss(
             outputs, batch, _return_breakdown=True
         )
-        self.log("val/loss", loss, on_step=False, on_epoch=True, logger=True)
-
-        for loss_name, indiv_loss in loss_breakdown.items():
-            self.log(
-                f"val/{loss_name}",
-                indiv_loss,
-                on_step=False, on_epoch=True, logger=True,
-            )
-
-        other_metrics = self.compute_validation_metrics(
-            batch, outputs, superimposition_metrics=True,
-        )
-
-        for k, v in other_metrics.items():
-            self.log(f"val/{k}", v, on_step=False, on_epoch=True, logger=True)
+        self._log(loss_breakdown, batch, outputs, train=False)

     def validation_epoch_end(self, _):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
         self.cached_weights = None

-    def compute_validation_metrics(self,
+    def _compute_validation_metrics(self,
         batch, outputs, superimposition_metrics=False
@@ -440,18 +438,23 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--wandb", action="store_true", default=False,
         help="Whether to log metrics to Weights & Biases"
     )
     parser.add_argument(
         "--experiment_name", type=str, default=None,
         help="Name of the current experiment. Used for wandb logging"
     )
+    parser.add_argument(
+        "--wandb_id", type=str, default=None,
+        help="ID of a previous run to be resumed"
+    )
     parser.add_argument(
         "--wandb_project", type=str, default=None,
         help="Name of the wandb project to which this run will belong"
     )
     parser.add_argument(
         "--wandb_entity", type=str, default=None,
         help="wandb username or team name to which runs are attributed"
     )
     parser.add_argument(
         "--script_modules", type=bool_type, default=False,
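The new --wandb_id flag records the ID of an earlier Weights & Biases run so that logging can resume into it. The actual hookup happens outside this hunk; one plausible wiring, assuming pytorch_lightning's WandbLogger, would be:

    from pytorch_lightning.loggers import WandbLogger

    # Hypothetical wiring; only --wandb_id itself is introduced by this diff.
    logger = WandbLogger(
        name=args.experiment_name,
        project=args.wandb_project,
        entity=args.wandb_entity,
        id=args.wandb_id,                 # resume this run ID, if given
        resume="allow" if args.wandb_id is not None else None,
    )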
@@ -465,16 +468,27 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--train_epoch_len", type=int, default=10000,
+        help=(
+            "The virtual length of each training epoch. Stochastic filtering "
+            "of training data means that training datasets have no "
+            "well-defined length. This virtual length affects frequency of "
+            "validation & checkpointing (by default, one of each per epoch)."
+        )
+    )
+    parser.add_argument(
+        "--_alignment_index_path", type=str, default=None,
     )
     parser.add_argument(
         "--log_lr", action="store_true", default=False,
         help="Whether to log the actual learning rate"
     )
     parser.add_argument(
         "--config_preset", type=str, default="initial_training",
-        help='Config setting. Choose e.g. "initial_training", "finetuning", "model_1", etc.'
+        help=(
+            'Config setting. Choose e.g. "initial_training", "finetuning", '
+            '"model_1", etc. By default, the actual values in the config are '
+            'used.'
+        )
     )
-    parser.add_argument(
-        "--_alignment_index_path", type=str, default=None,
-    )
     parser = pl.Trainer.add_argparse_args(parser)
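The new --train_epoch_len help text explains why the flag exists: stochastic filtering means the training set has no fixed size, so a virtual epoch length is imposed to give validation and checkpointing a predictable cadence. A minimal illustration of the idea as a wrapper dataset; this class is a sketch, not OpenFold's actual data pipeline:

    from torch.utils.data import Dataset

    class VirtualEpochDataset(Dataset):
        """Reports a fixed virtual length regardless of how many
        (stochastically filtered) samples actually exist."""
        def __init__(self, base_dataset, epoch_len=10000):
            self.base = base_dataset
            self.epoch_len = epoch_len

        def __len__(self):
            # Per-epoch hooks (validation, checkpointing) key off this.
            return self.epoch_len

        def __getitem__(self, idx):
            # Map the virtual index onto a real (resampleable) example.
            return self.base[idx % len(self.base)]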