OpenDAS / OpenFold · Commits

Commit 4358096c, authored Jul 20, 2022 by Gustaf Ahdritz
Parent: 236c6865

    Fix pLDDT bug
Showing 3 changed files with 37 additions and 172 deletions:
  openfold/config.py       +1   -1
  openfold/utils/loss.py   +1   -1
  train_openfold.py        +35  -170
openfold/config.py
@@ -550,7 +550,7 @@ config = mlc.ConfigDict(
                 "eps": 1e-4,
                 "weight": 1.0,
             },
-            "lddt": {
+            "plddt_loss": {
                 "min_resolution": 0.1,
                 "max_resolution": 3.0,
                 "cutoff": 15.0,
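The only change here is the key rename: the sub-config that parameterizes the pLDDT head's loss is now called "plddt_loss" instead of "lddt". The values live in an ml_collections ConfigDict and are read back by attribute access (see the loss.py change below), so the new name is what downstream code has to use. A minimal sketch with illustrative values only, not the full OpenFold config:

import ml_collections as mlc

# Hedged sketch: a stand-in for the renamed subtree, not the real config.
loss_config = mlc.ConfigDict({
    "plddt_loss": {
        "min_resolution": 0.1,
        "max_resolution": 3.0,
        "cutoff": 15.0,
    },
})

print(loss_config.plddt_loss.cutoff)  # 15.0
# loss_config.lddt now fails with AttributeError -- the old key is gone.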
openfold/utils/loss.py
@@ -1562,7 +1562,7 @@ class AlphaFoldLoss(nn.Module):
             "plddt_loss": lambda: lddt_loss(
                 logits=out["lddt_logits"],
                 all_atom_pred_pos=out["final_atom_positions"],
-                **{**batch, **self.config.lddt},
+                **{**batch, **self.config.plddt_loss},
             ),
             "masked_msa": lambda: masked_msa_loss(
                 logits=out["masked_msa_logits"],
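The loss_fns entry itself was already named "plddt_loss" (unchanged context above); only the reference to the sub-config changes to match the renamed key. Assuming AlphaFoldLoss follows the usual pattern of looking up each term's settings and weight in its config by the same name used in loss_fns, a mismatched key like the old "lddt" means the pLDDT term's parameters and weight cannot be resolved. A minimal sketch of that coupling, with hypothetical weights and loss values (not OpenFold's actual class):

import ml_collections as mlc

config = mlc.ConfigDict({
    "plddt_loss": {"cutoff": 15.0, "weight": 0.01},  # was "lddt" before this commit
    "masked_msa": {"eps": 1e-8, "weight": 2.0},
})

def weighted_total(loss_values):
    # loss_values maps term names (as used in loss_fns) to scalar losses.
    total = 0.0
    for name, value in loss_values.items():
        # KeyError here if the config key does not match the term name.
        total += config[name].weight * value
    return total

print(weighted_total({"plddt_loss": 0.5, "masked_msa": 1.2}))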
train_openfold.py
@@ -8,7 +8,6 @@ import os
 #os.environ["NODE_RANK"]="0"
 import random
-import sys
 import time
 import numpy as np
@@ -27,22 +26,14 @@ from openfold.data.data_modules import (
 )
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
-from openfold.np import residue_constants
-from openfold.utils.argparse import remove_arguments
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
+from openfold.utils.argparse import remove_arguments
 from openfold.utils.loss import AlphaFoldLoss, lddt_ca
-from openfold.utils.lr_schedulers import AlphaFoldLRScheduler
 from openfold.utils.seed import seed_everything
-from openfold.utils.superimposition import superimpose
 from openfold.utils.tensor_utils import tensor_tree_map
-from openfold.utils.validation_metrics import (
-    drmsd,
-    gdt_ts,
-    gdt_ha,
-)
 from scripts.zero_to_fp32 import (
     get_fp32_state_dict_from_zero_checkpoint
 )
@@ -66,36 +57,6 @@ class OpenFoldWrapper(pl.LightningModule):
     def forward(self, batch):
         return self.model(batch)

-    def _log(self, loss_breakdown, batch, outputs, train=True):
-        phase = "train" if train else "val"
-        for loss_name, indiv_loss in loss_breakdown.items():
-            self.log(
-                f"{phase}/{loss_name}",
-                indiv_loss,
-                on_step=train, on_epoch=(not train), logger=True,
-            )
-
-            if(train):
-                self.log(
-                    f"{phase}/{loss_name}_epoch",
-                    indiv_loss,
-                    on_step=False, on_epoch=True, logger=True,
-                )
-
-        with torch.no_grad():
-            other_metrics = self._compute_validation_metrics(
-                batch,
-                outputs,
-                superimposition_metrics=(not train)
-            )
-
-        for k, v in other_metrics.items():
-            self.log(
-                f"{phase}/{k}",
-                v,
-                on_step=False, on_epoch=True, logger=True
-            )
-
     def training_step(self, batch, batch_idx):
         if(self.ema.device != batch["aatype"].device):
             self.ema.to(batch["aatype"].device)
@@ -107,121 +68,54 @@ class OpenFoldWrapper(pl.LightningModule):
         batch = tensor_tree_map(lambda t: t[..., -1], batch)

         # Compute loss
-        loss, loss_breakdown = self.loss(
-            outputs, batch, _return_breakdown=True
-        )
-
-        # Log it
-        self._log(loss_breakdown, batch, outputs)
+        loss = self.loss(outputs, batch)
+
+        self.log("train/loss", loss, on_step=True, logger=True)

         return loss

-    def on_before_zero_grad(self, *args, **kwargs):
-        self.ema.update(self.model)
+    def training_step_end(self, outputs):
+        # Temporary measure to address DeepSpeed scheduler bug
+        if(self.trainer.global_step != self.last_lr_step):
+            self.lr_schedulers().step()
+            self.last_lr_step = self.trainer.global_step

     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
-            # model.state_dict() contains references to model weights rather
-            # than copies. Therefore, we need to clone them before calling
-            # load_state_dict().
-            clone_param = lambda t: t.detach().clone()
-            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
+            self.cached_weights = self.model.state_dict()
             self.model.load_state_dict(self.ema.state_dict()["params"])

-        # Run the model
+        # Calculate validation loss
         outputs = self(batch)
         batch = tensor_tree_map(lambda t: t[..., -1], batch)

-        # Compute loss and other metrics
-        batch["use_clamped_fape"] = 0.
-        _, loss_breakdown = self.loss(
-            outputs, batch, _return_breakdown=True
+        loss = lddt_ca(
+            outputs["final_atom_positions"],
+            batch["all_atom_positions"],
+            batch["all_atom_mask"],
+            eps=self.config.globals.eps,
+            per_residue=False,
         )

-        self._log(loss_breakdown, batch, outputs, train=False)
+        self.log("val/loss", loss, logger=True)

     def validation_epoch_end(self, _):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
         self.cached_weights = None

-    def _compute_validation_metrics(self,
-        batch,
-        outputs,
-        superimposition_metrics=False
-    ):
-        metrics = {}
-
-        gt_coords = batch["all_atom_positions"]
-        pred_coords = outputs["final_atom_positions"]
-        all_atom_mask = batch["all_atom_mask"]
-
-        # This is super janky for superimposition. Fix later
-        gt_coords_masked = gt_coords * all_atom_mask[..., None]
-        pred_coords_masked = pred_coords * all_atom_mask[..., None]
-        ca_pos = residue_constants.atom_order["CA"]
-        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
-        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
-        all_atom_mask_ca = all_atom_mask[..., ca_pos]
-
-        lddt_ca_score = lddt_ca(
-            pred_coords,
-            gt_coords,
-            all_atom_mask,
-            eps=self.config.globals.eps,
-            per_residue=False,
-        )
-        metrics["lddt_ca"] = lddt_ca_score
-
-        drmsd_ca_score = drmsd(
-            pred_coords_masked_ca,
-            gt_coords_masked_ca,
-            mask=all_atom_mask_ca,  # still required here to compute n
-        )
-        metrics["drmsd_ca"] = drmsd_ca_score
-
-        if(superimposition_metrics):
-            superimposed_pred, alignment_rmsd = superimpose(
-                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
-            )
-            gdt_ts_score = gdt_ts(
-                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
-            )
-            gdt_ha_score = gdt_ha(
-                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
-            )
-
-            metrics["alignment_rmsd"] = alignment_rmsd
-            metrics["gdt_ts"] = gdt_ts_score
-            metrics["gdt_ha"] = gdt_ha_score
-
-        return metrics
-
     def configure_optimizers(self,
         learning_rate: float = 1e-3,
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
         # Ignored as long as a DeepSpeed optimizer is configured
-        optimizer = torch.optim.Adam(
+        return torch.optim.Adam(
             self.model.parameters(),
             lr=learning_rate,
             eps=eps
         )

-        lr_scheduler = AlphaFoldLRScheduler(
-            optimizer,
-        )
-
-        return {
-            "optimizer": optimizer,
-            "lr_scheduler": {
-                "scheduler": lr_scheduler,
-                "interval": "step",
-                "name": "AlphaFoldLRScheduler",
-            }
-        }
+    def on_before_zero_grad(self, *args, **kwargs):
+        self.ema.update(self.model)

     def on_load_checkpoint(self, checkpoint):
         self.ema.load_state_dict(checkpoint["ema"])
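Both the removed and the retained validation paths swap the EMA parameters in before running validation and restore the originals in validation_epoch_end; the removed variant additionally clones the cached weights because state_dict() returns references to the live tensors. A standalone sketch of that pattern with placeholder names, not OpenFold's wrapper:

import torch

class EmaSwap:
    # Hedged sketch of the cache -> load EMA -> restore cycle used above.
    def __init__(self, model: torch.nn.Module, ema_state_dict: dict):
        self.model = model
        self.ema_state_dict = ema_state_dict
        self.cached_weights = None

    def load_ema_weights(self):
        # Clone before caching: state_dict() hands back references to the live
        # tensors, which the subsequent load_state_dict() would overwrite.
        self.cached_weights = {
            k: v.detach().clone() for k, v in self.model.state_dict().items()
        }
        self.model.load_state_dict(self.ema_state_dict)

    def restore_weights(self):
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None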
@@ -235,7 +129,7 @@ def main(args):
         seed_everything(args.seed)

     config = model_config(
-        args.config_preset,
+        "model_1",
         train=True,
         low_prec=(args.precision == "16")
     )
@@ -246,7 +140,7 @@ def main(args):
         sd = {k[len("module."):]: v for k, v in sd.items()}
         model_module.load_state_dict(sd)
         logging.info("Successfully loaded model weights...")

     # TorchScript components of the model
     if(args.script_modules):
         script_preset_(model_module)
@@ -265,18 +159,16 @@ def main(args):
     if(args.checkpoint_every_epoch):
         mc = ModelCheckpoint(
             every_n_epochs=1,
-            auto_insert_metric_name=False,
-            save_top_k=-1,
         )
         callbacks.append(mc)

     if(args.early_stopping):
         es = EarlyStoppingVerbose(
-            monitor="val/lddt_ca",
+            monitor="val/loss",
             min_delta=args.min_delta,
             patience=args.patience,
             verbose=False,
-            mode="max",
+            mode="min",
             check_finite=True,
             strict=True,
         )
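The monitored quantity has to be something the LightningModule actually logs during validation: the removed configuration watched the "val/lddt_ca" metric emitted by the removed _log helper, while the retained one watches the "val/loss" value logged directly in validation_step, and the mode changes from "max" to "min". A minimal sketch using PyTorch Lightning's stock EarlyStopping callback (EarlyStoppingVerbose is assumed here to be a thin wrapper around it):

from pytorch_lightning.callbacks import EarlyStopping

# Hedged sketch: `monitor` must name a key passed to self.log(...) during
# validation, and `mode` must match whether lower or higher is better.
es = EarlyStopping(
    monitor="val/loss",
    min_delta=0.0,
    patience=3,
    verbose=False,
    mode="min",
    check_finite=True,
    strict=True,
)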
@@ -306,8 +198,14 @@ def main(args):
         loggers.append(wdb_logger)

     if(args.deepspeed_config_path is not None):
+        #if "SLURM_JOB_ID" in os.environ:
+        #    cluster_environment = SLURMEnvironment()
+        #else:
+        #    cluster_environment = None
+
         strategy = DeepSpeedPlugin(
             config=args.deepspeed_config_path,
+            # cluster_environment=cluster_environment,
         )
         if(args.wandb):
             wdb_logger.experiment.save(args.deepspeed_config_path)
@@ -316,12 +214,7 @@ def main(args):
         strategy = DDPPlugin(find_unused_parameters=False)
     else:
         strategy = None
-
-    if(args.wandb):
-        freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
-        os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
-        wdb_logger.experiment.save(f"{freeze_path}")

     trainer = pl.Trainer.from_argparse_args(
         args,
         default_root_dir=args.output_dir,
@@ -459,65 +352,37 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--wandb", action="store_true", default=False,
-        help="Whether to log metrics to Weights & Biases"
     )
     parser.add_argument(
         "--experiment_name", type=str, default=None,
-        help="Name of the current experiment. Used for wandb logging"
     )
     parser.add_argument(
         "--wandb_id", type=str, default=None,
-        help="ID of a previous run to be resumed"
     )
     parser.add_argument(
         "--wandb_project", type=str, default=None,
-        help="Name of the wandb project to which this run will belong"
     )
     parser.add_argument(
         "--wandb_entity", type=str, default=None,
-        help="wandb username or team name to which runs are attributed"
     )
     parser.add_argument(
         "--script_modules", type=bool_type, default=False,
         help="Whether to TorchScript eligible components of them model"
     )
     parser.add_argument(
-        "--train_chain_data_cache_path", type=str, default=None,
+        "--train_prot_data_cache_path", type=str, default=None,
     )
     parser.add_argument(
-        "--distillation_chain_data_cache_path", type=str, default=None,
+        "--distillation_prot_data_cache_path", type=str, default=None,
     )
     parser.add_argument(
         "--train_epoch_len", type=int, default=10000,
-        help=(
-            "The virtual length of each training epoch. Stochastic filtering "
-            "of training data means that training datasets have no "
-            "well-defined length. This virtual length affects frequency of "
-            "validation & checkpointing (by default, one of each per epoch)."
-        )
     )
     parser.add_argument(
-        "--log_lr", action="store_true", default=False,
-        help="Whether to log the actual learning rate"
-    )
-    parser.add_argument(
-        "--config_preset", type=str, default="initial_training",
-        help=(
-            'Config setting. Choose e.g. "initial_training", "finetuning", '
-            '"model_1", etc. By default, the actual values in the config are '
-            'used.'
-        )
-    )
-    parser.add_argument(
-        "--_distillation_structure_index_path", type=str, default=None,
+        "--_alignment_index_path", type=str, default=None,
     )
     parser.add_argument(
-        "--alignment_index_path", type=str, default=None,
-        help="Training alignment index. See the README for instructions."
-    )
-    parser.add_argument(
-        "--distillation_alignment_index_path", type=str, default=None,
-        help="Distillation alignment index. See the README for instructions."
+        "--log_lr", action="store_true", default=False,
     )
     parser = pl.Trainer.add_argparse_args(parser)