# compare_utils.py
import os
import importlib
import pkgutil
import sys
import unittest

import numpy as np

from config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_
from tests.config import consts

# Give JAX some GPU memory discipline.
# By default JAX preallocates ~90% of GPU memory; the "platform" allocator
# disables that behavior and also forces JAX to proactively free memory it
# allocates, so the PyTorch model below can share the device.
# NOTE(review): these must be set before JAX first initializes its backend
# to take effect — confirm no earlier import triggers JAX initialization.
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
os.environ["JAX_PLATFORM_NAME"] = "gpu"


def alphafold_is_installed():
    """Return True if the ``alphafold`` package is importable on this system."""
    spec = importlib.util.find_spec("alphafold")
    return spec is not None


def skip_unless_alphafold_installed():
    """Decorator factory: skip the decorated test unless AlphaFold is importable."""
    installed = alphafold_is_installed()
    return unittest.skipUnless(installed, "Requires AlphaFold")


def import_alphafold():
    """
    Import the top-level ``alphafold`` package, eagerly import every
    submodule underneath it, then register the package in ``sys.modules``
    and in this module's globals.

    If AlphaFold is installed using the provided setuptools script, this
    is necessary to expose all of AlphaFold's submodules.
    """
    # Reuse a previously imported package if one is cached.
    if("alphafold" in sys.modules):
        return sys.modules["alphafold"]

    module = importlib.import_module("alphafold")

    # Forcefully import every submodule in the package tree so that
    # attribute access like alphafold.model.config works afterwards.
    for info in pkgutil.walk_packages(module.__path__, prefix="alphafold."):
        importlib.import_module(info.name)

    sys.modules["alphafold"] = module
    globals()["alphafold"] = module
    return module


def get_alphafold_config():
    """
    Build the original AlphaFold 'model_1_ptm' config with deterministic
    mode enabled (required for numerical comparison against OpenFold).
    Requires ``import_alphafold`` to have populated the ``alphafold`` global.
    """
    cfg = alphafold.model.config.model_config("model_1_ptm")
    cfg.model.global_config.deterministic = True
    return cfg


_param_path = "openfold/resources/params/params_model_1_ptm.npz"
_model = None
def get_global_pretrained_openfold():
    """
    Lazily construct, weight-load, and cache a single pretrained OpenFold
    model (model_1_ptm) on the GPU, returning the cached instance on
    subsequent calls.

    Returns:
        The module-level cached ``AlphaFold`` model, in eval mode, on CUDA.

    Raises:
        FileNotFoundError: if the pretrained parameter file is missing.
    """
    global _model
    if(_model is None):
        # Validate the parameter file before doing any expensive work.
        if(not os.path.exists(_param_path)):
            raise FileNotFoundError(
                "Cannot load pretrained parameters. Make sure to run the "
                "installation script before running tests."
            )
        # Build and fully initialize into a local first; only publish to the
        # module global once everything has succeeded, so a failure partway
        # through cannot leave a half-initialized model cached for later calls.
        model = AlphaFold(model_config("model_1_ptm").model)
        model = model.eval()
        import_jax_weights_(model, _param_path)
        _model = model.cuda()

    return _model


_orig_weights = None
def _get_orig_weights():
    """Load the raw AlphaFold JAX parameter archive once and cache it."""
    global _orig_weights
    if(_orig_weights is not None):
        return _orig_weights
    _orig_weights = np.load(_param_path)
    return _orig_weights


def _remove_key_prefix(d, prefix):
    for k, v in list(d.items()):
        if(k.startswith(prefix)):
            d.pop(k)
            d[k[len(prefix):]] = v


def fetch_alphafold_module_weights(weight_path):
    """
    Select the original AlphaFold parameters whose names contain
    ``weight_path``, strip the leading module path from their keys, and
    return them converted to a haiku-style nested parameter dict.
    """
    all_weights = _get_orig_weights()
    # Substring match: keep every parameter whose full name mentions the path.
    selected = {
        name: value
        for name, value in all_weights.items()
        if weight_path in name
    }
    if('/' in weight_path):
        parts = weight_path.split('/')
        # A trailing slash yields an empty final component; drop it.
        if(len(parts[-1]) == 0):
            parts = parts[:-1]
        # Everything before the last component is the prefix to remove.
        prefix = '/'.join(parts[:-1]) + '/'
        _remove_key_prefix(selected, prefix)
    return alphafold.model.utils.flat_params_to_haiku(selected)