OpenDAS / OpenFold · Commits · 142c5e65
"vscode:/vscode.git/clone" did not exist on "c0bdf412bac925881c1852a8aca671ec36027b99"
Commit 142c5e65 authored May 02, 2024 by Jennifer

Updates organization of command line flags for pl.Trainer

parent 2eda3215
Showing 1 changed file with 48 additions and 45 deletions.

train_openfold.py (+48, -45)
@@ -218,11 +218,6 @@ class OpenFoldWrapper(pl.LightningModule):
         learning_rate: float = 1e-3,
         eps: float = 1e-5,
     ) -> torch.optim.Adam:
-        # return torch.optim.Adam(
-        #     self.model.parameters(),
-        #     lr=learning_rate,
-        #     eps=eps
-        # )
         # Ignored as long as a DeepSpeed optimizer is configured
         optimizer = torch.optim.Adam(
             self.model.parameters(),
@@ -289,10 +284,13 @@ def main(args):
     if(args.seed is not None):
         seed_everything(args.seed, workers=True)
 
+    is_low_precision = args.precision in [
+        "bf16-mixed", "16", "bf16", "16-true", "16-mixed", "bf16-mixed"]
+
     config = model_config(
         args.config_preset,
         train=True,
-        low_prec=(str(args.precision) == "16")
+        low_prec=is_low_precision,
     )
     if args.experiment_config_json:
         with open(args.experiment_config_json, 'r') as f:
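Worth noting: the new membership test treats any 16-bit or bfloat16 mode as low precision, whereas the old check matched only the literal string "16" (the repeated "bf16-mixed" in the list is redundant but harmless for a membership test). A minimal, runnable sketch of the behavioral difference, using hypothetical precision values and omitting OpenFold's model_config:

# Sketch only; values are illustrative, not part of the commit.
for precision in ["32", "16", "16-mixed", "bf16"]:
    old_low_prec = (str(precision) == "16")  # previous check
    new_low_prec = precision in [
        "bf16-mixed", "16", "bf16", "16-true", "16-mixed"]  # new check
    print(precision, old_low_prec, new_low_prec)
# "16-mixed" and "bf16" now count as low precision; previously only "16" did.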
@@ -432,17 +430,17 @@ def main(args):
     os.system(f"{sys.executable} -m pip freeze > {freeze_path}")
     wdb_logger.experiment.save(f"{freeze_path}")
 
-    trainer = pl.Trainer(
-        num_nodes=args.num_nodes,
-        devices=args.gpus,
-        precision=args.precision,
-        max_epochs=args.max_epochs,
-        default_root_dir=args.output_dir,
-        strategy=strategy,
-        callbacks=callbacks,
-        logger=loggers,
-        profiler='simple',
-    )
+    trainer_kws = ['num_nodes', 'precision', 'max_epochs',
+        'log_every_n_steps', 'flush_logs_every_n_steps',
+        'num_sanity_val_steps', 'reload_dataloaders_every_n_epochs']
+    trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
+    trainer_args.update({
+        'default_root_dir': args.output_dir,
+        'strategy': strategy,
+        'callbacks': callbacks,
+        'logger': loggers,
+    })
+    trainer = pl.Trainer(**trainer_args)
 
     if(args.resume_model_weights_only):
         ckpt_path = None
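The new construction filters the argparse namespace against an allow-list and splats the result into pl.Trainer, so Trainer-bound flags can be added or removed in one place. A self-contained sketch of the pattern, using hypothetical flag names and a print in place of pl.Trainer so it runs without Lightning installed:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--num_nodes", type=int, default=1)
parser.add_argument("--max_epochs", type=int, default=1)
parser.add_argument("--output_dir", type=str, default="out/")  # not a Trainer kwarg

args = parser.parse_args(["--max_epochs", "3"])

trainer_kws = ["num_nodes", "max_epochs"]  # allow-list of Trainer kwargs
trainer_args = {k: v for k, v in vars(args).items() if k in trainer_kws}
trainer_args.update({"default_root_dir": args.output_dir})  # fixed kwargs
print(trainer_args)
# {'num_nodes': 1, 'max_epochs': 3, 'default_root_dir': 'out/'}
# The real code then calls pl.Trainer(**trainer_args).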
@@ -652,32 +650,39 @@ if __name__ == "__main__":
     parser.add_argument(
         "--experiment_config_json",
         default="",
         help="Path to a json file with custom config values to overwrite config settings",
     )
-    parser.add_argument("--num_nodes", type=int, default=1)
-    parser.add_argument("--gpus", type=int, default=None)
-    parser.add_argument("--max_epochs", type=int, default=None)
-    parser.add_argument("--precision", type=str, default="32")
-    parser.add_argument("--log_every_n_steps", type=int, default=50)
-    parser.add_argument("--accumulate_grad_batches", type=int, default=1)
-    parser.add_argument("--flush_logs_every_n_steps", type=int, default=5)
-    parser.add_argument("--num_sanity_val_steps", type=int, default=0)
-    parser.add_argument("--mpi_plugin", action="store_true", default=False)
-    # parser = pl.Trainer.add_argparse_args(parser)
-    # Disable the initial validation pass
-    parser.set_defaults(
-        num_sanity_val_steps=0,
-    )
-    # Remove some buggy/redundant arguments introduced by the Trainer
-    remove_arguments(
-        parser,
-        [
-            "--accelerator",
-            "--resume_from_checkpoint",
-            "--reload_dataloaders_every_epoch",
-            "--reload_dataloaders_every_n_epochs",
-        ]
-    )
+    parser.add_argument(
+        "--gpus", type=int, default=1,
+        help='For determining optimal strategy and effective batch size.'
+    )
+    parser.add_argument(
+        "--mpi_plugin", action="store_true", default=False,
+        help="Whether to use MPI for parallel processing"
+    )
+    trainer_group = parser.add_argument_group(
+        'Arguments to pass to PyTorch Lightning Trainer')
+    trainer_group.add_argument("--num_nodes", type=int, default=1)
+    trainer_group.add_argument(
+        "--precision", type=str, default='bf16',
+        help='Sets precision; lower precision improves runtime performance.',
+    )
+    trainer_group.add_argument("--max_epochs", type=int, default=1)
+    trainer_group.add_argument("--log_every_n_steps", type=int, default=25)
+    trainer_group.add_argument("--flush_logs_every_n_steps", type=int, default=5)
+    trainer_group.add_argument("--num_sanity_val_steps", type=int, default=0)
+    trainer_group.add_argument(
+        "--reload_dataloaders_every_n_epochs", type=int, default=1)
+    trainer_group.add_argument(
+        "--accumulate_grad_batches", type=int, default=1,
+        help="Accumulate gradients over k batches before next optimizer step."
+    )
     args = parser.parse_args()
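Argument groups in argparse affect only the --help layout; grouped flags still land in the same flat namespace, which is what lets the trainer_kws allow-list above pick them up. A minimal sketch, assuming nothing beyond the standard library:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--gpus", type=int, default=1)

trainer_group = parser.add_argument_group(
    "Arguments to pass to PyTorch Lightning Trainer")
trainer_group.add_argument("--num_nodes", type=int, default=1)

args = parser.parse_args([])
# Grouping changes only the help text; both flags share one namespace.
print(args.gpus, args.num_nodes)  # -> 1 1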
@@ -692,7 +697,5 @@ if __name__ == "__main__":
     if(args.resume_from_jax_params is not None and args.resume_from_ckpt is not None):
         raise ValueError(
             "Choose between loading pretrained Jax-weights and a checkpoint-path")
 
-    # This re-applies the training-time filters at the beginning of every epoch
-    args.reload_dataloaders_every_n_epochs = 1
     main(args)