import importlib
import os
import pkgutil
import sys
import unittest

import numpy as np
import torch

from openfold.config import model_config
from openfold.model.model import AlphaFold
from openfold.utils.import_weights import import_jax_weights_

from tests.config import consts

# Give JAX some GPU memory discipline
# (by default it hogs 90% of GPU memory. This disables that behavior and also
# forces it to proactively free memory that it allocates)
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
# Force JAX onto the GPU so the reference AlphaFold runs on the same device
# class as the OpenFold model being compared against it.
os.environ["JAX_PLATFORM_NAME"] = "gpu"
def skip_unless_ds4s_installed():
    """Build a unittest decorator that skips unless DeepSpeed4Science is present."""
    has_deepspeed = importlib.util.find_spec("deepspeed") is not None
    has_ds4s = (
        has_deepspeed
        and importlib.util.find_spec("deepspeed.ops.deepspeed4science") is not None
    )
    return unittest.skipUnless(has_ds4s, "Requires DeepSpeed with version ≥ 0.10.4")


def skip_unless_flash_attn_installed():
    """Build a unittest decorator that skips unless flash_attn is importable."""
    return unittest.skipUnless(
        importlib.util.find_spec("flash_attn") is not None,
        "Requires Flash Attention",
    )
def alphafold_is_installed():
    """Return True when the DeepMind alphafold package can be located."""
    spec = importlib.util.find_spec("alphafold")
    return spec is not None


def skip_unless_alphafold_installed():
    """Build a unittest decorator that skips unless AlphaFold is importable."""
    return unittest.skipUnless(
        alphafold_is_installed(), "Requires AlphaFold"
    )


def import_alphafold():
    """
    Import AlphaFold and force-load every one of its submodules.

    If AlphaFold is installed using the provided setuptools script, this
    is necessary to expose all of AlphaFold's precious insides
    """
    # Already imported once — reuse the cached module.
    if "alphafold" in sys.modules:
        return sys.modules["alphafold"]

    module = importlib.import_module("alphafold")
    # walk_packages only yields descriptors; importing each name is what
    # actually pulls every submodule into memory.
    for submodule_info in pkgutil.walk_packages(module.__path__, prefix="alphafold."):
        importlib.import_module(submodule_info.name)

    sys.modules["alphafold"] = module
    # Expose the package as a global of THIS module so the helpers below
    # (get_alphafold_config, fetch_alphafold_module_weights) can reach it.
    globals()["alphafold"] = module

    return module


def get_alphafold_config():
    """Return the AlphaFold config for the model under test, made deterministic."""
    # `alphafold` is injected into this module's globals by import_alphafold().
    cfg = alphafold.model.config.model_config(consts.model)  # noqa
    # Deterministic mode so outputs are reproducible across comparison runs.
    cfg.model.global_config.deterministic = True
    return cfg


# Path to the pretrained parameter archive for the model under test,
# resolved relative to this file's directory.
dir_path = os.path.dirname(os.path.realpath(__file__))
_param_path = os.path.join(dir_path, "..", f"openfold/resources/params/params_{consts.model}.npz")

# Process-wide cached model; populated lazily by get_global_pretrained_openfold().
_model = None
def get_global_pretrained_openfold():
    """
    Return a lazily-built, process-wide pretrained OpenFold model.

    On first call: builds an AlphaFold model for `consts.model`, switches it
    to eval mode, imports the pretrained JAX weights from `_param_path`, and
    moves it to CUDA. Later calls return the cached instance.

    Raises:
        FileNotFoundError: If the parameter file at `_param_path` is missing.
    """
    global _model
    if _model is None:
        _model = AlphaFold(model_config(consts.model))
        # Inference only — these comparison tests never train the model.
        _model = _model.eval()
        if not os.path.exists(_param_path):
            raise FileNotFoundError(
                """Cannot load pretrained parameters. Make sure to run the 
                installation script before running tests."""
            )
        import_jax_weights_(_model, _param_path, version=consts.model)
        # NOTE(review): assumes a CUDA device is available; these comparison
        # tests appear to be GPU-only.
        _model = _model.cuda()

    return _model


# Raw npz parameter archive; populated lazily by _get_orig_weights().
_orig_weights = None
def _get_orig_weights():
    """Load (once) and return the raw AlphaFold parameter archive at _param_path."""
    global _orig_weights
    if _orig_weights is not None:
        return _orig_weights
    _orig_weights = np.load(_param_path)
    return _orig_weights


def _remove_key_prefix(d, prefix):
    for k, v in list(d.items()):
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
102
        if k.startswith(prefix):
103
            d.pop(k)
104
            d[k[len(prefix):]] = v
105
106
107
108


def fetch_alphafold_module_weights(weight_path):
    """
    Extract the original AlphaFold parameters under *weight_path*.

    Args:
        weight_path: "/"-delimited prefix naming a module inside the flat
            npz parameter archive (a trailing "/" is tolerated).

    Returns:
        The matching parameters converted to a haiku-style nested dict via
        alphafold's flat_params_to_haiku.

    Raises:
        ImportError: If import_alphafold() has not been called first.
    """
    orig_weights = _get_orig_weights()
    params = {k: v for k, v in orig_weights.items() if weight_path in k}
    if "/" in weight_path:
        spl = weight_path.split("/")
        # Drop an empty final component so "a/b/" and "a/b" behave alike
        spl = spl if len(spl[-1]) != 0 else spl[:-1]
        prefix = "/".join(spl[:-1]) + "/"
        _remove_key_prefix(params, prefix)

    try:
        params = alphafold.model.utils.flat_params_to_haiku(params)  # noqa
    except (NameError, AttributeError) as e:
        # Was a bare `except:`, which swallowed *every* error (including
        # KeyboardInterrupt) and mislabeled it as an import problem. Only the
        # "alphafold global not injected yet" failure modes are translated.
        raise ImportError(
            "Make sure to call import_alphafold before running this function"
        ) from e
    return params


def _assert_abs_diff_small_base(compare_func, expected, actual, eps):
    # Helper function for comparing absolute differences of two torch tensors.
    abs_diff = torch.abs(expected - actual)
    err = compare_func(abs_diff)
    zero_tensor = torch.tensor(0, dtype=err.dtype)
    rtol = 1.6e-2 if err.dtype == torch.bfloat16 else 1.3e-6  
    torch.testing.assert_close(err, zero_tensor, atol=eps, rtol=rtol)


def assert_max_abs_diff_small(expected, actual, eps):
    """Assert that the largest element-wise |expected - actual| is within eps."""
    _assert_abs_diff_small_base(torch.max, expected, actual, eps)


def assert_mean_abs_diff_small(expected, actual, eps):
    """Assert that the mean element-wise |expected - actual| is within eps."""
    _assert_abs_diff_small_base(torch.mean, expected, actual, eps)