Unverified commit 164f6777, authored by LuGY and committed by GitHub

modify data module for train (#116)

parent 6fbc402e
+# Copyright 2022 HPC-AI Tech Inc.
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,19 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import copy
from functools import partial
import json
import logging
import os
-import pickle
from typing import Optional, Sequence, List, Any
import ml_collections as mlc
-import numpy as np
-import pytorch_lightning as pl
import torch
-from torch.utils.data import RandomSampler
from fastfold.data import (
    data_pipeline,
@@ -217,7 +213,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
            return data

        feats = self.feature_pipeline.process_features(
            data, self.mode
        )

        return feats
...@@ -380,8 +376,10 @@ class OpenFoldBatchCollator: ...@@ -380,8 +376,10 @@ class OpenFoldBatchCollator:
prot, self.stage prot, self.stage
) )
processed_prots.append(features) processed_prots.append(features)
# By this stack, the batch dimension is processed and added.
stack_fn = partial(torch.stack, dim=0) stack_fn = partial(torch.stack, dim=0)
# I have modified some codes. Now if the bs=1, the shape will be [...] rather than [1, ...]
# If bs>1(not allowed), the shape would be still [2, ...]
return dict_multimap(stack_fn, processed_prots) return dict_multimap(stack_fn, processed_prots)
@@ -478,230 +476,156 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
        return _batch_prop_gen(it)
-class OpenFoldDataModule(pl.LightningDataModule):
-    def __init__(self,
-        config: mlc.ConfigDict,
-        template_mmcif_dir: str,
-        max_template_date: str,
-        train_data_dir: Optional[str] = None,
-        train_alignment_dir: Optional[str] = None,
-        train_chain_data_cache_path: Optional[str] = None,
-        distillation_data_dir: Optional[str] = None,
-        distillation_alignment_dir: Optional[str] = None,
-        distillation_chain_data_cache_path: Optional[str] = None,
-        val_data_dir: Optional[str] = None,
-        val_alignment_dir: Optional[str] = None,
-        predict_data_dir: Optional[str] = None,
-        predict_alignment_dir: Optional[str] = None,
-        kalign_binary_path: str = '/usr/bin/kalign',
-        train_mapping_path: Optional[str] = None,
-        distillation_mapping_path: Optional[str] = None,
-        obsolete_pdbs_file_path: Optional[str] = None,
-        template_release_dates_cache_path: Optional[str] = None,
-        batch_seed: Optional[int] = None,
-        train_epoch_len: int = 50000,
-        _alignment_index_path: Optional[str] = None,
-        **kwargs
-    ):
-        super(OpenFoldDataModule, self).__init__()
-
-        self.config = config
-        self.template_mmcif_dir = template_mmcif_dir
-        self.max_template_date = max_template_date
-        self.train_data_dir = train_data_dir
-        self.train_alignment_dir = train_alignment_dir
-        self.train_chain_data_cache_path = train_chain_data_cache_path
-        self.distillation_data_dir = distillation_data_dir
-        self.distillation_alignment_dir = distillation_alignment_dir
-        self.distillation_chain_data_cache_path = (
-            distillation_chain_data_cache_path
-        )
-        self.val_data_dir = val_data_dir
-        self.val_alignment_dir = val_alignment_dir
-        self.predict_data_dir = predict_data_dir
-        self.predict_alignment_dir = predict_alignment_dir
-        self.kalign_binary_path = kalign_binary_path
-        self.train_mapping_path = train_mapping_path
-        self.distillation_mapping_path = distillation_mapping_path
-        self.template_release_dates_cache_path = (
-            template_release_dates_cache_path
-        )
-        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
-        self.batch_seed = batch_seed
-        self.train_epoch_len = train_epoch_len
-
-        if(self.train_data_dir is None and self.predict_data_dir is None):
-            raise ValueError(
-                'At least one of train_data_dir or predict_data_dir must be '
-                'specified'
-            )
-
-        self.training_mode = self.train_data_dir is not None
-
-        if(self.training_mode and train_alignment_dir is None):
-            raise ValueError(
-                'In training mode, train_alignment_dir must be specified'
-            )
-        elif(not self.training_mode and predict_alignment_dir is None):
-            raise ValueError(
-                'In inference mode, predict_alignment_dir must be specified'
-            )
-        elif(val_data_dir is not None and val_alignment_dir is None):
-            raise ValueError(
-                'If val_data_dir is specified, val_alignment_dir must '
-                'be specified as well'
-            )
-
-        # An ad-hoc measure for our particular filesystem restrictions
-        self._alignment_index = None
-        if(_alignment_index_path is not None):
-            with open(_alignment_index_path, "r") as fp:
-                self._alignment_index = json.load(fp)
-
-    def setup(self):
-        # Most of the arguments are the same for the three datasets
-        dataset_gen = partial(OpenFoldSingleDataset,
-            template_mmcif_dir=self.template_mmcif_dir,
-            max_template_date=self.max_template_date,
-            config=self.config,
-            kalign_binary_path=self.kalign_binary_path,
-            template_release_dates_cache_path=
-                self.template_release_dates_cache_path,
-            obsolete_pdbs_file_path=
-                self.obsolete_pdbs_file_path,
-        )
-
-        if(self.training_mode):
-            train_dataset = dataset_gen(
-                data_dir=self.train_data_dir,
-                alignment_dir=self.train_alignment_dir,
-                mapping_path=self.train_mapping_path,
-                max_template_hits=self.config.train.max_template_hits,
-                shuffle_top_k_prefiltered=
-                    self.config.train.shuffle_top_k_prefiltered,
-                treat_pdb_as_distillation=False,
-                mode="train",
-                _output_raw=True,
-                _alignment_index=self._alignment_index,
-            )
-
-            distillation_dataset = None
-            if(self.distillation_data_dir is not None):
-                distillation_dataset = dataset_gen(
-                    data_dir=self.distillation_data_dir,
-                    alignment_dir=self.distillation_alignment_dir,
-                    mapping_path=self.distillation_mapping_path,
-                    max_template_hits=self.train.max_template_hits,
-                    treat_pdb_as_distillation=True,
-                    mode="train",
-                    _output_raw=True,
-                )
-
-                d_prob = self.config.train.distillation_prob
-
-            if(distillation_dataset is not None):
-                datasets = [train_dataset, distillation_dataset]
-                d_prob = self.config.train.distillation_prob
-                probabilities = [1 - d_prob, d_prob]
-                chain_data_cache_paths = [
-                    self.train_chain_data_cache_path,
-                    self.distillation_chain_data_cache_path,
-                ]
-            else:
-                datasets = [train_dataset]
-                probabilities = [1.]
-                chain_data_cache_paths = [
-                    self.train_chain_data_cache_path,
-                ]
-
-            self.train_dataset = OpenFoldDataset(
-                datasets=datasets,
-                probabilities=probabilities,
-                epoch_len=self.train_epoch_len,
-                chain_data_cache_paths=chain_data_cache_paths,
-                _roll_at_init=False,
-            )
-
-            if(self.val_data_dir is not None):
-                self.eval_dataset = dataset_gen(
-                    data_dir=self.val_data_dir,
-                    alignment_dir=self.val_alignment_dir,
-                    mapping_path=None,
-                    max_template_hits=self.config.eval.max_template_hits,
-                    mode="eval",
-                    _output_raw=True,
-                )
-            else:
-                self.eval_dataset = None
-        else:
-            self.predict_dataset = dataset_gen(
-                data_dir=self.predict_data_dir,
-                alignment_dir=self.predict_alignment_dir,
-                mapping_path=None,
-                max_template_hits=self.config.predict.max_template_hits,
-                mode="predict",
-            )
-
-    def _gen_dataloader(self, stage):
-        generator = torch.Generator()
-        if(self.batch_seed is not None):
-            generator = generator.manual_seed(self.batch_seed)
-
-        dataset = None
-        if(stage == "train"):
-            dataset = self.train_dataset
-
-            # Filter the dataset, if necessary
-            dataset.reroll()
-        elif(stage == "eval"):
-            dataset = self.eval_dataset
-        elif(stage == "predict"):
-            dataset = self.predict_dataset
-        else:
-            raise ValueError("Invalid stage")
-
-        batch_collator = OpenFoldBatchCollator(self.config, stage)
-
-        dl = OpenFoldDataLoader(
-            dataset,
-            config=self.config,
-            stage=stage,
-            generator=generator,
-            batch_size=self.config.data_module.data_loaders.batch_size,
-            num_workers=self.config.data_module.data_loaders.num_workers,
-            collate_fn=batch_collator,
-        )
-
-        return dl
-
-    def train_dataloader(self):
-        return self._gen_dataloader("train")
-
-    def val_dataloader(self):
-        if(self.eval_dataset is not None):
-            return self._gen_dataloader("eval")
-        return None
-
-    def predict_dataloader(self):
-        return self._gen_dataloader("predict")
-
-
-class DummyDataset(torch.utils.data.Dataset):
-    def __init__(self, batch_path):
-        with open(batch_path, "rb") as f:
-            self.batch = pickle.load(f)
-
-    def __getitem__(self, idx):
-        return copy.deepcopy(self.batch)
-
-    def __len__(self):
-        return 1000
-
-
-class DummyDataLoader(pl.LightningDataModule):
-    def __init__(self, batch_path):
-        super().__init__()
-        self.dataset = DummyDataset(batch_path)
-
-    def train_dataloader(self):
-        return torch.utils.data.DataLoader(self.dataset)
+def SetupTrainDataset(
+    config: mlc.ConfigDict,
+    template_mmcif_dir: str,
+    max_template_date: str,
+    train_data_dir: Optional[str] = None,
+    train_alignment_dir: Optional[str] = None,
+    train_chain_data_cache_path: Optional[str] = None,
+    distillation_data_dir: Optional[str] = None,
+    distillation_alignment_dir: Optional[str] = None,
+    distillation_chain_data_cache_path: Optional[str] = None,
+    val_data_dir: Optional[str] = None,
+    val_alignment_dir: Optional[str] = None,
+    kalign_binary_path: str = '/usr/bin/kalign',
+    train_mapping_path: Optional[str] = None,
+    distillation_mapping_path: Optional[str] = None,
+    obsolete_pdbs_file_path: Optional[str] = None,
+    template_release_dates_cache_path: Optional[str] = None,
+    train_epoch_len: int = 50000,
+    _alignment_index_path: Optional[str] = None,
+    **kwargs,
+):
+    if(train_data_dir is None or train_alignment_dir is None):
+        raise ValueError(
+            'train_data_dir and train_alignment_dir must be specified'
+        )
+    elif(val_data_dir is not None and val_alignment_dir is None):
+        raise ValueError(
+            'If val_data_dir is specified, val_alignment_dir must '
+            'be specified as well'
+        )
+
+    _alignment_index = None
+    if(_alignment_index_path is not None):
+        with open(_alignment_index_path, "r") as fp:
+            _alignment_index = json.load(fp)
+
+    dataset_gen = partial(OpenFoldSingleDataset,
+        template_mmcif_dir=template_mmcif_dir,
+        max_template_date=max_template_date,
+        config=config,
+        kalign_binary_path=kalign_binary_path,
+        template_release_dates_cache_path=
+            template_release_dates_cache_path,
+        obsolete_pdbs_file_path=
+            obsolete_pdbs_file_path,
+    )
+
+    train_dataset = dataset_gen(
+        data_dir=train_data_dir,
+        alignment_dir=train_alignment_dir,
+        mapping_path=train_mapping_path,
+        max_template_hits=config.train.max_template_hits,
+        shuffle_top_k_prefiltered=
+            config.train.shuffle_top_k_prefiltered,
+        treat_pdb_as_distillation=False,
+        mode="train",
+        _output_raw=True,
+        _alignment_index=_alignment_index,
+    )
+
+    distillation_dataset = None
+    if(distillation_data_dir is not None):
+        distillation_dataset = dataset_gen(
+            data_dir=distillation_data_dir,
+            alignment_dir=distillation_alignment_dir,
+            mapping_path=distillation_mapping_path,
+            max_template_hits=config.train.max_template_hits,
+            treat_pdb_as_distillation=True,
+            mode="train",
+            _output_raw=True,
+        )
+
+        d_prob = config.train.distillation_prob
+
+    if(distillation_dataset is not None):
+        datasets = [train_dataset, distillation_dataset]
+        d_prob = config.train.distillation_prob
+        probabilities = [1 - d_prob, d_prob]
+        chain_data_cache_paths = [
+            train_chain_data_cache_path,
+            distillation_chain_data_cache_path,
+        ]
+    else:
+        datasets = [train_dataset]
+        probabilities = [1.]
+        chain_data_cache_paths = [
+            train_chain_data_cache_path,
+        ]
+
+    train_dataset = OpenFoldDataset(
+        datasets=datasets,
+        probabilities=probabilities,
+        epoch_len=train_epoch_len,
+        chain_data_cache_paths=chain_data_cache_paths,
+        _roll_at_init=False,
+    )
+
+    if(val_data_dir is not None):
+        eval_dataset = dataset_gen(
+            data_dir=val_data_dir,
+            alignment_dir=val_alignment_dir,
+            mapping_path=None,
+            max_template_hits=config.eval.max_template_hits,
+            mode="eval",
+            _output_raw=True,
+        )
+    else:
+        eval_dataset = None
+
+    return train_dataset, eval_dataset
+
+
+def TrainDataLoader(
+    config: mlc.ConfigDict,
+    train_dataset: torch.utils.data.Dataset,
+    test_dataset: Optional[torch.utils.data.Dataset] = None,
+    batch_seed: Optional[int] = None,
+):
+    if not config.data_module.data_loaders.batch_size == 1:
+        raise ValueError("Only support batch size equals to 1")
+
+    generator = torch.Generator()
+    if(batch_seed is not None):
+        generator = generator.manual_seed(batch_seed)
+
+    train_batch_collator = OpenFoldBatchCollator(config, "train")
+    train_dataset.reroll()
+    train_dataloader = OpenFoldDataLoader(
+        train_dataset,
+        config=config,
+        stage="train",
+        generator=generator,
+        batch_size=config.data_module.data_loaders.batch_size,
+        num_workers=config.data_module.data_loaders.num_workers,
+        collate_fn=train_batch_collator,
+    )
+
+    test_dataloader = None
+    if test_dataset is not None:
+        test_batch_collator = OpenFoldBatchCollator(config, "test")
+        test_dataloader = OpenFoldDataLoader(
+            train_dataset,
+            config=config,
+            stage="test",
+            generator=generator,
+            batch_size=config.data_module.data_loaders.batch_size,
+            num_workers=config.data_module.data_loaders.num_workers,
+            collate_fn=test_batch_collator,
+        )
+
+    return train_dataloader, test_dataloader
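Usage note (not part of the commit): the Lightning OpenFoldDataModule above is replaced by two plain functions, so a training script now builds the datasets and loaders explicitly. A minimal sketch, assuming SetupTrainDataset and TrainDataLoader are imported from this module and that config is the mlc.ConfigDict carrying the train, eval and data_module sections referenced in the new code; every path below is a placeholder.

# Sketch only: `config` and all paths are placeholders, not values from this commit.
train_dataset, eval_dataset = SetupTrainDataset(
    config=config,
    template_mmcif_dir="/database/pdb_mmcif/mmcif_files",
    max_template_date="2021-10-10",
    train_data_dir="/database/pdb_mmcif/mmcif_files",
    train_alignment_dir="/database/alignments",
    kalign_binary_path="/usr/bin/kalign",
    train_epoch_len=50000,
)

train_dataloader, test_dataloader = TrainDataLoader(
    config=config,
    train_dataset=train_dataset,
    test_dataset=eval_dataset,   # None unless val_data_dir and val_alignment_dir were given
    batch_seed=42,
)

for batch in train_dataloader:
    # batch_size is checked to be 1, so with the dict_multimap change below each
    # feature tensor arrives without a leading batch dimension.
    break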
+# Copyright 2022 HPC-AI Tech Inc
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
@@ -52,7 +53,8 @@ def dict_multimap(fn, dicts):
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
-            new_dict[k] = fn(all_v)
+            # When bs = 1, return [...] rather than [1, ...].
+            new_dict[k] = fn(all_v) if len(all_v) > 1 else all_v[0]

    return new_dict
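The collation change itself lives in dict_multimap: when the value list holds a single element (batch size 1), that element is returned as-is instead of being stacked, so no leading batch dimension appears. A minimal sketch of the behaviour, assuming the dict_multimap defined above is in scope; the tensors are made-up stand-ins for real feature dicts.

# Sketch only: toy tensors standing in for processed feature dicts.
import torch
from functools import partial

stack_fn = partial(torch.stack, dim=0)

single = [{"aatype": torch.zeros(128)}]                                 # batch size 1
double = [{"aatype": torch.zeros(128)}, {"aatype": torch.ones(128)}]    # batch size 2 (rejected upstream)

print(dict_multimap(stack_fn, single)["aatype"].shape)   # torch.Size([128]), no leading batch dim
print(dict_multimap(stack_fn, double)["aatype"].shape)   # torch.Size([2, 128])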