OpenDAS / OpenFold · Commits

Commit 142c5e65
Authored May 02, 2024 by Jennifer

Updates organization of command line flags for pl.Trainer

Parent: 2eda3215
Showing 1 changed file with 48 additions and 45 deletions

train_openfold.py  +48 −45
@@ -218,11 +218,6 @@ class OpenFoldWrapper(pl.LightningModule):
         learning_rate: float = 1e-3,
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
-        # return torch.optim.Adam(
-        #     self.model.parameters(),
-        #     lr=learning_rate,
-        #     eps=eps
-        # )
         # Ignored as long as a DeepSpeed optimizer is configured
         optimizer = torch.optim.Adam(
             self.model.parameters(),
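For readers skimming the hunk above: the commit drops a stale commented-out return and keeps the live Adam construction. A minimal, hypothetical sketch of the same configure_optimizers pattern follows; ToyModel, its linear layer, and the hyperparameter plumbing are invented for illustration, and only the hook name and the DeepSpeed caveat come from the diff.

# A minimal sketch (not the repository's code) of the configure_optimizers
# hook trimmed above; ToyModel and the layer sizes are made up.
import torch
import pytorch_lightning as pl

class ToyModel(pl.LightningModule):
    def __init__(self, learning_rate: float = 1e-3, eps: float = 1e-5):
        super().__init__()
        self.model = torch.nn.Linear(8, 1)
        self.learning_rate = learning_rate
        self.eps = eps

    def configure_optimizers(self) -> torch.optim.Adam:
        # Per the comment in the diff: when a DeepSpeed strategy is configured
        # with its own optimizer section, this return value is ignored.
        return torch.optim.Adam(
            self.model.parameters(), lr=self.learning_rate, eps=self.eps
        )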
@@ -289,10 +284,13 @@ def main(args):
     if (args.seed is not None):
         seed_everything(args.seed, workers=True)

+    is_low_precision = args.precision in [
+        "bf16-mixed", "16", "bf16", "16-true", "16-mixed", "bf16-mixed"
+    ]
     config = model_config(
         args.config_preset,
         train=True,
-        low_prec=(str(args.precision) == "16")
+        low_prec=is_low_precision,
     )
     if args.experiment_config_json:
         with open(args.experiment_config_json, 'r') as f:
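The old code treated only precision "16" as low precision; the new flag also covers the bf16 and mixed variants. A standalone sketch of the same membership test (the function and constant names below are mine; the accepted strings are taken from the diff):

# Hypothetical standalone version of the check added above, for illustration.
LOW_PRECISION_FLAGS = {"bf16-mixed", "16", "bf16", "16-true", "16-mixed"}

def is_low_precision(precision) -> bool:
    # Normalize to str so a value like 16 (int) also matches.
    return str(precision) in LOW_PRECISION_FLAGS

assert is_low_precision("bf16")
assert not is_low_precision("32")

Note that the list in the commit contains "bf16-mixed" twice; since it is only used for a membership test, the duplicate is harmless.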
@@ -432,17 +430,17 @@ def main(args):
     os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
     wdb_logger.experiment.save(f"{freeze_path}")

-    trainer = pl.Trainer(
-        num_nodes=args.num_nodes,
-        devices=args.gpus,
-        precision=args.precision,
-        max_epochs=args.max_epochs,
-        default_root_dir=args.output_dir,
-        strategy=strategy,
-        callbacks=callbacks,
-        logger=loggers,
-        profiler='simple',
-    )
+    trainer_kws = ['num_nodes', 'precision', 'max_epochs', 'log_every_n_steps',
+                   'flush_logs_every_n_steps', 'num_sanity_val_steps',
+                   'reload_dataloaders_every_n_epochs']
+    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
+    trainer_args.update({
+        'default_root_dir': args.output_dir,
+        'strategy': strategy,
+        'callbacks': callbacks,
+        'logger': loggers,
+    })
+    trainer = pl.Trainer(**trainer_args)

     if (args.resume_model_weights_only):
         ckpt_path = None
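This is the core of the reorganization: instead of hard-coding each pl.Trainer keyword, the script whitelists the argparse entries Trainer should see, merges in the programmatic values, and splats the dict into the constructor, so flags that Trainer does not accept never reach it. A self-contained sketch of that filter-and-forward pattern, with a stand-in make_trainer in place of pl.Trainer and invented flag names:

# Sketch of the filter-and-forward pattern introduced above; `make_trainer`,
# `demo_args`, and `accepted` are invented stand-ins, not the script's names.
import argparse

def make_trainer(num_nodes: int = 1, precision: str = "32", max_epochs: int = 1):
    return f"Trainer(num_nodes={num_nodes}, precision={precision}, max_epochs={max_epochs})"

parser = argparse.ArgumentParser()
parser.add_argument("--num_nodes", type=int, default=1)
parser.add_argument("--precision", type=str, default="bf16")
parser.add_argument("--max_epochs", type=int, default=1)
parser.add_argument("--unrelated_flag", type=str, default="ignored")
demo_args = parser.parse_args([])

# Keep only the entries the constructor accepts, then splat the dict into
# the call -- the same shape as trainer_args above.
accepted = ["num_nodes", "precision", "max_epochs"]
kwargs = {k: v for k, v in vars(demo_args).items() if k in accepted}
print(make_trainer(**kwargs))

The benefit of the whitelist is that adding a new script-only flag cannot break the Trainer call with an unexpected-keyword TypeError.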
@@ -652,32 +650,39 @@ if __name__ == "__main__":
     parser.add_argument(
         "--experiment_config_json",
         default="",
         help="Path to a json file with custom config values to overwrite config setting",
     )
-    parser.add_argument("--num_nodes", type=int, default=1)
-    parser.add_argument("--gpus", type=int, default=None)
-    parser.add_argument("--max_epochs", type=int, default=None)
-    parser.add_argument("--precision", type=str, default="32")
-    parser.add_argument("--log_every_n_steps", type=int, default=50)
-    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
-    parser.add_argument("--flush_logs_every_n_steps", type=int, default=5)
-    parser.add_argument("--num_sanity_val_steps", type=int, default=0)
-    parser.add_argument("--mpi_plugin", action="store_true", default=False)
     # parser = pl.Trainer.add_argparse_args(parser)
     # Disable the initial validation pass
     parser.set_defaults(
         num_sanity_val_steps=0,
     )
     # Remove some buggy/redundant arguments introduced by the Trainer
     remove_arguments(
         parser,
         [
             "--accelerator",
             "--resume_from_checkpoint",
             "--reload_dataloaders_every_epoch",
             "--reload_dataloaders_every_n_epochs",
         ]
     )
+    parser.add_argument("--gpus", type=int, default=1,
+                        help='For determining optimal strategy and effective batch size.')
+    parser.add_argument("--mpi_plugin", action="store_true", default=False,
+                        help="Whether to use MPI for parallel processing")
+    trainer_group = parser.add_argument_group(
+        'Arguments to pass to PyTorch Lightning Trainer')
+    trainer_group.add_argument("--num_nodes", type=int, default=1,)
+    trainer_group.add_argument("--precision", type=str, default='bf16',
+                               help='Sets precision, lower precision improves runtime performance.',)
+    trainer_group.add_argument("--max_epochs", type=int, default=1,)
+    trainer_group.add_argument("--log_every_n_steps", type=int, default=25,)
+    trainer_group.add_argument("--flush_logs_every_n_steps", type=int, default=5,)
+    trainer_group.add_argument("--num_sanity_val_steps", type=int, default=0,)
+    trainer_group.add_argument("--reload_dataloaders_every_n_epochs", type=int, default=1,)
+    trainer_group.add_argument("--accumulate_grad_batches", type=int, default=1,
+                               help="Accumulate gradients over k batches before next optimizer step.")

     args = parser.parse_args()
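The new trainer_group relies on argparse argument groups, which only change how --help output is rendered; parsed values still land in a single flat namespace, which is exactly what the vars(args) filter earlier in the diff depends on. A minimal, self-contained demonstration (the group title and flags here are illustrative, not the script's full set):

# Minimal demonstration of argparse argument groups, the mechanism the
# commit adopts to organize Trainer flags.
import argparse

parser = argparse.ArgumentParser(prog="train")
parser.add_argument("--gpus", type=int, default=1)

trainer_group = parser.add_argument_group("Trainer arguments")
trainer_group.add_argument("--num_nodes", type=int, default=1)
trainer_group.add_argument("--max_epochs", type=int, default=1)

# Groups only affect help output: parsed values share one flat namespace.
args = parser.parse_args(["--num_nodes", "2"])
print(args.gpus, args.num_nodes, args.max_epochs)  # 1 2 1
parser.print_help()  # renders "Trainer arguments:" as its own section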
@@ -692,7 +697,5 @@ if __name__ == "__main__":
     if (args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
         raise ValueError(
             "Choose between loading pretrained Jax-weights and a checkpoint-path"
         )
-    # This re-applies the training-time filters at the beginning of every epoch
-    args.reload_dataloaders_every_n_epochs = 1
     main(args)
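The hard-coded reload_dataloaders_every_n_epochs = 1 can be deleted here because the new --reload_dataloaders_every_n_epochs flag defaults to 1. The remaining guard raises ValueError by hand when both resume flags are set; for comparison only, argparse can enforce the same constraint natively. This is an alternative sketch, not what train_openfold.py does:

# Alternative to the manual ValueError guard above: argparse's own
# mutually exclusive groups reject conflicting flags at parse time.
import argparse

parser = argparse.ArgumentParser()
resume = parser.add_mutually_exclusive_group()
resume.add_argument("--resume_from_jax_params", type=str, default=None)
resume.add_argument("--resume_from_ckpt", type=str, default=None)

# Passing both now exits with a usage error instead of a ValueError:
# parser.parse_args(["--resume_from_jax_params", "a", "--resume_from_ckpt", "b"])
args = parser.parse_args([])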