data_transforms.py 37.9 KB
Newer Older
1
2
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
3
#
4
5
6
7
8
9
10
11
12
13
14
15
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import itertools
17
from functools import reduce, wraps
18
from operator import add
19

20
21
22
import numpy as np
import torch

23
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
24
from openfold.np import residue_constants as rc
25
from openfold.utils.rigid_utils import Rotation, Rigid
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
26
27
28
29
30
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
    batched_gather,
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
31

32

33
# Keys of the per-sequence MSA features. The MSA sampling, cropping and
# deletion transforms in this module loop over this list and select rows
# along dimension 0 of every listed feature that is present in the feature
# dict (or of its "extra_"-prefixed counterpart).
MSA_FEATURE_NAMES = [
    "msa",
    "deletion_matrix",
    "msa_mask",
    "msa_row_mask",
    "bert_mask",
    "true_msa",
]
41

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
42

43
44
45
46
47
def cast_to_64bit_ints(protein):
    """Promote every int32 feature in ``protein`` to int64, in place.

    :param protein: Dict of features; every value must expose ``dtype``.
    :return: The same dict, with each ``torch.int32`` value replaced by an
        int64 copy. Values of any other dtype are left untouched.
    """
    # We keep all ints as int64
    int32_keys = [k for k, v in protein.items() if v.dtype == torch.int32]
    for k in int32_keys:
        protein[k] = protein[k].to(torch.int64)

    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
51

52
def make_one_hot(x, num_classes):
    """One-hot encode a tensor of class indices along a new trailing axis.

    :param x: Integer tensor of class indices in ``[0, num_classes)``.
    :param num_classes: Size of the new trailing dimension.
    :return: A float32 tensor of shape ``(*x.shape, num_classes)`` on the
        same device as ``x``.
    """
    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    # scatter_ only accepts int64 indices, so narrower integer inputs
    # (e.g. int32) are widened first.
    index = x.unsqueeze(-1).to(torch.int64)
    x_one_hot.scatter_(-1, index, 1)
    return x_one_hot

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
57

58
def make_seq_mask(protein):
    """Add an all-ones ``seq_mask`` with the same shape as ``aatype``.

    :param protein: Feature dict containing ``"aatype"``.
    :return: The same dict with a float32 ``"seq_mask"`` entry added.
    """
    aatype = protein["aatype"]
    # Create the mask on aatype's device so it can be combined with the
    # other features without a device mismatch.
    protein["seq_mask"] = torch.ones(
        aatype.shape, dtype=torch.float32, device=aatype.device
    )
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
64

65
def make_template_mask(protein):
    """Add an all-ones ``template_mask`` with one entry per template.

    :param protein: Feature dict containing ``"template_aatype"``; its
        leading dimension is taken as the number of templates.
    :return: The same dict with a 1-D float32 ``"template_mask"`` added.
    """
    template_aatype = protein["template_aatype"]
    # Keep the mask on the same device as the template features.
    protein["template_mask"] = torch.ones(
        template_aatype.shape[0],
        dtype=torch.float32,
        device=template_aatype.device,
    )
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
71

72
def curry1(f):
    """Supply all arguments but the first.

    Decorator: calling the decorated function with every argument except
    the first returns a one-argument callable that, given that first
    argument ``x``, evaluates ``f(x, *args, **kwargs)``.

    :param f: Function whose first positional argument is supplied last.
    :return: The curried wrapper, carrying ``f``'s metadata via ``wraps``.
    """
    @wraps(f)
    def fc(*args, **kwargs):
        def apply(x):
            return f(x, *args, **kwargs)

        return apply

    return fc
79
80
81


def make_all_atom_aatype(protein):
    """Expose ``aatype`` under the additional key ``all_atom_aatype``.

    :param protein: Feature dict containing ``"aatype"``.
    :return: The same dict; ``"all_atom_aatype"`` refers to the same object
        as ``"aatype"`` (no copy is made).
    """
    protein["all_atom_aatype"] = protein["aatype"]
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
85

86
87
def fix_templates_aatype(protein):
    """Convert one-hot template residue types to indices in our ordering.

    When at least one template is present, ``template_aatype`` is reduced
    with an argmax over its last axis and each resulting index is remapped
    through ``rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE``. With zero templates the
    feature is left exactly as it was.

    :param protein: Feature dict containing ``"template_aatype"``.
    :return: The same dict with ``"template_aatype"`` updated in place.
    """
    num_templates = protein["template_aatype"].shape[0]
    if num_templates > 0:
        # Map one-hot to indices
        template_idx = torch.argmax(protein["template_aatype"], dim=-1)

        # Map hhsearch-aatype to our aatype. The lookup table lives on the
        # same device as the indices it is applied to.
        new_order = torch.tensor(
            rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE,
            dtype=torch.int64,
            device=template_idx.device,
        )
        protein["template_aatype"] = new_order[template_idx]

    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
104

105
def correct_msa_restypes(protein):
    """Correct MSA restype to have the same order as rc.

    Every entry of ``msa`` is remapped through
    ``rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE``. Every feature whose key contains
    ``"profile"`` has its last axis multiplied by the matching permutation
    matrix.

    :param protein: Feature dict containing an integer ``"msa"`` tensor and,
        optionally, profile tensors whose last dimension is 20, 21 or 22.
    :return: The same dict, updated in place.
    """
    new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    msa = protein["msa"]
    new_order = torch.tensor(
        new_order_list, dtype=torch.int64, device=msa.device
    )
    protein["msa"] = new_order[msa.to(torch.int64)]

    # perm_matrix[i, new_order_list[i]] = 1
    perm_matrix = np.zeros((22, 22), dtype=np.float32)
    perm_matrix[range(len(new_order_list)), new_order_list] = 1.0

    for k in protein:
        if "profile" in k:
            profile = protein[k]
            num_dim = profile.shape[-1]
            assert num_dim in [
                20,
                21,
                22,
            ], "num_dim for %s out of expected range: %s" % (k, num_dim)
            perm = torch.as_tensor(
                perm_matrix[:num_dim, :num_dim],
                dtype=profile.dtype,
                device=profile.device,
            )
            # Contract the profile's last axis with the permutation's first.
            protein[k] = torch.matmul(profile, perm)

    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
129

130
131
def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features.

    * ``aatype`` is replaced by the argmax over its last axis.
    * For each key in a fixed list, a trailing dimension of size 1 is
      squeezed away (torch tensors and NumPy arrays are both handled).
    * ``seq_length`` and ``num_alignments`` are reduced to their first
      element.

    Values that are already zero-dimensional are left unchanged.

    :param protein: Feature dict containing at least ``"aatype"``.
    :return: The same dict, updated in place.
    """
    protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)

    squeezable_keys = (
        "domain_name",
        "msa",
        "num_alignments",
        "seq_length",
        "sequence",
        "superfamily",
        "deletion_matrix",
        "resolution",
        "between_segment_residues",
        "residue_index",
        "template_all_atom_mask",
    )
    for k in squeezable_keys:
        if k not in protein:
            continue
        value = protein[k]
        # A 0-dim value has no trailing axis to inspect.
        if len(value.shape) == 0:
            continue
        final_dim = value.shape[-1]
        if isinstance(final_dim, int) and final_dim == 1:
            if torch.is_tensor(value):
                protein[k] = torch.squeeze(value, dim=-1)
            else:
                protein[k] = np.squeeze(value, axis=-1)

    for k in ("seq_length", "num_alignments"):
        # Indexing a 0-dim value would raise; it is already a scalar.
        if k in protein and len(protein[k].shape) > 0:
            protein[k] = protein[k][0]

    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
160

161
162
163
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
    """Replace a portion of the MSA with 'X'.

    Each ``msa`` entry that is not a gap (index 21) and each ``aatype``
    entry is independently replaced by index 20 with probability
    ``replace_proportion``.

    :param protein: Feature dict containing ``"msa"`` and ``"aatype"``.
    :param replace_proportion: Per-entry replacement probability.
    :return: The same dict, updated in place.
    """
    x_idx = 20
    gap_idx = 21

    msa = protein["msa"]
    # Draw the random mask on the feature's own device; a CPU mask cannot
    # be combined with an accelerator-resident MSA.
    msa_mask = torch.rand(msa.shape, device=msa.device) < replace_proportion
    msa_mask = torch.logical_and(msa_mask, msa != gap_idx)
    protein["msa"] = torch.where(msa_mask, torch.full_like(msa, x_idx), msa)

    aatype = protein["aatype"]
    aatype_mask = (
        torch.rand(aatype.shape, device=aatype.device) < replace_proportion
    )
    protein["aatype"] = torch.where(
        aatype_mask, torch.full_like(aatype, x_idx), aatype
    )
    return protein

182

183
@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
    """Sample the MSA randomly; remaining sequences are stored as `extra_*`.

    Row 0 is always kept as the first selected row. The other rows are
    shuffled and the first ``max_seq`` rows of the resulting order are
    selected.

    :param protein: Feature dict containing ``"msa"``; every key in
        ``MSA_FEATURE_NAMES`` that is present is sub-sampled along dim 0.
    :param max_seq: Maximum number of rows to keep.
    :param keep_extra: If true, rows that were not selected are stored under
        ``"extra_" + key``.
    :param seed: Optional integer seed for a reproducible shuffle. When
        ``None`` the global torch RNG is used.
    :return: The same dict, updated in place.
    """
    msa = protein["msa"]
    device = msa.device
    num_seq = msa.shape[0]

    # Only build a dedicated generator when a seed is requested: a fresh
    # torch.Generator starts from a fixed default seed, which would make
    # every unseeded call produce the same permutation.
    g = None
    if seed is not None:
        g = torch.Generator(device=device)
        g.manual_seed(seed)

    shuffled = torch.randperm(num_seq - 1, generator=g, device=device) + 1
    index_order = torch.cat(
        (torch.tensor([0], device=device), shuffled),
        dim=0
    )
    num_sel = min(max_seq, num_seq)
    sel_seq, not_sel_seq = torch.split(
        index_order, [num_sel, num_seq - num_sel]
    )

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            if keep_extra:
                protein["extra_" + k] = torch.index_select(
                    protein[k], 0, not_sel_seq
                )
            protein[k] = torch.index_select(protein[k], 0, sel_seq)

    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
210

211
212
213
214
215
@curry1
def add_distillation_flag(protein, distillation):
    """Store ``distillation`` under ``is_distillation`` and return ``protein``."""
    flag_key = 'is_distillation'
    protein[flag_key] = distillation
    return protein

216
217
218
@curry1
def sample_msa_distillation(protein, max_seq):
    """Sub-sample the MSA only for distillation examples.

    When ``protein["is_distillation"] == 1`` the MSA is sampled down to at
    most ``max_seq`` rows via ``sample_msa`` and the unselected rows are
    discarded; otherwise ``protein`` is returned unchanged.

    :param protein: Feature dict containing ``"is_distillation"``.
    :param max_seq: Maximum number of MSA rows to keep.
    :return: The (possibly sub-sampled) feature dict.
    """
    if protein["is_distillation"] == 1:
        protein = sample_msa(max_seq, keep_extra=False)(protein)
    return protein


223
224
@curry1
def crop_extra_msa(protein, max_extra_msa):
    """Randomly keep at most ``max_extra_msa`` rows of the extra MSA.

    One random subset of row indices is drawn and applied to every
    ``"extra_" + key`` feature (for ``key`` in ``MSA_FEATURE_NAMES``) that is
    present.

    :param protein: Feature dict containing ``"extra_msa"``.
    :param max_extra_msa: Maximum number of extra rows to keep.
    :return: The same dict, updated in place.
    """
    extra_msa = protein["extra_msa"]
    num_seq = extra_msa.shape[0]
    num_sel = min(max_extra_msa, num_seq)
    # index_select needs the indices on the same device as the data.
    select_indices = torch.randperm(num_seq, device=extra_msa.device)[:num_sel]

    for k in MSA_FEATURE_NAMES:
        extra_key = "extra_" + k
        if extra_key in protein:
            protein[extra_key] = torch.index_select(
                protein[extra_key], 0, select_indices
            )

    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
236

237
238
def delete_extra_msa(protein):
    """Remove every ``"extra_" + key`` MSA feature from ``protein``.

    :param protein: Feature dict; keys that are absent are ignored.
    :return: The same dict, updated in place.
    """
    for k in MSA_FEATURE_NAMES:
        protein.pop("extra_" + k, None)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
243

244
245
246
# Not used in inference
@curry1
def block_delete_msa(protein, config):
    """Delete randomly placed contiguous blocks of MSA rows.

    ``nb`` block start rows are drawn uniformly from ``[0, num_seq)``; each
    block covers ``floor(num_seq * config.msa_fraction_per_block)``
    consecutive rows, clamped to the last row. All covered rows are removed
    from every feature in ``MSA_FEATURE_NAMES`` that is present, except that
    row 0 (the original sequence) is always kept.

    :param protein: Feature dict containing ``"msa"``.
    :param config: Object providing ``msa_fraction_per_block``,
        ``randomize_num_blocks`` and ``num_blocks``. When
        ``randomize_num_blocks`` is true, the number of blocks is drawn
        uniformly from ``0..num_blocks`` inclusive.
    :return: The same dict, updated in place.
    """
    msa = protein["msa"]
    device = msa.device
    num_seq = msa.shape[0]
    if num_seq == 0:
        # Nothing to delete, and randint would reject an empty range.
        return protein

    # int() truncation equals floor for the non-negative product.
    block_num_seq = int(num_seq * config.msa_fraction_per_block)

    if config.randomize_num_blocks:
        nb = int(torch.randint(0, config.num_blocks + 1, (1,)).item())
    else:
        nb = int(config.num_blocks)

    del_block_starts = torch.randint(0, num_seq, (nb,), device=device)
    del_blocks = del_block_starts[:, None] + torch.arange(
        block_num_seq, device=device
    )
    del_blocks = torch.clamp(del_blocks, 0, num_seq - 1)

    keep_mask = torch.ones(num_seq, dtype=torch.bool, device=device)
    keep_mask[del_blocks.reshape(-1)] = False
    # Make sure we keep the original sequence
    keep_mask[0] = True
    keep_indices = torch.nonzero(keep_mask).squeeze(-1)

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = torch.index_select(protein[k], 0, keep_indices)

    return protein
277

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
278

279
@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
    """Assign each extra MSA sequence to its best-agreeing sampled sequence.

    Agreement between an extra sequence and a sampled sequence is the
    weighted count of positions, unmasked in both, at which they hold the
    same class. Classes 0-20 have weight 1, class 21 has weight
    ``gap_agreement_weight`` and class 22 has weight 0.

    :param protein: Feature dict containing ``"msa"``, ``"msa_mask"``,
        ``"extra_msa"`` and ``"extra_msa_mask"``.
    :param gap_agreement_weight: Weight given to agreement on class 21.
    :return: The same dict with an int64 ``"extra_cluster_assignment"``
        holding, per extra sequence, the index of the chosen sampled
        sequence.
    """
    device = protein["msa"].device
    weights = torch.cat(
        [
            torch.ones(21, device=device),
            gap_agreement_weight * torch.ones(1, device=device),
            torch.zeros(1, device=device),
        ],
        0,
    )

    # Make agreement score as weighted Hamming distance
    sample_one_hot = (
        protein["msa_mask"][:, :, None] * make_one_hot(protein["msa"], 23)
    )
    extra_one_hot = (
        protein["extra_msa_mask"][:, :, None]
        * make_one_hot(protein["extra_msa"], 23)
    )

    num_seq, num_res, _ = sample_one_hot.shape
    extra_num_seq, _, _ = extra_one_hot.shape

    # Equivalent to
    # einsum('nrc,mrc,c->nm', extra_one_hot, sample_one_hot, weights),
    # computed as one flattened matmul to avoid possible memory or
    # computation blowup. The result is indexed (extra, sample).
    extra_flat = torch.reshape(extra_one_hot, [extra_num_seq, num_res * 23])
    sample_flat = torch.reshape(
        sample_one_hot * weights, [num_seq, num_res * 23]
    )
    agreement = torch.matmul(extra_flat, sample_flat.transpose(0, 1))

    # Assign each sequence in the extra sequences to the closest MSA sample
    protein["extra_cluster_assignment"] = torch.argmax(agreement, dim=1).to(
        torch.int64
    )

    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
315

316
317
def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Computes the sum along segments of a tensor. Similar to
    tf.unsorted_segment_sum, but only supports 1-D indices.

    Row ``i`` of ``data`` is added into row ``segment_ids[i]`` of the
    output. Output rows that receive no input stay zero.

    :param data: A tensor whose segments are to be summed.
    :param segment_ids: The 1-D segment indices tensor; its length must
        equal ``data.shape[0]``.
    :param num_segments: The number of segments.
    :return: A tensor of same data type as the data argument, with shape
        ``(num_segments, *data.shape[1:])``.
    """
    assert (
        len(segment_ids.shape) == 1 and
        segment_ids.shape[0] == data.shape[0]
    )
    # scatter_add_ needs int64 indices on the same device as its target.
    index = segment_ids.to(device=data.device, dtype=torch.int64)
    index = index.reshape(
        index.shape[0], *((1,) * len(data.shape[1:]))
    ).expand(data.shape)

    # Accumulate in at least float32; float64 input keeps full precision
    # instead of being rounded through float32.
    acc_dtype = torch.promote_types(data.dtype, torch.float32)
    out_shape = [num_segments] + list(data.shape[1:])
    summed = torch.zeros(
        out_shape, dtype=acc_dtype, device=data.device
    ).scatter_add_(0, index, data.to(acc_dtype))

    return summed.to(data.dtype)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
342

343
344
345
@curry1
def summarize_clusters(protein):
    """Produce profile and deletion_matrix_mean within each cluster.

    Extra MSA rows are grouped by ``extra_cluster_assignment``. Each
    cluster's sums also include the sampled ("center") row it is attached
    to, and are divided by the number of unmasked contributions.

    :param protein: Feature dict containing ``"msa"``, ``"msa_mask"``,
        ``"deletion_matrix"``, ``"extra_msa"``, ``"extra_msa_mask"``,
        ``"extra_deletion_matrix"`` and ``"extra_cluster_assignment"``.
    :return: The same dict with ``"cluster_profile"`` and
        ``"cluster_deletion_mean"`` added.
    """
    num_seq = protein["msa"].shape[0]
    assignment = protein["extra_cluster_assignment"]

    def csum(x):
        # Sum the extra rows that share a cluster index.
        return unsorted_segment_sum(x, assignment, num_seq)

    extra_mask = protein["extra_msa_mask"]
    # Include center; the 1e-6 keeps the later divisions finite.
    mask_counts = 1e-6 + protein["msa_mask"] + csum(extra_mask)

    extra_one_hot = make_one_hot(protein["extra_msa"], 23)
    msa_sum = csum(extra_mask[:, :, None] * extra_one_hot)
    msa_sum += make_one_hot(protein["msa"], 23)  # Original sequence
    protein["cluster_profile"] = msa_sum / mask_counts[:, :, None]
    # Release the large intermediate before building the next one.
    del msa_sum, extra_one_hot

    del_sum = csum(extra_mask * protein["extra_deletion_matrix"])
    del_sum += protein["deletion_matrix"]  # Original sequence
    protein["cluster_deletion_mean"] = del_sum / mask_counts
    del del_sum

    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
368

369
370
def make_msa_mask(protein):
    """Mask features are all ones, but will later be zero-padded.

    :param protein: Feature dict containing ``"msa"``.
    :return: The same dict with float32 ``"msa_mask"`` (same shape as
        ``msa``) and ``"msa_row_mask"`` (one entry per MSA row) added.
    """
    msa = protein["msa"]
    # Create both masks on the MSA's device so they can be combined with it.
    protein["msa_mask"] = torch.ones(
        msa.shape, dtype=torch.float32, device=msa.device
    )
    protein["msa_row_mask"] = torch.ones(
        msa.shape[0], dtype=torch.float32, device=msa.device
    )
    return protein
376

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
377

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
378
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features.

    For residues whose type equals ``rc.restype_order["G"]`` the CA atom is
    selected; for all others the CB atom is selected.

    :param aatype: Tensor of residue type indices.
    :param all_atom_positions: Atom coordinates; the atom axis is the
        second-to-last dimension and is indexed with ``rc.atom_order``.
    :param all_atom_mask: Per-atom mask with the atom axis last, or ``None``.
    :return: ``pseudo_beta`` alone when ``all_atom_mask`` is ``None``,
        otherwise the tuple ``(pseudo_beta, pseudo_beta_mask)``.
    """
    is_gly = torch.eq(aatype, rc.restype_order["G"])
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]

    # The trailing singleton axis broadcasts over the coordinate dimension,
    # so the condition does not have to be tiled explicitly.
    pseudo_beta = torch.where(
        is_gly[..., None],
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_mask is None:
        return pseudo_beta

    pseudo_beta_mask = torch.where(
        is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
    )
    return pseudo_beta, pseudo_beta_mask

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
397

398
@curry1
def make_pseudo_beta(protein, prefix=""):
    """Create pseudo-beta (alpha for glycine) position and mask.

    :param protein: Feature dict. With ``prefix=""`` it must contain
        ``"aatype"``, ``"all_atom_positions"`` and ``"all_atom_mask"``; with
        ``prefix="template_"`` the ``template_``-prefixed equivalents.
    :param prefix: Either ``""`` or ``"template_"``.
    :return: The same dict with ``prefix + "pseudo_beta"`` and
        ``prefix + "pseudo_beta_mask"`` added.
    """
    assert prefix in ["", "template_"]

    aatype_key = "template_aatype" if prefix else "aatype"
    mask_key = "template_all_atom_mask" if prefix else "all_atom_mask"

    pseudo_beta, pseudo_beta_mask = pseudo_beta_fn(
        protein[aatype_key],
        protein[prefix + "all_atom_positions"],
        protein[mask_key],
    )
    protein[prefix + "pseudo_beta"] = pseudo_beta
    protein[prefix + "pseudo_beta_mask"] = pseudo_beta_mask
    return protein
411

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
412

413
414
@curry1
def add_constant_field(protein, key, value):
    """Store ``value`` as a tensor under ``key``.

    :param protein: Feature dict containing ``"msa"``, whose device is used
        for the new tensor.
    :param key: Name of the feature to create or overwrite.
    :param value: Data accepted by ``torch.tensor``.
    :return: The same dict, updated in place.
    """
    target_device = protein["msa"].device
    protein[key] = torch.tensor(value, device=target_device)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
418

419
420
421
def shaped_categorical(probs, epsilon=1e-10):
    """Draw one categorical sample per distribution along the last axis.

    :param probs: Tensor of non-negative class weights; the last dimension
        enumerates the classes.
    :param epsilon: Constant added to every weight before sampling.
    :return: Tensor of sampled class indices with shape ``probs.shape[:-1]``.
    """
    batch_shape = probs.shape[:-1]
    num_classes = probs.shape[-1]
    # Categorical works on a flat batch of distributions.
    flat_probs = torch.reshape(probs + epsilon, [-1, num_classes])
    distribution = torch.distributions.categorical.Categorical(flat_probs)
    samples = distribution.sample()
    return torch.reshape(samples, batch_shape)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
428

429
430
def make_hhblits_profile(protein):
    """Compute the HHblits MSA profile if not already present.

    :param protein: Feature dict containing ``"msa"`` unless
        ``"hhblits_profile"`` is already present.
    :return: The same dict. If ``"hhblits_profile"`` was missing it is set to
        the mean over MSA rows (dim 0) of the 22-class one-hot MSA.
    """
    if "hhblits_profile" not in protein:
        # Compute the profile for every residue (over all MSA sequences).
        msa_one_hot = make_one_hot(protein["msa"], 22)
        protein["hhblits_profile"] = torch.mean(msa_one_hot, dim=0)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
440

441
442
443
444
@curry1
def make_masked_msa(protein, config, replace_fraction):
    """Create data for BERT on raw MSA.

    Each MSA position is selected with probability ``replace_fraction``.
    A selected position is replaced by a class sampled from a mixture of a
    uniform distribution over the first 20 classes (``config.uniform_prob``),
    the ``hhblits_profile`` (``config.profile_prob``), the original class
    (``config.same_prob``), and a 23rd [MASK] class that receives the
    remaining probability.

    :param protein: Feature dict containing ``"msa"`` and
        ``"hhblits_profile"``.
    :param config: Object providing ``uniform_prob``, ``profile_prob`` and
        ``same_prob``.
    :param replace_fraction: Probability that a position is selected.
    :return: The same dict. ``"msa"`` holds the corrupted MSA, ``"true_msa"``
        the original, and ``"bert_mask"`` a float32 0/1 mask of selected
        positions.
    """
    msa = protein["msa"]
    device = msa.device

    # Add a random amino acid uniformly.
    random_aa = torch.tensor(
        [0.05] * 20 + [0.0, 0.0],
        dtype=torch.float32,
        device=device,
    )

    categorical_probs = (
        config.uniform_prob * random_aa
        + config.profile_prob * protein["hhblits_profile"]
        + config.same_prob * make_one_hot(msa, 22)
    )

    # Put all remaining probability on [MASK] which is a new column
    mask_prob = (
        1.0 - config.profile_prob - config.same_prob - config.uniform_prob
    )
    # Repeated float subtraction can land a hair below zero when the three
    # probabilities sum to exactly one; tolerate that and clamp.
    assert mask_prob >= -1e-6
    mask_prob = max(mask_prob, 0.0)

    # (0, 1) appends one entry at the end of the last dimension.
    categorical_probs = torch.nn.functional.pad(
        categorical_probs, (0, 1), value=mask_prob
    )

    # Draw the selection mask on the MSA's device so torch.where below does
    # not mix devices.
    mask_position = torch.rand(msa.shape, device=device) < replace_fraction

    bert_msa = shaped_categorical(categorical_probs)
    bert_msa = torch.where(mask_position, bert_msa, msa)

    # Mix real and masked MSA
    protein["bert_mask"] = mask_position.to(torch.float32)
    protein["true_msa"] = msa
    protein["msa"] = bert_msa

    return protein
483

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
484

485
@curry1
def make_fixed_size(
    protein,
    shape_schema,
    msa_cluster_size,
    extra_msa_size,
    num_res=0,
    num_templates=0,
):
    """Guess at the MSA and sequence dimension to make fixed size.

    For every feature except ``"extra_cluster_assignment"``, each dimension
    whose schema entry is one of the ``NUM_*`` placeholders is padded at its
    end up to the corresponding target size. A target that is zero or
    ``None`` leaves that dimension at its current size.

    :param protein: Feature dict of tensors.
    :param shape_schema: Maps each feature name to a sequence describing its
        dimensions; must have the same rank as the feature.
    :param msa_cluster_size: Target size for ``NUM_MSA_SEQ`` dimensions.
    :param extra_msa_size: Target size for ``NUM_EXTRA_SEQ`` dimensions.
    :param num_res: Target size for ``NUM_RES`` dimensions.
    :param num_templates: Target size for ``NUM_TEMPLATES`` dimensions.
    :return: The same dict, updated in place.
    """
    pad_size_map = {
        NUM_RES: num_res,
        NUM_MSA_SEQ: msa_cluster_size,
        NUM_EXTRA_SEQ: extra_msa_size,
        NUM_TEMPLATES: num_templates,
    }
    msg = "Rank mismatch between shape and shape schema for"

    for k, v in protein.items():
        # Don't transfer this to the accelerator.
        if k == "extra_cluster_assignment":
            continue

        shape = list(v.shape)
        schema = shape_schema[k]
        assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"

        pad_size = [
            pad_size_map.get(s2, None) or s1 for (s1, s2) in zip(shape, schema)
        ]

        # torch.nn.functional.pad takes (before, after) pairs starting from
        # the LAST dimension, hence the reversal.
        per_dim = [(0, target - cur) for cur, target in zip(shape, pad_size)]
        padding = list(itertools.chain.from_iterable(reversed(per_dim)))
        if padding:
            protein[k] = torch.nn.functional.pad(v, padding)
            protein[k] = torch.reshape(protein[k], pad_size)

    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
523

524
525
526
@curry1
def make_msa_feat(protein):
    """Create and concatenate MSA features.

    Writes ``"target_feat"`` (break flag + 21-class one-hot ``aatype``) and
    ``"msa_feat"`` (23-class one-hot ``msa``, deletion indicator, squashed
    deletion value and, when ``"cluster_profile"`` is present, the cluster
    profile and squashed cluster deletion mean). When
    ``"extra_deletion_matrix"`` is present, ``"extra_has_deletion"`` and
    ``"extra_deletion_value"`` are written as well.

    :param protein: Feature dict containing ``"between_segment_residues"``,
        ``"aatype"``, ``"msa"`` and ``"deletion_matrix"``.
    :return: The same dict, updated in place.
    """
    def squash(deletions):
        # Map deletion values through atan(d / 3) * 2 / pi.
        return torch.atan(deletions / 3.0) * (2.0 / np.pi)

    # Whether there is a domain break. Always zero for chains, but keeping for
    # compatibility with domain datasets.
    has_break = torch.clip(
        protein["between_segment_residues"].to(torch.float32), 0, 1
    )
    target_feat = [
        torch.unsqueeze(has_break, dim=-1),
        make_one_hot(protein["aatype"], 21),  # Everyone gets the original sequence.
    ]

    deletion_matrix = protein["deletion_matrix"]
    msa_feat = [
        make_one_hot(protein["msa"], 23),
        torch.unsqueeze(torch.clip(deletion_matrix, 0.0, 1.0), dim=-1),
        torch.unsqueeze(squash(deletion_matrix), dim=-1),
    ]

    if "cluster_profile" in protein:
        msa_feat.append(protein["cluster_profile"])
        msa_feat.append(
            torch.unsqueeze(squash(protein["cluster_deletion_mean"]), dim=-1)
        )

    if "extra_deletion_matrix" in protein:
        extra_deletion = protein["extra_deletion_matrix"]
        protein["extra_has_deletion"] = torch.clip(extra_deletion, 0.0, 1.0)
        protein["extra_deletion_value"] = squash(extra_deletion)

    protein["msa_feat"] = torch.cat(msa_feat, dim=-1)
    protein["target_feat"] = torch.cat(target_feat, dim=-1)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
574

575
576
577
578
@curry1
def select_feat(protein, feature_list):
    """Return a new dict holding only the entries whose key is in ``feature_list``."""
    selected = {}
    for name, value in protein.items():
        if name in feature_list:
            selected[name] = value
    return selected

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
579

580
581
582
@curry1
def crop_templates(protein, max_templates):
    """Keep at most the first ``max_templates`` entries of template features.

    :param protein: Feature dict; every value whose key starts with
        ``"template_"`` is sliced along its first dimension.
    :param max_templates: Maximum number of templates to keep.
    :return: The same dict, updated in place.
    """
    template_keys = [k for k in protein if k.startswith("template_")]
    for k in template_keys:
        protein[k] = protein[k][:max_templates]
    return protein
586

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
587

588
589
590
591
592
593
def make_atom14_masks(protein):
    """Construct denser atom positions (14 dimensions instead of 37).

    Builds per-residue-type lookup tables from ``rc`` and gathers them with
    ``protein["aatype"]`` to produce:

    * ``"atom14_atom_exists"``: float32 mask over the 14 atom slots;
    * ``"residx_atom14_to_atom37"``: int64 atom37 index of each atom14 slot;
    * ``"residx_atom37_to_atom14"``: int64 atom14 index of each atom37 slot;
    * ``"atom37_atom_exists"``: float32 mask over the 37 atom slots.

    Slots without an atom map to index 0 and are 0 in the matching mask.

    :param protein: Feature dict containing ``"aatype"``.
    :return: The same dict, updated in place.
    """
    device = protein["aatype"].device

    restype_atom14_to_atom37 = []
    restype_atom37_to_atom14 = []
    restype_atom14_mask = []
    restype_atom37_mask = []

    for rt in rc.restypes:
        restype_name = rc.restype_1to3[rt]
        atom_names = rc.restype_name_to_atom14_names[restype_name]

        restype_atom14_to_atom37.append(
            [(rc.atom_order[name] if name else 0) for name in atom_names]
        )
        atom_name_to_idx14 = {name: i for i, name in enumerate(atom_names)}
        restype_atom37_to_atom14.append(
            [atom_name_to_idx14.get(name, 0) for name in rc.atom_types]
        )
        restype_atom14_mask.append(
            [(1.0 if name else 0.0) for name in atom_names]
        )

        # Build each atom37 mask row as a plain list; the tensor is created
        # once below rather than written element by element.
        atom37_row = [0.0] * 37
        for atom_name in rc.residue_atoms[restype_name]:
            atom37_row[rc.atom_order[atom_name]] = 1.0
        restype_atom37_mask.append(atom37_row)

    # Add dummy mapping for restype 'UNK'
    restype_atom14_to_atom37.append([0] * 14)
    restype_atom37_to_atom14.append([0] * 37)
    restype_atom14_mask.append([0.0] * 14)
    restype_atom37_mask.append([0.0] * 37)

    restype_atom14_to_atom37 = torch.tensor(
        restype_atom14_to_atom37, dtype=torch.int64, device=device
    )
    restype_atom37_to_atom14 = torch.tensor(
        restype_atom37_to_atom14, dtype=torch.int64, device=device
    )
    restype_atom14_mask = torch.tensor(
        restype_atom14_mask, dtype=torch.float32, device=device
    )
    restype_atom37_mask = torch.tensor(
        restype_atom37_mask, dtype=torch.float32, device=device
    )

    protein_aatype = protein["aatype"].to(torch.long)

    # create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein
    protein["atom14_atom_exists"] = restype_atom14_mask[protein_aatype]
    protein["residx_atom14_to_atom37"] = restype_atom14_to_atom37[
        protein_aatype
    ]

    # create the gather indices for mapping back
    protein["residx_atom37_to_atom14"] = restype_atom37_to_atom14[
        protein_aatype
    ]

    # create the corresponding mask
    protein["atom37_atom_exists"] = restype_atom37_mask[protein_aatype]

    return protein
660
661
662


def make_atom14_masks_np(batch):
    """NumPy wrapper around ``make_atom14_masks``.

    NumPy array leaves of ``batch`` are converted to tensors,
    ``make_atom14_masks`` is applied, and every tensor leaf of the result is
    converted back to a NumPy array.

    :param batch: Feature dict containing ``"aatype"``.
    :return: The feature dict with the atom14/atom37 index and mask features
        added, with array values.
    """
    aatype = batch["aatype"]
    # A NumPy aatype carries no torch device; fall back to the CPU.
    device = aatype.device if torch.is_tensor(aatype) else torch.device("cpu")

    batch = tree_map(
        lambda n: torch.tensor(n, device=device),
        batch,
        np.ndarray
    )
    out = make_atom14_masks(batch)
    # Move to host memory first; NumPy cannot read accelerator tensors.
    out = tensor_tree_map(lambda t: np.array(t.cpu()), out)
    return out
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
671
672
673
674
675
676


def make_atom14_positions(protein):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Args:
        protein: Feature dict providing "aatype", "all_atom_mask",
            "all_atom_positions", "atom14_atom_exists" and
            "residx_atom14_to_atom37".
    Returns:
        The same dict, updated in place with "atom14_atom_exists",
        "atom14_gt_exists", "atom14_gt_positions",
        "atom14_alt_gt_positions", "atom14_alt_gt_exists" and
        "atom14_atom_is_ambiguous".
    """
    all_atom_mask = protein["all_atom_mask"]
    all_atom_positions = protein["all_atom_positions"]
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        all_atom_mask,
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            all_atom_positions,
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(all_atom_positions.shape[:-2]),
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms. Residue types without any swap
    # keep the identity.
    all_matrices = {
        res: torch.eye(
            14,
            dtype=all_atom_mask.dtype,
            device=all_atom_mask.device,
        )
        for res in restype_3
    }
    atom14_range = torch.arange(14, device=all_atom_mask.device)
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        atom14_names = rc.restype_name_to_atom14_names[resname]
        correspondences = torch.arange(14, device=all_atom_mask.device)
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = atom14_names.index(source_atom_swap)
            target_index = atom14_names.index(target_atom_swap)
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index

        # Build the permutation matrix once per residue type, after every
        # swap has been recorded: row i has its single 1 in column
        # correspondences[i].
        renaming_matrix = all_atom_mask.new_zeros((14, 14))
        renaming_matrix[atom14_range, correspondences] = 1.0
        all_matrices[resname] = renaming_matrix

    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[protein["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    protein["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )
    protein["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask.  shape: (21, 14).
    restype_atom14_is_ambiguous = all_atom_mask.new_zeros((21, 14))
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[rc.restype_3to1[resname]]
        atom14_names = rc.restype_name_to_atom14_names[resname]
        for atom_name1, atom_name2 in swap.items():
            atom_idx1 = atom14_names.index(atom_name1)
            atom_idx2 = atom14_names.index(atom_name2)
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
        protein["aatype"]
    ]

    return protein


776
def atom37_to_frames(protein, eps=1e-8):
    """Builds per-residue rigid-group frames from atom37 coordinates.

    Eight rigid groups are considered per residue type. Group 0 is based on
    the atoms (C, CA, N), group 3 on (CA, C, O), and groups 4-7 on the last
    three atoms of each chi angle defined for the residue type. Groups 1 and
    2 receive no base atoms here and stay zero in the group mask.

    Args:
        protein: Feature dict providing "aatype", "all_atom_positions" and
            "all_atom_mask".
        eps: Forwarded to `Rigid.from_3_points`.
    Returns:
        The same dict, updated in place with "rigidgroups_gt_frames",
        "rigidgroups_gt_exists", "rigidgroups_group_exists",
        "rigidgroups_group_is_ambiguous" and "rigidgroups_alt_gt_frames".
    """
    aatype = protein["aatype"]
    all_atom_positions = protein["all_atom_positions"]
    all_atom_mask = protein["all_atom_mask"]

    batch_dims = len(aatype.shape[:-1])

    # Names of the three atoms defining each rigid group, per residue type.
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
    restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
    restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]

    for restype, restype_letter in enumerate(rc.restypes):
        resname = rc.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if rc.chi_angles_mask[restype][chi_idx]:
                names = rc.chi_angles_atoms[resname][chi_idx]
                restype_rigidgroup_base_atom_names[
                    restype, chi_idx + 4, :
                ] = names[1:]

    # Which groups exist for which residue type.
    restype_rigidgroup_mask = all_atom_mask.new_zeros(
        (*aatype.shape[:-1], 21, 8),
    )
    restype_rigidgroup_mask[..., 0] = 1
    restype_rigidgroup_mask[..., 3] = 1
    restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor(
        rc.chi_angles_mask
    )

    # Translate atom names into atom37 indices; the empty name used for
    # undefined groups maps to index 0.
    lookuptable = rc.atom_order.copy()
    lookuptable[""] = 0
    lookup = np.vectorize(lambda x: lookuptable[x])
    restype_rigidgroup_base_atom37_idx = lookup(
        restype_rigidgroup_base_atom_names,
    )
    restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
        restype_rigidgroup_base_atom37_idx,
    )
    # Prepend singleton batch dimensions so the table can be gathered with
    # the batched `aatype`.
    restype_rigidgroup_base_atom37_idx = (
        restype_rigidgroup_base_atom37_idx.view(
            *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
        )
    )

    residx_rigidgroup_base_atom37_idx = batched_gather(
        restype_rigidgroup_base_atom37_idx,
        aatype,
        dim=-3,
        no_batch_dims=batch_dims,
    )

    base_atom_pos = batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=-2,
        no_batch_dims=len(all_atom_positions.shape[:-2]),
    )

    gt_frames = Rigid.from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        eps=eps,
    )

    group_exists = batched_gather(
        restype_rigidgroup_mask,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    # A frame has ground truth only if all three of its base atoms are
    # present and the group exists for this residue type.
    gt_atoms_exist = batched_gather(
        all_atom_mask,
        residx_rigidgroup_base_atom37_idx,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )
    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

    # Compose group 0 with diag(-1, 1, -1); every other group keeps the
    # identity rotation.
    rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1
    rots = Rotation(rot_mats=rots)

    gt_frames = gt_frames.compose(Rigid(rots, None))

    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), 21, 8
    )
    restype_rigidgroup_rots = torch.eye(
        3, dtype=all_atom_mask.dtype, device=aatype.device
    )
    restype_rigidgroup_rots = torch.tile(
        restype_rigidgroup_rots,
        (*((1,) * batch_dims), 21, 8, 1, 1),
    )

    # For residue types with renamable atoms, the group of the last defined
    # chi angle is flagged ambiguous and gets the rotation diag(1, -1, -1)
    # for the alternative frames.
    for resname in rc.residue_atom_renaming_swaps:
        restype = rc.restype_order[rc.restype_3to1[resname]]
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1

    residx_rigidgroup_is_ambiguous = batched_gather(
        restype_rigidgroup_is_ambiguous,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = batched_gather(
        restype_rigidgroup_rots,
        aatype,
        dim=-4,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = Rotation(
        rot_mats=residx_rigidgroup_ambiguity_rot
    )
    alt_gt_frames = gt_frames.compose(
        Rigid(residx_rigidgroup_ambiguity_rot, None)
    )

    gt_frames_tensor = gt_frames.to_tensor_4x4()
    alt_gt_frames_tensor = alt_gt_frames.to_tensor_4x4()

    protein["rigidgroups_gt_frames"] = gt_frames_tensor
    protein["rigidgroups_gt_exists"] = gt_exists
    protein["rigidgroups_group_exists"] = group_exists
    protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
    protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor

    return protein


def get_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
      A nested list of shape [residue_types=21, chis=4, atoms=4] (callers
      convert it to a tensor themselves). The residue types are in the order
      specified in rc.restypes + unknown residue type at the end. For chi
      angles which are not defined on the residue, the positions indices are
      by default set to 0.
    """
    chi_atom_indices = []
    for restype_letter in rc.restypes:
        resname = rc.restype_1to3[restype_letter]
        atom_indices = [
            [rc.atom_order[atom] for atom in chi_angle]
            for chi_angle in rc.chi_angles_atoms[resname]
        ]
        # Pad with zero rows for chi angles not defined on the AA.
        num_missing = 4 - len(atom_indices)
        atom_indices.extend([0, 0, 0, 0] for _ in range(num_missing))
        chi_atom_indices.append(atom_indices)

    # For UNKNOWN residue. Four independent rows, so mutating one of them
    # cannot silently change the others.
    chi_atom_indices.append([[0, 0, 0, 0] for _ in range(4)])

    return chi_atom_indices


@curry1
def atom37_to_torsion_angles(
    protein,
    prefix="",
):
    """
    Convert coordinates to torsion angles.

    This function is extremely sensitive to floating point imprecisions
    and should be run with double precision whenever possible.

    Args:
        Dict containing:
            * (prefix)aatype:
                [*, N_res] residue indices
            * (prefix)all_atom_positions:
                [*, N_res, 37, 3] atom positions (in atom37
                format)
            * (prefix)all_atom_mask:
                [*, N_res, 37] atom position mask
    Returns:
        The same dictionary updated with the following features:

        "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Torsion angles
        "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Alternate torsion angles (accounting for 180-degree symmetry)
        "(prefix)torsion_angles_mask" ([*, N_res, 7])
            Torsion angles mask
    """
    aatype = protein[prefix + "aatype"]
    all_atom_positions = protein[prefix + "all_atom_positions"]
    all_atom_mask = protein[prefix + "all_atom_mask"]

    # Residue indices above 20 are treated as index 20, the extra row
    # appended to the lookup tables below.
    aatype = torch.clamp(aatype, max=20)

    # Shift positions and mask by one residue so that entry i holds the
    # atoms of residue i - 1; the first residue gets an all-zero pad.
    pad = all_atom_positions.new_zeros(
        [*all_atom_positions.shape[:-3], 1, 37, 3]
    )
    prev_all_atom_positions = torch.cat(
        [pad, all_atom_positions[..., :-1, :, :]], dim=-3
    )

    pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
    prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)

    # Four atoms per backbone torsion.
    # NOTE(review): presumably atom37 indices 0, 1, 2, 4 are N, CA, C, O --
    # confirm against rc.atom_order.
    pre_omega_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
        dim=-2,
    )
    phi_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
        dim=-2,
    )
    psi_atom_pos = torch.cat(
        [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
        dim=-2,
    )

    # A torsion is valid only if all four of its atoms are present. The
    # dtype is pinned on every product so all three masks keep the dtype of
    # `all_atom_mask`.
    pre_omega_mask = torch.prod(
        prev_all_atom_mask[..., 1:3], dim=-1, dtype=all_atom_mask.dtype
    ) * torch.prod(all_atom_mask[..., :2], dim=-1, dtype=all_atom_mask.dtype)
    phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
        all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
    )
    psi_mask = (
        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
        * all_atom_mask[..., 4]
    )

    chi_atom_indices = torch.as_tensor(
        get_chi_atom_indices(), device=aatype.device
    )

    atom_indices = chi_atom_indices[..., aatype, :, :]
    chis_atom_pos = batched_gather(
        all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
    )

    # Per-residue-type chi mask, with an all-zero row appended for index 20.
    chi_angles_mask = list(rc.chi_angles_mask)
    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
    chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)

    chis_mask = chi_angles_mask[aatype, :]

    chi_angle_atoms_mask = batched_gather(
        all_atom_mask,
        atom_indices,
        dim=-1,
        no_batch_dims=len(atom_indices.shape[:-2]),
    )
    chi_angle_atoms_mask = torch.prod(
        chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
    )
    chis_mask = chis_mask * chi_angle_atoms_mask

    # Order along the torsion axis: pre-omega, phi, psi, then the chis.
    torsions_atom_pos = torch.cat(
        [
            pre_omega_atom_pos[..., None, :, :],
            phi_atom_pos[..., None, :, :],
            psi_atom_pos[..., None, :, :],
            chis_atom_pos,
        ],
        dim=-3,
    )

    torsion_angles_mask = torch.cat(
        [
            pre_omega_mask[..., None],
            phi_mask[..., None],
            psi_mask[..., None],
            chis_mask,
        ],
        dim=-1,
    )

    # Frame built from the first three atoms of each torsion; the fourth
    # atom is then expressed in that frame.
    torsion_frames = Rigid.from_3_points(
        torsions_atom_pos[..., 1, :],
        torsions_atom_pos[..., 2, :],
        torsions_atom_pos[..., 0, :],
        eps=1e-8,
    )

    fourth_atom_rel_pos = torsion_frames.invert().apply(
        torsions_atom_pos[..., 3, :]
    )

    torsion_angles_sin_cos = torch.stack(
        [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
    )

    # Normalise the pair; the 1e-8 keeps the square root away from zero.
    denom = torch.sqrt(
        torch.sum(
            torch.square(torsion_angles_sin_cos),
            dim=-1,
            dtype=torsion_angles_sin_cos.dtype,
            keepdims=True,
        )
        + 1e-8
    )
    torsion_angles_sin_cos = torsion_angles_sin_cos / denom

    # Flip the sign of the third torsion (psi); the index expression
    # broadcasts the 7-vector over the batch dims and the sin/cos axis.
    torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
        [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
    )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]

    chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
        rc.chi_pi_periodic,
    )[aatype, ...]

    # +1 for the three backbone torsions, -1 for chis flagged in
    # rc.chi_pi_periodic, +1 for the remaining chis.
    mirror_torsion_angles = torch.cat(
        [
            all_atom_mask.new_ones(*aatype.shape, 3),
            1.0 - 2.0 * chi_is_ambiguous,
        ],
        dim=-1,
    )

    alt_torsion_angles_sin_cos = (
        torsion_angles_sin_cos * mirror_torsion_angles[..., None]
    )

    protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
    protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
    protein[prefix + "torsion_angles_mask"] = torsion_angles_mask

    return protein


def get_backbone_frames(protein):
    """Copies rigid group 0 of the ground-truth frames into backbone features.

    Args:
        protein: Feature dict providing "rigidgroups_gt_frames" and
            "rigidgroups_gt_exists".
    Returns:
        The same dict, updated in place with:
            "backbone_rigid_tensor": index 0 along the third-to-last axis of
                "rigidgroups_gt_frames"
            "backbone_rigid_mask": index 0 along the last axis of
                "rigidgroups_gt_exists"
    """
    # DISCREPANCY: AlphaFold uses tensor_7s here. I don't know why.
    protein["backbone_rigid_tensor"] = protein["rigidgroups_gt_frames"][
        ..., 0, :, :
    ]
    protein["backbone_rigid_mask"] = protein["rigidgroups_gt_exists"][..., 0]

    return protein


def get_chi_angles(protein):
    """Splits the side-chain (chi) entries off the torsion angle features.

    Entries from index 3 onwards along the torsion axis are taken from
    "torsion_angles_sin_cos" and "torsion_angles_mask" and cast to the dtype
    of "all_atom_mask".

    Args:
        protein: Feature dict providing "all_atom_mask",
            "torsion_angles_sin_cos" and "torsion_angles_mask".
    Returns:
        The same dict, updated in place with "chi_angles_sin_cos" and
        "chi_mask".
    """
    target_dtype = protein["all_atom_mask"].dtype

    chi_sin_cos = protein["torsion_angles_sin_cos"][..., 3:, :]
    protein["chi_angles_sin_cos"] = chi_sin_cos.to(target_dtype)

    chi_mask = protein["torsion_angles_mask"][..., 3:]
    protein["chi_mask"] = chi_mask.to(target_dtype)

    return protein


@curry1
def random_crop_to_size(
    protein,
    crop_size,
    max_templates,
    shape_schema,
    subsample_templates=False,
    seed=None,
):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Args:
        protein: Feature dict; must contain "seq_length". "template_mask" and
            "use_clamped_fape" are consulted when present.
        crop_size: Maximum number of residues to keep.
        max_templates: Maximum number of templates to keep.
        shape_schema: Maps feature names to a per-dimension description;
            dimensions equal to `NUM_RES` are cropped along the residue axis.
            Features missing from the schema are left untouched.
        subsample_templates: If truthy and templates are present, permute the
            templates and draw a random template start offset before
            cropping.
        seed: Optional seed for the local random generator, so that repeated
            calls with the same seed crop identically.
    Returns:
        The same dict with cropped features and "seq_length" replaced by the
        cropped length.
    """
    device = protein["seq_length"].device

    # We want each ensemble to be cropped the same way
    g = torch.Generator(device=device)
    if seed is not None:
        g.manual_seed(seed)

    seq_length = int(protein["seq_length"])

    if "template_mask" in protein:
        num_templates = protein["template_mask"].shape[-1]
    else:
        num_templates = 0

    # No need to subsample templates if there aren't any
    subsample_templates = subsample_templates and num_templates

    num_res_crop_size = min(seq_length, crop_size)

    def _randint(lower, upper):
        # Draws from [lower, upper], both ends inclusive.
        return int(torch.randint(
            lower,
            upper + 1,
            (1,),
            device=device,
            generator=g,
        )[0])

    if subsample_templates:
        templates_crop_start = _randint(0, num_templates)
        templates_select_indices = torch.randperm(
            num_templates, device=device, generator=g
        )
    else:
        templates_crop_start = 0

    num_templates_crop_size = min(
        num_templates - templates_crop_start, max_templates
    )

    n = seq_length - num_res_crop_size
    if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
        right_anchor = n
    else:
        # Shrink the range of admissible start positions by a random amount
        # before drawing the start itself.
        x = _randint(0, n)
        right_anchor = n - x

    num_res_crop_start = _randint(0, right_anchor)

    for k, v in protein.items():
        if k not in shape_schema or (
            "template" not in k and NUM_RES not in shape_schema[k]
        ):
            continue

        is_template_feature = k.startswith("template")

        # randomly permute the templates before cropping them.
        if is_template_feature and subsample_templates:
            v = v[templates_select_indices]

        slices = []
        for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
            is_num_res = dim_size == NUM_RES
            if i == 0 and is_template_feature:
                dim_crop_size = num_templates_crop_size
                dim_crop_start = templates_crop_start
            else:
                dim_crop_start = num_res_crop_start if is_num_res else 0
                dim_crop_size = num_res_crop_size if is_num_res else dim
            slices.append(slice(dim_crop_start, dim_crop_start + dim_crop_size))

        # Multidimensional indexing needs a tuple; indexing with a list of
        # slices is deprecated.
        protein[k] = v[tuple(slices)]

    protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)

    return protein