Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
d2eae833
Commit
d2eae833
authored
Jul 18, 2023
by
Geoffrey Yu
Browse files
added an example training script for training openfold multimer
parent
2f2793f4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
606 additions
and
0 deletions
+606
-0
train_openfold_multimer.py
train_openfold_multimer.py
+606
-0
No files found.
train_openfold_multimer.py
0 → 100644
View file @
d2eae833
import
argparse
import
logging
import
os
import
random
import
sys
import
time
import
numpy
as
np
import
pytorch_lightning
as
pl
from
pytorch_lightning.callbacks.lr_monitor
import
LearningRateMonitor
from
pytorch_lightning.callbacks.model_checkpoint
import
ModelCheckpoint
from
pytorch_lightning.loggers
import
WandbLogger
from
pytorch_lightning.plugins.training_type
import
DeepSpeedPlugin
,
DDPPlugin
from
pytorch_lightning.plugins.environments
import
SLURMEnvironment
import
torch
from
openfold.config
import
model_config
from
openfold.data.data_modules
import
(
OpenFoldMultimerDataModule
,
DummyDataLoader
,
)
from
openfold.model.model
import
AlphaFold
from
openfold.model.torchscript
import
script_preset_
from
openfold.np
import
residue_constants
from
openfold.utils.argparse_utils
import
remove_arguments
from
openfold.utils.callbacks
import
(
EarlyStoppingVerbose
,
)
from
openfold.utils.exponential_moving_average
import
ExponentialMovingAverage
from
openfold.utils.loss
import
AlphaFoldMultimerLoss
,
lddt_ca
from
openfold.utils.lr_schedulers
import
AlphaFoldLRScheduler
from
openfold.utils.seed
import
seed_everything
from
openfold.utils.superimposition
import
superimpose
from
openfold.utils.tensor_utils
import
tensor_tree_map
from
openfold.utils.validation_metrics
import
(
drmsd
,
gdt_ts
,
gdt_ha
,
)
from
openfold.utils.import_weights
import
(
import_jax_weights_
,
)
from
scripts.zero_to_fp32
import
(
get_fp32_state_dict_from_zero_checkpoint
,
get_global_step_from_zero_checkpoint
)
from
openfold.utils.logger
import
PerformanceLoggingCallback
class
OpenFoldWrapper
(
pl
.
LightningModule
):
def
__init__
(
self
,
config
):
super
(
OpenFoldWrapper
,
self
).
__init__
()
self
.
config
=
config
self
.
config
.
loss
.
masked_msa
.
num_classes
=
22
# somehow need overwrite this part in multimer loss config
self
.
config
.
model
.
evoformer_stack
.
no_blocks
=
4
# no need to go overboard here
self
.
config
.
model
.
evoformer_stack
.
blocks_per_ckpt
=
None
# don't want to set up
self
.
model
=
AlphaFold
(
config
)
self
.
loss
=
AlphaFoldMultimerLoss
(
config
.
loss
)
self
.
ema
=
ExponentialMovingAverage
(
model
=
self
.
model
,
decay
=
config
.
ema
.
decay
)
self
.
cached_weights
=
None
self
.
last_lr_step
=
-
1
def
forward
(
self
,
batch
):
return
self
.
model
(
batch
)
def
_log
(
self
,
loss_breakdown
,
batch
,
outputs
,
train
=
True
):
phase
=
"train"
if
train
else
"val"
for
loss_name
,
indiv_loss
in
loss_breakdown
.
items
():
self
.
log
(
f
"
{
phase
}
/
{
loss_name
}
"
,
indiv_loss
,
on_step
=
train
,
on_epoch
=
(
not
train
),
logger
=
True
,
)
if
(
train
):
self
.
log
(
f
"
{
phase
}
/
{
loss_name
}
_epoch"
,
indiv_loss
,
on_step
=
False
,
on_epoch
=
True
,
logger
=
True
,
)
with
torch
.
no_grad
():
other_metrics
=
self
.
_compute_validation_metrics
(
batch
,
outputs
,
superimposition_metrics
=
(
not
train
)
)
for
k
,
v
in
other_metrics
.
items
():
assert
(
len
(
v
.
shape
)
==
1
)
self
.
log
(
f
"
{
phase
}
/
{
k
}
"
,
torch
.
mean
(
v
),
on_step
=
False
,
on_epoch
=
True
,
logger
=
True
)
def
training_step
(
self
,
batch
,
batch_idx
):
all_chain_features
,
ground_truth
=
batch
if
(
self
.
ema
.
device
!=
all_chain_features
[
"aatype"
].
device
):
self
.
ema
.
to
(
all_chain_features
[
"aatype"
].
device
)
# Run the model
outputs
=
self
(
all_chain_features
)
# Compute loss
loss
,
loss_breakdown
=
self
.
loss
(
outputs
,
(
all_chain_features
,
ground_truth
),
_return_breakdown
=
True
)
# Log it
self
.
_log
(
loss_breakdown
,
all_chain_features
,
outputs
)
return
loss
def
on_before_zero_grad
(
self
,
*
args
,
**
kwargs
):
self
.
ema
.
update
(
self
.
model
)
def
validation_step
(
self
,
batch
,
batch_idx
):
all_chain_features
,
ground_truth
=
batch
# At the start of validation, load the EMA weights
if
(
self
.
cached_weights
is
None
):
# model.state_dict() contains references to model weights rather
# than copies. Therefore, we need to clone them before calling
# load_state_dict().
clone_param
=
lambda
t
:
t
.
detach
().
clone
()
self
.
cached_weights
=
tensor_tree_map
(
clone_param
,
self
.
model
.
state_dict
())
self
.
model
.
load_state_dict
(
self
.
ema
.
state_dict
()[
"params"
])
# Run the model
outputs
=
self
(
all_chain_features
)
# Compute loss and other metrics
all_chain_features
[
"use_clamped_fape"
]
=
0.
_
,
loss_breakdown
=
self
.
loss
(
outputs
,
all_chain_features
,
_return_breakdown
=
True
)
self
.
_log
(
loss_breakdown
,
all_chain_features
,
outputs
,
train
=
False
)
def
validation_epoch_end
(
self
,
_
):
# Restore the model weights to normal
self
.
model
.
load_state_dict
(
self
.
cached_weights
)
self
.
cached_weights
=
None
def
_compute_validation_metrics
(
self
,
batch
,
outputs
,
superimposition_metrics
=
False
):
metrics
=
{}
gt_coords
=
batch
[
"all_atom_positions"
]
pred_coords
=
outputs
[
"final_atom_positions"
]
all_atom_mask
=
batch
[
"all_atom_mask"
]
# This is super janky for superimposition. Fix later
gt_coords_masked
=
gt_coords
*
all_atom_mask
[...,
None
]
pred_coords_masked
=
pred_coords
*
all_atom_mask
[...,
None
]
ca_pos
=
residue_constants
.
atom_order
[
"CA"
]
gt_coords_masked_ca
=
gt_coords_masked
[...,
ca_pos
,
:]
pred_coords_masked_ca
=
pred_coords_masked
[...,
ca_pos
,
:]
all_atom_mask_ca
=
all_atom_mask
[...,
ca_pos
]
lddt_ca_score
=
lddt_ca
(
pred_coords
,
gt_coords
,
all_atom_mask
,
eps
=
self
.
config
.
globals
.
eps
,
per_residue
=
False
,
)
metrics
[
"lddt_ca"
]
=
lddt_ca_score
drmsd_ca_score
=
drmsd
(
pred_coords_masked_ca
,
gt_coords_masked_ca
,
mask
=
all_atom_mask_ca
,
# still required here to compute n
)
metrics
[
"drmsd_ca"
]
=
drmsd_ca_score
if
(
superimposition_metrics
):
superimposed_pred
,
alignment_rmsd
=
superimpose
(
gt_coords_masked_ca
,
pred_coords_masked_ca
,
all_atom_mask_ca
,
)
gdt_ts_score
=
gdt_ts
(
superimposed_pred
,
gt_coords_masked_ca
,
all_atom_mask_ca
)
gdt_ha_score
=
gdt_ha
(
superimposed_pred
,
gt_coords_masked_ca
,
all_atom_mask_ca
)
metrics
[
"alignment_rmsd"
]
=
alignment_rmsd
metrics
[
"gdt_ts"
]
=
gdt_ts_score
metrics
[
"gdt_ha"
]
=
gdt_ha_score
return
metrics
def
configure_optimizers
(
self
,
learning_rate
:
float
=
1e-3
,
eps
:
float
=
1e-5
,
)
->
torch
.
optim
.
Adam
:
# return torch.optim.Adam(
# self.model.parameters(),
# lr=learning_rate,
# eps=eps
# )
# Ignored as long as a DeepSpeed optimizer is configured
optimizer
=
torch
.
optim
.
Adam
(
self
.
model
.
parameters
(),
lr
=
learning_rate
,
eps
=
eps
)
if
self
.
last_lr_step
!=
-
1
:
for
group
in
optimizer
.
param_groups
:
if
'initial_lr'
not
in
group
:
group
[
'initial_lr'
]
=
learning_rate
lr_scheduler
=
AlphaFoldLRScheduler
(
optimizer
,
)
return
{
"optimizer"
:
optimizer
,
"lr_scheduler"
:
{
"scheduler"
:
lr_scheduler
,
"interval"
:
"step"
,
"name"
:
"AlphaFoldLRScheduler"
,
}
}
def
on_load_checkpoint
(
self
,
checkpoint
):
ema
=
checkpoint
[
"ema"
]
if
(
not
self
.
model
.
template_config
.
enabled
):
ema
[
"params"
]
=
{
k
:
v
for
k
,
v
in
ema
[
"params"
].
items
()
if
not
"template"
in
k
}
self
.
ema
.
load_state_dict
(
ema
)
def
on_save_checkpoint
(
self
,
checkpoint
):
checkpoint
[
"ema"
]
=
self
.
ema
.
state_dict
()
def
resume_last_lr_step
(
self
,
lr_step
):
self
.
last_lr_step
=
lr_step
def
load_from_jax
(
self
,
jax_path
):
model_basename
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
os
.
path
.
normpath
(
jax_path
)
)
)[
0
]
model_version
=
"_"
.
join
(
model_basename
.
split
(
"_"
)[
1
:])
import_jax_weights_
(
self
.
model
,
jax_path
,
version
=
model_version
)
def
main
(
args
):
if
(
args
.
seed
is
not
None
):
seed_everything
(
args
.
seed
)
config
=
model_config
(
args
.
config_preset
,
train
=
True
,
low_prec
=
(
str
(
args
.
precision
)
==
"16"
)
)
model_module
=
OpenFoldWrapper
(
config
)
if
(
args
.
resume_from_ckpt
):
if
(
os
.
path
.
isdir
(
args
.
resume_from_ckpt
)):
last_global_step
=
get_global_step_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
else
:
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
last_global_step
=
int
(
sd
[
'global_step'
])
model_module
.
resume_last_lr_step
(
last_global_step
)
logging
.
info
(
"Successfully loaded last lr step..."
)
if
(
args
.
resume_from_ckpt
and
args
.
resume_model_weights_only
):
if
(
os
.
path
.
isdir
(
args
.
resume_from_ckpt
)):
sd
=
get_fp32_state_dict_from_zero_checkpoint
(
args
.
resume_from_ckpt
)
else
:
sd
=
torch
.
load
(
args
.
resume_from_ckpt
)
sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
.
items
()}
model_module
.
load_state_dict
(
sd
)
logging
.
info
(
"Successfully loaded model weights..."
)
if
(
args
.
resume_from_jax_params
):
model_module
.
load_from_jax
(
args
.
resume_from_jax_params
)
logging
.
info
(
f
"Successfully loaded JAX parameters at
{
args
.
resume_from_jax_params
}
..."
)
# TorchScript components of the model
if
(
args
.
script_modules
):
script_preset_
(
model_module
)
#data_module = DummyDataLoader("new_batch.pickle")
data_module
=
OpenFoldMultimerDataModule
(
config
=
config
.
data
,
batch_seed
=
args
.
seed
,
**
vars
(
args
)
)
data_module
.
prepare_data
()
data_module
.
setup
()
callbacks
=
[]
if
(
args
.
checkpoint_every_epoch
):
mc
=
ModelCheckpoint
(
every_n_epochs
=
1
,
auto_insert_metric_name
=
False
,
save_top_k
=-
1
,
)
callbacks
.
append
(
mc
)
if
(
args
.
early_stopping
):
es
=
EarlyStoppingVerbose
(
monitor
=
"val/lddt_ca"
,
min_delta
=
args
.
min_delta
,
patience
=
args
.
patience
,
verbose
=
False
,
mode
=
"max"
,
check_finite
=
True
,
strict
=
True
,
)
callbacks
.
append
(
es
)
if
(
args
.
log_performance
):
global_batch_size
=
args
.
num_nodes
*
args
.
gpus
perf
=
PerformanceLoggingCallback
(
log_file
=
os
.
path
.
join
(
args
.
output_dir
,
"performance_log.json"
),
global_batch_size
=
global_batch_size
,
)
callbacks
.
append
(
perf
)
if
(
args
.
log_lr
):
lr_monitor
=
LearningRateMonitor
(
logging_interval
=
"step"
)
callbacks
.
append
(
lr_monitor
)
loggers
=
[]
if
(
args
.
wandb
):
wdb_logger
=
WandbLogger
(
name
=
args
.
experiment_name
,
save_dir
=
args
.
output_dir
,
id
=
args
.
wandb_id
,
project
=
args
.
wandb_project
,
**
{
"entity"
:
args
.
wandb_entity
}
)
loggers
.
append
(
wdb_logger
)
if
(
args
.
deepspeed_config_path
is
not
None
):
strategy
=
DeepSpeedPlugin
(
config
=
args
.
deepspeed_config_path
,
)
if
(
args
.
wandb
):
wdb_logger
.
experiment
.
save
(
args
.
deepspeed_config_path
)
wdb_logger
.
experiment
.
save
(
"openfold/config.py"
)
elif
(
args
.
gpus
is
not
None
and
args
.
gpus
>
1
)
or
args
.
num_nodes
>
1
:
strategy
=
DDPPlugin
(
find_unused_parameters
=
False
)
else
:
strategy
=
None
if
(
args
.
wandb
):
freeze_path
=
f
"
{
wdb_logger
.
experiment
.
dir
}
/package_versions.txt"
os
.
system
(
f
"
{
sys
.
executable
}
-m pip freeze >
{
freeze_path
}
"
)
wdb_logger
.
experiment
.
save
(
f
"
{
freeze_path
}
"
)
trainer
=
pl
.
Trainer
.
from_argparse_args
(
args
,
default_root_dir
=
args
.
output_dir
,
strategy
=
strategy
,
callbacks
=
callbacks
,
logger
=
loggers
,
accelerator
=
"gpu"
,
devices
=
1
)
if
(
args
.
resume_model_weights_only
):
ckpt_path
=
None
else
:
ckpt_path
=
args
.
resume_from_ckpt
trainer
.
fit
(
model_module
,
datamodule
=
data_module
,
ckpt_path
=
ckpt_path
,
)
def
bool_type
(
bool_str
:
str
):
bool_str_lower
=
bool_str
.
lower
()
if
bool_str_lower
in
(
'false'
,
'f'
,
'no'
,
'n'
,
'0'
):
return
False
elif
bool_str_lower
in
(
'true'
,
't'
,
'yes'
,
'y'
,
'1'
):
return
True
else
:
raise
ValueError
(
f
'Cannot interpret
{
bool_str
}
as bool'
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"train_data_dir"
,
type
=
str
,
help
=
"Directory containing training mmCIF files"
)
parser
.
add_argument
(
"train_alignment_dir"
,
type
=
str
,
help
=
"Directory containing precomputed training alignments"
)
parser
.
add_argument
(
"template_mmcif_dir"
,
type
=
str
,
help
=
"Directory containing mmCIF files to search for templates"
)
parser
.
add_argument
(
"train_mmcif_data_cache_path"
,
type
=
str
,
help
=
"json file that records each mmcif's chains, seqs, resolution etc"
)
parser
.
add_argument
(
"output_dir"
,
type
=
str
,
help
=
'''Directory in which to output checkpoints, logs, etc. Ignored
if not on rank 0'''
)
parser
.
add_argument
(
"max_template_date"
,
type
=
str
,
help
=
'''Cutoff for all templates. In training mode, templates are also
filtered by the release date of the target'''
)
parser
.
add_argument
(
"--distillation_data_dir"
,
type
=
str
,
default
=
None
,
help
=
"Directory containing training PDB files"
)
parser
.
add_argument
(
"--distillation_alignment_dir"
,
type
=
str
,
default
=
None
,
help
=
"Directory containing precomputed distillation alignments"
)
parser
.
add_argument
(
"--val_data_dir"
,
type
=
str
,
default
=
None
,
help
=
"Directory containing validation mmCIF files"
)
parser
.
add_argument
(
"--val_alignment_dir"
,
type
=
str
,
default
=
None
,
help
=
"Directory containing precomputed validation alignments"
)
parser
.
add_argument
(
"--kalign_binary_path"
,
type
=
str
,
default
=
'/usr/bin/kalign'
,
help
=
"Path to the kalign binary"
)
parser
.
add_argument
(
"--train_filter_path"
,
type
=
str
,
default
=
None
,
help
=
'''Optional path to a text file containing names of training
examples to include, one per line. Used to filter the training
set'''
)
parser
.
add_argument
(
"--distillation_filter_path"
,
type
=
str
,
default
=
None
,
help
=
"""See --train_filter_path"""
)
parser
.
add_argument
(
"--obsolete_pdbs_file_path"
,
type
=
str
,
default
=
None
,
help
=
"""Path to obsolete.dat file containing list of obsolete PDBs and
their replacements."""
)
parser
.
add_argument
(
"--template_release_dates_cache_path"
,
type
=
str
,
default
=
None
,
help
=
"""Output of scripts/generate_mmcif_cache.py run on template mmCIF
files."""
)
parser
.
add_argument
(
"--use_small_bfd"
,
type
=
bool_type
,
default
=
False
,
help
=
"Whether to use a reduced version of the BFD database"
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
None
,
help
=
"Random seed"
)
parser
.
add_argument
(
"--deepspeed_config_path"
,
type
=
str
,
default
=
None
,
help
=
"Path to DeepSpeed config. If not provided, DeepSpeed is disabled"
)
parser
.
add_argument
(
"--checkpoint_every_epoch"
,
action
=
"store_true"
,
default
=
False
,
help
=
"""Whether to checkpoint at the end of every training epoch"""
)
parser
.
add_argument
(
"--early_stopping"
,
type
=
bool_type
,
default
=
False
,
help
=
"Whether to stop training when validation loss fails to decrease"
)
parser
.
add_argument
(
"--min_delta"
,
type
=
float
,
default
=
0
,
help
=
"""The smallest decrease in validation loss that counts as an
improvement for the purposes of early stopping"""
)
parser
.
add_argument
(
"--patience"
,
type
=
int
,
default
=
3
,
help
=
"Early stopping patience"
)
parser
.
add_argument
(
"--resume_from_ckpt"
,
type
=
str
,
default
=
None
,
help
=
"Path to a model checkpoint from which to restore training state"
)
parser
.
add_argument
(
"--resume_model_weights_only"
,
type
=
bool_type
,
default
=
False
,
help
=
"Whether to load just model weights as opposed to training state"
)
parser
.
add_argument
(
"--resume_from_jax_params"
,
type
=
str
,
default
=
None
,
help
=
"""Path to an .npz JAX parameter file with which to initialize the model"""
)
parser
.
add_argument
(
"--log_performance"
,
type
=
bool_type
,
default
=
False
,
help
=
"Measure performance"
)
parser
.
add_argument
(
"--wandb"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to log metrics to Weights & Biases"
)
parser
.
add_argument
(
"--experiment_name"
,
type
=
str
,
default
=
None
,
help
=
"Name of the current experiment. Used for wandb logging"
)
parser
.
add_argument
(
"--wandb_id"
,
type
=
str
,
default
=
None
,
help
=
"ID of a previous run to be resumed"
)
parser
.
add_argument
(
"--wandb_project"
,
type
=
str
,
default
=
None
,
help
=
"Name of the wandb project to which this run will belong"
)
parser
.
add_argument
(
"--wandb_entity"
,
type
=
str
,
default
=
None
,
help
=
"wandb username or team name to which runs are attributed"
)
parser
.
add_argument
(
"--script_modules"
,
type
=
bool_type
,
default
=
False
,
help
=
"Whether to TorchScript eligible components of them model"
)
parser
.
add_argument
(
"--train_chain_data_cache_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--distillation_chain_data_cache_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--train_epoch_len"
,
type
=
int
,
default
=
10000
,
help
=
(
"The virtual length of each training epoch. Stochastic filtering "
"of training data means that training datasets have no "
"well-defined length. This virtual length affects frequency of "
"validation & checkpointing (by default, one of each per epoch)."
)
)
parser
.
add_argument
(
"--log_lr"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Whether to log the actual learning rate"
)
parser
.
add_argument
(
"--config_preset"
,
type
=
str
,
default
=
"initial_training"
,
help
=
(
'Config setting. Choose e.g. "initial_training", "finetuning", '
'"model_1", etc. By default, the actual values in the config are '
'used.'
)
)
parser
.
add_argument
(
"--_distillation_structure_index_path"
,
type
=
str
,
default
=
None
,
)
parser
.
add_argument
(
"--alignment_index_path"
,
type
=
str
,
default
=
None
,
help
=
"Training alignment index. See the README for instructions."
)
parser
.
add_argument
(
"--distillation_alignment_index_path"
,
type
=
str
,
default
=
None
,
help
=
"Distillation alignment index. See the README for instructions."
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
# Disable the initial validation pass
parser
.
set_defaults
(
num_sanity_val_steps
=
0
,
)
# Remove some buggy/redundant arguments introduced by the Trainer
remove_arguments
(
parser
,
[
"--accelerator"
,
"--resume_from_checkpoint"
,
"--reload_dataloaders_every_epoch"
,
"--reload_dataloaders_every_n_epochs"
,
]
)
args
=
parser
.
parse_args
()
if
(
args
.
seed
is
None
and
((
args
.
gpus
is
not
None
and
args
.
gpus
>
1
)
or
(
args
.
num_nodes
is
not
None
and
args
.
num_nodes
>
1
))):
raise
ValueError
(
"For distributed training, --seed must be specified"
)
if
(
str
(
args
.
precision
)
==
"16"
and
args
.
deepspeed_config_path
is
not
None
):
raise
ValueError
(
"DeepSpeed and FP16 training are not compatible"
)
if
(
args
.
resume_from_jax_params
is
not
None
and
args
.
resume_from_ckpt
is
not
None
):
raise
ValueError
(
"Choose between loading pretrained Jax-weights and a checkpoint-path"
)
# This re-applies the training-time filters at the beginning of every epoch
args
.
reload_dataloaders_every_n_epochs
=
1
main
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment