Commit 9a617649 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Finish bringing the alignment index out of hiding

parent 805b45cc
......@@ -36,9 +36,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
treat_pdb_as_distillation: bool = True,
mapping_path: Optional[str] = None,
mode: str = "train",
alignment_index: Optional[Any] = None,
_output_raw: bool = False,
_structure_index: Optional[Any] = None,
_alignment_index: Optional[Any] = None,
):
"""
Args:
......@@ -84,9 +84,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
self.config = config
self.treat_pdb_as_distillation = treat_pdb_as_distillation
self.mode = mode
self.alignment_index = alignment_index
self._output_raw = _output_raw
self._structure_index = _structure_index
self._alignment_index = _alignment_index
self.supported_exts = [".cif", ".core", ".pdb"]
......@@ -100,8 +100,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
"scripts/generate_mmcif_cache.py before running OpenFold"
)
if(_alignment_index is not None):
self._chain_ids = list(_alignment_index.keys())
if(alignment_index is not None):
self._chain_ids = list(alignment_index.keys())
elif(mapping_path is None):
self._chain_ids = list(os.listdir(alignment_dir))
else:
......@@ -129,7 +129,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
if(not self._output_raw):
self.feature_pipeline = feature_pipeline.FeaturePipeline(config)
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, alignment_index):
with open(path, 'r') as f:
mmcif_string = f.read()
......@@ -148,7 +148,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
mmcif=mmcif_object,
alignment_dir=alignment_dir,
chain_id=chain_id,
_alignment_index=_alignment_index
alignment_index=alignment_index
)
return data
......@@ -163,10 +163,10 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
name = self.idx_to_chain_id(idx)
alignment_dir = os.path.join(self.alignment_dir, name)
_alignment_index = None
if(self._alignment_index is not None):
alignment_index = None
if(self.alignment_index is not None):
alignment_dir = self.alignment_dir
_alignment_index = self._alignment_index[name]
alignment_index = self.alignment_index[name]
if(self.mode == 'train' or self.mode == 'eval'):
spl = name.rsplit('_', 1)
......@@ -196,11 +196,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
path += ext
if(ext == ".cif"):
data = self._parse_mmcif(
path, file_id, chain_id, alignment_dir, _alignment_index,
path, file_id, chain_id, alignment_dir, alignment_index,
)
elif(ext == ".core"):
data = self.data_pipeline.process_core(
path, alignment_dir, _alignment_index,
path, alignment_dir, alignment_index,
)
elif(ext == ".pdb"):
data = self.data_pipeline.process_pdb(
......@@ -208,8 +208,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
alignment_dir=alignment_dir,
is_distillation=self.treat_pdb_as_distillation,
chain_id=chain_id,
alignment_index=alignment_index,
_structure_index=self._structure_index[name],
_alignment_index=_alignment_index,
)
else:
raise ValueError("Extension branch missing")
......@@ -218,7 +218,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
data = self.data_pipeline.process_fasta(
fasta_path=path,
alignment_dir=alignment_dir,
_alignment_index=_alignment_index,
alignment_index=alignment_index,
)
if(self._output_raw):
......@@ -500,8 +500,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
batch_seed: Optional[int] = None,
train_epoch_len: int = 50000,
_distillation_structure_index_path: Optional[str] = None,
_alignment_index_path: Optional[str] = None,
_distillation_alignment_index_path: Optional[str] = None,
alignment_index_path: Optional[str] = None,
distillation_alignment_index_path: Optional[str] = None,
**kwargs
):
super(OpenFoldDataModule, self).__init__()
......@@ -559,15 +559,15 @@ class OpenFoldDataModule(pl.LightningDataModule):
with open(_distillation_structure_index_path, "r") as fp:
self._distillation_structure_index = json.load(fp)
self._alignment_index = None
if(_alignment_index_path is not None):
with open(_alignment_index_path, "r") as fp:
self._alignment_index = json.load(fp)
self.alignment_index = None
if(alignment_index_path is not None):
with open(alignment_index_path, "r") as fp:
self.alignment_index = json.load(fp)
self._distillation_alignment_index = None
if(_distillation_alignment_index_path is not None):
with open(_distillation_alignment_index_path, "r") as fp:
self._distillation_alignment_index = json.load(fp)
self.distillation_alignment_index = None
if(distillation_alignment_index_path is not None):
with open(distillation_alignment_index_path, "r") as fp:
self.distillation_alignment_index = json.load(fp)
def setup(self):
# Most of the arguments are the same for the three datasets
......@@ -592,7 +592,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.config.train.shuffle_top_k_prefiltered,
treat_pdb_as_distillation=False,
mode="train",
_alignment_index=self._alignment_index,
alignment_index=self.alignment_index,
)
distillation_dataset = None
......@@ -604,8 +604,8 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_hits=self.config.train.max_template_hits,
treat_pdb_as_distillation=True,
mode="train",
alignment_index=self.distillation_alignment_index,
_structure_index=self._distillation_structure_index,
_alignment_index=self._distillation_alignment_index,
)
d_prob = self.config.train.distillation_prob
......@@ -625,7 +625,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.train_chain_data_cache_path,
]
generator = None
if(self.batch_seed is not None):
generator = torch.Generator()
generator = generator.manual_seed(self.batch_seed + 1)
......@@ -638,7 +637,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
generator=generator,
_roll_at_init=False,
)
if(self.val_data_dir is not None):
self.eval_dataset = dataset_gen(
......
......@@ -462,18 +462,18 @@ class DataPipeline:
def _parse_msa_data(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None,
alignment_index: Optional[Any] = None,
) -> Mapping[str, Any]:
msa_data = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
if(alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), "rb")
def read_msa(start, size):
fp.seek(start)
msa = fp.read(size).decode("utf-8")
return msa
for (name, start, size) in _alignment_index["files"]:
for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".a3m"):
......@@ -517,17 +517,17 @@ class DataPipeline:
def _parse_template_hits(
self,
alignment_dir: str,
_alignment_index: Optional[Any] = None
alignment_index: Optional[Any] = None
) -> Mapping[str, Any]:
all_hits = {}
if(_alignment_index is not None):
fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')
if(alignment_index is not None):
fp = open(os.path.join(alignment_dir, alignment_index["db"]), 'rb')
def read_template(start, size):
fp.seek(start)
return fp.read(size).decode("utf-8")
for (name, start, size) in _alignment_index["files"]:
for (name, start, size) in alignment_index["files"]:
ext = os.path.splitext(name)[-1]
if(ext == ".hhr"):
......@@ -550,9 +550,9 @@ class DataPipeline:
def _get_msas(self,
alignment_dir: str,
input_sequence: Optional[str] = None,
_alignment_index: Optional[str] = None,
alignment_index: Optional[str] = None,
):
msa_data = self._parse_msa_data(alignment_dir, _alignment_index)
msa_data = self._parse_msa_data(alignment_dir, alignment_index)
if(len(msa_data) == 0):
if(input_sequence is None):
raise ValueError(
......@@ -576,10 +576,10 @@ class DataPipeline:
self,
alignment_dir: str,
input_sequence: Optional[str] = None,
_alignment_index: Optional[str] = None
alignment_index: Optional[str] = None
) -> Mapping[str, Any]:
msas, deletion_matrices = self._get_msas(
alignment_dir, input_sequence, _alignment_index
alignment_dir, input_sequence, alignment_index
)
msa_features = make_msa_features(
msas=msas,
......@@ -592,7 +592,7 @@ class DataPipeline:
self,
fasta_path: str,
alignment_dir: str,
_alignment_index: Optional[str] = None,
alignment_index: Optional[str] = None,
) -> FeatureDict:
"""Assembles features for a single sequence in a FASTA file"""
with open(fasta_path) as f:
......@@ -606,7 +606,7 @@ class DataPipeline:
input_description = input_descs[0]
num_res = len(input_sequence)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
hits = self._parse_template_hits(alignment_dir, alignment_index)
template_features = make_template_features(
input_sequence,
hits,
......@@ -619,7 +619,7 @@ class DataPipeline:
num_res=num_res,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {
**sequence_features,
......@@ -632,7 +632,7 @@ class DataPipeline:
mmcif: mmcif_parsing.MmcifObject, # parsing is expensive, so no path
alignment_dir: str,
chain_id: Optional[str] = None,
_alignment_index: Optional[str] = None,
alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a specific chain in an mmCIF object.
......@@ -650,7 +650,7 @@ class DataPipeline:
mmcif_feats = make_mmcif_features(mmcif, chain_id)
input_sequence = mmcif.chain_to_seqres[chain_id]
hits = self._parse_template_hits(alignment_dir, _alignment_index)
hits = self._parse_template_hits(alignment_dir, alignment_index)
template_features = make_template_features(
input_sequence,
hits,
......@@ -658,7 +658,7 @@ class DataPipeline:
query_release_date=to_date(mmcif.header["release_date"])
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**mmcif_feats, **template_features, **msa_features}
......@@ -669,7 +669,7 @@ class DataPipeline:
is_distillation: bool = True,
chain_id: Optional[str] = None,
_structure_index: Optional[str] = None,
_alignment_index: Optional[str] = None,
alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a PDB file.
......@@ -696,14 +696,14 @@ class DataPipeline:
is_distillation=is_distillation
)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
hits = self._parse_template_hits(alignment_dir, alignment_index)
template_features = make_template_features(
input_sequence,
hits,
self.template_featurizer,
)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)
msa_features = self._process_msa_feats(alignment_dir, input_sequence, alignment_index)
return {**pdb_feats, **template_features, **msa_features}
......@@ -711,7 +711,7 @@ class DataPipeline:
self,
core_path: str,
alignment_dir: str,
_alignment_index: Optional[str] = None,
alignment_index: Optional[str] = None,
) -> FeatureDict:
"""
Assembles features for a protein in a ProteinNet .core file.
......@@ -724,7 +724,7 @@ class DataPipeline:
description = os.path.splitext(os.path.basename(core_path))[0].upper()
core_feats = make_protein_features(protein_object, description)
hits = self._parse_template_hits(alignment_dir, _alignment_index)
hits = self._parse_template_hits(alignment_dir, alignment_index)
template_features = make_template_features(
input_sequence,
hits,
......@@ -809,7 +809,7 @@ class DataPipeline:
alignment_dir = os.path.join(
super_alignment_dir, desc
)
hits = self._parse_template_hits(alignment_dir, _alignment_index=None)
hits = self._parse_template_hits(alignment_dir, alignment_index=None)
template_features = make_template_features(
seq,
hits,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment