Commit 9d6127cb authored by Christina Floristean's avatar Christina Floristean
Browse files

Fix missing import in tests

parent fe8869c3
...@@ -11,6 +11,8 @@ from openfold.config import model_config ...@@ -11,6 +11,8 @@ from openfold.config import model_config
from openfold.model.model import AlphaFold from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_ from openfold.utils.import_weights import import_jax_weights_
from tests.config import consts
# Give JAX some GPU memory discipline # Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also # (by default it hogs 90% of GPU memory. This disables that behavior and also
# forces it to proactively free memory that it allocates) # forces it to proactively free memory that it allocates)
...@@ -98,7 +100,7 @@ def _remove_key_prefix(d, prefix): ...@@ -98,7 +100,7 @@ def _remove_key_prefix(d, prefix):
for k, v in list(d.items()): for k, v in list(d.items()):
if k.startswith(prefix): if k.startswith(prefix):
d.pop(k) d.pop(k)
d[k[len(prefix) :]] = v d[k[len(prefix):]] = v
def fetch_alphafold_module_weights(weight_path): def fetch_alphafold_module_weights(weight_path):
...@@ -107,7 +109,6 @@ def fetch_alphafold_module_weights(weight_path): ...@@ -107,7 +109,6 @@ def fetch_alphafold_module_weights(weight_path):
if "/" in weight_path: if "/" in weight_path:
spl = weight_path.split("/") spl = weight_path.split("/")
spl = spl if len(spl[-1]) != 0 else spl[:-1] spl = spl if len(spl[-1]) != 0 else spl[:-1]
module_name = spl[-1]
prefix = "/".join(spl[:-1]) + "/" prefix = "/".join(spl[:-1]) + "/"
_remove_key_prefix(params, prefix) _remove_key_prefix(params, prefix)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment