Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
2b08407d
Commit
2b08407d
authored
Jan 25, 2023
by
Lucas Bickmann
Browse files
Added support for Jax-parameter loading to train_openfold.py
parent
2ef7893a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
0 deletions
+26
-0
train_openfold.py
train_openfold.py
+26
-0
No files found.
train_openfold.py
View file @
2b08407d
...
...
@@ -37,6 +37,9 @@ from openfold.utils.validation_metrics import (
gdt_ts
,
gdt_ha
,
)
from
openfold.utils.import_weights
import
(
import_jax_weights_
,
)
from
scripts.zero_to_fp32
import
(
get_fp32_state_dict_from_zero_checkpoint
,
get_global_step_from_zero_checkpoint
...
...
@@ -241,6 +244,17 @@ class OpenFoldWrapper(pl.LightningModule):
def
resume_last_lr_step
(
self
,
lr_step
):
self
.
last_lr_step
=
lr_step
def
load_from_jax
(
self
,
jax_path
):
model_basename
=
os
.
path
.
splitext
(
os
.
path
.
basename
(
os
.
path
.
normpath
(
jax_path
)
)
)[
0
]
model_version
=
"_"
.
join
(
model_basename
.
split
(
"_"
)[
1
:])
import_jax_weights_
(
self
.
model
,
jax_path
,
version
=
model_version
)
def
main
(
args
):
if
(
args
.
seed
is
not
None
):
...
...
@@ -269,6 +283,9 @@ def main(args):
sd
=
{
k
[
len
(
"module."
):]:
v
for
k
,
v
in
sd
.
items
()}
model_module
.
load_state_dict
(
sd
)
logging
.
info
(
"Successfully loaded model weights..."
)
if
(
args
.
jax_param_path
):
model_module
.
load_from_jax
(
args
.
jax_param_path
)
logging
.
info
(
f
"Successfully loaded JAX parameters at
{
args
.
jax_param_path
}
..."
)
# TorchScript components of the model
if
(
args
.
script_modules
):
...
...
@@ -531,6 +548,12 @@ if __name__ == "__main__":
'used.'
)
)
parser
.
add_argument
(
"--jax_param_path"
,
type
=
str
,
default
=
None
,
help
=
"""Path to JAX model parameters. If None, and openfold_checkpoint_path
is also None, parameters are selected automatically according to
the model name from openfold/resources/params"""
)
parser
.
add_argument
(
"--_distillation_structure_index_path"
,
type
=
str
,
default
=
None
,
)
...
...
@@ -570,6 +593,9 @@ if __name__ == "__main__":
if
(
str
(
args
.
precision
)
==
"16"
and
args
.
deepspeed_config_path
is
not
None
):
raise
ValueError
(
"DeepSpeed and FP16 training are not compatible"
)
if
(
str
(
args
.
jax_param_path
)
is
not
None
and
args
.
resume_from_ckpt
is
not
None
):
raise
ValueError
(
"Choose between loading pretrained Jax-weights and a checkpoint-path"
)
# This re-applies the training-time filters at the beginning of every epoch
args
.
reload_dataloaders_every_n_epochs
=
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment