# compare_utils.py
import os

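# Pin the comparison tests to a single GPU; the device index below is
# machine-specific, not something OpenFold itself requires.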
os.environ["CUDA_VISIBLE_DEVICES"] = "4,"

import importlib
import pkgutil
import sys
import unittest

import numpy as np

from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_
from tests.config import consts

# Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also
# forces it to proactively free memory that it allocates)
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_PLATFORM_NAME"] = "gpu"


def alphafold_is_installed():
    return importlib.util.find_spec("alphafold") is not None


def skip_unless_alphafold_installed():
    return unittest.skipUnless(alphafold_is_installed(), "Requires AlphaFold")

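# Usage sketch for skip_unless_alphafold_installed (the test class name below
# is hypothetical, for illustration only):
#
#   @skip_unless_alphafold_installed()
#   class TestCompareSomeModule(unittest.TestCase):
#       ...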

def import_alphafold():
    """
    If AlphaFold is installed using the provided setuptools script, this
    is necessary to expose all of AlphaFold's precious insides
    """
    if "alphafold" in sys.modules:
        return sys.modules["alphafold"]
    module = importlib.import_module("alphafold")
    # Forcefully import alphafold's submodules
    submodules = pkgutil.walk_packages(module.__path__, prefix=("alphafold."))
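    # walk_packages only lists what is reachable from alphafold.__path__;
    # importing each entry eagerly is what lets dotted attribute access such as
    # alphafold.model.config work later in this file.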
    for submodule_info in submodules:
        importlib.import_module(submodule_info.name)
    sys.modules["alphafold"] = module
    globals()["alphafold"] = module

    return module


def get_alphafold_config():
    config = alphafold.model.config.model_config("model_1_ptm")
    config.model.global_config.deterministic = True
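    # deterministic=True is intended to switch off stochastic parts of the model
    # (e.g. dropout) so its outputs can be compared against OpenFold's directly.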
    return config


_param_path = "openfold/resources/params/params_model_1_ptm.npz"
_model = None


def get_global_pretrained_openfold():
    global _model
    if _model is None:
        _model = AlphaFold(model_config("model_1_ptm"))
        _model = _model.eval()
        if not os.path.exists(_param_path):
            raise FileNotFoundError(
                """Cannot load pretrained parameters. Make sure to run the 
                installation script before running tests."""
            )
        import_jax_weights_(_model, _param_path, version="model_1_ptm")
        _model = _model.cuda()

    return _model

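# Usage sketch for get_global_pretrained_openfold ("batch" stands in for a
# preprocessed OpenFold feature dict, and torch is assumed to be imported by
# the calling test):
#
#   model = get_global_pretrained_openfold()
#   with torch.no_grad():
#       out = model(batch)
#
# A single module-level instance is cached so the JAX weights are only imported
# once per test session.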

_orig_weights = None


def _get_orig_weights():
    global _orig_weights
    if _orig_weights is None:
        _orig_weights = np.load(_param_path)

    return _orig_weights


def _remove_key_prefix(d, prefix):
    for k, v in list(d.items()):
        if k.startswith(prefix):
            d.pop(k)
            d[k[len(prefix) :]] = v


def fetch_alphafold_module_weights(weight_path):
    orig_weights = _get_orig_weights()
    params = {k: v for k, v in orig_weights.items() if weight_path in k}
    if "/" in weight_path:
        spl = weight_path.split("/")
        spl = spl if len(spl[-1]) != 0 else spl[:-1]
        module_name = spl[-1]
        prefix = "/".join(spl[:-1]) + "/"
        _remove_key_prefix(params, prefix)
    params = alphafold.model.utils.flat_params_to_haiku(params)
    return params
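
# Usage sketch for fetch_alphafold_module_weights (the scope string below is
# only an illustration of the expected "a/b/c" form):
#
#   evoformer_params = fetch_alphafold_module_weights(
#       "alphafold/alphafold_iteration/evoformer"
#   )
#
# Keys containing the scope are kept, the leading prefix up to the module name
# is stripped, and the remaining flat parameters are re-nested into the haiku
# structure that alphafold.model expects.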