Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
d7367f9a
"docs/git@developer.sourcefind.cn:OpenDAS/mmdeploy.git" did not exist on "4353fa59a0746882ea4cd68c5bf15c9588c3eabc"
Commit
d7367f9a
authored
Jun 24, 2022
by
Gustaf Ahdritz
Browse files
Touch up OF/AF conversion script
parent
76599b12
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
19 additions
and
4 deletions
+19
-4
scripts/convert_of_weights_to_jax.py
scripts/convert_of_weights_to_jax.py
+19
-4
No files found.
scripts/convert_of_weights_to_jax.py
View file @
d7367f9a
#!/usr/bin/env python
# Copyright 2022 AlQuraishi Laboratory
# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Converts OpenFold .pt checkpoints into AlphaFold .npz ones, which can then be
# used to run inference using DeepMind's JAX code.
import
argparse
import
argparse
import
numpy
as
np
import
numpy
as
np
...
@@ -54,7 +68,7 @@ def main(args):
...
@@ -54,7 +68,7 @@ def main(args):
translation
=
generate_translation_dict
(
model
,
args
.
config_preset
)
translation
=
generate_translation_dict
(
model
,
args
.
config_preset
)
translation
=
process_translation_dict
(
translation
)
translation
=
process_translation_dict
(
translation
)
af_weight_template
=
data
=
np
.
load
(
args
.
template_npz_path
)
af_weight_template
=
np
.
load
(
args
.
template_npz_path
)
af_weight_template
=
{
k
:
v
for
k
,
v
in
af_weight_template
.
items
()
if
k
in
translation
}
af_weight_template
=
{
k
:
v
for
k
,
v
in
af_weight_template
.
items
()
if
k
in
translation
}
zero
=
lambda
n
:
n
*
0
zero
=
lambda
n
:
n
*
0
af_weight_template
=
tree_map
(
zero
,
af_weight_template
,
np
.
ndarray
)
af_weight_template
=
tree_map
(
zero
,
af_weight_template
,
np
.
ndarray
)
...
@@ -80,7 +94,8 @@ if __name__ == "__main__":
...
@@ -80,7 +94,8 @@ if __name__ == "__main__":
type
=
str
,
type
=
str
,
default
=
"openfold/resources/params/params_model_1_ptm.npz"
,
default
=
"openfold/resources/params/params_model_1_ptm.npz"
,
help
=
"""Path to an AlphaFold checkpoint w/ a superset of the OF
help
=
"""Path to an AlphaFold checkpoint w/ a superset of the OF
checkpoint's parameters"""
checkpoint's parameters. params_model_1_ptm.npz always works.
"""
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment