OpenDAS / OpenFold · Commit 7de0ab00

Authored Jan 24, 2024 by Jennifer; committed May 06, 2024 by Jennifer Wei
Parent: a51b08cd

First-pass changes to run with PyTorch Lightning (pl) 2.1
Showing 1 changed file with 55 additions and 51 deletions.

train_openfold.py  (+55, -51)
@@ -55,7 +55,7 @@ class OpenFoldWrapper(pl.LightningModule):
        self.ema = ExponentialMovingAverage(
            model=self.model, decay=config.ema.decay
        )
        self.cached_weights = None
        self.last_lr_step = -1
        self.save_hyperparameters()
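The wrapper keeps an exponential moving average (EMA) of the model weights, configured by `config.ema.decay`. As an illustration only (OpenFold's `ExponentialMovingAverage` class is defined elsewhere in the repository; this stand-in is not its API), a minimal EMA over a state dict looks like:

```python
# Illustrative stand-in, not OpenFold's ExponentialMovingAverage class.
import torch

class SimpleEMA:
    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        # Shadow copy of every parameter/buffer, detached from autograd.
        self.params = {k: v.detach().clone() for k, v in model.state_dict().items()}

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        for k, v in model.state_dict().items():
            if v.dtype.is_floating_point:
                # ema <- decay * ema + (1 - decay) * current
                self.params[k].mul_(self.decay).add_(v, alpha=1.0 - self.decay)
            else:
                self.params[k].copy_(v)

    def to(self, device):
        self.params = {k: v.to(device) for k, v in self.params.items()}
        return self
```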
@@ -73,7 +73,7 @@ class OpenFoldWrapper(pl.LightningModule):
                on_step=train, on_epoch=(not train), logger=True, sync_dist=False,
            )

            if(train):
                self.log(
                    f"{phase}/{loss_name}_epoch",
                    indiv_loss,
@@ -82,12 +82,12 @@ class OpenFoldWrapper(pl.LightningModule):
        with torch.no_grad():
            other_metrics = self._compute_validation_metrics(
                batch,
                outputs,
                superimposition_metrics=(not train)
            )

        for k, v in other_metrics.items():
            self.log(
                f"{phase}/{k}",
                torch.mean(v),
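The `self.log` calls follow the usual LightningModule semantics: step-level points while training, epoch-level aggregation during validation, and no cross-rank synchronization of the logged value. A minimal sketch of that pattern (the helper function and its arguments are illustrative, not OpenFold code):

```python
# Illustrative sketch of the on_step/on_epoch logging pattern used above.
def log_losses(module, loss_breakdown: dict, train: bool):
    phase = "train" if train else "val"
    for loss_name, indiv_loss in loss_breakdown.items():
        module.log(
            f"{phase}/{loss_name}",
            indiv_loss,
            on_step=train,          # per-step points only while training
            on_epoch=(not train),   # epoch aggregation only for validation
            logger=True,
            sync_dist=False,        # no cross-rank reduction of the logged value
        )
```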
@@ -96,7 +96,7 @@ class OpenFoldWrapper(pl.LightningModule):
            )

    def training_step(self, batch, batch_idx):
        if(self.ema.device != batch["aatype"].device):
            self.ema.to(batch["aatype"].device)

        ground_truth = batch.pop('gt_features', None)
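For orientation, the surrounding training step keeps the EMA shadow weights on the same device as the incoming batch and separates the ground-truth features from the model inputs. A hedged sketch of that flow (helper names such as `self.loss` and the exact forward signature are assumptions, not taken from this diff):

```python
# Hedged sketch only; `self.loss` and the forward call signature are assumed.
def training_step(self, batch, batch_idx):
    if self.ema.device != batch["aatype"].device:
        self.ema.to(batch["aatype"].device)

    ground_truth = batch.pop("gt_features", None)   # labels handled separately
    outputs = self.model(batch)                     # assumed forward signature
    loss, loss_breakdown = self.loss(outputs, batch, ground_truth)  # assumed helper

    self.ema.update(self.model)                     # refresh the shadow weights
    return loss
```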
@@ -127,12 +127,13 @@ class OpenFoldWrapper(pl.LightningModule):
    def validation_step(self, batch, batch_idx):
        # At the start of validation, load the EMA weights
        if(self.cached_weights is None):
            # model.state_dict() contains references to model weights rather
            # than copies. Therefore, we need to clone them before calling
            # load_state_dict().
-           clone_param = lambda t: t.detach().clone()
+           def clone_param(t):
+               return t.detach().clone()
            self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
            self.model.load_state_dict(self.ema.state_dict()["params"])

        ground_truth = batch.pop('gt_features', None)
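This hunk replaces the lambda with a named `clone_param` function (hence the 12 to 13 line count change); the snapshot logic itself is unchanged. For illustration, cloning every tensor leaf of a state dict works as below (this is a simplified stand-in for OpenFold's `tensor_tree_map` helper, not its actual implementation):

```python
import torch

def tree_map(fn, tree):
    # Simplified stand-in for OpenFold's tensor_tree_map utility.
    if isinstance(tree, dict):
        return {k: tree_map(fn, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(fn, v) for v in tree)
    if isinstance(tree, torch.Tensor):
        return fn(tree)
    return tree

def clone_param(t: torch.Tensor) -> torch.Tensor:
    # detach() drops the autograd graph; clone() copies storage so the cached
    # weights survive the subsequent load_state_dict() of the EMA parameters.
    return t.detach().clone()

# cached_weights = tree_map(clone_param, model.state_dict())
```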
@@ -160,17 +161,17 @@ class OpenFoldWrapper(pl.LightningModule):
        self.model.load_state_dict(self.cached_weights)
        self.cached_weights = None

    def _compute_validation_metrics(self,
        batch,
        outputs,
        superimposition_metrics=False
    ):
        metrics = {}

        gt_coords = batch["all_atom_positions"]
        pred_coords = outputs["final_atom_positions"]
        all_atom_mask = batch["all_atom_mask"]

        # This is super janky for superimposition. Fix later
        gt_coords_masked = gt_coords * all_atom_mask[..., None]
        pred_coords_masked = pred_coords * all_atom_mask[..., None]
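`all_atom_mask` has one fewer trailing dimension than the coordinate tensors, so indexing with `[..., None]` adds a size-1 axis that broadcasts over the x/y/z dimension. A toy example with made-up shapes:

```python
import torch

gt_coords = torch.randn(2, 5, 3)                      # (batch, atoms, xyz), toy shapes
all_atom_mask = torch.tensor([[1., 1., 0., 1., 0.],
                              [1., 0., 1., 1., 1.]])  # (batch, atoms)
gt_coords_masked = gt_coords * all_atom_mask[..., None]  # zeros out unresolved atoms
```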
@@ -178,7 +179,7 @@ class OpenFoldWrapper(pl.LightningModule):
        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
        all_atom_mask_ca = all_atom_mask[..., ca_pos]

        lddt_ca_score = lddt_ca(
            pred_coords,
            gt_coords,
@@ -186,18 +187,18 @@ class OpenFoldWrapper(pl.LightningModule):
            eps=self.config.globals.eps,
            per_residue=False,
        )
        metrics["lddt_ca"] = lddt_ca_score

        drmsd_ca_score = drmsd(
            pred_coords_masked_ca,
            gt_coords_masked_ca,
            mask=all_atom_mask_ca,  # still required here to compute n
        )
        metrics["drmsd_ca"] = drmsd_ca_score

        if(superimposition_metrics):
            superimposed_pred, alignment_rmsd = superimpose(
                gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
            )
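For reference, dRMSD compares the two structures' internal pairwise-distance matrices, so unlike the alignment RMSD it needs no superimposition step. A hedged sketch of the idea (OpenFold's `drmsd` is the authoritative implementation; this is not its code and ignores masking):

```python
import torch

def drmsd_sketch(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    """Toy dRMSD over (N, 3) CA coordinates, ignoring masking."""
    d_pred = torch.cdist(pred, pred)   # (N, N) pairwise distances
    d_gt = torch.cdist(gt, gt)
    n = pred.shape[0]
    # Mean squared difference over all ordered pairs i != j (diagonal is zero).
    return torch.sqrt(((d_pred - d_gt) ** 2).sum() / (n * (n - 1)))
```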
@@ -211,7 +212,7 @@ class OpenFoldWrapper(pl.LightningModule):
            metrics["alignment_rmsd"] = alignment_rmsd
            metrics["gdt_ts"] = gdt_ts_score
            metrics["gdt_ha"] = gdt_ha_score

        return metrics

    def configure_optimizers(self,
@@ -220,8 +221,8 @@ class OpenFoldWrapper(pl.LightningModule):
    ) -> torch.optim.Adam:
        # Ignored as long as a DeepSpeed optimizer is configured
        optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate,
            eps=eps
        )
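In Lightning 2.x, `configure_optimizers` may return a bare optimizer or an optimizer together with a scheduler dict. A minimal sketch of the latter (the constant scheduler here is a placeholder, not OpenFold's own learning-rate schedule, which is configured elsewhere):

```python
import torch

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3, eps=1e-5)
    # Placeholder constant schedule, purely illustrative.
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0)
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }
```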
@@ -246,8 +247,9 @@ class OpenFoldWrapper(pl.LightningModule):
    def on_load_checkpoint(self, checkpoint):
        ema = checkpoint["ema"]
        if(not self.model.template_config.enabled):
            ema["params"] = {k: v for k, v in ema["params"].items() if not "template" in k}
        self.ema.load_state_dict(ema)

    def on_save_checkpoint(self, checkpoint):
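The dict comprehension drops every EMA entry whose key mentions "template" when the template stack is disabled (idiomatically, the condition would read `"template" not in k`). A toy example with hypothetical keys:

```python
ema = {
    "params": {
        "evoformer.linear.weight": 1.0,          # hypothetical keys
        "template_pair_stack.linear.weight": 2.0,
    },
    "decay": 0.999,
}
ema["params"] = {k: v for k, v in ema["params"].items() if "template" not in k}
# -> {"evoformer.linear.weight": 1.0}
```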
@@ -258,13 +260,13 @@ class OpenFoldWrapper(pl.LightningModule):
    def load_from_jax(self, jax_path):
        model_basename = os.path.splitext(
                os.path.basename(
                    os.path.normpath(jax_path)
                )
        )[0]
        model_version = "_".join(model_basename.split("_")[1:])
        import_jax_weights_(
            self.model, jax_path, version=model_version
        )


def get_model_state_dict_from_ds_checkpoint(checkpoint_dir):
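The version string is recovered from the parameter file name by stripping the directory, the extension, and the leading `params` token. A worked example with a hypothetical path:

```python
import os

jax_path = "/data/alphafold/params_model_1_ptm.npz"      # hypothetical path
model_basename = os.path.splitext(
    os.path.basename(os.path.normpath(jax_path))
)[0]                                                      # "params_model_1_ptm"
model_version = "_".join(model_basename.split("_")[1:])   # "model_1_ptm"
```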
@@ -331,30 +333,31 @@ def main(args):
    if args.resume_from_jax_params:
        model_module.load_from_jax(args.resume_from_jax_params)
        logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")

    # TorchScript components of the model
    if(args.script_modules):
        script_preset_(model_module)

    if "multimer" in args.config_preset:
        data_module = OpenFoldMultimerDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )
    else:
        data_module = OpenFoldDataModule(
            config=config.data,
            batch_seed=args.seed,
            **vars(args)
        )

    data_module.prepare_data()
    data_module.setup()

    callbacks = []
    if(args.checkpoint_every_epoch):
        mc = ModelCheckpoint(
            every_n_epochs=1,
            auto_insert_metric_name=False,
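`ModelCheckpoint` with `every_n_epochs=1` saves at the end of every epoch; the remaining arguments are elided in this hunk. A hedged sketch of such a callback (the `save_top_k` and `filename` values below are assumptions, not taken from the diff):

```python
from pytorch_lightning.callbacks import ModelCheckpoint

mc = ModelCheckpoint(
    every_n_epochs=1,
    auto_insert_metric_name=False,
    save_top_k=-1,                        # assumption: keep every epoch's checkpoint
    filename="epoch{epoch}-step{step}",   # assumed naming pattern
)
```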
@@ -362,7 +365,7 @@ def main(args):
        )
        callbacks.append(mc)

    if(args.early_stopping):
        es = EarlyStoppingVerbose(
            monitor="val/lddt_ca",
            min_delta=args.min_delta,
@@ -374,7 +377,7 @@ def main(args):
        )
        callbacks.append(es)

    if(args.log_performance):
        global_batch_size = args.num_nodes * args.gpus
        perf = PerformanceLoggingCallback(
            log_file=os.path.join(args.output_dir, "performance_log.json"),
@@ -382,7 +385,7 @@ def main(args):
        )
        callbacks.append(perf)

    if(args.log_lr):
        lr_monitor = LearningRateMonitor(logging_interval="step")
        callbacks.append(lr_monitor)
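All of the optional callbacks end up in a single list handed to the Trainer. A minimal Lightning 2.x sketch of that wiring (the Trainer arguments and the `mode="max"` choice are illustrative; `EarlyStoppingVerbose` is OpenFold's own subclass, replaced here by the stock callback):

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor

callbacks = [
    # Illustrative values; OpenFold wires EarlyStoppingVerbose with args-driven settings.
    EarlyStopping(monitor="val/lddt_ca", mode="max", patience=10),
    LearningRateMonitor(logging_interval="step"),
]

trainer = pl.Trainer(
    accelerator="gpu",   # illustrative
    devices=1,           # illustrative
    callbacks=callbacks,
)
```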
@@ -448,7 +451,7 @@ def main(args):
    ckpt_path = args.resume_from_ckpt
    trainer.fit(
        model_module,
        datamodule=data_module,
        ckpt_path=ckpt_path,
    )
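In Lightning 2.x the resume point is passed per call via `ckpt_path` (the old `Trainer(resume_from_checkpoint=...)` argument no longer exists), so this call resumes when `args.resume_from_ckpt` is set and starts a fresh run when it is `None`:

```python
# Lightning 2.x API: ckpt_path=None starts fresh; a path restores training state
# from that checkpoint (model_module and data_module come from the script above).
trainer.fit(model_module, datamodule=data_module, ckpt_path=args.resume_from_ckpt)
```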
@@ -686,16 +689,17 @@ if __name__ == "__main__":
    args = parser.parse_args()

    if(args.seed is None and
        ((args.gpus is not None and args.gpus > 1) or
        (args.num_nodes is not None and args.num_nodes > 1))):
        raise ValueError("For distributed training, --seed must be specified")

    if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
        raise ValueError("DeepSpeed and FP16 training are not compatible")

    if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
        raise ValueError("Choose between loading pretrained Jax-weights and a checkpoint-path")

    main(args)