Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
OpenFold
Commits
9d6127cb
Commit
9d6127cb
authored
Dec 15, 2023
by
Christina Floristean
Browse files
Fix missing import in tests
parent
fe8869c3
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
2 deletions
+3
-2
tests/compare_utils.py
tests/compare_utils.py
+3
-2
No files found.
tests/compare_utils.py
View file @
9d6127cb
...
@@ -11,6 +11,8 @@ from openfold.config import model_config
...
@@ -11,6 +11,8 @@ from openfold.config import model_config
from
openfold.model.model
import
AlphaFold
from
openfold.model.model
import
AlphaFold
from
openfold.utils.import_weights
import
import_jax_weights_
from
openfold.utils.import_weights
import
import_jax_weights_
from
tests.config
import
consts
# Give JAX some GPU memory discipline
# Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also
# (by default it hogs 90% of GPU memory. This disables that behavior and also
# forces it to proactively free memory that it allocates)
# forces it to proactively free memory that it allocates)
...
@@ -98,7 +100,7 @@ def _remove_key_prefix(d, prefix):
...
@@ -98,7 +100,7 @@ def _remove_key_prefix(d, prefix):
for
k
,
v
in
list
(
d
.
items
()):
for
k
,
v
in
list
(
d
.
items
()):
if
k
.
startswith
(
prefix
):
if
k
.
startswith
(
prefix
):
d
.
pop
(
k
)
d
.
pop
(
k
)
d
[
k
[
len
(
prefix
)
:]]
=
v
d
[
k
[
len
(
prefix
):]]
=
v
def
fetch_alphafold_module_weights
(
weight_path
):
def
fetch_alphafold_module_weights
(
weight_path
):
...
@@ -107,7 +109,6 @@ def fetch_alphafold_module_weights(weight_path):
...
@@ -107,7 +109,6 @@ def fetch_alphafold_module_weights(weight_path):
if
"/"
in
weight_path
:
if
"/"
in
weight_path
:
spl
=
weight_path
.
split
(
"/"
)
spl
=
weight_path
.
split
(
"/"
)
spl
=
spl
if
len
(
spl
[
-
1
])
!=
0
else
spl
[:
-
1
]
spl
=
spl
if
len
(
spl
[
-
1
])
!=
0
else
spl
[:
-
1
]
module_name
=
spl
[
-
1
]
prefix
=
"/"
.
join
(
spl
[:
-
1
])
+
"/"
prefix
=
"/"
.
join
(
spl
[:
-
1
])
+
"/"
_remove_key_prefix
(
params
,
prefix
)
_remove_key_prefix
(
params
,
prefix
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment