"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "54c4e0761a5b9e5d102c1c0dface786cf489fdd7"
Unverified Commit 164f6777, authored by LuGY, committed by GitHub

modify data module for train (#116)

parent 6fbc402e
+# Copyright 2022 HPC-AI Tech Inc.
 # Copyright 2021 AlQuraishi Laboratory
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -12,19 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 from functools import partial
 import json
 import logging
 import os
-import pickle
 from typing import Optional, Sequence, List, Any

 import ml_collections as mlc
-import numpy as np
-import pytorch_lightning as pl
 import torch
-from torch.utils.data import RandomSampler

 from fastfold.data import (
     data_pipeline,
@@ -380,8 +376,10 @@ class OpenFoldBatchCollator:
                 prot, self.stage
             )
             processed_prots.append(features)

+        # By this stack, the batch dimension is processed and added.
         stack_fn = partial(torch.stack, dim=0)
+        # Modified: if bs == 1, the shape will now be [...] rather than [1, ...].
+        # If bs > 1 (not allowed), the shape would still be [2, ...].
         return dict_multimap(stack_fn, processed_prots)
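For reference, a minimal standalone illustration of the stacking comments above (not FastFold code; the feature name and shapes are made up):

# Minimal illustration: stacking per-sample feature tensors along dim 0
# introduces the leading batch dimension.
from functools import partial
import torch

stack_fn = partial(torch.stack, dim=0)

aatype_a = torch.zeros(256, 21)  # hypothetical per-sample feature, shape [N_res, 21]
aatype_b = torch.zeros(256, 21)

print(stack_fn([aatype_a, aatype_b]).shape)  # torch.Size([2, 256, 21])

# torch.stack on a single-element list would still give [1, 256, 21]; the
# bs == 1 case returning [...] with no leading dimension comes from the
# dict_multimap change at the end of this diff, not from torch.stack itself.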
@@ -478,8 +476,7 @@ class OpenFoldDataLoader(torch.utils.data.DataLoader):
         return _batch_prop_gen(it)


-class OpenFoldDataModule(pl.LightningDataModule):
-    def __init__(self,
+def SetupTrainDataset(
         config: mlc.ConfigDict,
         template_mmcif_dir: str,
         max_template_date: str,
@@ -491,60 +488,19 @@ class OpenFoldDataModule(pl.LightningDataModule):
         distillation_chain_data_cache_path: Optional[str] = None,
         val_data_dir: Optional[str] = None,
         val_alignment_dir: Optional[str] = None,
-        predict_data_dir: Optional[str] = None,
-        predict_alignment_dir: Optional[str] = None,
         kalign_binary_path: str = '/usr/bin/kalign',
         train_mapping_path: Optional[str] = None,
         distillation_mapping_path: Optional[str] = None,
         obsolete_pdbs_file_path: Optional[str] = None,
         template_release_dates_cache_path: Optional[str] = None,
-        batch_seed: Optional[int] = None,
         train_epoch_len: int = 50000,
         _alignment_index_path: Optional[str] = None,
-        **kwargs
+        **kwargs,
     ):
-        super(OpenFoldDataModule, self).__init__()
-
-        self.config = config
-        self.template_mmcif_dir = template_mmcif_dir
-        self.max_template_date = max_template_date
-        self.train_data_dir = train_data_dir
-        self.train_alignment_dir = train_alignment_dir
-        self.train_chain_data_cache_path = train_chain_data_cache_path
-        self.distillation_data_dir = distillation_data_dir
-        self.distillation_alignment_dir = distillation_alignment_dir
-        self.distillation_chain_data_cache_path = (
-            distillation_chain_data_cache_path
-        )
-        self.val_data_dir = val_data_dir
-        self.val_alignment_dir = val_alignment_dir
-        self.predict_data_dir = predict_data_dir
-        self.predict_alignment_dir = predict_alignment_dir
-        self.kalign_binary_path = kalign_binary_path
-        self.train_mapping_path = train_mapping_path
-        self.distillation_mapping_path = distillation_mapping_path
-        self.template_release_dates_cache_path = (
-            template_release_dates_cache_path
-        )
-        self.obsolete_pdbs_file_path = obsolete_pdbs_file_path
-        self.batch_seed = batch_seed
-        self.train_epoch_len = train_epoch_len
-
-        if(self.train_data_dir is None and self.predict_data_dir is None):
+    if(train_data_dir is None or train_alignment_dir is None):
         raise ValueError(
-            'At least one of train_data_dir or predict_data_dir must be '
-            'specified'
-        )
-
-        self.training_mode = self.train_data_dir is not None
-        if(self.training_mode and train_alignment_dir is None):
-            raise ValueError(
-                'In training mode, train_alignment_dir must be specified'
-            )
-        elif(not self.training_mode and predict_alignment_dir is None):
-            raise ValueError(
-                'In inference mode, predict_alignment_dir must be specified'
+            'train_data_dir and train_alignment_dir must be specified'
         )
     elif(val_data_dir is not None and val_alignment_dir is None):
         raise ValueError(
@@ -552,156 +508,124 @@ class OpenFoldDataModule(pl.LightningDataModule):
             'be specified as well'
         )

-        # An ad-hoc measure for our particular filesystem restrictions
-        self._alignment_index = None
+    _alignment_index = None
     if(_alignment_index_path is not None):
         with open(_alignment_index_path, "r") as fp:
-            self._alignment_index = json.load(fp)
+            _alignment_index = json.load(fp)

-    def setup(self):
-        # Most of the arguments are the same for the three datasets
     dataset_gen = partial(OpenFoldSingleDataset,
-        template_mmcif_dir=self.template_mmcif_dir,
-        max_template_date=self.max_template_date,
-        config=self.config,
-        kalign_binary_path=self.kalign_binary_path,
+        template_mmcif_dir=template_mmcif_dir,
+        max_template_date=max_template_date,
+        config=config,
+        kalign_binary_path=kalign_binary_path,
         template_release_dates_cache_path=
-            self.template_release_dates_cache_path,
+            template_release_dates_cache_path,
         obsolete_pdbs_file_path=
-            self.obsolete_pdbs_file_path,
+            obsolete_pdbs_file_path,
     )

-        if(self.training_mode):
     train_dataset = dataset_gen(
-        data_dir=self.train_data_dir,
-        alignment_dir=self.train_alignment_dir,
-        mapping_path=self.train_mapping_path,
-        max_template_hits=self.config.train.max_template_hits,
+        data_dir=train_data_dir,
+        alignment_dir=train_alignment_dir,
+        mapping_path=train_mapping_path,
+        max_template_hits=config.train.max_template_hits,
         shuffle_top_k_prefiltered=
-            self.config.train.shuffle_top_k_prefiltered,
+            config.train.shuffle_top_k_prefiltered,
         treat_pdb_as_distillation=False,
         mode="train",
         _output_raw=True,
-        _alignment_index=self._alignment_index,
+        _alignment_index=_alignment_index,
     )

     distillation_dataset = None
-    if(self.distillation_data_dir is not None):
+    if(distillation_data_dir is not None):
         distillation_dataset = dataset_gen(
-            data_dir=self.distillation_data_dir,
-            alignment_dir=self.distillation_alignment_dir,
-            mapping_path=self.distillation_mapping_path,
-            max_template_hits=self.train.max_template_hits,
+            data_dir=distillation_data_dir,
+            alignment_dir=distillation_alignment_dir,
+            mapping_path=distillation_mapping_path,
+            max_template_hits=config.train.max_template_hits,
             treat_pdb_as_distillation=True,
             mode="train",
             _output_raw=True,
         )

-        d_prob = self.config.train.distillation_prob
+        d_prob = config.train.distillation_prob

     if(distillation_dataset is not None):
         datasets = [train_dataset, distillation_dataset]
-        d_prob = self.config.train.distillation_prob
+        d_prob = config.train.distillation_prob
         probabilities = [1 - d_prob, d_prob]
         chain_data_cache_paths = [
-            self.train_chain_data_cache_path,
-            self.distillation_chain_data_cache_path,
+            train_chain_data_cache_path,
+            distillation_chain_data_cache_path,
         ]
     else:
         datasets = [train_dataset]
         probabilities = [1.]
         chain_data_cache_paths = [
-            self.train_chain_data_cache_path,
+            train_chain_data_cache_path,
         ]

-        self.train_dataset = OpenFoldDataset(
+    train_dataset = OpenFoldDataset(
         datasets=datasets,
         probabilities=probabilities,
-        epoch_len=self.train_epoch_len,
+        epoch_len=train_epoch_len,
         chain_data_cache_paths=chain_data_cache_paths,
         _roll_at_init=False,
     )

-        if(self.val_data_dir is not None):
-            self.eval_dataset = dataset_gen(
-                data_dir=self.val_data_dir,
-                alignment_dir=self.val_alignment_dir,
+    if(val_data_dir is not None):
+        eval_dataset = dataset_gen(
+            data_dir=val_data_dir,
+            alignment_dir=val_alignment_dir,
             mapping_path=None,
-            max_template_hits=self.config.eval.max_template_hits,
+            max_template_hits=config.eval.max_template_hits,
             mode="eval",
             _output_raw=True,
         )
     else:
-        self.eval_dataset = None
+        eval_dataset = None
-        else:
-            self.predict_dataset = dataset_gen(
-                data_dir=self.predict_data_dir,
-                alignment_dir=self.predict_alignment_dir,
-                mapping_path=None,
-                max_template_hits=self.config.predict.max_template_hits,
-                mode="predict",
-            )
-
-    def _gen_dataloader(self, stage):
-        generator = torch.Generator()
-        if(self.batch_seed is not None):
-            generator = generator.manual_seed(self.batch_seed)
-
-        dataset = None
-        if(stage == "train"):
-            dataset = self.train_dataset
-            # Filter the dataset, if necessary
-            dataset.reroll()
-        elif(stage == "eval"):
-            dataset = self.eval_dataset
-        elif(stage == "predict"):
-            dataset = self.predict_dataset
-        else:
-            raise ValueError("Invalid stage")
-
-        batch_collator = OpenFoldBatchCollator(self.config, stage)
-
-        dl = OpenFoldDataLoader(
-            dataset,
-            config=self.config,
-            stage=stage,
-            generator=generator,
-            batch_size=self.config.data_module.data_loaders.batch_size,
-            num_workers=self.config.data_module.data_loaders.num_workers,
-            collate_fn=batch_collator,
-        )
-
-        return dl
-
-    def train_dataloader(self):
-        return self._gen_dataloader("train")
-
-    def val_dataloader(self):
-        if(self.eval_dataset is not None):
-            return self._gen_dataloader("eval")
-        return None
-
-    def predict_dataloader(self):
-        return self._gen_dataloader("predict")
-
-class DummyDataset(torch.utils.data.Dataset):
-    def __init__(self, batch_path):
-        with open(batch_path, "rb") as f:
-            self.batch = pickle.load(f)
-
-    def __getitem__(self, idx):
-        return copy.deepcopy(self.batch)
-
-    def __len__(self):
-        return 1000
-
-
-class DummyDataLoader(pl.LightningDataModule):
-    def __init__(self, batch_path):
-        super().__init__()
-        self.dataset = DummyDataset(batch_path)
-
-    def train_dataloader(self):
-        return torch.utils.data.DataLoader(self.dataset)
+    return train_dataset, eval_dataset
+
+
+def TrainDataLoader(
+    config: mlc.ConfigDict,
+    train_dataset: torch.utils.data.Dataset,
+    test_dataset: Optional[torch.utils.data.Dataset] = None,
+    batch_seed: Optional[int] = None,
+):
+    if not config.data_module.data_loaders.batch_size == 1:
+        raise ValueError("Only support batch size equals to 1")
+
+    generator = torch.Generator()
+    if(batch_seed is not None):
+        generator = generator.manual_seed(batch_seed)
+
+    train_batch_collator = OpenFoldBatchCollator(config, "train")
+    train_dataset.reroll()
+    train_dataloader = OpenFoldDataLoader(
+        train_dataset,
+        config=config,
+        stage="train",
+        generator=generator,
+        batch_size=config.data_module.data_loaders.batch_size,
+        num_workers=config.data_module.data_loaders.num_workers,
+        collate_fn=train_batch_collator,
+    )
+
+    test_dataloader = None
+    if test_dataset is not None:
+        test_batch_collator = OpenFoldBatchCollator(config, "test")
+        test_dataloader = OpenFoldDataLoader(
+            test_dataset,
+            config=config,
+            stage="test",
+            generator=generator,
+            batch_size=config.data_module.data_loaders.batch_size,
+            num_workers=config.data_module.data_loaders.num_workers,
+            collate_fn=test_batch_collator,
+        )
+
+    return train_dataloader, test_dataloader
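Taken together, SetupTrainDataset and TrainDataLoader replace the Lightning data module with a plain function-call pattern. A rough usage sketch follows; the config object and every path and date in it are placeholders (assumptions), not values from this commit.

# Rough usage sketch of the new training data API (placeholders throughout).
config = ...  # ml_collections.ConfigDict with train/eval/data_module sections

train_dataset, eval_dataset = SetupTrainDataset(
    config=config,
    template_mmcif_dir="/data/pdb_mmcif/mmcif_files",   # placeholder path
    max_template_date="2021-10-10",                      # placeholder date
    train_data_dir="/data/train/mmcif",                  # placeholder path
    train_alignment_dir="/data/train/alignments",        # placeholder path
    val_data_dir="/data/val/mmcif",                      # optional; placeholder path
    val_alignment_dir="/data/val/alignments",            # optional; placeholder path
)

train_dataloader, test_dataloader = TrainDataLoader(
    config=config,
    train_dataset=train_dataset,
    test_dataset=eval_dataset,   # may be None if no validation set was configured
    batch_seed=42,
)

for batch in train_dataloader:
    # With batch_size fixed to 1, each feature tensor comes back without a
    # leading batch dimension (see the dict_multimap change below).
    ...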
+# Copyright 2022 HPC-AI Tech Inc
 # Copyright 2021 AlQuraishi Laboratory
 # Copyright 2021 DeepMind Technologies Limited
 #
@@ -52,7 +53,8 @@ def dict_multimap(fn, dicts):
         if type(v) is dict:
             new_dict[k] = dict_multimap(fn, all_v)
         else:
-            new_dict[k] = fn(all_v)
+            # when bs = 1, returns [...] rather than [1, ...]
+            new_dict[k] = fn(all_v) if len(all_v) > 1 else all_v[0]

     return new_dict
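To make the new len(all_v) > 1 branch concrete, here is a small standalone check. The if/else lines are taken from this diff; the surrounding loop is filled in from the upstream OpenFold helper this file derives from, so treat it as an approximation, and the feature dict is made up.

from functools import partial
import torch

def dict_multimap(fn, dicts):
    # Loop scaffolding reconstructed from the upstream OpenFold helper
    # (an assumption); the else-branch matches the change in this diff.
    first = dicts[0]
    new_dict = {}
    for k, v in first.items():
        all_v = [d[k] for d in dicts]
        if type(v) is dict:
            new_dict[k] = dict_multimap(fn, all_v)
        else:
            # when bs = 1, returns [...] rather than [1, ...]
            new_dict[k] = fn(all_v) if len(all_v) > 1 else all_v[0]
    return new_dict

stack_fn = partial(torch.stack, dim=0)
sample = {"aatype": torch.zeros(256), "extra": {"profile": torch.zeros(256, 22)}}

single = dict_multimap(stack_fn, [sample])
print(single["aatype"].shape)            # torch.Size([256])    -> no leading batch dim
print(single["extra"]["profile"].shape)  # torch.Size([256, 22])

double = dict_multimap(stack_fn, [sample, sample])
print(double["aatype"].shape)            # torch.Size([2, 256]) -> batch dim restored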