OpenDAS / OpenFold · Commit ff368008

first pass changes to run with pl 2.1

Authored Jan 24, 2024 by Jennifer
Parent: 456103da

Showing 3 changed files with 148 additions and 114 deletions:

  openfold/data/data_modules.py    +5   -4
  openfold/utils/seed.py           +1   -1
  train_openfold.py              +142 -109
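In brief, the commit ports OpenFold's training code from the PyTorch Lightning 1.x API to 2.1: the DataModule setup() hooks gain a stage argument, seed_everything and the DeepSpeed/DDP strategies move to their new import locations, validation_epoch_end becomes on_validation_epoch_end, and the removed Trainer.from_argparse_args/add_argparse_args helpers are replaced by a hand-built keyword allow-list plus manually re-added CLI flags.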
openfold/data/data_modules.py
@@ -937,7 +937,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
         with open(distillation_alignment_index_path, "r") as fp:
             self.distillation_alignment_index = json.load(fp)

-    def setup(self):
+    def setup(self, stage=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleDataset,
             template_mmcif_dir=self.template_mmcif_dir,
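Lightning 2.x invokes this hook as setup(stage=...) with one of "fit", "validate", "test" or "predict", so overrides need to accept the argument. A minimal sketch with a toy datamodule (not OpenFold's):

    import pytorch_lightning as pl

    class ToyDataModule(pl.LightningDataModule):
        # Lightning 2.x passes the running stage to this hook; accepting
        # it with a default also keeps the manual data_module.setup()
        # call in train_openfold.py working.
        def setup(self, stage=None):
            print(f"setup called with stage={stage}")

    ToyDataModule().setup(stage="fit")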
@@ -1016,7 +1016,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 mode="predict",
             )

-    def _gen_dataloader(self, stage):
+    def _gen_dataloader(self, stage=None):
         generator = None
         if self.batch_seed is not None:
             generator = torch.Generator()
@@ -1053,7 +1053,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
     def val_dataloader(self):
         if self.eval_dataset is not None:
             return self._gen_dataloader("eval")
-        return None
+        # Temp fix to pass the validation step
+        return []

     def predict_dataloader(self):
         return self._gen_dataloader("predict")
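The swap from return None to return [] is flagged in the diff itself as a temporary fix: an empty list gives the 2.x evaluation loop zero validation batches to run, so training without an eval dataset can still pass the validation step, whereas a bare None presumably tripped up the loop's dataloader handling in this first pass.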
@@ -1085,7 +1086,7 @@ class OpenFoldMultimerDataModule(OpenFoldDataModule):
         self.training_mode = self.train_data_dir is not None
         self.val_mmcif_data_cache_path = val_mmcif_data_cache_path

-    def setup(self):
+    def setup(self, setup=None):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleMultimerDataset,
             template_mmcif_dir=self.template_mmcif_dir,
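One thing to watch in this first pass: the multimer override names its new parameter setup rather than stage. Lightning 2.x passes the stage as a keyword (setup(stage=...)) when it calls this hook, so only the stage=None spelling used in the base class above will receive it; the setup=None variant looks like a slip to revisit.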
openfold/utils/seed.py
@@ -2,7 +2,7 @@ import os
 import logging
 import random

 import numpy as np
-from pytorch_lightning.utilities.seed import seed_everything
+from pytorch_lightning import seed_everything

 from openfold.utils.suppress_output import SuppressLogging
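pytorch_lightning.utilities.seed no longer exports seed_everything in 2.x; the package root does, in both 1.x and 2.x, so the new import is the forward-compatible spelling. A quick sanity check, assuming a recent pytorch_lightning install:

    from pytorch_lightning import seed_everything

    # Seeds Python's random module, NumPy and torch (CPU and CUDA) in one
    # call; workers=True additionally derives seeds for DataLoader workers.
    seed_everything(42, workers=True)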
train_openfold.py
@@ -7,7 +7,7 @@ import pytorch_lightning as pl
 from pytorch_lightning.callbacks.lr_monitor import LearningRateMonitor
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers import WandbLogger
-from pytorch_lightning.plugins.training_type import DeepSpeedPlugin, DDPPlugin
+from pytorch_lightning.strategies import DeepSpeedStrategy, DDPStrategy
 import torch

 from openfold.config import model_config
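The old pytorch_lightning.plugins.training_type module is gone: training-type plugins were renamed to strategies, and in 2.x DeepSpeedStrategy and DDPStrategy live under pytorch_lightning.strategies and are passed through the Trainer's strategy argument. A minimal sketch of the 2.x usage:

    import pytorch_lightning as pl
    from pytorch_lightning.strategies import DDPStrategy

    # Strategy objects go straight into the Trainer; find_unused_parameters
    # mirrors the setting used later in this script.
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=2,
        strategy=DDPStrategy(find_unused_parameters=False),
    )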
@@ -55,7 +55,7 @@ class OpenFoldWrapper(pl.LightningModule):
         self.ema = ExponentialMovingAverage(
             model=self.model, decay=config.ema.decay
         )

         self.cached_weights = None
         self.last_lr_step = -1
@@ -66,12 +66,12 @@ class OpenFoldWrapper(pl.LightningModule):
         phase = "train" if train else "val"

         for loss_name, indiv_loss in loss_breakdown.items():
             self.log(
                 f"{phase}/{loss_name}",
                 indiv_loss,
                 on_step=train, on_epoch=(not train), logger=True,
             )
             if(train):
                 self.log(
                     f"{phase}/{loss_name}_epoch",
                     indiv_loss,
@@ -80,12 +80,12 @@ class OpenFoldWrapper(pl.LightningModule):
         with torch.no_grad():
             other_metrics = self._compute_validation_metrics(
                 batch,
                 outputs,
                 superimposition_metrics=(not train)
             )

         for k, v in other_metrics.items():
             self.log(
                 f"{phase}/{k}",
                 torch.mean(v),
@@ -93,7 +93,7 @@ class OpenFoldWrapper(pl.LightningModule):
             )

     def training_step(self, batch, batch_idx):
         if(self.ema.device != batch["aatype"].device):
             self.ema.to(batch["aatype"].device)

         ground_truth = batch.pop('gt_features', None)
@@ -124,12 +124,13 @@ class OpenFoldWrapper(pl.LightningModule):
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
             # model.state_dict() contains references to model weights rather
             # than copies. Therefore, we need to clone them before calling
             # load_state_dict().
-            clone_param = lambda t: t.detach().clone()
+            def clone_param(t):
+                return t.detach().clone()
             self.cached_weights = tensor_tree_map(clone_param, self.model.state_dict())
             self.model.load_state_dict(self.ema.state_dict()["params"])

         ground_truth = batch.pop('gt_features', None)
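The lambda here becomes a named function, but the weight-swap logic is unchanged: before validating, the training weights are cloned (state_dict() hands back references, not copies) and the EMA parameters are loaded in their place. A self-contained sketch of that swap, with illustrative names rather than OpenFold's:

    import torch
    from torch import nn

    model = nn.Linear(4, 4)
    # Stand-in for the EMA parameter dict the wrapper keeps in self.ema
    ema_params = {k: torch.zeros_like(v) for k, v in model.state_dict().items()}

    # state_dict() returns references to the live tensors, so clone first
    cached = {k: v.detach().clone() for k, v in model.state_dict().items()}
    model.load_state_dict(ema_params)   # validate against the EMA weights
    # ... run validation ...
    model.load_state_dict(cached)       # then restore the training weights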
@@ -151,23 +152,23 @@ class OpenFoldWrapper(pl.LightningModule):
         )

         self._log(loss_breakdown, batch, outputs, train=False)

-    def validation_epoch_end(self, _):
+    def on_validation_epoch_end(self, _):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
         self.cached_weights = None

     def _compute_validation_metrics(self,
         batch,
         outputs,
         superimposition_metrics=False
     ):
         metrics = {}

         gt_coords = batch["all_atom_positions"]
         pred_coords = outputs["final_atom_positions"]
         all_atom_mask = batch["all_atom_mask"]

         # This is super janky for superimposition. Fix later
         gt_coords_masked = gt_coords * all_atom_mask[..., None]
         pred_coords_masked = pred_coords * all_atom_mask[..., None]
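validation_epoch_end was removed in Lightning 2.0, hence the rename. Note, though, that the 2.x replacement hook is invoked with no arguments, so the _ parameter retained here will raise a TypeError once the trainer calls it; under 2.x, per-batch outputs have to be accumulated by hand if they are needed. A minimal sketch of the 2.x pattern (toy module, not OpenFold's):

    import torch
    import pytorch_lightning as pl

    class ToyModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self._val_losses = []  # accumulate manually under Lightning 2.x

        def validation_step(self, batch, batch_idx):
            loss = torch.tensor(0.0)  # placeholder metric
            self._val_losses.append(loss)
            return loss

        def on_validation_epoch_end(self):  # 2.x: no outputs argument
            epoch_loss = torch.stack(self._val_losses).mean()
            self._val_losses.clear()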
@@ -175,7 +176,7 @@ class OpenFoldWrapper(pl.LightningModule):
         gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
         pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
         all_atom_mask_ca = all_atom_mask[..., ca_pos]

         lddt_ca_score = lddt_ca(
             pred_coords,
             gt_coords,
@@ -183,18 +184,18 @@ class OpenFoldWrapper(pl.LightningModule):
             eps=self.config.globals.eps,
             per_residue=False,
         )

         metrics["lddt_ca"] = lddt_ca_score

         drmsd_ca_score = drmsd(
             pred_coords_masked_ca,
             gt_coords_masked_ca,
             mask=all_atom_mask_ca,  # still required here to compute n
         )

         metrics["drmsd_ca"] = drmsd_ca_score

         if(superimposition_metrics):
             superimposed_pred, alignment_rmsd = superimpose(
                 gt_coords_masked_ca, pred_coords_masked_ca, all_atom_mask_ca,
             )
@@ -208,22 +209,22 @@ class OpenFoldWrapper(pl.LightningModule):
             metrics["alignment_rmsd"] = alignment_rmsd
             metrics["gdt_ts"] = gdt_ts_score
             metrics["gdt_ha"] = gdt_ha_score

         return metrics

     def configure_optimizers(self,
         learning_rate: float = 1e-3,
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
         # return torch.optim.Adam(
         #     self.model.parameters(),
         #     lr=learning_rate,
         #     eps=eps
         # )
         # Ignored as long as a DeepSpeed optimizer is configured
         optimizer = torch.optim.Adam(
             self.model.parameters(),
             lr=learning_rate,
             eps=eps
         )
@@ -248,8 +249,9 @@ class OpenFoldWrapper(pl.LightningModule):
     def on_load_checkpoint(self, checkpoint):
         ema = checkpoint["ema"]
         if(not self.model.template_config.enabled):
             ema["params"] = {k: v for k, v in ema["params"].items() if not "template" in k}
         self.ema.load_state_dict(ema)

     def on_save_checkpoint(self, checkpoint):
@@ -260,69 +262,72 @@ class OpenFoldWrapper(pl.LightningModule):
     def load_from_jax(self, jax_path):
         model_basename = os.path.splitext(
                 os.path.basename(
                     os.path.normpath(jax_path)
                 )
         )[0]
         model_version = "_".join(model_basename.split("_")[1:])
         import_jax_weights_(
             self.model, jax_path, version=model_version
         )


 def main(args):
     if(args.seed is not None):
         seed_everything(args.seed)

     config = model_config(
         args.config_preset,
         train=True,
         low_prec=(str(args.precision) == "16")
     )
     model_module = OpenFoldWrapper(config)
     if(args.resume_from_ckpt):
         if(os.path.isdir(args.resume_from_ckpt)):
             last_global_step = get_global_step_from_zero_checkpoint(args.resume_from_ckpt)
         else:
             sd = torch.load(args.resume_from_ckpt)
             last_global_step = int(sd['global_step'])
         model_module.resume_last_lr_step(last_global_step)
         logging.info("Successfully loaded last lr step...")

     if(args.resume_from_ckpt and args.resume_model_weights_only):
         if(os.path.isdir(args.resume_from_ckpt)):
             sd = get_fp32_state_dict_from_zero_checkpoint(args.resume_from_ckpt)
         else:
             sd = torch.load(args.resume_from_ckpt)
         sd = {k[len("module."):]: v for k, v in sd.items()}
         import_openfold_weights_(model=model_module, state_dict=sd)
         logging.info("Successfully loaded model weights...")

     if(args.resume_from_jax_params):
         model_module.load_from_jax(args.resume_from_jax_params)
         logging.info(f"Successfully loaded JAX parameters at {args.resume_from_jax_params}...")

     # TorchScript components of the model
     if(args.script_modules):
         script_preset_(model_module)

     if "multimer" in args.config_preset:
         data_module = OpenFoldMultimerDataModule(
             config=config.data,
             batch_seed=args.seed,
             **vars(args)
         )
     else:
         data_module = OpenFoldDataModule(
             config=config.data,
             batch_seed=args.seed,
             **vars(args)
         )

     data_module.prepare_data()
     data_module.setup()

     callbacks = []
     if(args.checkpoint_every_epoch):
         mc = ModelCheckpoint(
             every_n_epochs=1,
             auto_insert_metric_name=False,
@@ -330,7 +335,7 @@ def main(args):
         )
         callbacks.append(mc)

     if(args.early_stopping):
         es = EarlyStoppingVerbose(
             monitor="val/lddt_ca",
             min_delta=args.min_delta,
@@ -342,7 +347,7 @@ def main(args):
         )
         callbacks.append(es)

     if(args.log_performance):
         global_batch_size = args.num_nodes * args.gpus
         perf = PerformanceLoggingCallback(
             log_file=os.path.join(args.output_dir, "performance_log.json"),
@@ -350,12 +355,12 @@ def main(args):
         )
         callbacks.append(perf)

     if(args.log_lr):
         lr_monitor = LearningRateMonitor(logging_interval="step")
         callbacks.append(lr_monitor)

     loggers = []
     if(args.wandb):
         wdb_logger = WandbLogger(
             name=args.experiment_name,
             save_dir=args.output_dir,
@@ -365,38 +370,43 @@ def main(args):
         )
         loggers.append(wdb_logger)

     if(args.deepspeed_config_path is not None):
-        strategy = DeepSpeedPlugin(
+        strategy = DeepSpeedStrategy(
             config=args.deepspeed_config_path,
         )
         if(args.wandb):
             wdb_logger.experiment.save(args.deepspeed_config_path)
             wdb_logger.experiment.save("openfold/config.py")
     elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
-        strategy = DDPPlugin(find_unused_parameters=False)
+        strategy = DDPStrategy(find_unused_parameters=False)
     else:
         strategy = None

     if(args.wandb):
         freeze_path = f"{wdb_logger.experiment.dir}/package_versions.txt"
         os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
         wdb_logger.experiment.save(f"{freeze_path}")

-    trainer = pl.Trainer.from_argparse_args(
-        args,
-        default_root_dir=args.output_dir,
-        strategy=strategy,
-        callbacks=callbacks,
-        logger=loggers,
-    )
+    # Raw dump of all args from pl.Trainer constructor
+    trainer_kws = set([
+        'accelerator', 'strategy', 'devices', 'num_nodes', 'precision',
+        'logger', 'callbacks', 'fast_dev_run', 'max_epochs', 'min_epochs',
+        'max_steps', 'min_steps', 'max_tim', 'limit_train_batches',
+        'limit_val_batches', 'limit_test_batches', 'limit_predict_batches',
+        'overfit_batches', 'val_check_interval', 'check_val_every_n_epoch',
+        'num_sanity_val_steps', 'log_every_n_steps', 'enable_checkpointing',
+        'enable_progress_bar', 'enable_model_summary', 'accumulate_grad_batches',
+        'gradient_clip_val', 'gradient_clip_algorithm', 'deterministic',
+        'benchmark', 'inference_mode', 'use_distributed_sampler', 'profiler',
+        'detect_anomaly', 'barebones', 'plugins', 'sync_batchnorm',
+        'reload_dataloaders_every_n_epochs', 'default_root_dir',
+    ])
+    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
+    trainer_args.update({
+        'default_root_dir': args.output_dir,
+        'strategy': strategy,
+        'callbacks': callbacks,
+        'logger': loggers,
+    })
+    trainer = pl.Trainer(**trainer_args)

     if(args.resume_model_weights_only):
         ckpt_path = None
     else:
         ckpt_path = args.resume_from_ckpt

     trainer.fit(
         model_module,
         datamodule=data_module,
         ckpt_path=ckpt_path,
     )
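Trainer.from_argparse_args no longer exists in 2.0, so the commit filters the argparse namespace through a hand-maintained allow-list of Trainer keywords. Two caveats: 'max_tim' in that set looks like a truncated 'max_time' (the actual Trainer keyword), so any --max_time value would be silently dropped; and the list has to track the Trainer signature by hand. A sketch of one way to avoid that drift, derived from the constructor signature rather than what the commit does:

    import inspect
    import pytorch_lightning as pl

    # Derive the accepted keyword names from pl.Trainer's own signature so
    # the allow-list cannot drift out of sync (e.g. the 'max_tim' typo
    # above); "self" is excluded since it is not a keyword argument.
    # `args` is the parsed argparse.Namespace from parse_args().
    trainer_kws = set(inspect.signature(pl.Trainer.__init__).parameters) - {"self"}
    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
    trainer = pl.Trainer(**trainer_args)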
@@ -595,36 +605,59 @@ if __name__ == "__main__":
         "--distillation_alignment_index_path", type=str, default=None,
         help="Distillation alignment index. See the README for instructions."
     )
-    parser = pl.Trainer.add_argparse_args(parser)
-
-    # Disable the initial validation pass
-    parser.set_defaults(
-        num_sanity_val_steps=0,
-    )
-
-    # Remove some buggy/redundant arguments introduced by the Trainer
-    remove_arguments(
-        parser,
-        [
-            "--accelerator",
-            "--resume_from_checkpoint",
-            "--reload_dataloaders_every_epoch",
-            "--reload_dataloaders_every_n_epochs",
-        ]
-    )
+    parser.add_argument(
+        "--num_nodes", type=int, default=1,
+    )
+    parser.add_argument(
+        "--gpus", type=int, default=1,
+    )
+    parser.add_argument(
+        "--precision", type=str, default=None,
+    )
+    parser.add_argument(
+        "--replace_sampler_ddp", type=bool_type, default=True,
+    )
+    parser.add_argument(
+        "--max_epochs", type=int, default=1,
+    )
+    parser.add_argument(
+        "--log_every_n_steps", type=int, default=25,
+    )
+    parser.add_argument(
+        "--num_sanity_val_steps", type=int, default=0,
+    )
+    # parser = pl.Trainer.add_argparse_args(parser)
+    #
+    # # Disable the initial validation pass
+    # parser.set_defaults(
+    #     num_sanity_val_steps=0,
+    # )
+    #
+    # # Remove some buggy/redundant arguments introduced by the Trainer
+    # remove_arguments(
+    #     parser,
+    #     [
+    #         "--accelerator",
+    #         "--resume_from_checkpoint",
+    #         "--reload_dataloaders_every_epoch",
+    #         "--reload_dataloaders_every_n_epochs",
+    #     ]
+    # )

     args = parser.parse_args()

     if(args.seed is None and
         ((args.gpus is not None and args.gpus > 1) or
          (args.num_nodes is not None and args.num_nodes > 1))):
         raise ValueError("For distributed training, --seed must be specified")

     if(str(args.precision) == "16" and args.deepspeed_config_path is not None):
         raise ValueError("DeepSpeed and FP16 training are not compatible")

     if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
         raise ValueError(
             "Choose between loading pretrained Jax-weights and a checkpoint-path"
         )

     # This re-applies the training-time filters at the beginning of every epoch
     args.reload_dataloaders_every_n_epochs = 1
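With Trainer.add_argparse_args also removed in 2.0, the flags the script used to inherit from the Trainer are re-added explicitly and the old helper block is kept commented out for reference. One flag worth revisiting: replace_sampler_ddp was renamed use_distributed_sampler in Lightning 2.0, so the re-added --replace_sampler_ddp option parses fine but is filtered out by the trainer_kws allow-list above and never reaches the Trainer.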