# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
from functools import reduce, wraps
from operator import add

import numpy as np
import torch

from openfold.config import NUM_RES, NUM_EXTRA_SEQ, NUM_TEMPLATES, NUM_MSA_SEQ
from openfold.np import residue_constants as rc
from openfold.utils.affine_utils import T
from openfold.utils.tensor_utils import (
    tree_map,
    tensor_tree_map,
    batched_gather,
)


# Keys of per-MSA-row features. The MSA transforms in this module select,
# crop or delete these together (and their ``extra_``-prefixed copies) with
# index_select along dimension 0, i.e. the MSA row index.
MSA_FEATURE_NAMES = [
    "msa", "deletion_matrix", "msa_mask",
    "msa_row_mask", "bert_mask", "true_msa",
]


def cast_to_64bit_ints(protein):
    """Promote every int32 tensor in ``protein`` to int64 (in place).

    Tensors of any other dtype are left untouched. Returns ``protein``.
    """
    # We keep all ints as int64
    int32_keys = [key for key, val in protein.items() if val.dtype == torch.int32]
    for key in int32_keys:
        protein[key] = protein[key].type(torch.int64)
    return protein


def make_one_hot(x, num_classes):
    """One-hot encode an integer-valued tensor.

    Args:
        x: tensor of class indices; values must lie in ``[0, num_classes)``.
        num_classes: size of the new trailing dimension.

    Returns:
        A float32 tensor of shape ``(*x.shape, num_classes)`` on the same
        device as ``x``, with 1.0 at ``[..., x[...]]`` and 0.0 elsewhere.
    """
    # Allocate on x's device so GPU inputs do not hit a device mismatch.
    x_one_hot = torch.zeros(*x.shape, num_classes, device=x.device)
    # scatter_ requires an int64 index; accept int32 (etc.) inputs as well.
    x_one_hot.scatter_(-1, x.long().unsqueeze(-1), 1)
    return x_one_hot


def make_seq_mask(protein):
    """Add ``seq_mask``: a float32 tensor of ones shaped like ``aatype``."""
    seq_shape = protein["aatype"].shape
    protein["seq_mask"] = torch.ones(seq_shape, dtype=torch.float32)
    return protein


def make_template_mask(protein):
    """Add ``template_mask``: one float32 1.0 per template.

    The template count is the leading dimension of ``template_aatype``.
    """
    n_templates = protein["template_aatype"].shape[0]
    protein["template_mask"] = torch.ones(n_templates, dtype=torch.float32)
    return protein


def curry1(f):
    """Supply all arguments but the first.

    ``curry1(f)(*args, **kwargs)`` returns a one-argument callable ``g`` with
    ``g(x) == f(x, *args, **kwargs)``.

    The returned factory carries ``f``'s name, qualname, module and docstring
    (via ``functools.wraps``). Without this, every decorated transform shows
    up as ``fc`` and cannot be pickled by reference.
    """

    @wraps(f)
    def fc(*args, **kwargs):
        return lambda x: f(x, *args, **kwargs)

    return fc


@curry1
def add_distillation_flag(protein, distillation):
    """Record whether this example comes from distillation data.

    Stores ``float(distillation)`` as a 0-dim float32 tensor under
    ``protein["is_distillation"]`` and returns ``protein``.
    """
    flag_value = float(distillation)
    protein["is_distillation"] = torch.tensor(flag_value, dtype=torch.float32)
    return protein


def make_all_atom_aatype(protein):
    """Alias ``aatype`` as ``all_atom_aatype`` (same tensor object, no copy)."""
    aatype = protein["aatype"]
    protein["all_atom_aatype"] = aatype
    return protein


def fix_templates_aatype(protein):
    """Turn one-hot ``template_aatype`` into indices in our restype order.

    The argmax over the last dimension yields HHsearch/HHblits aatype indices.
    Each index is then remapped through
    ``rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE``.
    """
    one_hot = protein["template_aatype"]
    n_templates = one_hot.shape[0]

    # Map one-hot to indices
    hh_indices = torch.argmax(one_hot, dim=-1)

    # Map hhsearch-aatype to our aatype: one copy of the table per template,
    # gathered along the class axis.
    remap_table = torch.tensor(
        rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE, dtype=torch.int64
    ).expand(n_templates, -1)
    protein["template_aatype"] = torch.gather(remap_table, 1, index=hh_indices)
    return protein


def correct_msa_restypes(protein):
    """Correct MSA restype to have the same order as rc.

    Remaps each token of ``protein["msa"]`` through
    ``rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE``. It also applies the matching
    permutation to the last (class) dimension of every feature whose key
    contains "profile" (e.g. ``hhblits_profile``).
    """
    new_order_list = rc.MAP_HHBLITS_AATYPE_TO_OUR_AATYPE
    msa = protein["msa"]
    new_order = torch.tensor(new_order_list, dtype=msa.dtype, device=msa.device)
    # Element-wise lookup: msa[i, j] -> new_order_list[msa[i, j]].
    protein["msa"] = new_order[msa.long()]

    perm_matrix = np.zeros((22, 22), dtype=np.float32)
    perm_matrix[range(len(new_order_list)), new_order_list] = 1.0

    # Only existing keys are reassigned, so iterating the dict is safe.
    for k in protein:
        if "profile" in k:
            num_dim = protein[k].shape[-1]
            assert num_dim in [
                20,
                21,
                22,
            ], "num_dim for %s out of expected range: %s" % (k, num_dim)
            perm = torch.as_tensor(
                perm_matrix[:num_dim, :num_dim],
                dtype=protein[k].dtype,
                device=protein[k].device,
            )
            # Contract the last dimension with the permutation (the TF
            # reference uses tensordot with axes=1; torch.dot is 1-D only).
            protein[k] = torch.tensordot(protein[k], perm, dims=1)
    return protein


def squeeze_features(protein):
    """Remove singleton and repeated dimensions in protein features.

    - ``aatype`` goes from one-hot to indices (argmax over the last dim).
    - Each listed feature that is present loses a trailing size-1 dim.
    - ``seq_length`` and ``num_alignments`` keep only their first element.
    """
    protein["aatype"] = torch.argmax(protein["aatype"], dim=-1)

    squeezable = (
        "domain_name",
        "msa",
        "num_alignments",
        "seq_length",
        "sequence",
        "superfamily",
        "deletion_matrix",
        "resolution",
        "between_segment_residues",
        "residue_index",
        "template_all_atom_mask",
    )
    for name in squeezable:
        if name not in protein:
            continue
        last_dim = protein[name].shape[-1]
        if isinstance(last_dim, int) and last_dim == 1:
            protein[name] = torch.squeeze(protein[name], dim=-1)

    for name in ("seq_length", "num_alignments"):
        if name in protein:
            protein[name] = protein[name][0]
    return protein


@curry1
def randomly_replace_msa_with_unknown(protein, replace_proportion):
    """Replace a portion of the MSA with 'X'.

    Every non-gap MSA entry (gap = index 21) and every ``aatype`` entry is
    independently set to the unknown token 'X' (index 20) with probability
    ``replace_proportion``.
    """
    unknown_idx, gap_idx = 20, 21

    msa = protein["msa"]
    replace_in_msa = torch.rand(msa.shape) < replace_proportion
    replace_in_msa = replace_in_msa & (msa != gap_idx)
    protein["msa"] = msa.masked_fill(replace_in_msa, unknown_idx)

    aatype = protein["aatype"]
    replace_in_aatype = torch.rand(aatype.shape) < replace_proportion
    protein["aatype"] = aatype.masked_fill(replace_in_aatype, unknown_idx)
    return protein


@curry1
def sample_msa(protein, max_seq, keep_extra):
    """Sample MSA randomly; remaining sequences are stored as `extra_*`.

    Row 0 always stays first. Rows 1..num_seq-1 are shuffled, and the first
    ``max_seq`` rows of the resulting order are selected. Every feature in
    ``MSA_FEATURE_NAMES`` present in ``protein`` is subset along dimension 0.
    If ``keep_extra`` is true, the unselected rows are stored under
    ``"extra_" + name``.
    """
    num_seq = protein["msa"].shape[0]
    # Index tensors must live on the features' device for index_select.
    device = protein["msa"].device
    # Keep row 0 in front; only rows 1..num_seq-1 are shuffled.
    shuffled = torch.randperm(max(num_seq - 1, 0), device=device) + 1
    # The leading 0 is omitted for an empty MSA so the split sizes add up.
    first = torch.zeros(min(num_seq, 1), dtype=torch.long, device=device)
    index_order = torch.cat((first, shuffled), dim=0)
    num_sel = min(max_seq, num_seq)
    sel_seq, not_sel_seq = torch.split(
        index_order, [num_sel, num_seq - num_sel]
    )

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            if keep_extra:
                protein["extra_" + k] = torch.index_select(
                    protein[k], 0, not_sel_seq
                )
            protein[k] = torch.index_select(protein[k], 0, sel_seq)
    return protein


@curry1
def crop_extra_msa(protein, max_extra_msa):
    """Randomly keep at most ``max_extra_msa`` rows of the extra MSA.

    One random permutation of the extra-MSA rows is drawn. Its first
    ``min(max_extra_msa, num_rows)`` entries select the rows of every
    ``extra_``-prefixed feature from ``MSA_FEATURE_NAMES`` present in
    ``protein``.
    """
    extra_msa = protein["extra_msa"]
    num_seq = extra_msa.shape[0]
    num_sel = min(max_extra_msa, num_seq)
    # Build the indices on the features' device so index_select also works
    # for GPU tensors.
    select_indices = torch.randperm(num_seq, device=extra_msa.device)[:num_sel]
    for k in MSA_FEATURE_NAMES:
        key = "extra_" + k
        if key in protein:
            protein[key] = torch.index_select(protein[key], 0, select_indices)
    return protein


def delete_extra_msa(protein):
    """Remove every ``extra_``-prefixed MSA feature from ``protein`` in place.

    Keys come from ``MSA_FEATURE_NAMES``; missing keys are ignored.
    """
    for extra_key in ["extra_" + name for name in MSA_FEATURE_NAMES]:
        protein.pop(extra_key, None)
    return protein


# Not used in inference
@curry1
def block_delete_msa(protein, config):
    """Randomly delete contiguous blocks of MSA rows.

    ``config`` provides:
        msa_fraction_per_block: block length as a fraction of the MSA depth
            (floored to an int).
        num_blocks: number of blocks to delete.
        randomize_num_blocks: if true, the number of blocks is instead drawn
            uniformly from ``0..num_blocks``.

    Block start rows are drawn uniformly from the MSA, and blocks are clipped
    to the last row. Row 0 (the first sequence) is never deleted. Every
    feature in ``MSA_FEATURE_NAMES`` present in ``protein`` is subset along
    dimension 0, preserving row order.
    """
    num_seq = protein["msa"].shape[0]
    if num_seq == 0:
        return protein  # nothing to delete
    device = protein["msa"].device
    block_num_seq = int(np.floor(num_seq * config.msa_fraction_per_block))

    if config.randomize_num_blocks:
        # Uniform integer in [0, num_blocks].
        nb = int(torch.randint(config.num_blocks + 1, (1,)).item())
    else:
        nb = int(config.num_blocks)

    del_block_starts = torch.randint(num_seq, (nb,), device=device)
    del_blocks = del_block_starts[:, None] + torch.arange(
        block_num_seq, device=device
    )
    del_blocks = torch.clamp(del_blocks, 0, num_seq - 1)

    # Make sure we keep the original sequence
    keep_mask = torch.ones(num_seq, dtype=torch.bool, device=device)
    keep_mask[del_blocks.reshape(-1)] = False
    keep_mask[0] = True
    keep_indices = torch.arange(num_seq, device=device)[keep_mask]

    for k in MSA_FEATURE_NAMES:
        if k in protein:
            protein[k] = torch.index_select(protein[k], 0, keep_indices)

    return protein


@curry1
def nearest_neighbor_clusters(protein, gap_agreement_weight=0.0):
    """Assign each extra-MSA row to its most similar sampled MSA row.

    Similarity is a weighted count of positions where both (masked) rows carry
    the same token, over 23 token classes:
    - indices 0-20 have weight 1;
    - index 21 (gap) has weight ``gap_agreement_weight``;
    - index 22 ([MASK]) has weight 0.

    Writes ``protein["extra_cluster_assignment"]`` (int64): for each extra
    row, the index of the best-matching row of ``protein["msa"]``.
    """
    per_class_weight = torch.cat(
        [
            torch.ones(21),
            gap_agreement_weight * torch.ones(1),
            torch.zeros(1),
        ],
        0,
    )

    # Make agreement score as weighted Hamming distance
    sampled = protein["msa_mask"][:, :, None] * make_one_hot(protein["msa"], 23)
    extra = protein["extra_msa_mask"][:, :, None] * make_one_hot(
        protein["extra_msa"], 23
    )

    n_sampled, n_res, _ = sampled.shape
    n_extra, _, _ = extra.shape

    # Same result as einsum('mrc,nrc,c->mn', sampled, extra, weights), but
    # computed as one flat matrix product to avoid a large intermediate.
    extra_flat = torch.reshape(extra, [n_extra, n_res * 23])
    sampled_flat = torch.reshape(
        sampled * per_class_weight, [n_sampled, n_res * 23]
    )
    agreement = extra_flat @ sampled_flat.T

    # Assign each sequence in the extra sequences to the closest MSA sample
    best = agreement.argmax(dim=1)
    protein["extra_cluster_assignment"] = best.to(torch.int64)
    return protein


def unsorted_segment_sum(data, segment_ids, num_segments):
    """
    Computes the sum along segments of a tensor. Analogous to tf.unsorted_segment_sum.

    :param data: A tensor whose segments are to be summed.
    :param segment_ids: The segment indices tensor. Either 1-D (one id per
        entry of ``data``'s first dimension) or exactly ``data``'s shape.
        Values must lie in ``[0, num_segments)``.
    :param num_segments: The number of segments.
    :return: A tensor of shape ``(num_segments, *data.shape[1:])`` with the
        same data type and device as the data argument.
    """
    # segment_ids.shape should be a prefix of data.shape
    assert data.shape[: segment_ids.dim()] == segment_ids.shape, (
        f"segment_ids shape {tuple(segment_ids.shape)} is not a prefix of "
        f"data shape {tuple(data.shape)}"
    )

    # segment_ids is a 1-D tensor: broadcast it to the same shape as data
    if segment_ids.dim() == 1:
        trailing = [1] * (data.dim() - 1)
        segment_ids = segment_ids.view(-1, *trailing).expand_as(data)

    # data.shape and segment_ids.shape should be equal
    assert data.shape == segment_ids.shape

    # Accumulate in at least float32 (as before), but keep float64 inputs in
    # float64 instead of silently losing precision.
    acc_dtype = torch.promote_types(data.dtype, torch.float32)
    shape = [num_segments] + list(data.shape[1:])
    tensor = torch.zeros(shape, dtype=acc_dtype, device=data.device)
    tensor = tensor.scatter_add_(0, segment_ids.long(), data.to(acc_dtype))
    return tensor.to(data.dtype)


@curry1
def summarize_clusters(protein):
    """Produce profile and deletion_matrix_mean within each cluster.

    Each row of ``protein["msa"]`` is a cluster centre. Extra-MSA rows join
    the cluster given by ``protein["extra_cluster_assignment"]``, weighted by
    ``extra_msa_mask``. Writes:
    - ``cluster_profile``: 23-class one-hot counts divided by member count;
    - ``cluster_deletion_mean``: summed deletions divided by member count.
    The centre itself is included in both sums and in the count.
    """
    n_clusters = protein["msa"].shape[0]

    def seg_sum(values):
        return unsorted_segment_sum(
            values, protein["extra_cluster_assignment"], n_clusters
        )

    extra_mask = protein["extra_msa_mask"]
    # 1e-6 guards the division; msa_mask counts the centre itself.
    member_count = 1e-6 + protein["msa_mask"] + seg_sum(extra_mask)

    profile_sum = seg_sum(
        extra_mask[:, :, None] * make_one_hot(protein["extra_msa"], 23)
    )
    profile_sum += make_one_hot(protein["msa"], 23)  # Original sequence
    protein["cluster_profile"] = profile_sum / member_count[:, :, None]
    del profile_sum

    deletion_sum = seg_sum(extra_mask * protein["extra_deletion_matrix"])
    deletion_sum += protein["deletion_matrix"]  # Original sequence
    protein["cluster_deletion_mean"] = deletion_sum / member_count
    del deletion_sum

    return protein


def make_msa_mask(protein):
    """Mask features are all ones, but will later be zero-padded.

    Adds ``msa_mask`` (float32 ones shaped like ``msa``) and ``msa_row_mask``
    (float32 ones, one per MSA row).
    """
    msa_shape = protein["msa"].shape
    protein["msa_mask"] = torch.ones(msa_shape, dtype=torch.float32)
    protein["msa_row_mask"] = torch.ones(msa_shape[0], dtype=torch.float32)
    return protein


def pseudo_beta_fn(aatype, all_atom_positions, all_atom_mask):
    """Create pseudo beta features.

    Uses the CB coordinates of every residue, except glycine, which uses CA.
    Returns only ``pseudo_beta`` when ``all_atom_mask`` is None. Otherwise
    returns ``(pseudo_beta, pseudo_beta_mask)``, where the mask is taken from
    the same atom.
    """
    is_gly = torch.eq(aatype, rc.restype_order["G"])
    ca_idx = rc.atom_order["CA"]
    cb_idx = rc.atom_order["CB"]

    # Repeat the glycine flag over the xyz axis.
    gly_xyz = torch.tile(is_gly[..., None], [1] * len(is_gly.shape) + [3])
    pseudo_beta = torch.where(
        gly_xyz,
        all_atom_positions[..., ca_idx, :],
        all_atom_positions[..., cb_idx, :],
    )

    if all_atom_mask is None:
        return pseudo_beta

    pseudo_beta_mask = torch.where(
        is_gly, all_atom_mask[..., ca_idx], all_atom_mask[..., cb_idx]
    )
    return pseudo_beta, pseudo_beta_mask


@curry1
def make_pseudo_beta(protein, prefix=""):
    """Create pseudo-beta (alpha for glycine) position and mask.

    ``prefix`` is "" for target features or "template_" for template
    features. Writes ``<prefix>pseudo_beta`` and ``<prefix>pseudo_beta_mask``.
    """
    assert prefix in ["", "template_"]
    if prefix:
        aatype_key, mask_key = "template_aatype", "template_all_atom_mask"
    else:
        aatype_key, mask_key = "aatype", "all_atom_mask"

    beta, beta_mask = pseudo_beta_fn(
        protein[aatype_key],
        protein[prefix + "all_atom_positions"],
        protein[mask_key],
    )
    protein[prefix + "pseudo_beta"] = beta
    protein[prefix + "pseudo_beta_mask"] = beta_mask
    return protein


@curry1
def add_constant_field(protein, key, value):
    """Store ``torch.tensor(value)`` under ``protein[key]``; return ``protein``."""
    constant = torch.tensor(value)
    protein[key] = constant
    return protein


def shaped_categorical(probs, epsilon=1e-10):
    """Draw one categorical sample for every leading index of ``probs``.

    The last dimension of ``probs`` holds class weights. ``epsilon`` is added
    to each weight before they are handed to ``Categorical``. The result has
    shape ``probs.shape[:-1]``.
    """
    batch_shape, n_classes = probs.shape[:-1], probs.shape[-1]
    flat = torch.reshape(probs + epsilon, [-1, n_classes])
    sampler = torch.distributions.categorical.Categorical(flat)
    return torch.reshape(sampler.sample(), batch_shape)


def make_hhblits_profile(protein):
    """Compute the HHblits MSA profile if not already present.

    The profile is the mean of the 22-class one-hot MSA over all sequences
    (dimension 0). An existing ``hhblits_profile`` entry is left as is.
    """
    if "hhblits_profile" not in protein:
        msa_one_hot = make_one_hot(protein["msa"], 22)
        protein["hhblits_profile"] = msa_one_hot.mean(dim=0)
    return protein


@curry1
def make_masked_msa(protein, config, replace_fraction):
    """Create data for BERT on raw MSA.

    Each MSA position is chosen with probability ``replace_fraction``. A
    chosen position gets a token sampled from a mixture of four sources:
    - ``config.uniform_prob``: uniform over the first 20 tokens;
    - ``config.profile_prob``: ``hhblits_profile``;
    - ``config.same_prob``: the original token;
    - the remaining probability: a new [MASK] class (index 22).

    Writes ``bert_mask`` (float32, 1.0 where chosen), ``true_msa`` (the
    original MSA) and ``msa`` (the corrupted MSA).
    """
    # Add a random amino acid uniformly.
    uniform_aa = torch.tensor([0.05] * 20 + [0.0, 0.0], dtype=torch.float32)

    probs = (
        config.uniform_prob * uniform_aa
        + config.profile_prob * protein["hhblits_profile"]
        + config.same_prob * make_one_hot(protein["msa"], 22)
    )

    mask_prob = (
        1.0 - config.profile_prob - config.same_prob - config.uniform_prob
    )
    assert mask_prob >= 0.0

    # Put all remaining probability on [MASK] which is a new column.
    # F.pad takes (left, right) pairs starting from the LAST dimension, so
    # this appends one column on the right of the class axis only.
    pad = [0, 1] + [0, 0] * (len(probs.shape) - 1)
    probs = torch.nn.functional.pad(probs, pad, value=mask_prob)

    chosen = torch.rand(protein["msa"].shape) < replace_fraction
    sampled_tokens = shaped_categorical(probs)

    # Mix real and masked MSA
    bert_msa = torch.where(chosen, sampled_tokens, protein["msa"])
    protein["bert_mask"] = chosen.to(torch.float32)
    protein["true_msa"] = protein["msa"]
    protein["msa"] = bert_msa

    return protein


@curry1
def make_fixed_size(
    protein,
    shape_schema,
    msa_cluster_size,
    extra_msa_size,
    num_res=0,
    num_templates=0,
):
    """Pad every feature to the fixed sizes described by ``shape_schema``.

    ``shape_schema[key]`` lists one placeholder per dimension of
    ``protein[key]``. The NUM_RES / NUM_MSA_SEQ / NUM_EXTRA_SEQ /
    NUM_TEMPLATES placeholders are replaced by the corresponding argument.
    Any other placeholder, or a target that is 0/None, keeps the current
    size. Padding uses zeros (``torch.nn.functional.pad`` defaults).
    ``extra_cluster_assignment`` is skipped.
    """
    target_by_placeholder = {
        NUM_RES: num_res,
        NUM_MSA_SEQ: msa_cluster_size,
        NUM_EXTRA_SEQ: extra_msa_size,
        NUM_TEMPLATES: num_templates,
    }

    for key, value in protein.items():
        # Don't transfer this to the accelerator.
        if key == "extra_cluster_assignment":
            continue
        cur_shape = list(value.shape)
        schema = shape_schema[key]
        msg = "Rank mismatch between shape and shape schema for"
        assert len(cur_shape) == len(
            schema
        ), f"{msg} {key}: {cur_shape} vs {schema}"

        target = [
            target_by_placeholder.get(placeholder, None) or cur
            for cur, placeholder in zip(cur_shape, schema)
        ]

        # F.pad wants flattened (before, after) pairs, last dimension first.
        padding = []
        for dim in reversed(range(len(target))):
            padding.extend((0, target[dim] - value.shape[dim]))

        if padding:
            padded = torch.nn.functional.pad(value, padding)
            protein[key] = torch.reshape(padded, target)

    return protein


@curry1
def make_msa_feat(protein):
    """Create and concatenate MSA features.

    Writes ``target_feat``: the domain-break flag followed by the 21-class
    one-hot ``aatype``.

    Writes ``msa_feat``, concatenated along the last dimension:
    - the 23-class one-hot MSA;
    - the clipped deletion flag;
    - the squashed deletion count;
    - when ``cluster_profile`` exists, ``cluster_profile`` and the squashed
      ``cluster_deletion_mean``.

    When ``extra_deletion_matrix`` exists, also writes ``extra_has_deletion``
    and ``extra_deletion_value``.
    """

    def squash(counts):
        # (2 / pi) * arctan(counts / 3)
        return torch.atan(counts / 3.0) * (2.0 / np.pi)

    # Whether there is a domain break. Always zero for chains, but keeping for
    # compatibility with domain datasets.
    has_break = torch.clip(
        protein["between_segment_residues"].to(torch.float32), 0, 1
    )
    target_feat = [
        has_break.unsqueeze(-1),
        make_one_hot(protein["aatype"], 21),  # Everyone gets the original sequence.
    ]

    deletions = protein["deletion_matrix"]
    msa_feat = [
        make_one_hot(protein["msa"], 23),
        torch.clip(deletions, 0.0, 1.0).unsqueeze(-1),
        squash(deletions).unsqueeze(-1),
    ]
    if "cluster_profile" in protein:
        msa_feat += [
            protein["cluster_profile"],
            squash(protein["cluster_deletion_mean"]).unsqueeze(-1),
        ]

    if "extra_deletion_matrix" in protein:
        extra_deletions = protein["extra_deletion_matrix"]
        protein["extra_has_deletion"] = torch.clip(extra_deletions, 0.0, 1.0)
        protein["extra_deletion_value"] = squash(extra_deletions)

    protein["msa_feat"] = torch.cat(msa_feat, dim=-1)
    protein["target_feat"] = torch.cat(target_feat, dim=-1)
    return protein


@curry1
def select_feat(protein, feature_list):
    """Return a new dict holding only the entries whose key is in ``feature_list``.

    Entries keep the iteration order of ``protein``.
    """
    selected = {}
    for name, feat in protein.items():
        if name in feature_list:
            selected[name] = feat
    return selected


@curry1
def crop_templates(protein, max_templates):
    """Keep the first ``max_templates`` entries of every ``template_*`` feature.

    Slices along the leading dimension, in place, and returns ``protein``.
    """
    for name in protein:
        if not name.startswith("template_"):
            continue
        protein[name] = protein[name][:max_templates]
    return protein


def make_atom14_masks(protein):
    """Construct denser atom positions (14 dimensions instead of 37).

    Per-restype lookup tables are built from ``rc``: one row per entry of
    ``rc.restypes``, plus an all-zero row for 'UNK'. They are then indexed
    with ``protein["aatype"]`` to add:
    - ``atom14_atom_exists``: float32 atom14 existence mask;
    - ``residx_atom14_to_atom37``: int64 atom37 index of each atom14 slot;
    - ``residx_atom37_to_atom14``: int64 atom14 index of each atom37 slot;
    - ``atom37_atom_exists``: float32 atom37 existence mask.
    Tables are created on the device of ``aatype``.
    """
    device = protein["aatype"].device
    to37_rows, to14_rows, exists14_rows = [], [], []

    for letter in rc.restypes:
        names14 = rc.restype_name_to_atom14_names[rc.restype_1to3[letter]]
        to37_rows.append([rc.atom_order[n] if n else 0 for n in names14])
        slot_of = {n: slot for slot, n in enumerate(names14)}
        to14_rows.append([slot_of.get(n, 0) for n in rc.atom_types])
        exists14_rows.append([1.0 if n else 0.0 for n in names14])

    # Add dummy mapping for restype 'UNK'
    to37_rows.append([0] * 14)
    to14_rows.append([0] * 37)
    exists14_rows.append([0.0] * 14)

    table_to37 = torch.tensor(to37_rows, dtype=torch.int32, device=device)
    table_to14 = torch.tensor(to14_rows, dtype=torch.int32, device=device)
    table_exists14 = torch.tensor(
        exists14_rows, dtype=torch.float32, device=device
    )

    aatype = protein["aatype"]
    # (residx, atom14) -> atom37 indices and the atom14 existence mask.
    protein["atom14_atom_exists"] = table_exists14[aatype]
    protein["residx_atom14_to_atom37"] = table_to37[aatype].long()
    # Gather indices for mapping back, (residx, atom37) -> atom14.
    protein["residx_atom37_to_atom14"] = table_to14[aatype].long()

    # atom37 existence: flag every atom listed in rc.residue_atoms.
    table_exists37 = torch.zeros([21, 37], dtype=torch.float32, device=device)
    for row, letter in enumerate(rc.restypes):
        for atom_name in rc.residue_atoms[rc.restype_1to3[letter]]:
            table_exists37[row, rc.atom_order[atom_name]] = 1

    protein["atom37_atom_exists"] = table_exists37[aatype]
    return protein


def make_atom14_masks_np(batch):
    """NumPy front-end for ``make_atom14_masks``.

    Converts the ``np.ndarray`` leaves of ``batch`` to tensors with
    ``tree_map``, runs ``make_atom14_masks``, then converts the tensor leaves
    of the result back to arrays with ``tensor_tree_map``.
    """
    torch_batch = tree_map(lambda arr: torch.tensor(arr), batch, np.ndarray)
    return tensor_tree_map(
        lambda t: np.array(t), make_atom14_masks(torch_batch)
    )


def make_atom14_positions(protein):
    """Constructs denser atom positions (14 dimensions instead of 37).

    Reads ``atom14_atom_exists``, ``residx_atom14_to_atom37``,
    ``all_atom_positions``, ``all_atom_mask`` and ``aatype``, and adds:
    - ``atom14_gt_exists`` / ``atom14_gt_positions``: the atom37 ground truth
      gathered into the atom14 layout and masked by atom existence;
    - ``atom14_alt_gt_positions`` / ``atom14_alt_gt_exists``: the same with
      the atom pairs in ``rc.residue_atom_renaming_swaps`` swapped;
    - ``atom14_atom_is_ambiguous``: 1 for atoms that take part in such a
      swap.
    """
    residx_atom14_mask = protein["atom14_atom_exists"]
    residx_atom14_to_atom37 = protein["residx_atom14_to_atom37"]
    all_atom_mask = protein["all_atom_mask"]

    # Create a mask for known ground truth positions.
    residx_atom14_gt_mask = residx_atom14_mask * batched_gather(
        all_atom_mask,
        residx_atom14_to_atom37,
        dim=-1,
        no_batch_dims=len(all_atom_mask.shape[:-1]),
    )

    # Gather the ground truth positions.
    residx_atom14_gt_positions = residx_atom14_gt_mask[..., None] * (
        batched_gather(
            protein["all_atom_positions"],
            residx_atom14_to_atom37,
            dim=-2,
            no_batch_dims=len(protein["all_atom_positions"].shape[:-2]),
        )
    )

    protein["atom14_atom_exists"] = residx_atom14_mask
    protein["atom14_gt_exists"] = residx_atom14_gt_mask
    protein["atom14_gt_positions"] = residx_atom14_gt_positions

    # As the atom naming is ambiguous for 7 of the 20 amino acids, provide
    # alternative ground truth coordinates where the naming is swapped
    restype_3 = [rc.restype_1to3[res] for res in rc.restypes]
    restype_3 += ["UNK"]

    # Matrices for renaming ambiguous atoms (identity for all other restypes).
    all_matrices = {
        res: torch.eye(
            14,
            dtype=all_atom_mask.dtype,
            device=all_atom_mask.device,
        )
        for res in restype_3
    }
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        correspondences = torch.arange(14, device=all_atom_mask.device)
        for source_atom_swap, target_atom_swap in swap.items():
            source_index = rc.restype_name_to_atom14_names[resname].index(
                source_atom_swap
            )
            target_index = rc.restype_name_to_atom14_names[resname].index(
                target_atom_swap
            )
            correspondences[source_index] = target_index
            correspondences[target_index] = source_index
        # Build the permutation once, after every swap of this residue has
        # been recorded; an empty swap dict then yields the identity.
        renaming_matrix = all_atom_mask.new_zeros((14, 14))
        for index, correspondence in enumerate(correspondences):
            renaming_matrix[index, correspondence] = 1.0
        all_matrices[resname] = renaming_matrix

    renaming_matrices = torch.stack(
        [all_matrices[restype] for restype in restype_3]
    )

    # Pick the transformation matrices for the given residue sequence,
    # shape (..., num_res, 14, 14).
    renaming_transform = renaming_matrices[protein["aatype"]]

    # Apply it to the ground truth positions. shape (..., num_res, 14, 3).
    alternative_gt_positions = torch.einsum(
        "...rac,...rab->...rbc", residx_atom14_gt_positions, renaming_transform
    )
    protein["atom14_alt_gt_positions"] = alternative_gt_positions

    # Create the mask for the alternative ground truth (differs from the
    # ground truth mask, if only one of the atoms in an ambiguous pair has a
    # ground truth position).
    alternative_gt_mask = torch.einsum(
        "...ra,...rab->...rb", residx_atom14_gt_mask, renaming_transform
    )
    protein["atom14_alt_gt_exists"] = alternative_gt_mask

    # Create an ambiguous atoms mask.  shape: (21, 14).
    restype_atom14_is_ambiguous = all_atom_mask.new_zeros((21, 14))
    for resname, swap in rc.residue_atom_renaming_swaps.items():
        for atom_name1, atom_name2 in swap.items():
            restype = rc.restype_order[rc.restype_3to1[resname]]
            atom_idx1 = rc.restype_name_to_atom14_names[resname].index(
                atom_name1
            )
            atom_idx2 = rc.restype_name_to_atom14_names[resname].index(
                atom_name2
            )
            restype_atom14_is_ambiguous[restype, atom_idx1] = 1
            restype_atom14_is_ambiguous[restype, atom_idx2] = 1

    # From this create an ambiguous_mask for the given sequence.
    protein["atom14_atom_is_ambiguous"] = restype_atom14_is_ambiguous[
        protein["aatype"]
    ]

    return protein


def atom37_to_frames(protein):
    """Build per-residue rigid-group frames from atom37 coordinates.

    Every residue type gets 8 rigid groups. Each group's frame is defined by
    three base atoms, which are fed to ``T.from_3_points`` as
    (p_neg_x_axis, origin, p_xy_plane):

      * group 0: (C, CA, N)
      * group 3: (CA, C, O)
      * groups 4-7: the last three atoms of chi angle 1-4, only where that
        chi exists according to ``rc.chi_angles_mask``
      * groups 1 and 2 are never marked as existing by this function

    Row 20 of the per-type tables (the extra residue type beyond the 20 in
    ``rc.chi_angles_mask``) only has groups 0 and 3 marked.

    Group 0 is additionally composed with the rotation diag(-1, 1, -1).
    For residue types listed in ``rc.residue_atom_renaming_swaps``, the
    group of the last existing chi is flagged as ambiguous. Its alternative
    frame is composed with diag(1, -1, -1).

    Args:
        protein: dict with "aatype", "all_atom_positions" and
            "all_atom_mask". Shapes presumably [*, N_res], [*, N_res, 37, 3]
            and [*, N_res, 37] -- TODO confirm against callers.

    Returns:
        ``protein`` (mutated in place) with these keys added:
        "rigidgroups_gt_frames" and "rigidgroups_alt_gt_frames" (frames
        converted with ``T.to_4x4``), "rigidgroups_gt_exists",
        "rigidgroups_group_exists" and "rigidgroups_group_is_ambiguous".
    """
    aatype = protein["aatype"]
    positions = protein["all_atom_positions"]
    atom_mask = protein["all_atom_mask"]

    n_batch = len(aatype.shape[:-1])
    ones_prefix = (1,) * n_batch

    def _identity_stack(*lead):
        # 3x3 identity repeated to shape (*lead, 3, 3); each call is a fresh copy.
        eye = torch.eye(3, dtype=atom_mask.dtype, device=aatype.device)
        return torch.tile(eye, (*lead, 1, 1))

    # Names of the three base atoms for every (residue type, rigid group).
    group_atom_names = np.full([21, 8, 3], "", dtype=object)
    group_atom_names[:, 0, :] = ["C", "CA", "N"]
    group_atom_names[:, 3, :] = ["CA", "C", "O"]
    for res_idx, one_letter in enumerate(rc.restypes):
        three_letter = rc.restype_1to3[one_letter]
        for chi, chi_present in enumerate(rc.chi_angles_mask[res_idx]):
            if not chi_present:
                continue
            chi_atoms = rc.chi_angles_atoms[three_letter][chi]
            group_atom_names[res_idx, chi + 4, :] = chi_atoms[1:]

    # Which rigid groups exist for each residue type.
    type_group_mask = atom_mask.new_zeros((*aatype.shape[:-1], 21, 8))
    type_group_mask[..., 0] = 1
    type_group_mask[..., 3] = 1
    type_group_mask[..., :20, 4:] = atom_mask.new_tensor(rc.chi_angles_mask)

    # Translate atom names into atom37 indices; empty names map to index 0.
    name_to_atom37 = rc.atom_order.copy()
    name_to_atom37[""] = 0
    type_base_idx = aatype.new_tensor(
        np.vectorize(lambda atom_name: name_to_atom37[atom_name])(
            group_atom_names
        )
    )
    type_base_idx = type_base_idx.view(*ones_prefix, *type_base_idx.shape)

    # Per-residue base atom indices, then their coordinates.
    res_base_idx = batched_gather(
        type_base_idx,
        aatype,
        dim=-3,
        no_batch_dims=n_batch,
    )
    base_atom_pos = batched_gather(
        positions,
        res_base_idx,
        dim=-2,
        no_batch_dims=len(positions.shape[:-2]),
    )

    gt_frames = T.from_3_points(
        p_neg_x_axis=base_atom_pos[..., 0, :],
        origin=base_atom_pos[..., 1, :],
        p_xy_plane=base_atom_pos[..., 2, :],
        eps=1e-8,
    )

    group_exists = batched_gather(
        type_group_mask,
        aatype,
        dim=-2,
        no_batch_dims=n_batch,
    )

    # A frame is usable only if all three of its base atoms are present.
    base_atoms_present = batched_gather(
        atom_mask,
        res_base_idx,
        dim=-1,
        no_batch_dims=len(atom_mask.shape[:-1]),
    )
    gt_exists = base_atoms_present.min(dim=-1).values * group_exists

    # Compose group 0 with diag(-1, 1, -1); all other groups get identity.
    backbone_flip = _identity_stack(*ones_prefix, 8)
    backbone_flip[..., 0, 0, 0] = -1
    backbone_flip[..., 0, 2, 2] = -1
    gt_frames = gt_frames.compose(T(backbone_flip, None))

    # Flag the last chi group of residues with renaming-symmetric atoms and
    # give it a diag(1, -1, -1) alternative rotation.
    type_is_ambiguous = atom_mask.new_zeros(*ones_prefix, 21, 8)
    type_alt_rots = _identity_stack(*ones_prefix, 21, 8)
    for resname in rc.residue_atom_renaming_swaps:
        res_idx = rc.restype_order[rc.restype_3to1[resname]]
        group = int(sum(rc.chi_angles_mask[res_idx]) - 1) + 4
        type_is_ambiguous[..., res_idx, group] = 1
        type_alt_rots[..., res_idx, group, 1, 1] = -1
        type_alt_rots[..., res_idx, group, 2, 2] = -1

    res_is_ambiguous = batched_gather(
        type_is_ambiguous,
        aatype,
        dim=-2,
        no_batch_dims=n_batch,
    )
    res_alt_rot = batched_gather(
        type_alt_rots,
        aatype,
        dim=-4,
        no_batch_dims=n_batch,
    )
    alt_gt_frames = gt_frames.compose(T(res_alt_rot, None))

    protein.update(
        {
            "rigidgroups_gt_frames": gt_frames.to_4x4(),
            "rigidgroups_gt_exists": gt_exists,
            "rigidgroups_group_exists": group_exists,
            "rigidgroups_group_is_ambiguous": res_is_ambiguous,
            "rigidgroups_alt_gt_frames": alt_gt_frames.to_4x4(),
        }
    )

    return protein


def get_chi_atom_indices():
    """Return atom37 indices of the four atoms defining each chi angle.

    Returns:
      A nested Python list (not a tensor) of shape
      [residue_types=21, chis=4, atoms=4]. Rows follow the order of
      ``rc.restypes``, with one extra row appended for the unknown residue
      type. Chi angles that are not defined for a residue are filled with
      index 0.
    """
    table = []
    for one_letter in rc.restypes:
        chi_definitions = rc.chi_angles_atoms[rc.restype_1to3[one_letter]]
        quads = [
            [rc.atom_order[atom_name] for atom_name in chi_atoms]
            for chi_atoms in chi_definitions
        ]
        # Pad up to four chis with placeholder quadruples.
        quads.extend([0, 0, 0, 0] for _ in range(4 - len(quads)))
        table.append(quads)

    # The unknown residue type has no chi angles.
    table.append([[0, 0, 0, 0]] * 4)
    return table


@curry1
def atom37_to_torsion_angles(
    protein,
    prefix="",
):
    """
    Compute sin/cos torsion-angle features from atom37 coordinates.

    Seven torsions are produced per residue, in this order: pre-omega, phi,
    psi, chi1, chi2, chi3, chi4. Pre-omega and phi use atoms of the previous
    residue. For the first residue those atoms are zero-padded, so their
    masks are 0 there.

    The computation is very sensitive to floating-point error; prefer
    running it in double precision.

    Args:
        protein: dict containing
            * (prefix)aatype: [*, N_res] residue type indices (clamped to 20)
            * (prefix)all_atom_positions: [*, N_res, 37, 3] atom37 positions
            * (prefix)all_atom_mask: [*, N_res, 37] atom37 presence mask
        prefix: string prepended to every key that is read and written.

    Returns:
        ``protein`` (mutated in place) with:
            "(prefix)torsion_angles_sin_cos" [*, N_res, 7, 2]:
                normalized (sin, cos) per torsion. The sign of torsion 2
                (psi) is flipped.
            "(prefix)alt_torsion_angles_sin_cos" [*, N_res, 7, 2]:
                the same features, negated for chis marked in
                ``rc.chi_pi_periodic`` (180-degree symmetric chis).
            "(prefix)torsion_angles_mask" [*, N_res, 7]:
                1 where all four defining atoms (and the chi itself) exist.
    """
    aatype = torch.clamp(protein[prefix + "aatype"], max=20)
    positions = protein[prefix + "all_atom_positions"]
    atom_mask = protein[prefix + "all_atom_mask"]

    # Shift by one residue along N_res so that index i holds residue i - 1.
    prev_positions = torch.cat(
        [
            positions.new_zeros([*positions.shape[:-3], 1, 37, 3]),
            positions[..., :-1, :, :],
        ],
        dim=-3,
    )
    prev_mask = torch.cat(
        [
            atom_mask.new_zeros([*atom_mask.shape[:-2], 1, 37]),
            atom_mask[..., :-1, :],
        ],
        dim=-2,
    )

    # Four-atom quadruples for the backbone torsions (atom37 slots 0-2 and 4).
    pre_omega_atom_pos = torch.cat(
        [prev_positions[..., 1:3, :], positions[..., :2, :]], dim=-2
    )
    phi_atom_pos = torch.cat(
        [prev_positions[..., 2:3, :], positions[..., :3, :]], dim=-2
    )
    psi_atom_pos = torch.cat(
        [positions[..., :3, :], positions[..., 4:5, :]], dim=-2
    )

    pre_omega_mask = torch.prod(prev_mask[..., 1:3], dim=-1) * torch.prod(
        atom_mask[..., :2], dim=-1
    )
    phi_mask = prev_mask[..., 2] * torch.prod(
        atom_mask[..., :3], dim=-1, dtype=atom_mask.dtype
    )
    psi_mask = (
        torch.prod(atom_mask[..., :3], dim=-1, dtype=atom_mask.dtype)
        * atom_mask[..., 4]
    )

    # Chi quadruples, looked up per residue type.
    chi_index_table = torch.as_tensor(
        get_chi_atom_indices(), device=aatype.device
    )
    chi_atom_idx = chi_index_table[..., aatype, :, :]
    n_gather_batch = len(chi_atom_idx.shape[:-2])
    chis_atom_pos = batched_gather(
        positions, chi_atom_idx, dim=-2, no_batch_dims=n_gather_batch
    )

    # Per-type chi existence, with an all-zero row for the unknown type.
    type_chi_mask = atom_mask.new_tensor(
        list(rc.chi_angles_mask) + [[0.0, 0.0, 0.0, 0.0]]
    )
    chi_atoms_present = batched_gather(
        atom_mask, chi_atom_idx, dim=-1, no_batch_dims=n_gather_batch
    )
    chi_atoms_present = torch.prod(
        chi_atoms_present, dim=-1, dtype=chi_atoms_present.dtype
    )
    chis_mask = type_chi_mask[aatype, :] * chi_atoms_present

    backbone_atom_pos = torch.stack(
        [pre_omega_atom_pos, phi_atom_pos, psi_atom_pos], dim=-3
    )
    torsions_atom_pos = torch.cat([backbone_atom_pos, chis_atom_pos], dim=-3)

    backbone_mask = torch.stack([pre_omega_mask, phi_mask, psi_mask], dim=-1)
    torsion_angles_mask = torch.cat([backbone_mask, chis_mask], dim=-1)

    # Express the fourth atom in the frame spanned by the first three.
    torsion_frames = T.from_3_points(
        torsions_atom_pos[..., 1, :],
        torsions_atom_pos[..., 2, :],
        torsions_atom_pos[..., 0, :],
        eps=1e-8,
    )
    local_fourth = torsion_frames.invert().apply(torsions_atom_pos[..., 3, :])

    sin_cos = torch.stack([local_fourth[..., 2], local_fourth[..., 1]], dim=-1)
    norm = torch.sqrt(
        torch.sum(
            torch.square(sin_cos),
            dim=-1,
            dtype=sin_cos.dtype,
            keepdim=True,
        )
        + 1e-8
    )
    sin_cos = sin_cos / norm

    # Flip the sign of torsion 2 (psi); broadcast over batch dims and (sin, cos).
    torsion_sign = atom_mask.new_tensor([1.0, 1.0, -1.0, 1.0, 1.0, 1.0, 1.0])
    n_lead = len(sin_cos.shape[:-2])
    sin_cos = sin_cos * torsion_sign.view(*((1,) * n_lead), -1, 1)

    # Alternative angles: negate chis that are pi-periodic for this residue type.
    chi_is_ambiguous = sin_cos.new_tensor(rc.chi_pi_periodic)[aatype, ...]
    mirror = torch.cat(
        [
            atom_mask.new_ones(*aatype.shape, 3),
            1.0 - 2.0 * chi_is_ambiguous,
        ],
        dim=-1,
    )
    alt_sin_cos = sin_cos * mirror[..., None]

    protein[prefix + "torsion_angles_sin_cos"] = sin_cos
    protein[prefix + "alt_torsion_angles_sin_cos"] = alt_sin_cos
    protein[prefix + "torsion_angles_mask"] = torsion_angles_mask

    return protein


def get_backbone_frames(protein):
    """Expose rigid group 0 as the backbone frame and its existence mask.

    Reads "rigidgroups_gt_frames" and "rigidgroups_gt_exists". Writes group 0
    of each to "backbone_affine_tensor" and "backbone_affine_mask".
    Mutates and returns ``protein``.
    """
    # TODO: verify that group 0 is the intended backbone frame here.
    all_frames = protein["rigidgroups_gt_frames"]
    protein["backbone_affine_tensor"] = all_frames[..., 0, :, :]

    all_exists = protein["rigidgroups_gt_exists"]
    protein["backbone_affine_mask"] = all_exists[..., 0]

    return protein


def get_chi_angles(protein):
    """Copy the chi part of the torsion features into dedicated keys.

    Takes entries from index 3 onward of "torsion_angles_sin_cos" and
    "torsion_angles_mask". Casts them to the dtype of "all_atom_mask" and
    stores them as "chi_angles_sin_cos" and "chi_mask".
    Mutates and returns ``protein``.
    """
    target_dtype = protein["all_atom_mask"].dtype
    chis = slice(3, None)

    sin_cos = protein["torsion_angles_sin_cos"]
    protein["chi_angles_sin_cos"] = sin_cos[..., chis, :].to(target_dtype)

    torsion_mask = protein["torsion_angles_mask"]
    protein["chi_mask"] = torsion_mask[..., chis].to(target_dtype)

    return protein


@curry1
def random_crop_to_size(
    protein,
    crop_size,
    max_templates,
    shape_schema,
    subsample_templates=False,
    seed=None,
    batch_mode="clamped",
):
    """Crop randomly to `crop_size`, or keep as is if shorter than that.

    Features listed in ``shape_schema`` are cropped along every NUM_RES
    dimension to a common random window of ``min(seq_length, crop_size)``
    residues. Template features (keys starting with "template") are also
    cropped along their first dimension to at most ``max_templates``
    entries. When ``subsample_templates`` is set, they are first randomly
    permuted and start from a random offset.

    Args:
        protein: feature dict; must contain "seq_length". Modified in place.
        crop_size: maximum number of residues to keep.
        max_templates: maximum number of templates to keep (``None`` = no
            cap).
        shape_schema: maps feature name -> per-dimension size placeholders
            (``NUM_RES`` marks residue dimensions).
        subsample_templates: randomly permute templates and pick a random
            starting offset before cropping.
        seed: if given, all random draws come from a generator seeded with
            it, so identically seeded calls crop identically. Otherwise the
            global torch RNG is used.
        batch_mode: "clamped" draws the start uniformly in
            [0, seq_length - crop]. "unclamped" first draws x in the same
            range and then the start in [0, seq_length - crop - x].

    Returns:
        ``protein`` with cropped features and an updated "seq_length".

    Raises:
        ValueError: if ``batch_mode`` is not "clamped" or "unclamped".
    """
    device = protein["seq_length"].device
    seq_length = int(protein["seq_length"].item())

    if "template_mask" in protein:
        num_templates = protein["template_mask"].shape[-1]
    else:
        num_templates = 0

    num_res_crop_size = min(seq_length, crop_size)

    # We want each ensemble to be cropped the same way, which requires an
    # explicit seed. An *unseeded* torch.Generator always starts from the
    # same default seed, which would make every crop identical. So fall back
    # to the global RNG when no seed is supplied.
    g = None
    if seed is not None:
        g = torch.Generator(device=device)
        g.manual_seed(seed)

    def _randint(lower, upper):
        # Uniform integer in the closed interval [lower, upper].
        return int(
            torch.randint(
                lower,
                upper + 1,
                (1,),
                device=device,
                generator=g,
            )[0]
        )

    if subsample_templates:
        templates_crop_start = _randint(0, num_templates)
        templates_select_indices = torch.randperm(
            num_templates, device=device, generator=g
        )
    else:
        templates_crop_start = 0

    # Always honour max_templates, not only when subsampling.
    num_templates_crop_size = num_templates - templates_crop_start
    if max_templates is not None:
        num_templates_crop_size = min(num_templates_crop_size, max_templates)

    # n is the largest valid start index; _randint is inclusive, so the
    # anchors must not add 1 (that would allow a start of n + 1 and a crop
    # one residue short).
    n = seq_length - num_res_crop_size
    if batch_mode == "clamped":
        right_anchor = n
    elif batch_mode == "unclamped":
        x = _randint(0, n)
        right_anchor = n - x
    else:
        raise ValueError("Invalid batch mode")

    num_res_crop_start = _randint(0, right_anchor)

    for k, v in protein.items():
        if k not in shape_schema or (
            "template" not in k and NUM_RES not in shape_schema[k]
        ):
            continue

        is_template = k.startswith("template")

        # randomly permute the templates before cropping them.
        if is_template and subsample_templates:
            v = v[templates_select_indices]

        slices = []
        for i, (dim_size, dim) in enumerate(zip(shape_schema[k], v.shape)):
            if i == 0 and is_template:
                start, size = templates_crop_start, num_templates_crop_size
            elif dim_size == NUM_RES:
                start, size = num_res_crop_start, num_res_crop_size
            else:
                start, size = 0, dim
            slices.append(slice(start, start + size))
        # Index with a tuple; list-of-slices indexing is deprecated.
        protein[k] = v[tuple(slices)]

    protein["seq_length"] = protein["seq_length"].new_tensor(num_res_crop_size)
    return protein