OpenDAS / OpenFold · Commits

Commit 6dc34d71
Authored Jan 24, 2024 by Jennifer

    first pass changes to run with pl 2.1

Parent: 5f5a79a7
Showing 3 changed files, with 139 additions and 107 deletions (whitespace-only changes hidden):

    openfold/data/data_modules.py   +5    -4
    openfold/utils/seed.py          +1    -1
    train_openfold.py               +133  -102
openfold/data/data_modules.py

@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
         with open(distillation_alignment_index_path, "r") as fp:
             self.distillation_alignment_index = json.load(fp)
 
-    def setup(self):
+    def setup(self, stage=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleDataset,
             template_mmcif_dir=self.template_mmcif_dir,

@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
             mode="predict",
         )
 
-    def _gen_dataloader(self, stage):
+    def _gen_dataloader(self, stage=None):
         generator = None
         if self.batch_seed is not None:
             generator = torch.Generator()

@@ -1053,7 +1053,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
     def val_dataloader(self):
         if self.eval_dataset is not None:
             return self._gen_dataloader("eval")
-        return None
+        # Temp fix to pass the validation step
+        return []
 
     def predict_dataloader(self):
         return self._gen_dataloader("predict")
@@ -1085,7 +1086,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
         self.training_mode = self.train_data_dir is not None
         self.val_mmcif_data_cache_path = val_mmcif_data_cache_path
 
-    def setup(self):
+    def setup(self, stage=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleMultimerDataset,
             template_mmcif_dir=self.template_mmcif_dir,
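These signature changes track the Lightning 2.x LightningDataModule API: the trainer now always invokes `setup` with a `stage` argument ("fit", "validate", "test", or "predict"), while `train_openfold.py` also calls `data_module.setup()` by hand with no argument, so the `stage=None` default satisfies both call sites. Per the in-line comment, `return []` is a stopgap so the validation loop passes when no eval dataset is configured. A minimal sketch of the expected 2.x hook shapes (the class and tensors are illustrative, not OpenFold code):

    import torch
    import pytorch_lightning as pl
    from torch.utils.data import DataLoader, TensorDataset

    class ToyDataModule(pl.LightningDataModule):
        def setup(self, stage=None):
            # Lightning 2.x passes stage="fit" | "validate" | "test" | "predict";
            # the default keeps manual data_module.setup() calls working too.
            self.train_ds = TensorDataset(torch.randn(8, 4), torch.randn(8, 1))
            self.eval_ds = None  # e.g. no validation set was configured

        def train_dataloader(self):
            return DataLoader(self.train_ds, batch_size=2)

        def val_dataloader(self):
            if self.eval_ds is not None:
                return DataLoader(self.eval_ds, batch_size=2)
            return []  # an empty list effectively skips validation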
openfold/utils/seed.py

@@ -2,7 +2,7 @@ import os
 import logging
 import random
 import numpy as np
 
-from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning import seed_everything
 
 from openfold.utils.suppress_output import SuppressLogging
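In Lightning 2.x, `seed_everything` is exported from the package root; the old `pytorch_lightning.utilities.seed` import path was removed. For reference:

    from pytorch_lightning import seed_everything

    seed_everything(42, workers=True)  # workers=True also seeds DataLoader workers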
train_openfold.py

@@ -8,7 +8,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
-from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
+from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy
 import torch
 
 from openfold.config import model_config
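The `pytorch_lightning.plugins.training_type` module is gone in Lightning 2.0; the training-type plugins were renamed to strategies. The one-for-one replacements used in this file (the DeepSpeed config path below is illustrative):

    from pytorch_lightning.strategies import DDPStrategy, DeepSpeedStrategy

    ddp = DDPStrategy(find_unused_parameters=False)
    deepspeed = DeepSpeedStrategy(config="deepspeed_config.json")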
@@ -56,7 +56,7 @@ class OpenFoldWrapper(pl.LightningModule):
         self.ema = ExponentialMovingAverage(
             model=self.model, decay=config.ema.decay
         )
 
         self.cached_weights = None
         self.last_lr_step = -1
         self.save_hyperparameters()
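For reference, `save_hyperparameters` records the module's `__init__` arguments under `self.hparams` and persists them in checkpoints; it must actually be called, since a bare `self.save_hyperparameters` attribute reference is a no-op. A sketch with illustrative names:

    import pytorch_lightning as pl

    class Wrapper(pl.LightningModule):
        def __init__(self, config, lr=1e-3):
            super().__init__()
            self.save_hyperparameters()  # self.hparams.lr == 1e-3, stored in ckpts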
@@ -68,12 +68,12 @@ class OpenFoldWrapper(pl.LightningModule):
         phase = "train" if train else "val"
 
         for loss_name, indiv_loss in loss_breakdown.items():
             self.log(
                 f"{phase}/{loss_name}",
                 indiv_loss,
                 on_step=train, on_epoch=(not train), logger=True,
             )
 
             if(train):
                 self.log(
                     f"{phase}/{loss_name}_epoch",
                     indiv_loss,
@@ -82,12 +82,12 @@ class OpenFoldWrapper(pl.LightningModule):
         with torch.no_grad():
             other_metrics = self._compute_validation_metrics(
                 batch,
                 outputs,
                 superimposition_metrics=(not train)
             )
 
         for k, v in other_metrics.items():
             self.log(
                 f"{phase}/{k}",
                 torch.mean(v),
@@ -95,7 +95,7 @@ class OpenFoldWrapper(pl.LightningModule):
         )
 
     def training_step(self, batch, batch_idx):
         if(self.ema.device != batch["aatype"].device):
             self.ema.to(batch["aatype"].device)
 
         ground_truth = batch.pop('gt_features', None)
@@ -126,12 +126,13 @@ class OpenFoldWrapper(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
             # model.state_dict() contains references to model weights rather
             # than copies. Therefore, we need to clone them before calling
             # load_state_dict().
-            clone_param = lambda t: t.detach().clone()
+            def clone_param(t):
+                return t.detach().clone()
             self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
 
             self.model.load_state_dict(self.ema.state_dict()["params"])
 
         ground_truth = batch.pop('gt_features', None)
@@ -153,23 +154,23 @@ class OpenFoldWrapper(pl.LightningModule):
         )
 
         self._log(loss_breakdown, batch, outputs, train=False)
 
-    def validation_epoch_end(self, _):
+    def on_validation_epoch_end(self):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
         self.cached_weights = None
 
     def _compute_validation_metrics(self,
         batch,
         outputs,
         superimposition_metrics=False
     ):
         metrics = {}
 
         gt_coords = batch["all_atom_positions"]
         pred_coords = outputs["final_atom_positions"]
         all_atom_mask = batch["all_atom_mask"]
 
         # This is super janky for superimposition. Fix later
         gt_coords_masked = gt_coords * all_atom_mask[..., None]
         pred_coords_masked = pred_coords * all_atom_mask[..., None]
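`validation_epoch_end(self, outputs)` was removed in Lightning 2.0. Its replacement, `on_validation_epoch_end`, is invoked with no arguments, which is harmless here because the method never used `outputs`; it only restores the cached weights. Modules that did consume `outputs` must accumulate that state themselves. A sketch of the pattern (illustrative names, not OpenFold code):

    import torch
    import pytorch_lightning as pl

    class ValHookExample(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(4, 1)
            self._val_losses = []  # replaces the old `outputs` argument

        def validation_step(self, batch, batch_idx):
            x, y = batch
            loss = torch.nn.functional.mse_loss(self.layer(x), y)
            self._val_losses.append(loss)
            return loss

        def on_validation_epoch_end(self):
            # Called with no arguments in 2.x; per-batch state lives on the module.
            self.log("val/loss_epoch", torch.stack(self._val_losses).mean())
            self._val_losses.clear()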
@@ -177,7 +178,7 @@ class OpenFoldWrapper(pl.LightningModule):
         gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
         pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
         all_atom_mask_ca = all_atom_mask[..., ca_pos]
 
         lddt_ca_score = lddt_ca(
             pred_coords,
             gt_coords,
@@ -185,18 +186,18 @@ class OpenFoldWrapper(pl.LightningModule):
             eps=self.config.globals.eps,
             per_residue=False,
         )
 
         metrics["lddt_ca"] = lddt_ca_score
 
         drmsd_ca_score = drmsd(
             pred_coords_masked_ca,
             gt_coords_masked_ca,
             mask=all_atom_mask_ca, # still required here to compute n
         )
 
         metrics["drmsd_ca"] = drmsd_ca_score
 
         if(superimposition_metrics):
             superimposed_pred, alignment_rmsd = superimpose(
                 gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
             )
@@ -210,22 +211,22 @@ class OpenFoldWrapper(pl.LightningModule):
             metrics["alignment_rmsd"] = alignment_rmsd
             metrics["gdt_ts"] = gdt_ts_score
             metrics["gdt_ha"] = gdt_ha_score
 
         return metrics
 
     def configure_optimizers(self,
         learning_rate: float = 1e-3,
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
         # return torch.optim.Adam(
         #     self.model.parameters(),
         #     lr=learning_rate,
         #     eps=eps
         # )
         # Ignored as long as a DeepSpeed optimizer is configured
         optimizer = torch.optim.Adam(
             self.model.parameters(),
             lr=learning_rate,
             eps=eps
         )
@@ -250,8 +251,9 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_load_checkpoint(self, checkpoint):
         ema = checkpoint["ema"]
         if(not self.model.template_config.enabled):
             ema["params"] = {k: v for k, v in ema["params"].items() if not "template" in k}
 
         self.ema.load_state_dict(ema)
 
     def on_save_checkpoint(self, checkpoint):
@@ -262,23 +264,23 @@ class OpenFoldWrapper(pl.LightningModule):
     def load_from_jax(self, jax_path):
         model_basename = os.path.splitext(
             os.path.basename(os.path.normpath(jax_path))
         )[0]
         model_version = "_".join(model_basename.split("_")[1:])
         import_jax_weights_(
             self.model, jax_path, version=model_version
         )
 
 def main(args):
     if(args.seed is not None):
         seed_everything(args.seed)
 
     config = model_config(
         args.config_preset,
         train=True,
         low_prec=(str(args.precision) == "16")
     )
 
     if args.experiment_config_json:
@@ -321,30 +323,31 @@ def main(args):
     if args.resume_from_jax_params:
         model_module.load_from_jax(args.resume_from_jax_params)
         logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")
 
     # TorchScript components of the model
     if(args.script_modules):
         script_preset_(model_module)
 
     if "multimer" in args.config_preset:
         data_module = OpenFoldMultimerDataModule(
             config=config.data,
             batch_seed=args.seed,
             **vars(args)
         )
     else:
         data_module = OpenFoldDataModule(
             config=config.data,
             batch_seed=args.seed,
             **vars(args)
         )
 
     data_module.prepare_data()
     data_module.setup()
 
     callbacks = []
     if(args.checkpoint_every_epoch):
         mc = ModelCheckpoint(
             every_n_epochs=1,
             auto_insert_metric_name=False,
@@ -352,7 +355,7 @@ def main(args):
         )
         callbacks.append(mc)
 
     if(args.early_stopping):
         es = EarlyStoppingVerbose(
             monitor="val/lddt_ca",
             min_delta=args.min_delta,
@@ -364,7 +367,7 @@ def main(args):
         )
         callbacks.append(es)
 
     if(args.log_performance):
         global_batch_size = args.num_nodes * args.gpus
         perf = PerformanceLoggingCallback(
             log_file=os.path.join(args.output_dir, "performance_log.json"),
@@ -372,12 +375,12 @@ def main(args):
         )
         callbacks.append(perf)
 
     if(args.log_lr):
         lr_monitor = LearningRateMonitor(logging_interval="step")
         callbacks.append(lr_monitor)
 
     loggers = []
     if(args.wandb):
         wdb_logger = WandbLogger(
             name=args.experiment_name,
             save_dir=args.output_dir,
@@ -388,38 +391,43 @@ def main(args):
         )
         loggers.append(wdb_logger)
 
     if(args.deepspeed_config_path is not None):
-        strategy = DeepSpeedPlugin(
+        strategy = DeepSpeedStrategy(
             config=args.deepspeed_config_path,
         )
         if(args.wandb):
             wdb_logger.experiment.save(args.deepspeed_config_path)
             wdb_logger.experiment.save("openfold/config.py")
     elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
-        strategy = DDPPlugin(find_unused_parameters=False)
+        strategy = DDPStrategy(find_unused_parameters=False)
     else:
         strategy = None
 
     if(args.wandb):
         freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
         os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
         wdb_logger.experiment.save(f"{freeze_path}")
 
-    trainer = pl.Trainer.from_argparse_args(
-        args,
-        default_root_dir=args.output_dir,
-        strategy=strategy,
-        callbacks=callbacks,
-        logger=loggers,
-    )
+    # Raw dump of all args from pl.Trainer constructor
+    trainer_kws = set([
+        'accelerator', 'strategy', 'devices', 'num_nodes', 'precision',
+        'logger', 'callbacks', 'fast_dev_run', 'max_epochs', 'min_epochs',
+        'max_steps', 'min_steps', 'max_time', 'limit_train_batches',
+        'limit_val_batches', 'limit_test_batches', 'limit_predict_batches',
+        'overfit_batches', 'val_check_interval', 'check_val_every_n_epoch',
+        'num_sanity_val_steps', 'log_every_n_steps', 'enable_checkpointing',
+        'enable_progress_bar', 'enable_model_summary', 'accumulate_grad_batches',
+        'gradient_clip_val', 'gradient_clip_algorithm', 'deterministic',
+        'benchmark', 'inference_mode', 'use_distributed_sampler', 'profiler',
+        'detect_anomaly', 'barebones', 'plugins', 'sync_batchnorm',
+        'reload_dataloaders_every_n_epochs', 'default_root_dir',
+    ])
+    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
+    trainer_args.update({
+        'default_root_dir': args.output_dir,
+        'strategy': strategy,
+        'callbacks': callbacks,
+        'logger': loggers,
+    })
+    trainer = pl.Trainer(**trainer_args)
 
     if(args.resume_model_weights_only):
         ckpt_path = None
     else:
         ckpt_path = args.resume_from_ckpt
 
     trainer.fit(
         model_module,
         datamodule=data_module,
         ckpt_path=ckpt_path,
     )
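`pl.Trainer.from_argparse_args` was likewise removed in 2.0, hence the hand-maintained `trainer_kws` allow-list feeding a plain `pl.Trainer(**trainer_args)` call. The list has to track the 2.x constructor exactly; a misspelled key silently drops that argument. One way to avoid the manual list (a sketch, not what this commit does) is to derive it from the constructor signature:

    import inspect
    import pytorch_lightning as pl

    def trainer_kwargs_from_args(args):
        """Keep only the argparse entries that pl.Trainer.__init__ accepts."""
        valid = set(inspect.signature(pl.Trainer.__init__).parameters) - {"self"}
        return {k: v for k, v in vars(args).items() if k in valid}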
@@ -621,36 +629,59 @@ if __name__ == "__main__":
     parser.add_argument(
         "--experiment_config_json",
         default="",
         help="Path to a json file with custom config values to overwrite config setting",
     )
-    parser = pl.Trainer.add_argparse_args(parser)
-
-    # Disable the initial validation pass
-    parser.set_defaults(
-        num_sanity_val_steps=0,
-    )
-
-    # Remove some buggy/redundant arguments introduced by the Trainer
-    remove_arguments(
-        parser,
-        [
-            "--accelerator",
-            "--resume_from_checkpoint",
-            "--reload_dataloaders_every_epoch",
-            "--reload_dataloaders_every_n_epochs",
-        ]
-    )
+    parser.add_argument(
+        "--num_nodes", type=int, default=1,
+    )
+    parser.add_argument(
+        "--gpus", type=int, default=1,
+    )
+    parser.add_argument(
+        "--precision", type=str, default=None,
+    )
+    parser.add_argument(
+        "--replace_sampler_ddp", type=bool_type, default=True,
+    )
+    parser.add_argument(
+        "--max_epochs", type=int, default=1,
+    )
+    parser.add_argument(
+        "--log_every_n_steps", type=int, default=25,
+    )
+    parser.add_argument(
+        "--num_sanity_val_steps", type=int, default=0,
+    )
+    # parser = pl.Trainer.add_argparse_args(parser)
+    #
+    # # Disable the initial validation pass
+    # parser.set_defaults(
+    #     num_sanity_val_steps=0,
+    # )
+    # # Remove some buggy/redundant arguments introduced by the Trainer
+    # remove_arguments(
+    #     parser,
+    #     [
+    #         "--accelerator",
+    #         "--resume_from_checkpoint",
+    #         "--reload_dataloaders_every_epoch",
+    #         "--reload_dataloaders_every_n_epochs",
+    #     ]
+    # )
 
     args = parser.parse_args()
 
     if(args.seed is None and
         ((args.gpus is not None and args.gpus > 1) or
         (args.num_nodes is not None and args.num_nodes > 1))):
         raise ValueError("For distributed training, --seed must be specified")
 
     if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
         raise ValueError("DeepSpeed and FP16 training are not compatible")
 
     if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
         raise ValueError(
             "Choose between loading pretrained Jax-weights and a checkpoint-path"
         )
 
     # This re-applies the training-time filters at the beginning of every epoch
     args.reload_dataloaders_every_n_epochs = 1
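`pl.Trainer.add_argparse_args` is also gone in 2.0, so the flags the script still needs are re-declared as ordinary argparse arguments with the old defaults. Note that the 2.x Trainer itself no longer accepts `gpus=` or `replace_sampler_ddp=` (the latter was renamed `use_distributed_sampler`); a hand-rolled `--gpus` flag is typically mapped onto `accelerator`/`devices`, roughly as follows (a sketch, not part of the commit):

    import argparse
    import pytorch_lightning as pl

    parser = argparse.ArgumentParser()
    parser.add_argument("--gpus", type=int, default=1)
    parser.add_argument("--num_nodes", type=int, default=1)
    args = parser.parse_args()

    # Lightning 2.x expresses device counts via accelerator/devices:
    trainer = pl.Trainer(
        accelerator="gpu" if args.gpus > 0 else "cpu",
        devices=args.gpus if args.gpus > 0 else 1,
        num_nodes=args.num_nodes,
    )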