Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
142c5e65
Commit
142c5e65
authored
May 02, 2024
by
Jennifer
Browse files
Updates organization of command line flags for pl.Trainer
parent
2eda3215
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
48 additions
and
45 deletions
+48
-45
train_openfold.py
train_openfold.py
+48
-45
No files found.
train_openfold.py
View file @
142c5e65
...
@@ -218,11 +218,6 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -218,11 +218,6 @@ class OpenFoldWrapper(pl.LightningModule):
learning_rate
:
float
=
1e-3
,
learning_rate
:
float
=
1e-3
,
eps
:
float
=
1e-5
,
eps
:
float
=
1e-5
,
)
->
torch
.
optim
.
Adam
:
)
->
torch
.
optim
.
Adam
:
# return torch.optim.Adam(
# self.model.parameters(),
# lr=learning_rate,
# eps=eps
# )
# Ignored as long as a DeepSpeed optimizer is configured
# Ignored as long as a DeepSpeed optimizer is configured
optimizer
=
torch
.
optim
.
Adam
(
optimizer
=
torch
.
optim
.
Adam
(
self
.
model
.
parameters
(),
self
.
model
.
parameters
(),
...
@@ -289,10 +284,13 @@ def main(args):
...
@@ -289,10 +284,13 @@ def main(args):
if
(
args
.
seed
is
not
None
):
if
(
args
.
seed
is
not
None
):
seed_everything
(
args
.
seed
,
workers
=
True
)
seed_everything
(
args
.
seed
,
workers
=
True
)
is_low_precision
=
args
.
precision
in
[
"bf16-mixed"
,
"16"
,
"bf16"
,
"16-true"
,
"16-mixed"
,
"bf16-mixed"
]
config
=
model_config
(
config
=
model_config
(
args
.
config_preset
,
args
.
config_preset
,
train
=
True
,
train
=
True
,
low_prec
=
(
str
(
args
.
precision
)
==
"16"
)
low_prec
=
is_low_precision
,
)
)
if
args
.
experiment_config_json
:
if
args
.
experiment_config_json
:
with
open
(
args
.
experiment_config_json
,
'r'
)
as
f
:
with
open
(
args
.
experiment_config_json
,
'r'
)
as
f
:
...
@@ -432,17 +430,17 @@ def main(args):
...
@@ -432,17 +430,17 @@ def main(args):
os
.
system
(
f
"
{
sys
.
executable
}
-m pip freeze >
{
freeze_path
}
"
)
os
.
system
(
f
"
{
sys
.
executable
}
-m pip freeze >
{
freeze_path
}
"
)
wdb_logger
.
experiment
.
save
(
f
"
{
freeze_path
}
"
)
wdb_logger
.
experiment
.
save
(
f
"
{
freeze_path
}
"
)
trainer
=
pl
.
Trainer
(
trainer
_kws
=
[
'num_nodes'
,
'precision'
,
'max_epochs'
,
'log_every_n_steps'
,
num_nodes
=
args
.
num_nodes
,
'flush_logs_ever_n_steps'
,
'num_sanity_val_steps'
,
'reload_dataloaders_every_n_epochs'
]
devices
=
args
.
gpus
,
trainer_args
=
{
k
:
v
for
k
,
v
in
vars
(
args
).
items
()
if
k
in
trainer_kws
}
precision
=
args
.
precision
,
trainer_args
.
update
({
max_epochs
=
args
.
max_epochs
,
'default_root_dir'
:
args
.
output_dir
,
default_root_dir
=
args
.
output_dir
,
'strategy'
:
strategy
,
strategy
=
strategy
,
'callbacks'
:
callbacks
,
callbacks
=
callback
s
,
'logger'
:
logger
s
,
logger
=
loggers
,
})
profiler
=
'simple'
,
trainer
=
pl
.
Trainer
(
**
trainer_args
)
)
if
(
args
.
resume_model_weights_only
):
if
(
args
.
resume_model_weights_only
):
ckpt_path
=
None
ckpt_path
=
None
...
@@ -652,32 +650,39 @@ if __name__ == "__main__":
...
@@ -652,32 +650,39 @@ if __name__ == "__main__":
parser
.
add_argument
(
parser
.
add_argument
(
"--experiment_config_json"
,
default
=
""
,
help
=
"Path to a json file with custom config values to overwrite config setting"
,
"--experiment_config_json"
,
default
=
""
,
help
=
"Path to a json file with custom config values to overwrite config setting"
,
)
)
parser
.
add_argument
(
"--num_nodes"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
parser
.
add_argument
(
"--gpus"
,
type
=
int
,
default
=
None
)
"--gpus"
,
type
=
int
,
default
=
1
,
help
=
'For determining optimal strategy and effective batch size.'
parser
.
add_argument
(
"--max_epochs"
,
type
=
int
,
default
=
None
)
)
parser
.
add_argument
(
"--precision"
,
type
=
str
,
default
=
"32"
)
parser
.
add_argument
(
"--mpi_plugin"
,
action
=
"store_true"
,
default
=
False
,
parser
.
add_argument
(
"--log_every_n_steps"
,
type
=
int
,
default
=
50
)
help
=
"Whether to use MPI for parallele processing"
)
parser
.
add_argument
(
"--accumulate_grad_batches"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--flush_logs_every_n_steps"
,
type
=
int
,
default
=
5
)
trainer_group
=
parser
.
add_argument_group
(
parser
.
add_argument
(
"--num_sanity_val_steps"
,
type
=
int
,
default
=
0
)
'Arguments to pass to PyTorch Lightning Trainer'
)
parser
.
add_argument
(
"--mpi_plugin"
,
action
=
"store_true"
,
default
=
False
)
trainer_group
.
add_argument
(
# parser = pl.Trainer.add_argparse_args(parser)
"--num_nodes"
,
type
=
int
,
default
=
1
,
)
# Disable the initial validation pass
trainer_group
.
add_argument
(
parser
.
set_defaults
(
"--precision"
,
type
=
str
,
default
=
'bf16'
,
num_sanity_val_steps
=
0
,
help
=
'Sets precision, lower precision improves runtime performance.'
,
)
)
trainer_group
.
add_argument
(
# Remove some buggy/redundant arguments introduced by the Trainer
"--max_epochs"
,
type
=
int
,
default
=
1
,
remove_arguments
(
)
parser
,
trainer_group
.
add_argument
(
[
"--log_every_n_steps"
,
type
=
int
,
default
=
25
,
"--accelerator"
,
)
"--resume_from_checkpoint"
,
trainer_group
.
add_argument
(
"--reload_dataloaders_every_epoch"
,
"--flush_logs_every_n_steps"
,
type
=
int
,
default
=
5
,
"--reload_dataloaders_every_n_epochs"
,
)
]
trainer_group
.
add_argument
(
)
"--num_sanity_val_steps"
,
type
=
int
,
default
=
0
,
)
trainer_group
.
add_argument
(
"--reload_dataloaders_every_n_epochs"
,
type
=
int
,
default
=
1
,
)
trainer_group
.
add_argument
(
"--accumulate_grad_batches"
,
type
=
int
,
default
=
1
,
help
=
"Accumulate gradients over k batches before next optimizer step."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
@@ -692,7 +697,5 @@ if __name__ == "__main__":
...
@@ -692,7 +697,5 @@ if __name__ == "__main__":
if
(
args
.
resume_from_jax_params
is
not
None
and
args
.
resume_from_ckpt
is
not
None
):
if
(
args
.
resume_from_jax_params
is
not
None
and
args
.
resume_from_ckpt
is
not
None
):
raise
ValueError
(
"Choose between loading pretrained Jax-weights and a checkpoint-path"
)
raise
ValueError
(
"Choose between loading pretrained Jax-weights and a checkpoint-path"
)
# This re-applies the training-time filters at the beginning of every epoch
args
.
reload_dataloaders_every_n_epochs
=
1
main
(
args
)
main
(
args
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment