"deploy/Kubernetes/common/tests/resource_gpu.yaml" did not exist on "08fcd7e93ba5df3093a8b54fe79e0895fe7a5f15"
data_transforms.py 37.1 KB
Newer Older
1
2
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
3
#
4
5
6
7
8
9
10
11
12
13
14
15
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
import itertools
17
from functools import reduce
18
from operator import add
19

20
21
22
import numpy as np
import torch

23
from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
24
from openfold.np import residue_constants as rc
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
25
from openfold.utils.affine_utils import T
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
26
27
28
29
30
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
    batched_gather,
)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
31

32

33
# Per-MSA-row features. The MSA sampling/cropping transforms re-index each of
# these that is present along dim 0, and some also mirror them under an
# "extra_" prefix.
MSA_FEATURE_NAMES = [
    "msa",
    "deletion_matrix",
    "msa_mask",
    "msa_row_mask",
    "bert_mask",
    "true_msa",
]
41

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
42

43
44
45
46
47
48
49
def cast_to_64bit_ints(protein):
    """Promote every int32 tensor in ``protein`` to int64, in place.

    Tensors of any other dtype are left untouched. The same dict is returned
    so the function can be chained with other transforms.
    """
    # We keep all ints as int64
    int32_keys = [
        key for key, tensor in protein.items() if tensor.dtype == torch.int32
    ]
    for key in int32_keys:
        protein[key] = protein[key].type(torch.int64)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
50

51
52
53
54
55
def make_one_hot(x, num_classes):
    """One-hot encode an integer tensor along a new trailing dimension.

    Args:
        x: Integer tensor of class indices, each expected in
            ``[0, num_classes)``.
        num_classes: Size of the appended one-hot dimension.

    Returns:
        A tensor of the default floating dtype with shape
        ``(*x.shape, num_classes)``, allocated on the same device as ``x``.
    """
    # Allocate next to the input so GPU tensors don't hit a device mismatch
    # in scatter_.
    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    # scatter_ requires int64 indices; .long() also accepts int32 inputs.
    x_one_hot.scatter_(-1, x.long().unsqueeze(-1), 1)
    return x_one_hot

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
56

57
def make_seq_mask(protein):
    """Attach an all-ones float32 ``seq_mask`` shaped like ``aatype``."""
    aatype_shape = protein["aatype"].shape
    protein["seq_mask"] = torch.ones(aatype_shape, dtype=torch.float32)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
63

64
def make_template_mask(protein):
    """Attach an all-ones float32 ``template_mask`` with one entry per template.

    The number of templates is read from dim 0 of ``template_aatype``.
    """
    n_templates = protein["template_aatype"].shape[0]
    protein["template_mask"] = torch.ones(n_templates, dtype=torch.float32)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
70

71
def curry1(f):
    """Supply all arguments but the first.

    ``curry1(f)(*args, **kwargs)`` returns a one-argument callable that runs
    ``f(x, *args, **kwargs)`` when given ``x``. This lets a transform be
    configured up front and applied to a feature dict later.
    """

    def fc(*args, **kwargs):
        def apply_to(x):
            return f(x, *args, **kwargs)

        return apply_to

    return fc
78
79
80


def make_all_atom_aatype(protein):
    """Alias ``aatype`` as ``all_atom_aatype`` (same tensor object, no copy)."""
    protein.update(all_atom_aatype=protein["aatype"])
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
84

85
86
def fix_templates_aatype(protein):
    """Convert one-hot template residue types to indices in our residue order.

    ``template_aatype`` is argmax-ed over its last dimension, and each
    resulting HHsearch index is remapped through
    ``rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE``. When there are no templates the
    feature is left untouched.
    """
    num_templates = protein["template_aatype"].shape[0]
    if num_templates <= 0:
        return protein

    # Map one-hot to indices
    hhsearch_idx = torch.argmax(protein["template_aatype"], dim=-1)
    # Map hhsearch-aatype to our aatype with a direct table lookup; this is
    # equivalent to gathering from the table broadcast over templates.
    remap = torch.tensor(rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, dtype=torch.int64)
    protein["template_aatype"] = remap[hhsearch_idx]
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
102

103
def correct_msa_restypes(protein):
    """Correct MSA restype to have the same order as rc.

    Every ``msa`` entry is remapped through
    ``rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE``. Every feature whose key contains
    ``"profile"`` has its last dimension (20, 21 or 22 wide) permuted the same
    way, via multiplication with a permutation matrix.
    """
    new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    msa = protein["msa"]
    # Elementwise lookup: msa[i, j] <- new_order_list[msa[i, j]].
    new_order = torch.tensor(new_order_list, dtype=msa.dtype, device=msa.device)
    protein["msa"] = new_order[msa.long()]

    # perm_matrix[i, new_order_list[i]] = 1, so (profile @ perm)[..., j]
    # collects the old column(s) i that map to new index j.
    perm_matrix = np.zeros((22, 22), dtype=np.float32)
    perm_matrix[range(len(new_order_list)), new_order_list] = 1.0

    for k in protein:
        if "profile" in k:
            # torch shapes are plain sequences; `.as_list()` is a TF-only API.
            num_dim = protein[k].shape[-1]
            assert num_dim in [
                20,
                21,
                22,
            ], "num_dim for %s out of expected range: %s" % (k, num_dim)
            perm = torch.tensor(
                perm_matrix[:num_dim, :num_dim],
                dtype=protein[k].dtype,
                device=protein[k].device,
            )
            # torch.dot only handles 1-D vectors; matmul handles (..., D).
            protein[k] = torch.matmul(protein[k], perm)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
125

126
127
def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features.

    * ``aatype`` (one-hot) is collapsed to indices via argmax on its last dim.
    * A fixed set of features, when present, lose a trailing dimension of
      size 1.
    * ``seq_length`` and ``num_alignments``, when present, are reduced to
      their first entry.
    """
    protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)

    squeezable = (
        "domain_name",
        "msa",
        "num_alignments",
        "seq_length",
        "sequence",
        "superfamily",
        "deletion_matrix",
        "resolution",
        "between_segment_residues",
        "residue_index",
        "template_all_atom_mask",
    )
    for key in (name for name in squeezable if name in protein):
        last_dim = protein[key].shape[-1]
        if isinstance(last_dim, int) and last_dim == 1:
            protein[key] = torch.squeeze(protein[key], dim=-1)

    for key in ("seq_length", "num_alignments"):
        if key in protein:
            protein[key] = protein[key][0]
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
152

153
154
155
@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
    """Replace a portion of the MSA with 'X'.

    Each non-gap ``msa`` entry (gap index 21) is independently replaced by the
    unknown index 20 with probability ``replace_proportion``. ``aatype``
    entries are replaced the same way, without the gap exemption. Random
    draws use the global torch RNG: first for ``msa``, then for ``aatype``.
    """
    unknown_idx, gap_idx = 20, 21

    msa = protein["msa"]
    msa_hit = torch.rand(msa.shape) < replace_proportion
    msa_hit = msa_hit & (msa != gap_idx)
    protein["msa"] = msa.masked_fill(msa_hit, unknown_idx)

    aatype = protein["aatype"]
    aatype_hit = torch.rand(aatype.shape) < replace_proportion
    protein["aatype"] = aatype.masked_fill(aatype_hit, unknown_idx)

    return protein

@curry1
def sample_msa(protein, max_seq, keep_extra, seed=None):
    """Sample MSA randomly; remaining sequences are stored as `extra_*`.

    Row 0 (the first MSA row) always stays first. The other rows are shuffled,
    and the first ``max_seq`` rows of that order are kept for every feature in
    ``MSA_FEATURE_NAMES`` that is present.

    Args:
        protein: Feature dict; MSA features are indexed along dim 0.
        max_seq: Maximum number of MSA rows to keep.
        keep_extra: If True, the rows that were not selected are stored under
            ``"extra_" + name`` for each MSA feature.
        seed: Optional integer seed for the shuffle. If None, the generator is
            seeded non-deterministically.

    Returns:
        The same ``protein`` dict, modified in place.
    """
    msa = protein["msa"]
    num_seq = msa.shape[0]
    device = msa.device

    g = torch.Generator(device=device)
    if seed is not None:
        g.manual_seed(seed)
    else:
        # A fresh Generator always starts from the same default seed. Without
        # reseeding, every unseeded call would select the same rows.
        g.seed()
    # randperm must produce its output on the generator's device.
    shuffled = torch.randperm(num_seq - 1, generator=g, device=device) + 1
    index_order = torch.cat(
        (torch.zeros(1, dtype=shuffled.dtype, device=device), shuffled), dim=0
    )
    num_sel = min(max_seq, num_seq)
    sel_seq, not_sel_seq = torch.split(
        index_order, [num_sel, num_seq - num_sel]
    )

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            if keep_extra:
                protein["extra_" + k] = torch.index_select(
                    protein[k], 0, not_sel_seq
                )
            protein[k] = torch.index_select(protein[k], 0, sel_seq)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
195

196
197
198
199
200
201
202
@curry1
def sample_msa_distillation(protein, max_seq):
    """Subsample the MSA of distillation examples to at most ``max_seq`` rows.

    Examples whose ``is_distillation`` flag is not 1 are returned unchanged.
    Unselected rows are discarded (``keep_extra=False``).
    """
    if protein["is_distillation"] == 1:
        # sample_msa is wrapped by curry1: configure it first, then apply it
        # to the feature dict. Passing ``protein`` as the first argument would
        # only build (and return) a closure instead of sampling.
        protein = sample_msa(max_seq, keep_extra=False)(protein)
    return protein


203
204
@curry1
def crop_extra_msa(protein, max_extra_msa):
    """Randomly keep at most ``max_extra_msa`` rows of the extra-MSA features.

    A random permutation of the extra-MSA rows is truncated to
    ``max_extra_msa`` entries. Every ``extra_<name>`` feature (for ``name`` in
    MSA_FEATURE_NAMES) that is present is re-indexed with it along dim 0, so
    the kept rows also end up shuffled.
    """
    extra_msa = protein["extra_msa"]
    num_seq = extra_msa.shape[0]
    num_sel = min(max_extra_msa, num_seq)
    # Build the indices on the features' device; index_select rejects an
    # index tensor that lives on a different device.
    select_indices = torch.randperm(num_seq, device=extra_msa.device)[:num_sel]
    for k in MSA_FEATURE_NAMES:
        if "extra_" + k in protein:
            protein["extra_" + k] = torch.index_select(
                protein["extra_" + k], 0, select_indices
            )
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
215

216
217
def delete_extra_msa(protein):
    """Drop every ``extra_<name>`` feature for ``name`` in MSA_FEATURE_NAMES.

    Missing keys are ignored; the same dict is returned.
    """
    extra_keys = ["extra_" + name for name in MSA_FEATURE_NAMES]
    for key in extra_keys:
        protein.pop(key, None)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
222

223
224
225
# Not used in inference
@curry1
def block_delete_msa(protein, config):
    """Delete random contiguous blocks of MSA rows (training augmentation).

    ``nb`` block start rows are drawn uniformly from ``[0, num_seq)``. Each
    block spans ``floor(num_seq * config.msa_fraction_per_block)`` rows,
    clipped to the MSA. Row 0, the original sequence, is never deleted. Every
    feature in MSA_FEATURE_NAMES that is present is re-indexed along dim 0
    with the surviving rows, in ascending order.

    Args:
        protein: Feature dict containing at least ``msa``.
        config: Object with ``msa_fraction_per_block``, ``num_blocks`` and
            ``randomize_num_blocks``. If ``randomize_num_blocks`` is truthy,
            the number of blocks is drawn uniformly from ``[0, num_blocks]``.

    Returns:
        The same ``protein`` dict, modified in place.
    """
    num_seq = protein["msa"].shape[0]
    if num_seq == 0:
        return protein
    device = protein["msa"].device

    block_num_seq = int(
        torch.floor(
            torch.tensor(num_seq, dtype=torch.float32)
            * config.msa_fraction_per_block
        )
    )

    if config.randomize_num_blocks:
        nb = int(torch.randint(0, config.num_blocks + 1, (1,)).item())
    else:
        nb = int(config.num_blocks)

    del_block_starts = torch.randint(0, num_seq, (nb,), device=device)
    del_blocks = del_block_starts[:, None] + torch.arange(
        block_num_seq, device=device
    )
    del_blocks = torch.clip(del_blocks, 0, num_seq - 1)

    keep = torch.ones(num_seq, dtype=torch.bool, device=device)
    keep[del_blocks.reshape(-1)] = False
    # Make sure we keep the original sequence
    keep[0] = True
    keep_indices = torch.nonzero(keep).squeeze(-1)

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = torch.index_select(protein[k], 0, keep_indices)

    return protein
256

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
257

258
@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
    """Assign each extra-MSA sequence to its most similar sampled MSA row.

    Similarity is a weighted count of positions where two sequences share a
    residue class, with classes one-hot encoded over 23 values. Classes 0-20
    have weight 1, class 21 has weight ``gap_agreement_weight`` and class 22
    has weight 0. Positions whose ``msa_mask`` / ``extra_msa_mask`` is 0
    contribute nothing.

    Writes ``protein["extra_cluster_assignment"]`` (int64, one entry per
    extra sequence) and returns ``protein``.
    """
    num_classes = 23
    class_weights = torch.ones(num_classes)
    class_weights[21] = gap_agreement_weight
    class_weights[22] = 0.0

    # Make agreement score as weighted Hamming distance
    sample_one_hot = protein["msa_mask"][:, :, None] * make_one_hot(
        protein["msa"], num_classes
    )
    extra_one_hot = protein["extra_msa_mask"][:, :, None] * make_one_hot(
        protein["extra_msa"], num_classes
    )

    num_seq, num_res, _ = sample_one_hot.shape
    extra_num_seq, _, _ = extra_one_hot.shape

    # agreement[e, s] = sum_{r, c} extra[e, r, c] * sample[s, r, c] * w[c],
    # computed as one matmul over the flattened (residue, class) axis instead
    # of a large einsum intermediate.
    extra_flat = extra_one_hot.reshape(extra_num_seq, num_res * num_classes)
    sample_flat = (sample_one_hot * class_weights).reshape(
        num_seq, num_res * num_classes
    )
    agreement = extra_flat @ sample_flat.T

    # Assign each sequence in the extra sequences to the closest MSA sample
    closest = torch.argmax(agreement, dim=1)
    protein["extra_cluster_assignment"] = closest.to(torch.int64)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
290

291
292
293
294
295
296
297
298
299
def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.

    :param data: A tensor whose segments are to be summed.
    :param segment_ids: The int64 segment indices tensor. Either 1-D with one
        id per entry of dim 0 of ``data``, or exactly the shape of ``data``.
    :param num_segments: The number of segments.
    :return: A tensor of shape ``(num_segments, *data.shape[1:])`` with the
        same dtype and device as ``data``.
    """
    # segment_ids.shape must be a prefix of data.shape
    assert tuple(data.shape[: segment_ids.dim()]) == tuple(segment_ids.shape), (
        f"segment_ids shape {tuple(segment_ids.shape)} is not a prefix of "
        f"data shape {tuple(data.shape)}"
    )

    # A 1-D segment_ids labels whole rows: broadcast it over the trailing
    # dimensions so that it matches data elementwise.
    if segment_ids.dim() == 1:
        trailing = data.shape[1:]
        segment_ids = segment_ids.reshape(-1, *([1] * len(trailing))).expand(
            segment_ids.shape[0], *trailing
        )

    # data.shape and segment_ids.shape should be equal
    assert data.shape == segment_ids.shape

    # Accumulate floating data in at least float32 (float64 stays float64
    # instead of being truncated); other dtypes accumulate in float32 as
    # before. The result is cast back to data's dtype.
    if data.is_floating_point():
        acc_dtype = torch.promote_types(data.dtype, torch.float32)
    else:
        acc_dtype = torch.float32

    shape = [num_segments] + list(data.shape[1:])
    tensor = torch.zeros(shape, dtype=acc_dtype, device=data.device)
    tensor = tensor.scatter_add(0, segment_ids, data.to(acc_dtype))
    return tensor.type(data.dtype)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
318

319
320
321
@curry1
def summarize_clusters(protein):
    """Produce profile and deletion_matrix_mean within each cluster.

    Each sampled MSA row is a cluster centre. Extra-MSA rows are pooled into
    the cluster given by ``extra_cluster_assignment``, weighted by
    ``extra_msa_mask``. The counts used as denominators are
    ``1e-6 + msa_mask + (summed member masks)``. Adds:

    * ``cluster_profile``: (centre 23-class one-hot + summed masked member
      one-hots) / counts.
    * ``cluster_deletion_mean``: (centre deletions + summed masked member
      deletions) / counts.
    """
    num_clusters = protein["msa"].shape[0]

    def cluster_total(values):
        # Sum rows of ``values`` into their assigned cluster.
        return unsorted_segment_sum(
            values, protein["extra_cluster_assignment"], num_clusters
        )

    extra_mask = protein["extra_msa_mask"]
    counts = 1e-6 + protein["msa_mask"] + cluster_total(extra_mask)  # Include center

    profile_sum = cluster_total(
        extra_mask[:, :, None] * make_one_hot(protein["extra_msa"], 23)
    )
    profile_sum += make_one_hot(protein["msa"], 23)  # Original sequence
    protein["cluster_profile"] = profile_sum / counts[:, :, None]
    del profile_sum

    deletion_sum = cluster_total(extra_mask * protein["extra_deletion_matrix"])
    deletion_sum += protein["deletion_matrix"]  # Original sequence
    protein["cluster_deletion_mean"] = deletion_sum / counts
    del deletion_sum

    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
345

346
347
def make_msa_mask(protein):
    """Mask features are all ones, but will later be zero-padded.

    Adds ``msa_mask`` (same shape as ``msa``) and ``msa_row_mask`` (one entry
    per MSA row), both float32 and all ones.
    """
    msa_shape = protein["msa"].shape
    protein["msa_mask"] = torch.ones(msa_shape, dtype=torch.float32)
    protein["msa_row_mask"] = torch.ones(msa_shape[0], dtype=torch.float32)
    return protein
353

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
354

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
355
def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features.

    For each residue the pseudo-beta position is the CB atom position, except
    for glycine, which uses CA. If ``all_atom_mask`` is given, the matching
    mask entry is selected the same way and ``(positions, mask)`` is
    returned. Otherwise only the positions are returned.
    """
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]
    is_gly = aatype == rc.restype_order["G"]

    ca_pos = all_atom_positions[..., ca_idx, :]
    cb_pos = all_atom_positions[..., cb_idx, :]
    # The trailing singleton axis broadcasts the per-residue choice over xyz.
    pseudo_beta = torch.where(is_gly[..., None], ca_pos, cb_pos)

    if all_atom_mask is None:
        return pseudo_beta

    pseudo_beta_mask = torch.where(
        is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
    )
    return pseudo_beta, pseudo_beta_mask

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
374

375
@curry1
def make_pseudo_beta(protein, prefix=""):
    """Create pseudo-beta (alpha for glycine) position and mask.

    With ``prefix=""`` this reads ``aatype``, ``all_atom_positions`` and
    ``all_atom_mask``. With ``prefix="template_"`` it reads the
    ``template_``-prefixed versions. Results are written to
    ``<prefix>pseudo_beta`` and ``<prefix>pseudo_beta_mask``.
    """
    assert prefix in ["", "template_"]
    aatype_key = "template_aatype" if prefix else "aatype"
    mask_key = "template_all_atom_mask" if prefix else "all_atom_mask"

    positions, mask = pseudo_beta_fn(
        protein[aatype_key],
        protein[prefix + "all_atom_positions"],
        protein[mask_key],
    )
    protein[prefix + "pseudo_beta"] = positions
    protein[prefix + "pseudo_beta_mask"] = mask
    return protein
388

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
389

390
391
392
393
394
@curry1
def add_constant_field(protein, key, value):
    """Store ``torch.tensor(value)`` under ``protein[key]`` and return the dict."""
    protein.update({key: torch.tensor(value)})
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
395

396
397
398
def shaped_categorical(probs, epsilon=1e-10):
    """Draw one categorical sample per leading index of ``probs``.

    The last dimension of ``probs`` holds class weights for a
    ``torch.distributions.Categorical``. ``epsilon`` is added to every weight
    before sampling. Returns an integer tensor of shape ``probs.shape[:-1]``.
    """
    *batch_shape, num_classes = probs.shape
    flat_weights = (probs + epsilon).reshape(-1, num_classes)
    samples = torch.distributions.categorical.Categorical(flat_weights).sample()
    return samples.reshape(batch_shape)

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
405

406
407
def make_hhblits_profile(protein):
    """Compute the HHblits MSA profile if not already present.

    The profile is the mean over MSA rows (dim 0) of the 22-class one-hot
    encoding of ``msa``. An existing ``hhblits_profile`` is left unchanged.
    """
    if "hhblits_profile" not in protein:
        # Compute the profile for every residue (over all MSA sequences).
        protein["hhblits_profile"] = make_one_hot(protein["msa"], 22).mean(dim=0)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
417

418
419
420
421
@curry1
def make_masked_msa(protein, config, replace_fraction):
    """Create data for BERT on raw MSA.

    A replacement residue is sampled for every MSA entry from a mixture of:

    * uniform over the first 20 classes (weight ``config.uniform_prob``),
    * ``hhblits_profile`` (weight ``config.profile_prob``),
    * the entry itself (weight ``config.same_prob``),
    * a new [MASK] class 22 that receives the remaining probability.

    Each entry is replaced with probability ``replace_fraction``. Writes
    ``bert_mask`` (float32, 1 where replaced) and ``true_msa`` (the original
    MSA), and overwrites ``msa`` with the corrupted MSA.
    """
    msa = protein["msa"]

    # Add a random amino acid uniformly.
    uniform_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)
    categorical_probs = (
        config.uniform_prob * uniform_aa
        + config.profile_prob * protein["hhblits_profile"]
        + config.same_prob * make_one_hot(msa, 22)
    )

    # Put all remaining probability on [MASK] which is a new column
    mask_prob = 1.0 - config.profile_prob - config.same_prob - config.uniform_prob
    assert mask_prob >= 0.0
    # F.pad takes (before, after) pairs starting from the LAST dimension, so
    # this appends exactly one column to the class dimension.
    last_dim_pad = [0, 1] + [0, 0] * (categorical_probs.dim() - 1)
    categorical_probs = torch.nn.functional.pad(
        categorical_probs, last_dim_pad, value=mask_prob
    )

    mask_position = torch.rand(msa.shape) < replace_fraction
    sampled = shaped_categorical(categorical_probs)
    bert_msa = torch.where(mask_position, sampled, msa)

    # Mix real and masked MSA
    protein["bert_mask"] = mask_position.to(torch.float32)
    protein["true_msa"] = msa
    protein["msa"] = bert_msa
    return protein
455

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
456

457
@curry1
def make_fixed_size(
    protein,
    shape_schema,
    msa_cluster_size,
    extra_msa_size,
    num_res=0,
    num_templates=0,
):
    """Guess at the MSA and sequence dimension to make fixed size.

    Every feature except ``extra_cluster_assignment`` is zero-padded along
    each dimension whose schema placeholder is NUM_RES / NUM_MSA_SEQ /
    NUM_EXTRA_SEQ / NUM_TEMPLATES, up to the matching target size. Dimensions
    with other placeholders, or whose target is falsy (e.g. the default 0),
    keep their current size. ``shape_schema[k]`` must have the same rank as
    feature ``k``.
    """
    target_for = {
        NUM_RES: num_res,
        NUM_MSA_SEQ: msa_cluster_size,
        NUM_EXTRA_SEQ: extra_msa_size,
        NUM_TEMPLATES: num_templates,
    }

    for k, v in protein.items():
        # Don't transfer this to the accelerator.
        if k == "extra_cluster_assignment":
            continue
        shape = list(v.shape)
        schema = shape_schema[k]
        msg = "Rank mismatch between shape and shape schema for"
        assert len(shape) == len(schema), f"{msg} {k}: {shape} vs {schema}"

        pad_size = [
            target_for.get(placeholder, None) or current
            for current, placeholder in zip(shape, schema)
        ]

        # F.pad wants (before, after) pairs starting from the last dimension.
        padding = []
        for current, target in reversed(list(zip(shape, pad_size))):
            padding.extend((0, target - current))
        if padding:
            padded = torch.nn.functional.pad(v, padding)
            protein[k] = torch.reshape(padded, pad_size)

    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
496

497
498
499
@curry1
def make_msa_feat(protein):
    """Create and concatenate MSA features.

    * ``target_feat`` (per residue): [has_break (1), one-hot aatype (21)].
    * ``msa_feat`` (per MSA entry): [one-hot msa (23), has_deletion (1),
      deletion_value (1)]. When ``cluster_profile`` is present it is
      followed by [cluster_profile, transformed cluster_deletion_mean (1)].

    Deletion counts ``d`` are transformed as ``atan(d / 3) * 2 / pi``. When
    ``extra_deletion_matrix`` is present, ``extra_has_deletion`` and
    ``extra_deletion_value`` are computed from it the same way.
    """

    def squash_deletions(counts):
        return torch.atan(counts / 3.0) * (2.0 / np.pi)

    # Whether there is a domain break. Always zero for chains, but keeping for
    # compatibility with domain datasets.
    has_break = torch.clip(
        protein["between_segment_residues"].to(torch.float32), 0, 1
    )
    target_feat = [
        has_break.unsqueeze(-1),
        make_one_hot(protein["aatype"], 21),  # Everyone gets the original sequence.
    ]

    deletions = protein["deletion_matrix"]
    msa_feat = [
        make_one_hot(protein["msa"], 23),
        torch.clip(deletions, 0.0, 1.0).unsqueeze(-1),
        squash_deletions(deletions).unsqueeze(-1),
    ]

    if "cluster_profile" in protein:
        msa_feat.append(protein["cluster_profile"])
        msa_feat.append(
            squash_deletions(protein["cluster_deletion_mean"]).unsqueeze(-1)
        )

    if "extra_deletion_matrix" in protein:
        extra_deletions = protein["extra_deletion_matrix"]
        protein["extra_has_deletion"] = torch.clip(extra_deletions, 0.0, 1.0)
        protein["extra_deletion_value"] = squash_deletions(extra_deletions)

    protein["msa_feat"] = torch.cat(msa_feat, dim=-1)
    protein["target_feat"] = torch.cat(target_feat, dim=-1)
    return protein

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
547

548
549
550
551
@curry1
def select_feat(protein, feature_list):
    """Return a new dict holding only the features whose keys are in ``feature_list``."""
    selected = {}
    for key, value in protein.items():
        if key in feature_list:
            selected[key] = value
    return selected

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
552

553
554
555
@curry1
def crop_templates(protein, max_templates):
    """Keep only the first ``max_templates`` entries of every ``template_*`` feature."""
    template_keys = [key for key in protein if key.startswith("template_")]
    for key in template_keys:
        protein[key] = protein[key][:max_templates]
    return protein
559

Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
560

561
562
563
564
565
566
def make_atom14_masks(protein):
    """Construct denser atom positions (14 dimensions instead of 37).

    Builds per-residue-type lookup tables between the 14-slot (atom14) and
    37-slot (atom37) atom layouts from ``rc``, with an extra all-zero row
    appended for 'UNK'. The tables are indexed with ``protein["aatype"]`` to
    add:

    * ``atom14_atom_exists``: float32 mask of occupied atom14 slots.
    * ``residx_atom14_to_atom37``: int64 atom37 index of each atom14 slot.
    * ``residx_atom37_to_atom14``: int64 atom14 index of each atom37 slot.
    * ``atom37_atom_exists``: float32 mask of atoms listed in
      ``rc.residue_atoms`` for the residue.
    """
    device = protein["aatype"].device

    def _table(rows, dtype):
        return torch.tensor(rows, dtype=dtype, device=device)

    atom14_to_atom37_rows = []
    atom37_to_atom14_rows = []
    atom14_mask_rows = []
    for letter in rc.restypes:
        names14 = rc.restype_name_to_atom14_names[rc.restype_1to3[letter]]
        slot_of = {name: slot for slot, name in enumerate(names14)}
        atom14_to_atom37_rows.append(
            [rc.atom_order[name] if name else 0 for name in names14]
        )
        atom37_to_atom14_rows.append(
            [slot_of.get(name, 0) for name in rc.atom_types]
        )
        atom14_mask_rows.append([1.0 if name else 0.0 for name in names14])

    # Add dummy mapping for restype 'UNK'
    atom14_to_atom37_rows.append([0] * 14)
    atom37_to_atom14_rows.append([0] * 37)
    atom14_mask_rows.append([0.0] * 14)

    atom14_to_atom37 = _table(atom14_to_atom37_rows, torch.int32)
    atom37_to_atom14 = _table(atom37_to_atom14_rows, torch.int32)
    atom14_mask = _table(atom14_mask_rows, torch.float32)

    aatype = protein["aatype"]
    # create the mapping for (residx, atom14) --> atom37, i.e. an array
    # with shape (num_res, 14) containing the atom37 indices for this protein
    protein["atom14_atom_exists"] = atom14_mask[aatype]
    protein["residx_atom14_to_atom37"] = atom14_to_atom37[aatype].long()

    # create the gather indices for mapping back
    protein["residx_atom37_to_atom14"] = atom37_to_atom14[aatype].long()

    # create the corresponding mask
    atom37_mask = torch.zeros([21, 37], dtype=torch.float32, device=device)
    for restype, letter in enumerate(rc.restypes):
        for atom_name in rc.residue_atoms[rc.restype_1to3[letter]]:
            atom37_mask[restype, rc.atom_order[atom_name]] = 1

    protein["atom37_atom_exists"] = atom37_mask[aatype]
    return protein
632
633
634
635
636
637
638


def make_atom14_masks_np(batch):
    """NumPy front-end for make_atom14_masks.

    ``np.ndarray`` leaves of ``batch`` are converted with ``torch.tensor``
    (via ``tree_map``), the masks are computed, and the tensors of the result
    are converted back with ``np.array`` (via ``tensor_tree_map``).
    """
    as_torch = tree_map(lambda arr: torch.tensor(arr), batch, np.ndarray)
    with_masks = make_atom14_masks(as_torch)
    return tensor_tree_map(lambda t: np.array(t), with_masks)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
639
640
641
642
643
644


def make_atom14_positions(protein):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Reads ``atom14_atom_exists`` and ``residx_atom14_to_atom37`` together
    with ``all_atom_mask``, ``all_atom_positions`` and ``aatype``. Adds:

    * ``atom14_gt_exists`` / ``atom14_gt_positions``: ground truth gathered
      into the atom14 layout (positions multiplied by the mask).
    * ``atom14_alt_gt_positions`` / ``atom14_alt_gt_exists``: the same with
      the atom pairs from ``rc.residue_atom_renaming_swaps`` exchanged.
    * ``atom14_atom_is_ambiguous``: 1 for atom14 slots that take part in a
      swap.

    ``atom14_atom_exists`` is written back unchanged.
    """
    atom14_exists = protein["atom14_atom_exists"]
    atom14_to_atom37 = protein["residx_atom14_to_atom37"]
    all_atom_mask = protein["all_atom_mask"]
    all_atom_positions = protein["all_atom_positions"]

    # Create a mask for known ground truth positions.
    gt_exists = atom14_exists * batched_gather(
        all_atom_mask,
        atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )

    # Gather the ground truth positions.
    gathered_positions = batched_gather(
        all_atom_positions,
        atom14_to_atom37,
        dim=-2,
        no_batch_dims=len(all_atom_positions.shape[:-2]),
    )
    gt_positions = gt_exists[..., None] * gathered_positions

    protein["atom14_atom_exists"] = atom14_exists
    protein["atom14_gt_exists"] = gt_exists
    protein["atom14_gt_positions"] = gt_positions

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    resnames = [rc.restype_1to3[res] for res in rc.restypes] + ["UNK"]
    mask_dtype = all_atom_mask.dtype
    mask_device = all_atom_mask.device

    # Matrices for renaming ambiguous atoms: identity by default; residues
    # with swaps get a permutation exchanging the paired atom14 slots.
    renaming_by_resname = {
        name: torch.eye(14, dtype=mask_dtype, device=mask_device)
        for name in resnames
    }
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        atom14_names = rc.restype_name_to_atom14_names[resname]
        perm = torch.arange(14, device=mask_device)
        for source_name, target_name in swap.items():
            src = atom14_names.index(source_name)
            tgt = atom14_names.index(target_name)
            perm[src] = tgt
            perm[tgt] = src
        matrix = all_atom_mask.new_zeros((14, 14))
        matrix[torch.arange(14, device=mask_device), perm] = 1.0
        renaming_by_resname[resname] = matrix

    renaming_matrices = torch.stack(
        [renaming_by_resname[name] for name in resnames]
    )

    # Pick the transformation matrices for the given residue sequence
    # shape (num_res, 14, 14).
    renaming_transform = renaming_matrices[protein["aatype"]]

    # Apply it to the ground truth positions. shape (num_res, 14, 3).
    protein["atom14_alt_gt_positions"] = torch.einsum(
        "...rac,...rab->...rbc", gt_positions, renaming_transform
    )

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    protein["atom14_alt_gt_exists"] = torch.einsum(
        "...ra,...rab->...rb", gt_exists, renaming_transform
    )

    # Create an ambiguous atoms mask.  shape: (21, 14).
    is_ambiguous = all_atom_mask.new_zeros((21, 14))
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        restype = rc.restype_order[rc.restype_3to1[resname]]
        atom14_names = rc.restype_name_to_atom14_names[resname]
        for name_a, name_b in swap.items():
            is_ambiguous[restype, atom14_names.index(name_a)] = 1
            is_ambiguous[restype, atom14_names.index(name_b)] = 1

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = is_ambiguous[protein["aatype"]]
    return protein


def atom37_to_frames(protein):
    """Compute ground-truth rigid-group frames from atom37 coordinates.

    Every residue gets 8 rigid groups, each defined by three base atoms passed
    to ``T.from_3_points`` (negative-x-axis point, origin, xy-plane point):

      * group 0 uses (C, CA, N),
      * group 3 uses (CA, C, O),
      * groups 4-7 use the last three atoms of chi1-chi4, when that chi angle
        exists for the residue type (``rc.chi_angles_mask``).

    Groups without base atoms (1 and 2, plus undefined chis) get the empty
    name "", which maps to atom index 0. They are zeroed in
    ``rigidgroups_group_exists``.

    Args:
        protein: feature dict with "aatype", "all_atom_positions" and
            "all_atom_mask". Presumably these are [*, N_res],
            [*, N_res, 37, 3] and [*, N_res, 37], as documented for
            atom37_to_torsion_angles. TODO confirm.

    Returns:
        The same dict, updated in place with:
            "rigidgroups_gt_frames": 4x4 frame tensors (``T.to_4x4``)
            "rigidgroups_gt_exists": group exists AND all 3 base atoms present
            "rigidgroups_group_exists": group defined for the residue type
            "rigidgroups_group_is_ambiguous": 1 on the last chi group of
                residue types listed in ``rc.residue_atom_renaming_swaps``
            "rigidgroups_alt_gt_frames": frames with those ambiguous groups
                rotated by 180 degrees about their x axis
    """
    aatype = protein["aatype"]
    all_atom_positions = protein["all_atom_positions"]
    all_atom_mask = protein["all_atom_mask"]

    batch_dims = len(aatype.shape[:-1])

    # Base-atom names per (residue type incl. unknown, group, atom).
    restype_rigidgroup_base_atom_names = np.full([21, 8, 3], "", dtype=object)
    restype_rigidgroup_base_atom_names[:, 0, :] = ["C", "CA", "N"]
    restype_rigidgroup_base_atom_names[:, 3, :] = ["CA", "C", "O"]

    for restype, restype_letter in enumerate(rc.restypes):
        resname = rc.restype_1to3[restype_letter]
        for chi_idx in range(4):
            if rc.chi_angles_mask[restype][chi_idx]:
                # A chi group is framed by the last three of its four atoms.
                names = rc.chi_angles_atoms[resname][chi_idx]
                restype_rigidgroup_base_atom_names[
                    restype, chi_idx + 4, :
                ] = names[1:]

    # Which groups are defined for each residue type. Groups 0 and 3 always
    # exist; chi groups follow rc.chi_angles_mask. The unknown type (index
    # 20) keeps only 0 and 3.
    restype_rigidgroup_mask = all_atom_mask.new_zeros(
        (*aatype.shape[:-1], 21, 8),
    )
    restype_rigidgroup_mask[..., 0] = 1
    restype_rigidgroup_mask[..., 3] = 1
    restype_rigidgroup_mask[..., :20, 4:] = all_atom_mask.new_tensor(
        rc.chi_angles_mask
    )

    # Translate base-atom names into atom37 indices ("" -> 0).
    lookuptable = rc.atom_order.copy()
    lookuptable[""] = 0
    lookup = np.vectorize(lambda x: lookuptable[x])
    restype_rigidgroup_base_atom37_idx = lookup(
        restype_rigidgroup_base_atom_names,
    )
    restype_rigidgroup_base_atom37_idx = aatype.new_tensor(
        restype_rigidgroup_base_atom37_idx,
    )
    # Prepend singleton batch dims so batched_gather can broadcast.
    restype_rigidgroup_base_atom37_idx = (
        restype_rigidgroup_base_atom37_idx.view(
            *((1,) * batch_dims), *restype_rigidgroup_base_atom37_idx.shape
        )
    )

    # Per-residue base-atom indices, selected by residue type.
    residx_rigidgroup_base_atom37_idx = batched_gather(
        restype_rigidgroup_base_atom37_idx,
        aatype,
        dim=-3,
        no_batch_dims=batch_dims,
    )

    base_atom_pos = batched_gather(
        all_atom_positions,
        residx_rigidgroup_base_atom37_idx,
        dim=-2,
        no_batch_dims=len(all_atom_positions.shape[:-2]),
    )

    gt_frames = T.from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        eps=1e-8,
    )

    group_exists = batched_gather(
        restype_rigidgroup_mask,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    gt_atoms_exist = batched_gather(
        all_atom_mask,
        residx_rigidgroup_base_atom37_idx,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )
    # A frame is usable only if the group exists and all 3 base atoms do.
    gt_exists = torch.min(gt_atoms_exist, dim=-1)[0] * group_exists

    # Compose group 0 with diag(-1, 1, -1), i.e. a 180-degree rotation about
    # its y axis; all other groups get the identity.
    rots = torch.eye(3, dtype=all_atom_mask.dtype, device=aatype.device)
    rots = torch.tile(rots, (*((1,) * batch_dims), 8, 1, 1))
    rots[..., 0, 0, 0] = -1
    rots[..., 0, 2, 2] = -1

    gt_frames = gt_frames.compose(T(rots, None))

    # Residue types with swappable atom pairs: their last chi group is
    # ambiguous, and its alternative frame is rotated by diag(1, -1, -1)
    # (180 degrees about x).
    restype_rigidgroup_is_ambiguous = all_atom_mask.new_zeros(
        *((1,) * batch_dims), 21, 8
    )
    restype_rigidgroup_rots = torch.eye(
        3, dtype=all_atom_mask.dtype, device=aatype.device
    )
    restype_rigidgroup_rots = torch.tile(
        restype_rigidgroup_rots,
        (*((1,) * batch_dims), 21, 8, 1, 1),
    )

    for resname in rc.residue_atom_renaming_swaps:
        restype = rc.restype_order[rc.restype_3to1[resname]]
        # Index of the last defined chi, assuming chis are defined
        # contiguously from chi1.
        chi_idx = int(sum(rc.chi_angles_mask[restype]) - 1)
        restype_rigidgroup_is_ambiguous[..., restype, chi_idx + 4] = 1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 1, 1] = -1
        restype_rigidgroup_rots[..., restype, chi_idx + 4, 2, 2] = -1

    residx_rigidgroup_is_ambiguous = batched_gather(
        restype_rigidgroup_is_ambiguous,
        aatype,
        dim=-2,
        no_batch_dims=batch_dims,
    )

    residx_rigidgroup_ambiguity_rot = batched_gather(
        restype_rigidgroup_rots,
        aatype,
        dim=-4,
        no_batch_dims=batch_dims,
    )

    alt_gt_frames = gt_frames.compose(T(residx_rigidgroup_ambiguity_rot, None))

    gt_frames_tensor = gt_frames.to_4x4()
    alt_gt_frames_tensor = alt_gt_frames.to_4x4()

    protein["rigidgroups_gt_frames"] = gt_frames_tensor
    protein["rigidgroups_gt_exists"] = gt_exists
    protein["rigidgroups_group_exists"] = group_exists
    protein["rigidgroups_group_is_ambiguous"] = residx_rigidgroup_is_ambiguous
    protein["rigidgroups_alt_gt_frames"] = alt_gt_frames_tensor

    return protein


def get_chi_atom_indices():
    """Returns atom indices needed to compute chi angles for all residue types.

    Returns:
      A nested Python list (not a tensor; wrap it with ``torch.as_tensor``
      when a tensor is needed) of shape [residue_types=21, chis=4, atoms=4].
      The residue types are in the order specified in rc.restypes, with the
      unknown residue type at the end. For chi angles that are not defined
      on a residue, the atom37 position indices default to 0.
    """
    chi_atom_indices = []
    for residue_letter in rc.restypes:
        residue_chi_angles = rc.chi_angles_atoms[rc.restype_1to3[residue_letter]]
        atom_indices = [
            [rc.atom_order[atom] for atom in chi_angle]
            for chi_angle in residue_chi_angles
        ]
        # Pad chi angles not defined on this amino acid with zero indices.
        # Every row is a fresh list, so rows never alias one another.
        atom_indices.extend([0, 0, 0, 0] for _ in range(4 - len(atom_indices)))
        chi_atom_indices.append(atom_indices)

    # For the UNKNOWN residue type: four independent all-zero rows.
    chi_atom_indices.append([[0, 0, 0, 0] for _ in range(4)])

    return chi_atom_indices


@curry1
def atom37_to_torsion_angles(
    protein,
    prefix="",
):
    """
    Convert coordinates to torsion angles.

    This function is extremely sensitive to floating point imprecisions
    and should be run with double precision whenever possible.

    Args:
        Dict containing:
            * (prefix)aatype:
                [*, N_res] residue indices
            * (prefix)all_atom_positions:
                [*, N_res, 37, 3] atom positions (in atom37
                format)
            * (prefix)all_atom_mask:
                [*, N_res, 37] atom position mask
    Returns:
        The same dictionary updated with the following features:

        "(prefix)torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Torsion angles, ordered (pre-omega, phi, psi, chi1..chi4)
        "(prefix)alt_torsion_angles_sin_cos" ([*, N_res, 7, 2])
            Alternate torsion angles (accounting for 180-degree symmetry)
        "(prefix)torsion_angles_mask" ([*, N_res, 7])
            Torsion angles mask
    """
    aatype = protein[prefix + "aatype"]
    all_atom_positions = protein[prefix + "all_atom_positions"]
    all_atom_mask = protein[prefix + "all_atom_mask"]

    # Fold every index above 20 onto 20, the last ("unknown") row of the
    # per-restype tables built below.
    aatype = torch.clamp(aatype, max=20)

    # Shift atoms/masks by one residue so residue i can see residue i-1. The
    # first residue gets an all-zero (fully masked) predecessor.
    pad = all_atom_positions.new_zeros(
        [*all_atom_positions.shape[:-3], 1, 37, 3]
    )
    prev_all_atom_positions = torch.cat(
        [pad, all_atom_positions[..., :-1, :, :]], dim=-3
    )

    pad = all_atom_mask.new_zeros([*all_atom_mask.shape[:-2], 1, 37])
    prev_all_atom_mask = torch.cat([pad, all_atom_mask[..., :-1, :]], dim=-2)

    # Four atoms per backbone torsion. This assumes atom37 slots 0-4 are
    # N, CA, C, CB, O (TODO confirm against rc.atom_types):
    #   pre-omega: prev CA, prev C, N, CA
    #   phi:       prev C, N, CA, C
    #   psi:       N, CA, C, O
    pre_omega_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 1:3, :], all_atom_positions[..., :2, :]],
        dim=-2,
    )
    phi_atom_pos = torch.cat(
        [prev_all_atom_positions[..., 2:3, :], all_atom_positions[..., :3, :]],
        dim=-2,
    )
    psi_atom_pos = torch.cat(
        [all_atom_positions[..., :3, :], all_atom_positions[..., 4:5, :]],
        dim=-2,
    )

    # A torsion is valid only when all four of its atoms are present. All
    # three products use the mask dtype.
    pre_omega_mask = torch.prod(
        prev_all_atom_mask[..., 1:3], dim=-1, dtype=all_atom_mask.dtype
    ) * torch.prod(all_atom_mask[..., :2], dim=-1, dtype=all_atom_mask.dtype)
    phi_mask = prev_all_atom_mask[..., 2] * torch.prod(
        all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype
    )
    psi_mask = (
        torch.prod(all_atom_mask[..., :3], dim=-1, dtype=all_atom_mask.dtype)
        * all_atom_mask[..., 4]
    )

    # Side-chain chi angles: look up each residue's chi atoms by type.
    chi_atom_indices = torch.as_tensor(
        get_chi_atom_indices(), device=aatype.device
    )

    atom_indices = chi_atom_indices[..., aatype, :, :]
    chis_atom_pos = batched_gather(
        all_atom_positions, atom_indices, -2, len(atom_indices.shape[:-2])
    )

    # Chi-angle existence per restype, with an all-zero row for "unknown".
    chi_angles_mask = list(rc.chi_angles_mask)
    chi_angles_mask.append([0.0, 0.0, 0.0, 0.0])
    chi_angles_mask = all_atom_mask.new_tensor(chi_angles_mask)

    chis_mask = chi_angles_mask[aatype, :]

    chi_angle_atoms_mask = batched_gather(
        all_atom_mask,
        atom_indices,
        dim=-1,
        no_batch_dims=len(atom_indices.shape[:-2]),
    )
    chi_angle_atoms_mask = torch.prod(
        chi_angle_atoms_mask, dim=-1, dtype=chi_angle_atoms_mask.dtype
    )
    chis_mask = chis_mask * chi_angle_atoms_mask

    torsions_atom_pos = torch.cat(
        [
            pre_omega_atom_pos[..., None, :, :],
            phi_atom_pos[..., None, :, :],
            psi_atom_pos[..., None, :, :],
            chis_atom_pos,
        ],
        dim=-3,
    )

    torsion_angles_mask = torch.cat(
        [
            pre_omega_mask[..., None],
            phi_mask[..., None],
            psi_mask[..., None],
            chis_mask,
        ],
        dim=-1,
    )

    # Build a frame from atoms 1, 2 and 0 of each torsion. The fourth atom's
    # (z, y) coordinates in that frame give (sin, cos) of the torsion, up to
    # normalisation.
    torsion_frames = T.from_3_points(
        torsions_atom_pos[..., 1, :],
        torsions_atom_pos[..., 2, :],
        torsions_atom_pos[..., 0, :],
        eps=1e-8,
    )

    fourth_atom_rel_pos = torsion_frames.invert().apply(
        torsions_atom_pos[..., 3, :]
    )

    torsion_angles_sin_cos = torch.stack(
        [fourth_atom_rel_pos[..., 2], fourth_atom_rel_pos[..., 1]], dim=-1
    )

    denom = torch.sqrt(
        torch.sum(
            torch.square(torsion_angles_sin_cos),
            dim=-1,
            dtype=torsion_angles_sin_cos.dtype,
            keepdim=True,
        )
        + 1e-8
    )
    torsion_angles_sin_cos = torsion_angles_sin_cos / denom

    # Negate torsion index 2 (psi). Presumably this is because psi is
    # measured through the O atom; confirm against the reference
    # implementation. The index tuple reshapes [7] to [1, ..., 1, 7, 1].
    torsion_angles_sin_cos = torsion_angles_sin_cos * all_atom_mask.new_tensor(
        [1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0],
    )[((None,) * len(torsion_angles_sin_cos.shape[:-2])) + (slice(None), None)]

    chi_is_ambiguous = torsion_angles_sin_cos.new_tensor(
        rc.chi_pi_periodic,
    )[aatype, ...]

    # Factor -1 for pi-periodic chis, which flips the angle by 180 degrees;
    # +1 for everything else, including the three backbone torsions.
    mirror_torsion_angles = torch.cat(
        [
            all_atom_mask.new_ones(*aatype.shape, 3),
            1.0 - 2.0 * chi_is_ambiguous,
        ],
        dim=-1,
    )

    alt_torsion_angles_sin_cos = (
        torsion_angles_sin_cos * mirror_torsion_angles[..., None]
    )

    protein[prefix + "torsion_angles_sin_cos"] = torsion_angles_sin_cos
    protein[prefix + "alt_torsion_angles_sin_cos"] = alt_torsion_angles_sin_cos
    protein[prefix + "torsion_angles_mask"] = torsion_angles_mask

    return protein


def get_backbone_frames(protein):
    """Publish rigid group 0 (the backbone group) under backbone keys.

    ``backbone_affine_tensor`` is ``rigidgroups_gt_frames[..., 0, :, :]`` and
    ``backbone_affine_mask`` is ``rigidgroups_gt_exists[..., 0]``. Both are
    views produced by basic indexing. ``protein`` is updated in place and
    returned.
    """
    # TODO: Verify that this is correct
    group_frames = protein["rigidgroups_gt_frames"]
    group_exists = protein["rigidgroups_gt_exists"]

    protein["backbone_affine_tensor"] = group_frames[..., 0, :, :]
    protein["backbone_affine_mask"] = group_exists[..., 0]

    return protein


def get_chi_angles(protein):
    """Extract the chi-angle features from the full torsion features.

    Torsions 3 and onward (after pre-omega, phi and psi) are the chi angles.
    Their sin/cos values and mask are stored as ``chi_angles_sin_cos`` and
    ``chi_mask``, cast to the dtype of ``all_atom_mask``. ``protein`` is
    updated in place and returned.
    """
    target_dtype = protein["all_atom_mask"].dtype
    torsion_sin_cos = protein["torsion_angles_sin_cos"]
    torsion_mask = protein["torsion_angles_mask"]

    protein["chi_angles_sin_cos"] = torsion_sin_cos[..., 3:, :].to(target_dtype)
    protein["chi_mask"] = torsion_mask[..., 3:].to(target_dtype)

    return protein


@curry1
def random_crop_to_size(
    protein,
    crop_size,
    max_templates,
    shape_schema,
    subsample_templates=False,
    seed=None,
):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Args:
        protein: feature dict. It must contain "seq_length" (a one-element
            tensor) and may contain "template_mask" and "use_clamped_fape".
        crop_size: maximum number of residues to keep.
        max_templates: maximum number of templates kept when templates are
            subsampled.
        shape_schema: maps feature names to per-dimension placeholders.
            Dimensions equal to NUM_RES are cropped to the residue window.
            Dimension 0 of "template*" features is cropped to the template
            window. Features not in the schema are left untouched.
        subsample_templates: if truthy and templates exist, randomly permute
            the templates and keep a random-length window of them.
        seed: if not None, seeds a dedicated generator, so every call with
            the same seed (e.g. every ensemble member) crops identically. If
            None, torch's global RNG is used.

    Returns:
        The same dict, updated in place with the cropped features and the
        new "seq_length".
    """
    device = protein["seq_length"].device

    # We want each ensemble to be cropped the same way, so use a dedicated,
    # seeded generator when a seed is given. An *unseeded* torch.Generator
    # starts from a fixed default seed, which would make every unseeded
    # "random" crop identical. Fall back to the global RNG instead.
    g = None
    if seed is not None:
        g = torch.Generator(device=device)
        g.manual_seed(seed)

    seq_length = int(protein["seq_length"])

    if "template_mask" in protein:
        num_templates = protein["template_mask"].shape[-1]
    else:
        num_templates = 0

    # No need to subsample templates if there aren't any
    subsample_templates = subsample_templates and num_templates

    num_res_crop_size = min(seq_length, crop_size)

    def _randint(lower, upper):
        # Uniform integer in [lower, upper]; both ends are inclusive.
        return int(
            torch.randint(
                lower,
                upper + 1,
                (1,),
                device=device,
                generator=g,
            )[0]
        )

    if subsample_templates:
        templates_crop_start = _randint(0, num_templates)
        templates_select_indices = torch.randperm(
            num_templates, device=device, generator=g
        )
        num_templates_crop_size = min(
            num_templates - templates_crop_start, max_templates
        )
    else:
        templates_crop_start = 0
        num_templates_crop_size = num_templates

    # Number of residues that must be dropped; the crop start is drawn from
    # [0, right_anchor].
    n = seq_length - num_res_crop_size
    if "use_clamped_fape" in protein and protein["use_clamped_fape"] == 1.:
        right_anchor = n
    else:
        x = _randint(0, n)
        right_anchor = n - x

    num_res_crop_start = _randint(0, right_anchor)

    # Reassigning existing keys while iterating is safe: the dict's size
    # never changes.
    for k, v in protein.items():
        if k not in shape_schema or (
            "template" not in k and NUM_RES not in shape_schema[k]
        ):
            continue

        # randomly permute the templates before cropping them.
        if k.startswith("template") and subsample_templates:
            v = v[templates_select_indices]

        slices = []
        for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
            is_num_res = dim_size == NUM_RES
            if i == 0 and k.startswith("template"):
                dim_crop_size = num_templates_crop_size
                dim_crop_start = templates_crop_start
            else:
                dim_crop_start = num_res_crop_start if is_num_res else 0
                dim_crop_size = num_res_crop_size if is_num_res else dim
            slices.append(slice(dim_crop_start, dim_crop_start + dim_crop_size))
        # Index with a tuple: one slice per dimension.
        protein[k] = v[tuple(slices)]

    protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
    return protein