data_modules.py 27 KB
Newer Older
1
import copy
2
3
4
5
from functools import partial
import json
import logging
import os
6
import pickle
Gustaf Ahdritz's avatar
Fixes  
Gustaf Ahdritz committed
7
from typing import Optional, Sequence, List, Any
8
9

import ml_collections as mlc
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
10
import numpy as np
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler

from openfold.data import (
    data_pipeline,
    feature_pipeline,
    mmcif_parsing,
    templates,
)
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap


class OpenFoldSingleDataset(torch.utils.data.Dataset):
    # Dataset over a single directory of structure (or FASTA) files plus
    # their precomputed alignments. Each element corresponds to one chain ID.
    def __init__(self,
        data_dir: str,
        alignment_dir: str, 
        template_mmcif_dir: str,
        max_template_date: str,
        config: mlc.ConfigDict,
        chain_data_cache_path: Optional[str] = None,
        kalign_binary_path: str = '/usr/bin/kalign',
        max_template_hits: int = 4,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        shuffle_top_k_prefiltered: Optional[int] = None,
        treat_pdb_as_distillation: bool = True,
        filter_path: Optional[str] = None,
        mode: str = "train", 
        alignment_index: Optional[Any] = None,
        _output_raw: bool = False,
        _structure_index: Optional[Any] = None,
    ):
        """
            Args:
                data_dir:
                    A path to a directory containing mmCIF files (in train
                    mode) or FASTA files (in inference mode).
                alignment_dir:
                    A path to a directory containing only data in the format 
                    output by an AlignmentRunner 
                    (defined in openfold.features.alignment_runner).
                    I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
                    or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
                    files.
                template_mmcif_dir:
                    Path to a directory containing template mmCIF files.
                max_template_date:
                    Date cutoff forwarded to the template featurizer
                    (templates released after it are excluded there).
                config:
                    A dataset config object. See openfold.config
                chain_data_cache_path:
                    Path to cache of data_dir generated by
                    scripts/generate_chain_data_cache.py
                kalign_binary_path:
                    Path to kalign binary.
                max_template_hits:
                    An upper bound on how many templates are considered. During
                    training, the templates ultimately used are subsampled
                    from this total quantity.
                obsolete_pdbs_file_path:
                    Path to the file containing replacements for obsolete PDBs.
                template_release_dates_cache_path:
                    Path to the output of scripts/generate_mmcif_cache.
                shuffle_top_k_prefiltered:
                    Whether to uniformly shuffle the top k template hits before
                    parsing max_template_hits of them. Can be used to
                    approximate DeepMind's training-time template subsampling
                    scheme much more performantly.
                treat_pdb_as_distillation:
                    Whether to assume that .pdb files in the data_dir are from
                    the self-distillation set (and should be subjected to
                    special distillation set preprocessing steps).
                filter_path:
                    Optional path to a text file of chain IDs, one per line.
                    When given, the dataset is restricted to those chains.
                mode:
                    "train", "eval", or "predict"
                alignment_index:
                    Optional mapping keyed by chain ID. When given, chain IDs
                    are taken from its keys and alignment_dir is used as-is
                    (not joined with the chain name) in __getitem__.
                _output_raw:
                    If True, __getitem__ returns the raw data pipeline output
                    instead of running the feature pipeline on it.
                _structure_index:
                    Optional mapping from chain ID to an entry whose "files"
                    list names the structure file backing that chain.
        """
        super(OpenFoldSingleDataset, self).__init__()
        self.data_dir = data_dir

        # Optional per-chain metadata cache (resolution, sequence, cluster
        # size, ...) used both for filtering here and by OpenFoldDataset.
        self.chain_data_cache = None
        if chain_data_cache_path is not None:
            with open(chain_data_cache_path, "r") as fp:
                self.chain_data_cache = json.load(fp)
            assert isinstance(self.chain_data_cache, dict)

        self.alignment_dir = alignment_dir
        self.config = config
        self.treat_pdb_as_distillation = treat_pdb_as_distillation
        self.mode = mode
        self.alignment_index = alignment_index
        self._output_raw = _output_raw
        self._structure_index = _structure_index

        # Structure file extensions recognized by __getitem__, probed in order.
        self.supported_exts = [".cif", ".core", ".pdb"]

        valid_modes = ["train", "eval", "predict"]
        if(mode not in valid_modes):
            raise ValueError(f'mode must be one of {valid_modes}')

        if(template_release_dates_cache_path is None):
            logging.warning(
                "Template release dates cache does not exist. Remember to run "
                "scripts/generate_mmcif_cache.py before running OpenFold"
            )

        # Chain IDs come either from the alignment index or from the
        # subdirectory names of alignment_dir.
        if(alignment_index is not None):
            self._chain_ids = list(alignment_index.keys())
        else:
            self._chain_ids = list(os.listdir(alignment_dir))

        # Optionally restrict to an explicit list of chain IDs.
        if(filter_path is not None):
            with open(filter_path, "r") as f:
                chains_to_include = set([l.strip() for l in f.readlines()])

            self._chain_ids = [
                c for c in self._chain_ids if c in chains_to_include
            ]

        if self.chain_data_cache is not None:
            # Filter to include only chains where we have structure data
            # (entries in chain_data_cache)
            original_chain_ids = self._chain_ids
            self._chain_ids = [
                c for c in self._chain_ids if c in self.chain_data_cache
            ]
            if len(self._chain_ids) < len(original_chain_ids):
                missing = [
                    c for c in original_chain_ids
                    if c not in self.chain_data_cache
                ]
                max_to_print = 10
                missing_examples = ", ".join(missing[:max_to_print])
                if len(missing) > max_to_print:
                    missing_examples += ", ..."
                logging.warning(
                    "Removing %d alignment entries (%s) with no corresponding "
                    "entries in chain_data_cache (%s).",
                    len(missing),
                    missing_examples,
                    chain_data_cache_path)
       
        # Reverse lookup used by chain_id_to_idx.
        self._chain_id_to_idx_dict = {
            chain: i for i, chain in enumerate(self._chain_ids)
        }

        template_featurizer = templates.TemplateHitFeaturizer(
            mmcif_dir=template_mmcif_dir,
            max_template_date=max_template_date,
            max_hits=max_template_hits,
            kalign_binary_path=kalign_binary_path,
            release_dates_path=template_release_dates_cache_path,
            obsolete_pdbs_path=obsolete_pdbs_file_path,
            _shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
        )

        self.data_pipeline = data_pipeline.DataPipeline(
            template_featurizer=template_featurizer,
        )

        # The feature pipeline is only needed when featurized output is
        # requested; raw mode skips it entirely.
        if(not self._output_raw):
            self.feature_pipeline = feature_pipeline.FeaturePipeline(config) 

    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
        # Parse an mmCIF file and run it through the data pipeline.
        with open(path, 'r') as f:
            mmcif_string = f.read()

        mmcif_object = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )

        # Crash if an error is encountered. Any parsing errors should have
        # been dealt with at the alignment stage.
        if(mmcif_object.mmcif_object is None):
            raise list(mmcif_object.errors.values())[0]

        mmcif_object = mmcif_object.mmcif_object

        data = self.data_pipeline.process_mmcif(
            mmcif=mmcif_object,
            alignment_dir=alignment_dir,
            chain_id=chain_id,
            alignment_index=alignment_index
        )

        return data

    def chain_id_to_idx(self, chain_id):
        # Map a chain ID string to its integer dataset index.
        return self._chain_id_to_idx_dict[chain_id]

    def idx_to_chain_id(self, idx):
        # Map an integer dataset index back to its chain ID string.
        return self._chain_ids[idx]

    def __getitem__(self, idx):
        name = self.idx_to_chain_id(idx)
        alignment_dir = os.path.join(self.alignment_dir, name)

        alignment_index = None
        if(self.alignment_index is not None):
            # With an alignment index, alignments live in a single location;
            # use the root dir plus the per-chain index entry.
            alignment_dir = self.alignment_dir
            alignment_index = self.alignment_index[name]

        if(self.mode == 'train' or self.mode == 'eval'):
            # Chain IDs are either "{file_id}_{chain_id}" or bare "{file_id}".
            spl = name.rsplit('_', 1)
            if(len(spl) == 2):
                file_id, chain_id = spl
            else:
                file_id, = spl
                chain_id = None

            path = os.path.join(self.data_dir, file_id)
            # Determine the structure file extension: either from the
            # structure index entry or by probing the filesystem.
            structure_index_entry = None
            if(self._structure_index is not None):
                structure_index_entry = self._structure_index[name]
                assert(len(structure_index_entry["files"]) == 1)
                filename, _, _ = structure_index_entry["files"][0]
                ext = os.path.splitext(filename)[1]
            else:
                ext = None
                for e in self.supported_exts:
                    if(os.path.exists(path + e)):
                        ext = e
                        break

                if(ext is None):
                    raise ValueError("Invalid file type")

            path += ext
            if(ext == ".cif"):
                data = self._parse_mmcif(
                    path, file_id, chain_id, alignment_dir, alignment_index,
                )
            elif(ext == ".core"):
                data = self.data_pipeline.process_core(
                    path, alignment_dir, alignment_index,
                )
            elif(ext == ".pdb"):
                structure_index = None
                if(self._structure_index is not None):
                    structure_index = self._structure_index[name]
                data = self.data_pipeline.process_pdb(
                    pdb_path=path,
                    alignment_dir=alignment_dir,
                    is_distillation=self.treat_pdb_as_distillation,
                    chain_id=chain_id,
                    alignment_index=alignment_index,
                    _structure_index=structure_index,
                )
            else:
               raise ValueError("Extension branch missing") 
        else:
            # Predict mode: inputs are FASTA files.
            # NOTE(review): this joins `name` with itself rather than
            # self.data_dir — looks like it should be
            # os.path.join(self.data_dir, name + ".fasta"); confirm.
            path = os.path.join(name, name + ".fasta")
            data = self.data_pipeline.process_fasta(
                fasta_path=path,
                alignment_dir=alignment_dir,
                alignment_index=alignment_index,
            )

        if(self._output_raw):
            return data

        feats = self.feature_pipeline.process_features(
            data, self.mode 
        )

        # Tag every recycling slice of the example with its dataset index so
        # downstream code can recover which datapoint each batch element is.
        feats["batch_idx"] = torch.tensor(
            [idx for _ in range(feats["aatype"].shape[-1])],
            dtype=torch.int64,
            device=feats["aatype"].device)

        return feats

    def __len__(self):
        return len(self._chain_ids) 


Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
284
def deterministic_train_filter(
    chain_data_cache_entry: Any,
    max_resolution: float = 9.,
    max_single_aa_prop: float = 0.8,
) -> bool:
    """
        Hard (deterministic) training filter for a single chain.

        Args:
            chain_data_cache_entry:
                A chain_data_cache entry (a dict) with at least a "seq" key
                and, optionally, a "resolution" key.
            max_resolution:
                Chains with a recorded resolution above this value are
                rejected.
            max_single_aa_prop:
                Chains in which a single residue type makes up more than this
                fraction of the sequence are rejected.
        Returns:
            True if the chain passes all hard filters, False otherwise.
    """
    # Hard filters
    resolution = chain_data_cache_entry.get("resolution", None)
    if(resolution is not None and resolution > max_resolution):
        return False

    seq = chain_data_cache_entry["seq"]
    # An empty sequence can never pass (and would otherwise crash on the
    # max()/division below)
    if(not seq):
        return False

    # Reject low-complexity sequences dominated by a single residue type.
    # str.count over the distinct residues replaces the hand-rolled counter.
    largest_aa_count = max(seq.count(aa) for aa in set(seq))
    largest_single_aa_prop = largest_aa_count / len(seq)
    if(largest_single_aa_prop > max_single_aa_prop):
        return False

    return True


def get_stochastic_train_filter_prob(
    chain_data_cache_entry: Any,
) -> float:
    """
        Computes the probability with which a chain is admitted into a
        training epoch by the stochastic filters.

        Args:
            chain_data_cache_entry:
                A chain_data_cache entry (a dict) with a "seq" key and,
                optionally, a "cluster_size" key.
        Returns:
            The product of the individual filter probabilities, a single
            float in (0, 1]. (The previous -> List[float] annotation was
            incorrect; the function has always returned a scalar.)
    """
    # Stochastic filters
    probabilities = []
    
    # Downweight chains from large sequence clusters
    cluster_size = chain_data_cache_entry.get("cluster_size", None)
    if(cluster_size is not None and cluster_size > 0):
        probabilities.append(1 / cluster_size)
    
    # Length-based probability: length clamped to [256, 512], normalized
    # by 512, so short chains are sampled less often
    chain_length = len(chain_data_cache_entry["seq"])
    probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))

    # Risk of underflow here?
    out = 1
    for p in probabilities:
        out *= p

    return out
326
327


Gustaf Ahdritz's avatar
Fixes  
Gustaf Ahdritz committed
328
class OpenFoldDataset(torch.utils.data.Dataset):
    """
        Implements the stochastic filters applied during AlphaFold's training.
        Because samples are selected from constituent datasets randomly, the
        length of an OpenFoldFilteredDataset is arbitrary. Samples are selected
        and filtered once at initialization.
    """
    def __init__(self,
        datasets: Sequence[OpenFoldSingleDataset],
        probabilities: Sequence[float],
        epoch_len: int,
        generator: torch.Generator = None,
        _roll_at_init: bool = True,
    ):
        # datasets: the constituent OpenFoldSingleDatasets to draw from.
        # probabilities: per-dataset sampling weights (parallel to datasets).
        # epoch_len: number of datapoints per (virtual) epoch.
        # generator: optional torch RNG for reproducible sampling.
        # _roll_at_init: if False, callers must invoke reroll() themselves
        # before indexing (self.datapoints is not created otherwise).
        self.datasets = datasets
        self.probabilities = probabilities
        self.epoch_len = epoch_len
        self.generator = generator

        def looped_shuffled_dataset_idx(dataset_len):
            # Infinite generator of dataset indices: yields a fresh uniform
            # permutation of range(dataset_len), then reshuffles, forever.
            while True:
                # Uniformly shuffle each dataset's indices
                weights = [1. for _ in range(dataset_len)]
                shuf = torch.multinomial(
                    torch.tensor(weights),
                    num_samples=dataset_len,
                    replacement=False,
                    generator=self.generator,
                )
                for idx in shuf:
                    yield idx

        def looped_samples(dataset_idx):
            # Infinite generator of filtered datapoint indices for one
            # constituent dataset. Candidates are drawn in shuffled order,
            # pass the deterministic filter, and are then kept with the
            # chain's stochastic acceptance probability.
            max_cache_len = int(epoch_len * probabilities[dataset_idx])
            dataset = self.datasets[dataset_idx]
            idx_iter = looped_shuffled_dataset_idx(len(dataset))
            chain_data_cache = dataset.chain_data_cache
            while True:
                weights = []
                idx = []
                for _ in range(max_cache_len):
                    candidate_idx = next(idx_iter)
                    chain_id = dataset.idx_to_chain_id(candidate_idx)
                    chain_data_cache_entry = chain_data_cache[chain_id]
                    # Hard filters (resolution, low complexity)
                    if(not deterministic_train_filter(chain_data_cache_entry)):
                        continue

                    p = get_stochastic_train_filter_prob(
                        chain_data_cache_entry,
                    )
                    # Bernoulli weights: accept with probability p
                    weights.append([1. - p, p])
                    idx.append(candidate_idx)

                # One draw per candidate: column 1 (s == 1) means "accept"
                samples = torch.multinomial(
                    torch.tensor(weights),
                    num_samples=1,
                    generator=self.generator,
                )
                samples = samples.squeeze()

                cache = [i for i, s in zip(idx, samples) if s]

                for datapoint_idx in cache:
                    yield datapoint_idx

        # One persistent sample generator per constituent dataset; reroll()
        # advances them to draw each epoch's datapoints.
        self._samples = [looped_samples(i) for i in range(len(self.datasets))]

        if(_roll_at_init):
            self.reroll()

    def __getitem__(self, idx):
        # Resolve the precomputed (dataset, datapoint) pair for this slot.
        dataset_idx, datapoint_idx = self.datapoints[idx]
        return self.datasets[dataset_idx][datapoint_idx]

    def __len__(self):
        return self.epoch_len

    def reroll(self):
        # Resample the epoch: pick a dataset for each of the epoch_len slots
        # (weighted by self.probabilities), then pull the next filtered
        # datapoint index from that dataset's generator.
        dataset_choices = torch.multinomial(
            torch.tensor(self.probabilities),
            num_samples=self.epoch_len,
            replacement=True,
            generator=self.generator,
        )

        self.datapoints = []
        for dataset_idx in dataset_choices:
            samples = self._samples[dataset_idx]
            datapoint_idx = next(samples)
            self.datapoints.append((dataset_idx, datapoint_idx))

419
420

class OpenFoldBatchCollator:
    """Collates a list of per-sample feature dicts into one batched dict."""

    def __call__(self, prots):
        # For every key shared by the sample dicts, stack the corresponding
        # tensors along a new leading batch dimension.
        def _stack(tensors):
            return torch.stack(tensors, dim=0)

        return dict_multimap(_stack, prots)
Gustaf Ahdritz's avatar
Gustaf Ahdritz committed
424
425
426
427
428
429
430
431
432
433


class OpenFoldDataLoader(torch.utils.data.DataLoader):
    # DataLoader that post-processes each batch with sampled "batch
    # properties" — currently the number of recycling iterations, drawn
    # per-batch and used to trim the batch's recycling dimension.
    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        # config: the data config (see openfold.config); stage selects the
        # sub-config ("train", "eval", ...). generator: optional torch RNG
        # for reproducible property sampling.
        super().__init__(*args, **kwargs)
        self.config = config
        self.stage = stage    

        if(generator is None):
            generator = torch.Generator()
        
        self.generator = generator
        self._prep_batch_properties_probs()

    def _prep_batch_properties_probs(self):
        # Build self.prop_keys and self.prop_probs_tensor: one row of
        # categorical probabilities per batch property.
        keyed_probs = []
        stage_cfg = self.config[self.stage]

        max_iters = self.config.common.max_recycling_iters
        
        if(stage_cfg.uniform_recycling):
            # Sample the recycling count uniformly from {0, ..., max_iters}
            recycling_probs = [
                1. / (max_iters + 1) for _ in range(max_iters + 1)
            ]
        else:
            # Always use the maximum number of recycling iterations
            recycling_probs = [
                0. for _ in range(max_iters + 1)
            ]
            recycling_probs[-1] = 1.
        
        keyed_probs.append(
            ("no_recycling_iters", recycling_probs)
        )

        # Right-pad all probability rows to equal length so they can be
        # stacked into a single tensor for one multinomial call.
        keys, probs = zip(*keyed_probs)
        max_len = max([len(p) for p in probs])
        padding = [[0.] * (max_len - len(p)) for p in probs] 
        
        self.prop_keys = keys
        self.prop_probs_tensor = torch.tensor(
            [p + pad for p, pad in zip(probs, padding)],
            dtype=torch.float32,
        )

    def _add_batch_properties(self, batch):
        # Draw one value per batch property and splice it into the batch,
        # broadcast over the batch and recycling dimensions.
        samples = torch.multinomial(
            self.prop_probs_tensor,
            num_samples=1, # 1 per row
            replacement=True,
            generator=self.generator
        )

        # aatype's trailing dims are (..., seq, recycling); use it to infer
        # the batch shape and the size of the recycling dimension.
        aatype = batch["aatype"]
        batch_dims = aatype.shape[:-2]
        recycling_dim = aatype.shape[-1]
        no_recycling = recycling_dim
        for i, key in enumerate(self.prop_keys):
            sample = int(samples[i][0])
            sample_tensor = torch.tensor(
                sample, 
                device=aatype.device, 
                requires_grad=False
            )
            # Expand the scalar to (batch_dims..., recycling_dim) so it
            # lines up with the rest of the batch's features.
            orig_shape = sample_tensor.shape
            sample_tensor = sample_tensor.view(
                (1,) * len(batch_dims) + sample_tensor.shape + (1,)
            )
            sample_tensor = sample_tensor.expand(
                batch_dims + orig_shape + (recycling_dim,)
            )
            batch[key] = sample_tensor

            if(key == "no_recycling_iters"):
                no_recycling = sample 
        
        # Trim every feature's recycling dimension down to the sampled count.
        resample_recycling = lambda t: t[..., :no_recycling + 1]
        batch = tensor_tree_map(resample_recycling, batch)

        return batch

    def __iter__(self):
        # Wrap the parent iterator so every yielded batch passes through
        # _add_batch_properties first.
        it = super().__iter__()

        def _batch_prop_gen(iterator):
            for batch in iterator:
                yield self._add_batch_properties(batch)

        return _batch_prop_gen(it)
512
513
514
515
516
517
518
519
520


class OpenFoldDataModule(pl.LightningDataModule):
    """LightningDataModule wiring up OpenFold's train/eval/predict datasets.

    In training mode (train_data_dir given), builds a training dataset, an
    optional distillation dataset mixed in with probability
    config.train.distillation_prob, and an optional validation dataset.
    Otherwise, builds a single prediction dataset.
    """
    def __init__(self,
        config: mlc.ConfigDict,
        template_mmcif_dir: str,
        max_template_date: str,
        train_data_dir: Optional[str] = None,
        train_alignment_dir: Optional[str] = None,
        train_chain_data_cache_path: Optional[str] = None,
        distillation_data_dir: Optional[str] = None,
        distillation_alignment_dir: Optional[str] = None,
        distillation_chain_data_cache_path: Optional[str] = None,
        val_data_dir: Optional[str] = None,
        val_alignment_dir: Optional[str] = None,
        predict_data_dir: Optional[str] = None,
        predict_alignment_dir: Optional[str] = None,
        kalign_binary_path: str = '/usr/bin/kalign',
        train_filter_path: Optional[str] = None,
        distillation_filter_path: Optional[str] = None,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        batch_seed: Optional[int] = None,
        train_epoch_len: int = 50000, 
        _distillation_structure_index_path: Optional[str] = None,
        alignment_index_path: Optional[str] = None,
        distillation_alignment_index_path: Optional[str] = None,
        **kwargs
    ):
        """
            Args:
                config:
                    Data config (mlc.ConfigDict), forwarded to the datasets
                    and data loaders.
                template_mmcif_dir / max_template_date:
                    Template search settings shared by all datasets.
                train_data_dir / train_alignment_dir:
                    Training structures and alignments. Supplying
                    train_data_dir puts the module in training mode.
                predict_data_dir / predict_alignment_dir:
                    Inputs for inference mode (used when train_data_dir is
                    not given).
                val_data_dir / val_alignment_dir:
                    Optional validation inputs (training mode only).
                batch_seed:
                    Optional seed for the dataloader/dataset RNGs.
                train_epoch_len:
                    Number of samples drawn per "virtual" training epoch.
                alignment_index_path / distillation_alignment_index_path:
                    Optional JSON indices describing alignment databases.
            Raises:
                ValueError: if the directory arguments are inconsistent with
                    the selected mode.
        """
        super(OpenFoldDataModule, self).__init__()

        self.config = config
        self.template_mmcif_dir = template_mmcif_dir
        self.max_template_date = max_template_date
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.train_chain_data_cache_path = train_chain_data_cache_path
        self.distillation_data_dir = distillation_data_dir
        self.distillation_alignment_dir = distillation_alignment_dir
        self.distillation_chain_data_cache_path = (
            distillation_chain_data_cache_path
        )
        self.val_data_dir = val_data_dir
        self.val_alignment_dir = val_alignment_dir
        self.predict_data_dir = predict_data_dir
        self.predict_alignment_dir = predict_alignment_dir
        self.kalign_binary_path = kalign_binary_path
        self.train_filter_path = train_filter_path
        self.distillation_filter_path = distillation_filter_path
        self.template_release_dates_cache_path = (
            template_release_dates_cache_path
        )
        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
        self.batch_seed = batch_seed
        self.train_epoch_len = train_epoch_len

        if(self.train_data_dir is None and self.predict_data_dir is None):
            raise ValueError(
                'At least one of train_data_dir or predict_data_dir must be '
                'specified'
            )

        self.training_mode = self.train_data_dir is not None

        if(self.training_mode and train_alignment_dir is None):
            raise ValueError(
                'In training mode, train_alignment_dir must be specified'
            )
        elif(not self.training_mode and predict_alignment_dir is None):
            raise ValueError(
                'In inference mode, predict_alignment_dir must be specified'
            )      
        elif(val_data_dir is not None and val_alignment_dir is None):
            raise ValueError(
                'If val_data_dir is specified, val_alignment_dir must '
                'be specified as well'
            )

        # An ad-hoc measure for our particular filesystem restrictions
        self._distillation_structure_index = None
        if(_distillation_structure_index_path is not None):
            with open(_distillation_structure_index_path, "r") as fp:
                self._distillation_structure_index = json.load(fp)
        
        self.alignment_index = None
        if(alignment_index_path is not None):
            with open(alignment_index_path, "r") as fp:
                self.alignment_index = json.load(fp)

        self.distillation_alignment_index = None
        if(distillation_alignment_index_path is not None):
            with open(distillation_alignment_index_path, "r") as fp:
                self.distillation_alignment_index = json.load(fp)

    def setup(self, stage: Optional[str] = None):
        """Instantiate the datasets for the current mode.

        `stage` is accepted (and ignored) because Lightning invokes the
        setup hook as setup(stage); omitting it breaks under PL versions
        that pass the argument positionally. Calling setup() with no
        argument still works as before.
        """
        # Most of the arguments are the same for the three datasets 
        dataset_gen = partial(OpenFoldSingleDataset,
            template_mmcif_dir=self.template_mmcif_dir,
            max_template_date=self.max_template_date,
            config=self.config,
            kalign_binary_path=self.kalign_binary_path,
            template_release_dates_cache_path=
                self.template_release_dates_cache_path,
            obsolete_pdbs_file_path=
                self.obsolete_pdbs_file_path,
        )

        if(self.training_mode):
            train_dataset = dataset_gen(
                data_dir=self.train_data_dir,
                chain_data_cache_path=self.train_chain_data_cache_path,
                alignment_dir=self.train_alignment_dir,
                filter_path=self.train_filter_path,
                max_template_hits=self.config.train.max_template_hits,
                shuffle_top_k_prefiltered=
                    self.config.train.shuffle_top_k_prefiltered,
                treat_pdb_as_distillation=False,
                mode="train",
                alignment_index=self.alignment_index,
            )

            distillation_dataset = None
            if(self.distillation_data_dir is not None):
                distillation_dataset = dataset_gen(
                    data_dir=self.distillation_data_dir,
                    chain_data_cache_path=self.distillation_chain_data_cache_path,
                    alignment_dir=self.distillation_alignment_dir,
                    filter_path=self.distillation_filter_path,
                    max_template_hits=self.config.train.max_template_hits,
                    treat_pdb_as_distillation=True,
                    mode="train",
                    alignment_index=self.distillation_alignment_index,
                    _structure_index=self._distillation_structure_index,
                )
           
            if(distillation_dataset is not None):
                datasets = [train_dataset, distillation_dataset]
                d_prob = self.config.train.distillation_prob
                probabilities = [1. - d_prob, d_prob]
            else:
                datasets = [train_dataset]
                probabilities = [1.]

            generator = None
            if(self.batch_seed is not None):
                generator = torch.Generator()
                # Offset by 1 so this generator differs from the
                # dataloader generator seeded with batch_seed itself
                generator = generator.manual_seed(self.batch_seed + 1)
            
            self.train_dataset = OpenFoldDataset(
                datasets=datasets,
                probabilities=probabilities,
                epoch_len=self.train_epoch_len,
                generator=generator,
                _roll_at_init=False,
            )
    
            if(self.val_data_dir is not None):
                self.eval_dataset = dataset_gen(
                    data_dir=self.val_data_dir,
                    alignment_dir=self.val_alignment_dir,
                    filter_path=None,
                    max_template_hits=self.config.eval.max_template_hits,
                    mode="eval",
                )
            else:
                self.eval_dataset = None
        else:           
            # Define eval_dataset in inference mode too, so that
            # val_dataloader() returns None instead of raising
            # AttributeError
            self.eval_dataset = None
            self.predict_dataset = dataset_gen(
                data_dir=self.predict_data_dir,
                alignment_dir=self.predict_alignment_dir,
                filter_path=None,
                max_template_hits=self.config.predict.max_template_hits,
                mode="predict",
            )

    def _gen_dataloader(self, stage):
        """Build an OpenFoldDataLoader for the given stage
        ("train"/"eval"/"predict")."""
        generator = torch.Generator()
        if(self.batch_seed is not None):
            generator = generator.manual_seed(self.batch_seed)

        dataset = None
        if(stage == "train"):
            dataset = self.train_dataset
            # Re-sample the composite dataset's examples for the new epoch
            dataset.reroll()
        elif(stage == "eval"):
            dataset = self.eval_dataset
        elif(stage == "predict"):
            dataset = self.predict_dataset
        else:
            raise ValueError("Invalid stage")

        batch_collator = OpenFoldBatchCollator()

        dl = OpenFoldDataLoader(
            dataset,
            config=self.config,
            stage=stage,
            generator=generator,
            batch_size=self.config.data_module.data_loaders.batch_size,
            num_workers=self.config.data_module.data_loaders.num_workers,
            collate_fn=batch_collator,
        )

        return dl

    def train_dataloader(self):
        return self._gen_dataloader("train") 

    def val_dataloader(self):
        # getattr guards against setup() not having been called yet
        if(getattr(self, "eval_dataset", None) is not None):
            return self._gen_dataloader("eval")
        return None

    def predict_dataloader(self):
        return self._gen_dataloader("predict") 
730
731
732
733
734


class DummyDataset(torch.utils.data.Dataset):
    """A dataset that serves deep copies of one pre-pickled batch.

    Useful for benchmarking the training loop without real data.
    NOTE(review): pickle.load executes arbitrary code from the file;
    only point batch_path at trusted files.
    """
    def __init__(self, batch_path):
        with open(batch_path, "rb") as fp:
            self.batch = pickle.load(fp)

    def __getitem__(self, idx):
        # Deep-copy so callers may mutate their batch freely without
        # corrupting the shared template
        return copy.deepcopy(self.batch)

    def __len__(self):
        # Arbitrary fixed epoch length
        return 1000


class DummyDataLoader(pl.LightningDataModule):
    """Minimal LightningDataModule that repeatedly serves one pickled batch."""
    def __init__(self, batch_path):
        super(DummyDataLoader, self).__init__()
        # Eagerly load the single pickled batch
        self.dataset = DummyDataset(batch_path)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.dataset)