import copy
from functools import partial
import json
import logging
import os
import pickle
from typing import Optional, Sequence, List, Any

import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler

from openfold.data import (
    data_pipeline,
    feature_pipeline,
    mmcif_parsing,
    templates,
)
from openfold.utils.tensor_utils import tensor_tree_map, dict_multimap


class OpenFoldSingleDataset(torch.utils.data.Dataset):
    def __init__(self,
        data_dir: str,
        alignment_dir: str, 
        template_mmcif_dir: str,
        max_template_date: str,
        config: mlc.ConfigDict,
        kalign_binary_path: str = '/usr/bin/kalign',
        max_template_hits: int = 4,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        shuffle_top_k_prefiltered: Optional[int] = None,
        treat_pdb_as_distillation: bool = True,
        mapping_path: Optional[str] = None,
        mode: str = "train", 
        _output_raw: bool = False,
        _structure_index: Optional[Any] = None,
        _alignment_index: Optional[Any] = None,
    ):
        """
            Args:
                data_dir:
                    A path to a directory containing mmCIF files (in train
                    mode) or FASTA files (in inference mode).
                alignment_dir:
                    A path to a directory containing only data in the format 
                    output by an AlignmentRunner 
                    (defined in openfold.data.data_pipeline).
                    I.e. a directory of directories named {PDB_ID}_{CHAIN_ID}
                    or simply {PDB_ID}, each containing .a3m, .sto, and .hhr
                    files.
                template_mmcif_dir:
                    Path to a directory containing template mmCIF files.
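                max_template_date:
                    The latest release date (format "YYYY-MM-DD") a template
                    may have while remaining eligible as a hit.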
                config:
                    A dataset config object. See openfold.config
                kalign_binary_path:
                    Path to kalign binary.
                max_template_hits:
                    An upper bound on how many templates are considered. During
                    training, the templates ultimately used are subsampled
                    from this total quantity.
                template_release_dates_cache_path:
                    Path to the output of scripts/generate_mmcif_cache.
                obsolete_pdbs_file_path:
                    Path to the file containing replacements for obsolete PDBs.
                shuffle_top_k_prefiltered:
                    Whether to uniformly shuffle the top k template hits before
                    parsing max_template_hits of them. Can be used to
                    approximate DeepMind's training-time template subsampling
                    scheme much more performantly.
                treat_pdb_as_distillation:
                    Whether to assume that .pdb files in the data_dir are from
                    the self-distillation set (and should be subjected to
                    special distillation set preprocessing steps).
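                mapping_path:
                    Optional path to a file listing one chain ID per line.
                    When provided, it is used instead of a listing of
                    alignment_dir to enumerate the dataset's chains.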
                mode:
                    "train", "val", or "predict"
        """
        super(OpenFoldSingleDataset, self).__init__()
        self.data_dir = data_dir
        self.alignment_dir = alignment_dir
        self.config = config
        self.treat_pdb_as_distillation = treat_pdb_as_distillation
        self.mode = mode
        self._output_raw = _output_raw
        self._structure_index = _structure_index
        self._alignment_index = _alignment_index

        self.supported_exts = [".cif", ".core", ".pdb"]

        valid_modes = ["train", "eval", "predict"]
        if(mode not in valid_modes):
            raise ValueError(f'mode must be one of {valid_modes}')

        if(template_release_dates_cache_path is None):
            logging.warning(
                "Template release dates cache does not exist. Remember to run "
                "scripts/generate_mmcif_cache.py before running OpenFold"
            )

        if(_alignment_index is not None):
            self._chain_ids = list(_alignment_index.keys())
        elif(mapping_path is None):
            self._chain_ids = list(os.listdir(alignment_dir))
        else:
            with open(mapping_path, "r") as f:
                self._chain_ids = [l.strip() for l in f.readlines()]
       
        self._chain_id_to_idx_dict = {
            chain: i for i, chain in enumerate(self._chain_ids)
        }

        template_featurizer = templates.TemplateHitFeaturizer(
            mmcif_dir=template_mmcif_dir,
            max_template_date=max_template_date,
            max_hits=max_template_hits,
            kalign_binary_path=kalign_binary_path,
            release_dates_path=template_release_dates_cache_path,
            obsolete_pdbs_path=obsolete_pdbs_file_path,
            _shuffle_top_k_prefiltered=shuffle_top_k_prefiltered,
        )

        self.data_pipeline = data_pipeline.DataPipeline(
            template_featurizer=template_featurizer,
        )

        if(not self._output_raw):
            self.feature_pipeline = feature_pipeline.FeaturePipeline(config) 

    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
        with open(path, 'r') as f:
            mmcif_string = f.read()

        mmcif_object = mmcif_parsing.parse(
            file_id=file_id, mmcif_string=mmcif_string
        )

        # Crash if an error is encountered. Any parsing errors should have
        # been dealt with at the alignment stage.
        if(mmcif_object.mmcif_object is None):
            raise list(mmcif_object.errors.values())[0]

        mmcif_object = mmcif_object.mmcif_object

        data = self.data_pipeline.process_mmcif(
            mmcif=mmcif_object,
            alignment_dir=alignment_dir,
            chain_id=chain_id,
            _alignment_index=_alignment_index
        )

        return data

    def chain_id_to_idx(self, chain_id):
        return self._chain_id_to_idx_dict[chain_id]

    def idx_to_chain_id(self, idx):
        return self._chain_ids[idx]

    def __getitem__(self, idx):
        name = self.idx_to_chain_id(idx)
        alignment_dir = os.path.join(self.alignment_dir, name)

        _alignment_index = None
        if(self._alignment_index is not None):
            alignment_dir = self.alignment_dir
            _alignment_index = self._alignment_index[name]

        if(self.mode == 'train' or self.mode == 'eval'):
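            # Dataset entries are named {file_id}_{chain_id}, or just
            # {file_id} when no chain is specified (see the class docstring)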
            spl = name.rsplit('_', 1)
            if(len(spl) == 2):
                file_id, chain_id = spl
            else:
                file_id, = spl
                chain_id = None

            path = os.path.join(self.data_dir, file_id)
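            # If a structure index was provided, read the file's extension
            # from it rather than probing the filesystem for each supported
            # extension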
            structure_index_entry = None
            if(self._structure_index is not None):
                structure_index_entry = self._structure_index[name]
                assert(len(structure_index_entry["files"]) == 1)
                filename, _, _ = structure_index_entry["files"][0]
                ext = os.path.splitext(filename)[1]
            else:
                ext = None
                for e in self.supported_exts:
                    if(os.path.exists(path + e)):
                        ext = e
                        break

                if(ext is None):
                    raise ValueError("Invalid file type")

            path += ext
            if(ext == ".cif"):
                data = self._parse_mmcif(
                    path, file_id, chain_id, alignment_dir, _alignment_index,
                )
            elif(ext == ".core"):
                data = self.data_pipeline.process_core(
                    path, alignment_dir, _alignment_index,
                )
            elif(ext == ".pdb"):
                data = self.data_pipeline.process_pdb(
                    pdb_path=path,
                    alignment_dir=alignment_dir,
                    is_distillation=self.treat_pdb_as_distillation,
                    chain_id=chain_id,
                    _structure_index=structure_index_entry,
                    _alignment_index=_alignment_index,
                )
            else:
               raise ValueError("Extension branch missing") 
        else:
            path = os.path.join(self.data_dir, name + ".fasta")
            data = self.data_pipeline.process_fasta(
                fasta_path=path,
                alignment_dir=alignment_dir,
                _alignment_index=_alignment_index,
            )

        if(self._output_raw):
            return data

        feats = self.feature_pipeline.process_features(
            data, self.mode 
        )

        feats["batch_idx"] = torch.tensor([idx for _ in range(feats["aatype"].shape[-1])], dtype=torch.int64, device=feats["aatype"].device)

        return feats

    def __len__(self):
        return len(self._chain_ids) 


def deterministic_train_filter(
    chain_data_cache_entry: Any,
    max_resolution: float = 9.,
    max_single_aa_prop: float = 0.8,
) -> bool:
    # Hard filters
    resolution = chain_data_cache_entry.get("resolution", None)
    if(resolution is not None and resolution > max_resolution):
        return False

    seq = chain_data_cache_entry["seq"]
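    # Reject low-complexity sequences dominated by a single residue type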
    counts = {}
    for aa in seq:
        counts.setdefault(aa, 0)
        counts[aa] += 1
    largest_aa_count = max(counts.values())
    largest_single_aa_prop = largest_aa_count / len(seq)
    if(largest_single_aa_prop > max_single_aa_prop):
        return False

    return True


def get_stochastic_train_filter_prob(
    chain_data_cache_entry: Any,
) -> float:
    # Stochastic filters
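    # Each chain's keep-probability is the product of an inverse-cluster-size
    # term and a length term, following the AlphaFold training procedure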
    probabilities = []
    
    cluster_size = chain_data_cache_entry.get("cluster_size", None)
    if(cluster_size is not None and cluster_size > 0):
        probabilities.append(1 / cluster_size)
    
    chain_length = len(chain_data_cache_entry["seq"])
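    # Scales linearly from 0.5 at length <= 256 up to 1.0 at length >= 512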
    probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))

    # Risk of underflow here?
    out = 1
    for p in probabilities:
        out *= p

    return out


class OpenFoldDataset(torch.utils.data.Dataset):
    """
        Implements the stochastic filters applied during AlphaFold's training.
        Because samples are selected from constituent datasets randomly, the
        length of an OpenFoldDataset is arbitrary. Samples are selected
        and filtered once at initialization.
    """
    def __init__(self,
        datasets: Sequence[OpenFoldSingleDataset],
        probabilities: Sequence[float],
        epoch_len: int,
        chain_data_cache_paths: List[str],
        generator: Optional[torch.Generator] = None,
        _roll_at_init: bool = True,
    ):
        self.datasets = datasets
        self.probabilities = probabilities
        self.epoch_len = epoch_len
        self.generator = generator
        
        self.chain_data_caches = []
        for path in chain_data_cache_paths:
            with open(path, "r") as fp:
                self.chain_data_caches.append(json.load(fp))

        def looped_shuffled_dataset_idx(dataset_len):
            while True:
                # Uniformly shuffle each dataset's indices
                weights = [1. for _ in range(dataset_len)]
                shuf = torch.multinomial(
                    torch.tensor(weights),
                    num_samples=dataset_len,
                    replacement=False,
                    generator=self.generator,
                )
                for idx in shuf:
                    yield idx

        def looped_samples(dataset_idx):
            max_cache_len = int(epoch_len * probabilities[dataset_idx])
            dataset = self.datasets[dataset_idx]
            idx_iter = looped_shuffled_dataset_idx(len(dataset))
            chain_data_cache = self.chain_data_caches[dataset_idx]
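            # Repeatedly draw batches of candidate indices. Each candidate
            # that passes the deterministic filter is kept with its
            # stochastic probability p, a Bernoulli trial implemented as a
            # two-way multinomial over [1. - p, p]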
            while True:
                weights = []
                idx = []
                for _ in range(max_cache_len):
                    candidate_idx = next(idx_iter)
                    chain_id = dataset.idx_to_chain_id(candidate_idx)
                    chain_data_cache_entry = chain_data_cache[chain_id]
                    if(not deterministic_train_filter(chain_data_cache_entry)):
                        continue

                    p = get_stochastic_train_filter_prob(
                        chain_data_cache_entry,
                    )
                    weights.append([1. - p, p])
                    idx.append(candidate_idx)

                samples = torch.multinomial(
                    torch.tensor(weights),
                    num_samples=1,
                    generator=self.generator,
                )
                samples = samples.squeeze()

                cache = [i for i, s in zip(idx, samples) if s]

                for datapoint_idx in cache:
                    yield datapoint_idx

        self._samples = [looped_samples(i) for i in range(len(self.datasets))]

        if(_roll_at_init):
            self.reroll()

    def __getitem__(self, idx):
        dataset_idx, datapoint_idx = self.datapoints[idx]
        return self.datasets[dataset_idx][datapoint_idx]

    def __len__(self):
        return self.epoch_len

    def reroll(self):
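        # Resample the epoch: pick a dataset for each of the epoch_len slots
        # according to self.probabilities, then draw the next pre-filtered
        # index from that dataset's generator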
        dataset_choices = torch.multinomial(
            torch.tensor(self.probabilities),
            num_samples=self.epoch_len,
            replacement=True,
            generator=self.generator,
        )

        self.datapoints = []
        for dataset_idx in dataset_choices:
            samples = self._samples[dataset_idx]
            datapoint_idx = next(samples)
            self.datapoints.append((dataset_idx, datapoint_idx))


class OpenFoldBatchCollator:
    def __call__(self, prots):
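        # Stack corresponding feature tensors across proteins along a new
        # leading batch dimension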
        stack_fn = partial(torch.stack, dim=0)
        return dict_multimap(stack_fn, prots) 


class OpenFoldDataLoader(torch.utils.data.DataLoader):
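    """
    A DataLoader that samples batch-level properties (e.g. the clamped-FAPE
    flag and the number of recycling iterations) and attaches them to each
    batch it emits.
    """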
    def __init__(self, *args, config, stage="train", generator=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.config = config
        self.stage = stage    

        if(generator is None):
            generator = torch.Generator()
        
        self.generator = generator
        self._prep_batch_properties_probs()

    def _prep_batch_properties_probs(self):
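        # Build a categorical distribution for each batch-level property,
        # right-padding the probability vectors with zeros so they can all
        # be sampled with a single multinomial call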
        keyed_probs = []
        stage_cfg = self.config[self.stage]

        max_iters = self.config.common.max_recycling_iters
        if(stage_cfg.supervised):
            clamp_prob = self.config.supervised.clamp_prob
            keyed_probs.append(
                ("use_clamped_fape", [1 - clamp_prob, clamp_prob])
            )
        
        if(stage_cfg.uniform_recycling):
            recycling_probs = [
                1. / (max_iters + 1) for _ in range(max_iters + 1)
            ]
        else:
            recycling_probs = [
                0. for _ in range(max_iters + 1)
            ]
            recycling_probs[-1] = 1.
        
        keyed_probs.append(
            ("no_recycling_iters", recycling_probs)
        )

        keys, probs = zip(*keyed_probs)
        max_len = max([len(p) for p in probs])
        padding = [[0.] * (max_len - len(p)) for p in probs] 
        
        self.prop_keys = keys
        self.prop_probs_tensor = torch.tensor(
            [p + pad for p, pad in zip(probs, padding)],
            dtype=torch.float32,
        )

    def _add_batch_properties(self, batch):
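        # Draw one value per property and broadcast it over the batch and
        # recycling dimensions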
        samples = torch.multinomial(
            self.prop_probs_tensor,
            num_samples=1, # 1 per row
            replacement=True,
            generator=self.generator
        )

        aatype = batch["aatype"]
        batch_dims = aatype.shape[:-2]
        recycling_dim = aatype.shape[-1]
        no_recycling = recycling_dim
        for i, key in enumerate(self.prop_keys):
            sample = int(samples[i][0])
            sample_tensor = torch.tensor(
                sample, 
                device=aatype.device, 
                requires_grad=False
            )
            orig_shape = sample_tensor.shape
            sample_tensor = sample_tensor.view(
                (1,) * len(batch_dims) + sample_tensor.shape + (1,)
            )
            sample_tensor = sample_tensor.expand(
                batch_dims + orig_shape + (recycling_dim,)
            )
            batch[key] = sample_tensor

            if(key == "no_recycling_iters"):
                no_recycling = sample 
        
        # Features are precomputed for the maximum number of recycling
        # iterations; truncate the recycling dimension to the sampled count
        resample_recycling = lambda t: t[..., :no_recycling + 1]
        batch = tensor_tree_map(resample_recycling, batch)

        return batch

    def __iter__(self):
        it = super().__iter__()

        def _batch_prop_gen(iterator):
            for batch in iterator:
                yield self._add_batch_properties(batch)

        return _batch_prop_gen(it)


class OpenFoldDataModule(pl.LightningDataModule):
    def __init__(self,
        config: mlc.ConfigDict,
        template_mmcif_dir: str,
        max_template_date: str,
        train_data_dir: Optional[str] = None,
        train_alignment_dir: Optional[str] = None,
        train_chain_data_cache_path: Optional[str] = None,
        distillation_data_dir: Optional[str] = None,
        distillation_alignment_dir: Optional[str] = None,
        distillation_chain_data_cache_path: Optional[str] = None,
        val_data_dir: Optional[str] = None,
        val_alignment_dir: Optional[str] = None,
        predict_data_dir: Optional[str] = None,
        predict_alignment_dir: Optional[str] = None,
        kalign_binary_path: str = '/usr/bin/kalign',
        train_mapping_path: Optional[str] = None,
        distillation_mapping_path: Optional[str] = None,
        obsolete_pdbs_file_path: Optional[str] = None,
        template_release_dates_cache_path: Optional[str] = None,
        batch_seed: Optional[int] = None,
        train_epoch_len: int = 50000, 
        _distillation_structure_index_path: Optional[str] = None,
        _alignment_index_path: Optional[str] = None,
        _distillation_alignment_index_path: Optional[str] = None,
        **kwargs
    ):
        super(OpenFoldDataModule, self).__init__()

        self.config = config
        self.template_mmcif_dir = template_mmcif_dir
        self.max_template_date = max_template_date
        self.train_data_dir = train_data_dir
        self.train_alignment_dir = train_alignment_dir
        self.train_chain_data_cache_path = train_chain_data_cache_path
        self.distillation_data_dir = distillation_data_dir
        self.distillation_alignment_dir = distillation_alignment_dir
        self.distillation_chain_data_cache_path = (
            distillation_chain_data_cache_path
        )
        self.val_data_dir = val_data_dir
        self.val_alignment_dir = val_alignment_dir
        self.predict_data_dir = predict_data_dir
        self.predict_alignment_dir = predict_alignment_dir
        self.kalign_binary_path = kalign_binary_path
        self.train_mapping_path = train_mapping_path
        self.distillation_mapping_path = distillation_mapping_path
        self.template_release_dates_cache_path = (
            template_release_dates_cache_path
        )
        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
        self.batch_seed = batch_seed
        self.train_epoch_len = train_epoch_len

        if(self.train_data_dir is None and self.predict_data_dir is None):
            raise ValueError(
                'At least one of train_data_dir or predict_data_dir must be '
                'specified'
            )

        self.training_mode = self.train_data_dir is not None

        if(self.training_mode and train_alignment_dir is None):
            raise ValueError(
                'In training mode, train_alignment_dir must be specified'
            )
        elif(not self.training_mode and predict_alignment_dir is None):
            raise ValueError(
                'In inference mode, predict_alignment_dir must be specified'
            )      
        elif(val_data_dir is not None and val_alignment_dir is None):
            raise ValueError(
                'If val_data_dir is specified, val_alignment_dir must '
                'be specified as well'
            )

        # An ad-hoc measure for our particular filesystem restrictions
        self._distillation_structure_index = None
        if(_distillation_structure_index_path is not None):
            with open(_distillation_structure_index_path, "r") as fp:
                self._distillation_structure_index = json.load(fp)
        
        self._alignment_index = None
        if(_alignment_index_path is not None):
            with open(_alignment_index_path, "r") as fp:
                self._alignment_index = json.load(fp)

        self._distillation_alignment_index = None
        if(_distillation_alignment_index_path is not None):
            with open(_distillation_alignment_index_path, "r") as fp:
                self._distillation_alignment_index = json.load(fp)

    def setup(self):
        # Most of the arguments are the same for the three datasets 
        dataset_gen = partial(OpenFoldSingleDataset,
            template_mmcif_dir=self.template_mmcif_dir,
            max_template_date=self.max_template_date,
            config=self.config,
            kalign_binary_path=self.kalign_binary_path,
            template_release_dates_cache_path=
                self.template_release_dates_cache_path,
            obsolete_pdbs_file_path=
                self.obsolete_pdbs_file_path,
        )

        if(self.training_mode):
            train_dataset = dataset_gen(
                data_dir=self.train_data_dir,
                alignment_dir=self.train_alignment_dir,
                mapping_path=self.train_mapping_path,
                max_template_hits=self.config.train.max_template_hits,
                shuffle_top_k_prefiltered=
                    self.config.train.shuffle_top_k_prefiltered,
                treat_pdb_as_distillation=False,
                mode="train",
                _alignment_index=self._alignment_index,
            )

            distillation_dataset = None
            if(self.distillation_data_dir is not None):
                distillation_dataset = dataset_gen(
                    data_dir=self.distillation_data_dir,
                    alignment_dir=self.distillation_alignment_dir,
                    mapping_path=self.distillation_mapping_path,
                    max_template_hits=self.config.train.max_template_hits,
                    treat_pdb_as_distillation=True,
                    mode="train",
                    _structure_index=self._distillation_structure_index,
                    _alignment_index=self._distillation_alignment_index,
                )

            if(distillation_dataset is not None):
                datasets = [train_dataset, distillation_dataset]
                d_prob = self.config.train.distillation_prob
                probabilities = [1. - d_prob, d_prob]
                chain_data_cache_paths = [
                    self.train_chain_data_cache_path,
                    self.distillation_chain_data_cache_path,
                ]
            else:
                datasets = [train_dataset]
                probabilities = [1.]   
                chain_data_cache_paths = [
                    self.train_chain_data_cache_path,
                ]

            generator = None
            if(self.batch_seed is not None):
                generator = torch.Generator()
                generator = generator.manual_seed(self.batch_seed + 1)
            
            self.train_dataset = OpenFoldDataset(
                datasets=datasets,
                probabilities=probabilities,
                epoch_len=self.train_epoch_len,
                chain_data_cache_paths=chain_data_cache_paths,
                generator=generator,
                _roll_at_init=False,
            )

            if(self.val_data_dir is not None):
                self.eval_dataset = dataset_gen(
                    data_dir=self.val_data_dir,
                    alignment_dir=self.val_alignment_dir,
                    mapping_path=None,
                    max_template_hits=self.config.eval.max_template_hits,
                    mode="eval",
                )
            else:
                self.eval_dataset = None
        else:           
            self.predict_dataset = dataset_gen(
                data_dir=self.predict_data_dir,
                alignment_dir=self.predict_alignment_dir,
                mapping_path=None,
                max_template_hits=self.config.predict.max_template_hits,
                mode="predict",
            )

    def _gen_dataloader(self, stage):
        generator = torch.Generator()
        if(self.batch_seed is not None):
            generator = generator.manual_seed(self.batch_seed)

        dataset = None
        if(stage == "train"):
            dataset = self.train_dataset
            # Filter the dataset, if necessary
            dataset.reroll()
        elif(stage == "eval"):
            dataset = self.eval_dataset
        elif(stage == "predict"):
            dataset = self.predict_dataset
        else:
            raise ValueError("Invalid stage")

        batch_collator = OpenFoldBatchCollator()

        dl = OpenFoldDataLoader(
            dataset,
            config=self.config,
            stage=stage,
            generator=generator,
            batch_size=self.config.data_module.data_loaders.batch_size,
            num_workers=self.config.data_module.data_loaders.num_workers,
            collate_fn=batch_collator,
        )

        return dl

    def train_dataloader(self):
        return self._gen_dataloader("train") 

    def val_dataloader(self):
        if(self.eval_dataset is not None):
            return self._gen_dataloader("eval")
        return None

    def predict_dataloader(self):
        return self._gen_dataloader("predict") 


class DummyDataset(torch.utils.data.Dataset):
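    """Serves deep copies of a single pickled batch (useful for, e.g.,
    debugging and benchmarking runs)."""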
    def __init__(self, batch_path):
        with open(batch_path, "rb") as f:
            self.batch = pickle.load(f)

    def __getitem__(self, idx):
        return copy.deepcopy(self.batch)

    def __len__(self):
        return 1000


class DummyDataLoader(pl.LightningDataModule):
    def __init__(self, batch_path):
        super().__init__()
        self.dataset = DummyDataset(batch_path)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.dataset)