"src/lib/vscode:/vscode.git/clone" did not exist on "b6ab357e8cb8422b224cf72dfa80c6d3950de7c3"
Unverified commit 164f6777 authored by LuGY, committed by GitHub

modify data module for train (#116)

parent 6fbc402e
# Copyright 2022 HPC-AI Tech Inc.
# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,19 +13,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from functools import partial
import json
import logging
import os
import pickle
from typing import Optional, Sequence, List, Any
import ml_collections as mlc
import numpy as np
import pytorch_lightning as pl
import torch
from torch.utils.data import RandomSampler
from fastfold.data import (
data_pipeline,
@@ -380,8 +376,10 @@ class OpenFoldBatchCollator:
prot, self.stage
)
processed_prots.append(features)
        # This stack adds the batch dimension.
        stack_fn = partial(torch.stack, dim=0)
+        # dict_multimap has been modified: with bs = 1, shapes are now
+        # [...] rather than [1, ...]; with bs > 1 (not allowed), they
+        # would still be [bs, ...] (e.g. [2, ...]).
return dict_multimap(stack_fn, processed_prots)
@@ -478,8 +476,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
return _batch_prop_gen(it)
-class OpenFoldDataModule(pl.LightningDataModule):
-    def __init__(self,
+def SetupTrainDataset(
config: mlc.ConfigDict,
template_mmcif_dir: str,
max_template_date: str,
......@@ -491,60 +488,19 @@ class OpenFoldDataModule(pl.LightningDataModule):
distillation_chain_data_cache_path: Optional[str] = None,
val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None,
-    predict_data_dir: Optional[str] = None,
-    predict_alignment_dir: Optional[str] = None,
kalign_binary_path: str = '/usr/bin/kalign',
train_mapping_path: Optional[str] = None,
distillation_mapping_path: Optional[str] = None,
obsolete_pdbs_file_path: Optional[str] = None,
template_release_dates_cache_path: Optional[str] = None,
-    batch_seed: Optional[int] = None,
train_epoch_len: int = 50000,
_alignment_index_path: Optional[str] = None,
-    **kwargs
-):
-    super(OpenFoldDataModule, self).__init__()
+    **kwargs,
+):
-        self.config = config
-        self.template_mmcif_dir = template_mmcif_dir
-        self.max_template_date = max_template_date
-        self.train_data_dir = train_data_dir
-        self.train_alignment_dir = train_alignment_dir
-        self.train_chain_data_cache_path = train_chain_data_cache_path
-        self.distillation_data_dir = distillation_data_dir
-        self.distillation_alignment_dir = distillation_alignment_dir
-        self.distillation_chain_data_cache_path = (
-            distillation_chain_data_cache_path
-        )
-        self.val_data_dir = val_data_dir
-        self.val_alignment_dir = val_alignment_dir
-        self.predict_data_dir = predict_data_dir
-        self.predict_alignment_dir = predict_alignment_dir
-        self.kalign_binary_path = kalign_binary_path
-        self.train_mapping_path = train_mapping_path
-        self.distillation_mapping_path = distillation_mapping_path
-        self.template_release_dates_cache_path = (
-            template_release_dates_cache_path
-        )
-        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
-        self.batch_seed = batch_seed
-        self.train_epoch_len = train_epoch_len
-        if(self.train_data_dir is None and self.predict_data_dir is None):
+    if(train_data_dir is None or train_alignment_dir is None):
        raise ValueError(
-            'At least one of train_data_dir or predict_data_dir must be '
-            'specified'
-        )
-        self.training_mode = self.train_data_dir is not None
-        if(self.training_mode and train_alignment_dir is None):
-            raise ValueError(
-                'In training mode, train_alignment_dir must be specified'
-            )
-        elif(not self.training_mode and predict_alignment_dir is None):
-            raise ValueError(
-                'In inference mode, predict_alignment_dir must be specified'
+            'train_data_dir and train_alignment_dir must be specified'
        )
    elif(val_data_dir is not None and val_alignment_dir is None):
        raise ValueError(
@@ -552,156 +508,124 @@ class OpenFoldDataModule(pl.LightningDataModule):
'be specified as well'
)
# An ad-hoc measure for our particular filesystem restrictions
-        self._alignment_index = None
+    _alignment_index = None
if(_alignment_index_path is not None):
with open(_alignment_index_path, "r") as fp:
-            self._alignment_index = json.load(fp)
+            _alignment_index = json.load(fp)

-    def setup(self):
# Most of the arguments are the same for the three datasets
dataset_gen = partial(OpenFoldSingleDataset,
-        template_mmcif_dir=self.template_mmcif_dir,
-        max_template_date=self.max_template_date,
-        config=self.config,
-        kalign_binary_path=self.kalign_binary_path,
+        template_mmcif_dir=template_mmcif_dir,
+        max_template_date=max_template_date,
+        config=config,
+        kalign_binary_path=kalign_binary_path,
template_release_dates_cache_path=
-            self.template_release_dates_cache_path,
+            template_release_dates_cache_path,
obsolete_pdbs_file_path=
-            self.obsolete_pdbs_file_path,
+            obsolete_pdbs_file_path,
)
-        if(self.training_mode):
train_dataset = dataset_gen(
-        data_dir=self.train_data_dir,
-        alignment_dir=self.train_alignment_dir,
-        mapping_path=self.train_mapping_path,
-        max_template_hits=self.config.train.max_template_hits,
+        data_dir=train_data_dir,
+        alignment_dir=train_alignment_dir,
+        mapping_path=train_mapping_path,
+        max_template_hits=config.train.max_template_hits,
shuffle_top_k_prefiltered=
-            self.config.train.shuffle_top_k_prefiltered,
+            config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train",
_output_raw=True,
-        _alignment_index=self._alignment_index,
+        _alignment_index=_alignment_index,
)
distillation_dataset = None
-    if(self.distillation_data_dir is not None):
+    if(distillation_data_dir is not None):
distillation_dataset = dataset_gen(
-            data_dir=self.distillation_data_dir,
-            alignment_dir=self.distillation_alignment_dir,
-            mapping_path=self.distillation_mapping_path,
-            max_template_hits=self.train.max_template_hits,
+            data_dir=distillation_data_dir,
+            alignment_dir=distillation_alignment_dir,
+            mapping_path=distillation_mapping_path,
+            max_template_hits=config.train.max_template_hits,
treat_pdb_as_distillation=True,
mode="train",
_output_raw=True,
)
-    d_prob = self.config.train.distillation_prob
+    d_prob = config.train.distillation_prob
if(distillation_dataset is not None):
datasets = [train_dataset, distillation_dataset]
-        d_prob = self.config.train.distillation_prob
+        d_prob = config.train.distillation_prob
probabilities = [1 - d_prob, d_prob]
chain_data_cache_paths = [
-            self.train_chain_data_cache_path,
-            self.distillation_chain_data_cache_path,
+            train_chain_data_cache_path,
+            distillation_chain_data_cache_path,
]
else:
datasets = [train_dataset]
probabilities = [1.]
chain_data_cache_paths = [
-            self.train_chain_data_cache_path,
+            train_chain_data_cache_path,
]
-    self.train_dataset = OpenFoldDataset(
+    train_dataset = OpenFoldDataset(
datasets=datasets,
probabilities=probabilities,
-        epoch_len=self.train_epoch_len,
+        epoch_len=train_epoch_len,
chain_data_cache_paths=chain_data_cache_paths,
_roll_at_init=False,
)
-    if(self.val_data_dir is not None):
-        self.eval_dataset = dataset_gen(
-            data_dir=self.val_data_dir,
-            alignment_dir=self.val_alignment_dir,
+    if(val_data_dir is not None):
+        eval_dataset = dataset_gen(
+            data_dir=val_data_dir,
+            alignment_dir=val_alignment_dir,
mapping_path=None,
-            max_template_hits=self.config.eval.max_template_hits,
+            max_template_hits=config.eval.max_template_hits,
mode="eval",
_output_raw=True,
)
    else:
-        self.eval_dataset = None
-    else:
-        self.predict_dataset = dataset_gen(
-            data_dir=self.predict_data_dir,
-            alignment_dir=self.predict_alignment_dir,
-            mapping_path=None,
-            max_template_hits=self.config.predict.max_template_hits,
-            mode="predict",
-        )
+        eval_dataset = None
-    def _gen_dataloader(self, stage):
-        generator = torch.Generator()
-        if(self.batch_seed is not None):
-            generator = generator.manual_seed(self.batch_seed)
-        dataset = None
-        if(stage == "train"):
-            dataset = self.train_dataset
-            # Filter the dataset, if necessary
-            dataset.reroll()
-        elif(stage == "eval"):
-            dataset = self.eval_dataset
-        elif(stage == "predict"):
-            dataset = self.predict_dataset
-        else:
-            raise ValueError("Invalid stage")
-        batch_collator = OpenFoldBatchCollator(self.config, stage)
-        dl = OpenFoldDataLoader(
-            dataset,
-            config=self.config,
-            stage=stage,
-            generator=generator,
-            batch_size=self.config.data_module.data_loaders.batch_size,
-            num_workers=self.config.data_module.data_loaders.num_workers,
-            collate_fn=batch_collator,
-        )
-        return dl
-
-    def train_dataloader(self):
-        return self._gen_dataloader("train")
-
-    def val_dataloader(self):
-        if(self.eval_dataset is not None):
-            return self._gen_dataloader("eval")
-        return None
-
-    def predict_dataloader(self):
-        return self._gen_dataloader("predict")
+    return train_dataset, eval_dataset
class DummyDataset(torch.utils.data.Dataset):
def __init__(self, batch_path):
with open(batch_path, "rb") as f:
self.batch = pickle.load(f)
def __getitem__(self, idx):
return copy.deepcopy(self.batch)
def __len__(self):
return 1000
+def TrainDataLoader(
+    config: mlc.ConfigDict,
+    train_dataset: torch.utils.data.Dataset,
+    test_dataset: Optional[torch.utils.data.Dataset] = None,
+    batch_seed: Optional[int] = None,
+):
+    if config.data_module.data_loaders.batch_size != 1:
+        raise ValueError("Only a batch size of 1 is supported")
+
+    generator = torch.Generator()
+    if(batch_seed is not None):
+        generator = generator.manual_seed(batch_seed)
+
+    train_batch_collator = OpenFoldBatchCollator(config, "train")
+    train_dataset.reroll()
+    train_dataloader = OpenFoldDataLoader(
+        train_dataset,
+        config=config,
+        stage="train",
+        generator=generator,
+        batch_size=config.data_module.data_loaders.batch_size,
+        num_workers=config.data_module.data_loaders.num_workers,
+        collate_fn=train_batch_collator,
+    )
+
+    test_dataloader = None
+    if test_dataset is not None:
+        test_batch_collator = OpenFoldBatchCollator(config, "test")
+        test_dataloader = OpenFoldDataLoader(
+            test_dataset,
+            config=config,
+            stage="test",
+            generator=generator,
+            batch_size=config.data_module.data_loaders.batch_size,
+            num_workers=config.data_module.data_loaders.num_workers,
+            collate_fn=test_batch_collator,
+        )
+
+    return train_dataloader, test_dataloader

class DummyDataLoader(pl.LightningDataModule):
    def __init__(self, batch_path):
        super().__init__()
        self.dataset = DummyDataset(batch_path)

    def train_dataloader(self):
        return torch.utils.data.DataLoader(self.dataset)
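Taken together, this change replaces the `OpenFoldDataModule` Lightning class with a plain function-style API: `SetupTrainDataset` builds the train (and optional eval) datasets, and `TrainDataLoader` wraps them in loaders. A minimal usage sketch follows; the import path, preset name, and all paths below are hypothetical placeholders, and `model_config` is assumed to be the repo's usual config helper:

from fastfold.config import model_config  # assumed config helper
from fastfold.data.data_modules import SetupTrainDataset, TrainDataLoader  # assumed module path

config = model_config("model_1")  # hypothetical preset name

# Build the datasets from raw structures + precomputed alignments.
train_dataset, eval_dataset = SetupTrainDataset(
    config=config.data,
    template_mmcif_dir="/data/pdb_mmcif/mmcif_files",  # hypothetical path
    max_template_date="2021-10-10",
    train_data_dir="/data/train_mmcif",                # hypothetical path
    train_alignment_dir="/data/train_alignments",      # hypothetical path
)

# Wrap them in dataloaders; batch_size must be 1 per the check above.
train_dataloader, test_dataloader = TrainDataLoader(
    config=config.data,
    train_dataset=train_dataset,
    test_dataset=eval_dataset,
    batch_seed=42,
)

for batch in train_dataloader:
    pass  # feed `batch` to the training step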
# Copyright 2022 HPC-AI Tech Inc
# Copyright 2021 AlQuraishi Laboratory
# Copyright 2021 DeepMind Technologies Limited
#
@@ -52,7 +53,8 @@ def dict_multimap(fn, dicts):
if type(v) is dict:
new_dict[k] = dict_multimap(fn, all_v)
else:
-            new_dict[k] = fn(all_v)
+            # when bs = 1, returns [...] rather than [1, ...]
+            new_dict[k] = fn(all_v) if len(all_v) > 1 else all_v[0]
return new_dict
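To illustrate the changed behavior, here is a self-contained sketch using a simplified, non-recursive stand-in for `dict_multimap` (the tensors and the "aatype" key are hypothetical):

import torch
from functools import partial

def dict_multimap(fn, dicts):
    # simplified, non-recursive stand-in for the real helper above
    return {
        k: fn([d[k] for d in dicts]) if len(dicts) > 1 else dicts[0][k]
        for k in dicts[0]
    }

stack_fn = partial(torch.stack, dim=0)

one = [{"aatype": torch.zeros(128)}]                    # bs = 1
two = [{"aatype": torch.zeros(128)} for _ in range(2)]  # bs = 2 (not allowed in training)

print(dict_multimap(stack_fn, one)["aatype"].shape)  # torch.Size([128]) -- no leading batch dim
print(dict_multimap(stack_fn, two)["aatype"].shape)  # torch.Size([2, 128])

This is what the collator comment earlier in the diff refers to: with batch size 1, the collated feature dict carries no leading batch dimension.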