Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
0f2b2929
"vscode:/vscode.git/clone" did not exist on "44bde250656f68864dcd4942d8648454afc928fa"
Commit
0f2b2929
authored
Mar 13, 2024
by
Jennifer
Browse files
Adds `experiment_config_json` for setting custom configurations with a json.
parent
5bfad074
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
19 additions
and
1 deletion
+19
-1
run_pretrained_openfold.py
run_pretrained_openfold.py
+9
-0
train_openfold.py
train_openfold.py
+10
-1
No files found.
run_pretrained_openfold.py
View file @
0f2b2929
...
@@ -20,6 +20,7 @@ import os
...
@@ -20,6 +20,7 @@ import os
import
pickle
import
pickle
import
random
import
random
import
time
import
time
import
json
logging
.
basicConfig
()
logging
.
basicConfig
()
logger
=
logging
.
getLogger
(
__file__
)
logger
=
logging
.
getLogger
(
__file__
)
...
@@ -180,6 +181,11 @@ def main(args):
...
@@ -180,6 +181,11 @@ def main(args):
config
=
model_config
(
args
.
config_preset
,
long_sequence_inference
=
args
.
long_sequence_inference
)
config
=
model_config
(
args
.
config_preset
,
long_sequence_inference
=
args
.
long_sequence_inference
)
if
args
.
experiment_config_json
:
with
open
(
args
.
experiment_config_json
,
'r'
)
as
f
:
custom_config_dict
=
json
.
load
(
f
)
config
.
update_from_flattened_dict
(
custom_config_dict
)
if
args
.
trace_model
:
if
args
.
trace_model
:
if
not
config
.
data
.
predict
.
fixed_size
:
if
not
config
.
data
.
predict
.
fixed_size
:
raise
ValueError
(
raise
ValueError
(
...
@@ -453,6 +459,9 @@ if __name__ == "__main__":
...
@@ -453,6 +459,9 @@ if __name__ == "__main__":
"--cif_output"
,
action
=
"store_true"
,
default
=
False
,
"--cif_output"
,
action
=
"store_true"
,
default
=
False
,
help
=
"Output predicted models in ModelCIF format instead of PDB format (default)"
help
=
"Output predicted models in ModelCIF format instead of PDB format (default)"
)
)
parser
.
add_argument
(
"--experiment_config_json"
,
default
=
""
,
help
=
"Path to a json file with custom config values to overwrite config setting"
,
)
add_data_args
(
parser
)
add_data_args
(
parser
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
train_openfold.py
View file @
0f2b2929
...
@@ -2,6 +2,7 @@ import argparse
...
@@ -2,6 +2,7 @@ import argparse
import
logging
import
logging
import
os
import
os
import
sys
import
sys
import
json
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
from
pytorch_lightning.callbacks.lr_monitor
import
LearningRateMonitor
from
pytorch_lightning.callbacks.lr_monitor
import
LearningRateMonitor
...
@@ -39,7 +40,6 @@ from scripts.zero_to_fp32 import (
...
@@ -39,7 +40,6 @@ from scripts.zero_to_fp32 import (
get_fp32_state_dict_from_zero_checkpoint
,
get_fp32_state_dict_from_zero_checkpoint
,
get_global_step_from_zero_checkpoint
get_global_step_from_zero_checkpoint
)
)
from
scripts.zero_to_fp32
import
get_optim_files
,
parse_optim_states
,
get_model_state_file
from
openfold.utils.logger
import
PerformanceLoggingCallback
from
openfold.utils.logger
import
PerformanceLoggingCallback
...
@@ -59,6 +59,7 @@ class OpenFoldWrapper(pl.LightningModule):
...
@@ -59,6 +59,7 @@ class OpenFoldWrapper(pl.LightningModule):
self
.
cached_weights
=
None
self
.
cached_weights
=
None
self
.
last_lr_step
=
-
1
self
.
last_lr_step
=
-
1
self
.
save_hyperparameters
def
forward
(
self
,
batch
):
def
forward
(
self
,
batch
):
return
self
.
model
(
batch
)
return
self
.
model
(
batch
)
...
@@ -280,6 +281,11 @@ def main(args):
...
@@ -280,6 +281,11 @@ def main(args):
train
=
True
,
train
=
True
,
low_prec
=
(
str
(
args
.
precision
)
==
"16"
)
low_prec
=
(
str
(
args
.
precision
)
==
"16"
)
)
)
if
args
.
experiment_config_json
:
with
open
(
args
.
experiment_config_json
,
'r'
)
as
f
:
custom_config_dict
=
json
.
load
(
f
)
config
.
update_from_flattened_dict
(
custom_config_dict
)
model_module
=
OpenFoldWrapper
(
config
)
model_module
=
OpenFoldWrapper
(
config
)
if
args
.
resume_from_ckpt
:
if
args
.
resume_from_ckpt
:
...
@@ -611,6 +617,9 @@ if __name__ == "__main__":
...
@@ -611,6 +617,9 @@ if __name__ == "__main__":
"--distillation_alignment_index_path"
,
type
=
str
,
default
=
None
,
"--distillation_alignment_index_path"
,
type
=
str
,
default
=
None
,
help
=
"Distillation alignment index. See the README for instructions."
help
=
"Distillation alignment index. See the README for instructions."
)
)
parser
.
add_argument
(
"--experiment_config_json"
,
default
=
""
,
help
=
"Path to a json file with custom config values to overwrite config setting"
,
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
parser
=
pl
.
Trainer
.
add_argparse_args
(
parser
)
# Disable the initial validation pass
# Disable the initial validation pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment