data_utils.py 4.57 KB
Newer Older
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15
from random import randint
16
import torch
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
17
import numpy as np
18
from scipy.spatial.transform import Rotation
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
19

20
21
22
23
24
25
26
27
28
29
30
31
32
from tests.config import consts


def random_asym_ids(n_res, split_chains=True, min_chain_len=4):
    """Generate a random per-residue chain-id (asym_id) assignment.

    Args:
        n_res: Total number of residues to label.
        split_chains: If False, every residue gets id 0 (single chain,
            returned as a plain int list).
        min_chain_len: Minimum number of residues per chain.

    Returns:
        A float32 array of 1-based chain ids of length ``n_res`` (or a
        list of zeros when ``split_chains`` is False).
    """
    # Multimer mode picks a random chain count; monomer mode is one chain.
    if consts.is_multimer:
        n_chain = randint(1, n_res // min_chain_len)
    else:
        n_chain = 1

    if not split_chains:
        # NOTE(review): this branch yields 0-based int ids, while the path
        # below yields 1-based float32 ids — confirm callers tolerate that.
        return [0] * n_res

    assert n_res >= n_chain

    chain_lens = []
    ids = []
    # The last chain absorbs whatever residues remain after the loop.
    last_chain = n_chain - 1
    for chain_idx in range(n_chain - 1):
        # Largest length this chain may take while leaving room for the rest.
        remaining = n_res - sum(chain_lens)
        upper = remaining - n_chain + chain_idx - min_chain_len
        if upper <= min_chain_len:
            # Not enough residues left to split further; stop early.
            last_chain = chain_idx
            break
        length = randint(min_chain_len, upper)
        chain_lens.append(length)
        ids += [chain_idx] * length
    ids += [last_chain] * (n_res - sum(chain_lens))

    # Shift to 1-based ids (0 is conventionally reserved for padding).
    return np.array(ids).astype(np.float32) + 1
45

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
46
47
48

def random_template_feats(n_templ, n, batch_size=None):
    """Build a dict of randomly generated template features for tests.

    Args:
        n_templ: Number of templates.
        n: Number of residues.
        batch_size: Optional leading batch dimension; omitted when None.

    Returns:
        Dict of float32 arrays (``template_aatype`` is int64); in multimer
        mode an ``asym_id`` feature is added as well.
    """
    lead = [] if batch_size is None else [batch_size]

    # Insertion order below matters only for RNG reproducibility.
    feats = {}
    feats["template_mask"] = np.random.randint(0, 2, (*lead, n_templ))
    feats["template_pseudo_beta_mask"] = np.random.randint(
        0, 2, (*lead, n_templ, n)
    )
    feats["template_pseudo_beta"] = np.random.rand(*lead, n_templ, n, 3)
    feats["template_aatype"] = np.random.randint(0, 22, (*lead, n_templ, n))
    feats["template_all_atom_mask"] = np.random.randint(
        0, 2, (*lead, n_templ, n, 37)
    )
    feats["template_all_atom_positions"] = (
        np.random.rand(*lead, n_templ, n, 37, 3) * 10
    )
    feats["template_torsion_angles_sin_cos"] = np.random.rand(
        *lead, n_templ, n, 7, 2
    )
    feats["template_alt_torsion_angles_sin_cos"] = np.random.rand(
        *lead, n_templ, n, 7, 2
    )
    feats["template_torsion_angles_mask"] = np.random.rand(*lead, n_templ, n, 7)

    feats = {key: arr.astype(np.float32) for key, arr in feats.items()}
    # aatype is a categorical index feature, so it stays integral.
    feats["template_aatype"] = feats["template_aatype"].astype(np.int64)

    if consts.is_multimer:
        # Same chain assignment replicated across batch and template dims.
        chain_ids = np.array(random_asym_ids(n))
        feats["asym_id"] = np.tile(
            chain_ids[np.newaxis, :], (*lead, n_templ, 1)
        )

    return feats

77

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
78
79
def random_extra_msa_feats(n_extra, n, batch_size=None):
    """Build a dict of randomly generated extra-MSA features for tests.

    Args:
        n_extra: Number of extra MSA sequences.
        n: Number of residues.
        batch_size: Optional leading batch dimension; omitted when None.

    Returns:
        Dict with ``extra_msa`` (int64) plus float32 mask/deletion features,
        each of shape ``(*batch, n_extra, n)``.
    """
    lead = () if batch_size is None else (batch_size,)
    shape = (*lead, n_extra, n)

    # Draw order is fixed for RNG reproducibility.
    msa = np.random.randint(0, 22, shape).astype(np.int64)
    has_deletion = np.random.randint(0, 2, shape).astype(np.float32)
    deletion_value = np.random.rand(*shape).astype(np.float32)
    msa_mask = np.random.randint(0, 2, shape).astype(np.float32)

    return {
        "extra_msa": msa,
        "extra_has_deletion": has_deletion,
        "extra_deletion_value": deletion_value,
        "extra_msa_mask": msa_mask,
    }
97
98


99
def random_affines_vector(dim):
    """Generate random affines in quaternion+translation form.

    Args:
        dim: Iterable of leading dimensions.

    Returns:
        Float32 array of shape ``(*dim, 7)``: columns 0-3 are a unit
        quaternion, columns 4-6 a translation in ``[0, 1)``.

    Fix: the original called ``Rotation.random(random_state=42)`` inside
    the per-entry loop, re-seeding each draw so *every* entry carried the
    identical rotation. Drawing all rotations in one seeded call keeps the
    determinism but yields distinct rotations per entry.
    """
    prod_dim = 1
    for d in dim:
        prod_dim *= d

    affines = np.zeros((prod_dim, 7)).astype(np.float32)

    # One seeded, vectorized draw of prod_dim distinct rotations.
    affines[:, :4] = (
        Rotation.random(prod_dim, random_state=42).as_quat().astype(np.float32)
    )
    # Translations come from the global RNG, matching the original draws.
    affines[:, 4:] = np.random.rand(prod_dim, 3).astype(np.float32)

    return affines.reshape(*dim, 7)


115
def random_affines_4x4(dim):
    """Generate random affines as homogeneous 4x4 transformation matrices.

    Args:
        dim: Iterable of leading dimensions.

    Returns:
        Float32 array of shape ``(*dim, 4, 4)``: a rotation in the upper
        3x3 block, a translation in ``[0, 1)`` in the last column, and a
        ``[0, 0, 0, 1]`` bottom row.

    Fix: the original called ``Rotation.random(random_state=42)`` inside
    the per-entry loop, re-seeding each draw so *every* entry carried the
    identical rotation. Drawing all rotations in one seeded call keeps the
    determinism but yields distinct rotations per entry.
    """
    prod_dim = 1
    for d in dim:
        prod_dim *= d

    affines = np.zeros((prod_dim, 4, 4)).astype(np.float32)

    # One seeded, vectorized draw of prod_dim distinct rotation matrices.
    affines[:, :3, :3] = (
        Rotation.random(prod_dim, random_state=42).as_matrix().astype(np.float32)
    )
    # Translations come from the global RNG, matching the original draws.
    affines[:, :3, 3] = np.random.rand(prod_dim, 3).astype(np.float32)

    affines[:, 3, 3] = 1

    return affines.reshape(*dim, 4, 4)
131
132


133
134
135
136
137
def random_attention_inputs(batch_size, n_seq, n, no_heads, c_hidden, inf=1e9,
                            dtype=torch.float32, requires_grad=False,
                            device="cuda"):
    """Generate random query/key-value/mask/bias tensors for attention tests.

    Args:
        batch_size: Leading batch dimension.
        n_seq: Number of sequences.
        n: Number of residues.
        no_heads: Number of attention heads (bias dimension only).
        c_hidden: Hidden channel dimension of q/kv.
        inf: Magnitude used to mask out attention logits.
        dtype: Tensor dtype.
        requires_grad: Whether q, kv and z_bias track gradients.
        device: Target device (default "cuda", matching the original
            hard-coded ``.cuda()`` behavior; pass "cpu" for CPU-only runs).

    Returns:
        Tuple ``(q, kv, mask, biases)`` where ``biases`` is
        ``[mask_bias, z_bias]``.

    Fixes: the device is now a parameter instead of a hard-coded
    ``.cuda()``, and tensors are created directly on the device so that
    ``requires_grad=True`` yields leaf tensors (``.cuda()`` on a leaf
    produces a non-leaf whose ``.grad`` never populates).
    """
    q = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype,
                   requires_grad=requires_grad, device=device)
    kv = torch.rand(batch_size, n_seq, n, c_hidden, dtype=dtype,
                    requires_grad=requires_grad, device=device)

    # Binary mask; never needs gradients.
    mask = torch.randint(0, 2, (batch_size, n_seq, 1, 1, n), dtype=dtype,
                         requires_grad=False, device=device)

    z_bias = torch.rand(batch_size, 1, no_heads, n, n, dtype=dtype,
                        requires_grad=requires_grad, device=device)
    # Masked positions get -inf-like logits; kept positions get 0.
    mask_bias = inf * (mask - 1)

    biases = [mask_bias, z_bias]

    return q, kv, mask, biases