Commit a3e8ebbc authored by Gustaf Ahdritz

Add missing function

parent b9faee76
@@ -37,6 +37,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         mapping_path: Optional[str] = None,
         mode: str = "train",
         _output_raw: bool = False,
+        _alignment_index: Optional[Any] = None
     ):
         """
             Args:
@@ -83,6 +84,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         self.treat_pdb_as_distillation = treat_pdb_as_distillation
         self.mode = mode
         self._output_raw = _output_raw
+        self._alignment_index = _alignment_index

         valid_modes = ["train", "eval", "predict"]
         if(mode not in valid_modes):
@@ -94,7 +96,9 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                 "scripts/generate_mmcif_cache.py before running OpenFold"
             )

-        if(mapping_path is None):
+        if(_alignment_index is not None):
+            self._chain_ids = list(_alignment_index.keys())
+        elif(mapping_path is None):
             self._chain_ids = list(os.listdir(alignment_dir))
         else:
             with open(mapping_path, "r") as f:
@@ -121,7 +125,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         if(not self._output_raw):
             self.feature_pipeline = feature_pipeline.FeaturePipeline(config)

-    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir):
+    def _parse_mmcif(self, path, file_id, chain_id, alignment_dir, _alignment_index):
         with open(path, 'r') as f:
             mmcif_string = f.read()
@@ -140,6 +144,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             mmcif=mmcif_object,
             alignment_dir=alignment_dir,
             chain_id=chain_id,
+            _alignment_index=_alignment_index
         )

         return data
@@ -154,6 +159,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         name = self.idx_to_chain_id(idx)
         alignment_dir = os.path.join(self.alignment_dir, name)

+        _alignment_index = None
+        if(self._alignment_index is not None):
+            alignment_dir = self.alignment_dir
+            _alignment_index = self._alignment_index[name]
+
         if(self.mode == 'train' or self.mode == 'eval'):
             spl = name.rsplit('_', 1)
             if(len(spl) == 2):
@@ -165,11 +175,11 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             path = os.path.join(self.data_dir, file_id)
             if(os.path.exists(path + ".cif")):
                 data = self._parse_mmcif(
-                    path + ".cif", file_id, chain_id, alignment_dir,
+                    path + ".cif", file_id, chain_id, alignment_dir, _alignment_index,
                 )
             elif(os.path.exists(path + ".core")):
                 data = self.data_pipeline.process_core(
-                    path + ".core", alignment_dir,
+                    path + ".core", alignment_dir, _alignment_index,
                 )
             elif(os.path.exists(path + ".pdb")):
                 data = self.data_pipeline.process_pdb(
@@ -177,6 +187,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
                     alignment_dir=alignment_dir,
                     is_distillation=self.treat_pdb_as_distillation,
                     chain_id=chain_id,
+                    _alignment_index=_alignment_index,
                 )
             else:
                 raise ValueError("Invalid file type")
@@ -185,6 +196,7 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
             data = self.data_pipeline.process_fasta(
                 fasta_path=path,
                 alignment_dir=alignment_dir,
+                _alignment_index=_alignment_index,
             )

         if(self._output_raw):
@@ -201,16 +213,16 @@
 def deterministic_train_filter(
-    chain_data_cache_entry: Any,
+    prot_data_cache_entry: Any,
     max_resolution: float = 9.,
     max_single_aa_prop: float = 0.8,
 ) -> bool:
     # Hard filters
-    resolution = chain_data_cache_entry.get("resolution", None)
+    resolution = prot_data_cache_entry.get("resolution", None)
     if(resolution is not None and resolution > max_resolution):
         return False

-    seq = chain_data_cache_entry["seq"]
+    seq = prot_data_cache_entry["seq"]
     counts = {}
     for aa in seq:
         counts.setdefault(aa, 0)
@@ -224,16 +236,16 @@ def deterministic_train_filter(
 def get_stochastic_train_filter_prob(
-    chain_data_cache_entry: Any,
+    prot_data_cache_entry: Any,
 ) -> List[float]:
     # Stochastic filters
     probabilities = []

-    cluster_size = chain_data_cache_entry.get("cluster_size", None)
+    cluster_size = prot_data_cache_entry.get("cluster_size", None)
     if(cluster_size is not None and cluster_size > 0):
         probabilities.append(1 / cluster_size)

-    chain_length = len(chain_data_cache_entry["seq"])
+    chain_length = len(prot_data_cache_entry["seq"])
     probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))

     # Risk of underflow here?
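
For intuition, the filter arithmetic worked through on an invented cache entry. The tail of get_stochastic_train_filter_prob is elided by this diff, so the reduction of the accumulated factors into a single probability (a running product, which the underflow comment hints at) is an assumption, as are the concrete numbers:

    cluster_size = 20   # hypothetical cache entry values
    chain_length = 300

    probabilities = []
    probabilities.append(1 / cluster_size)                               # 0.05
    probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))  # 300/512 ~ 0.586

    p = 1.0  # assumed reduction; the diff cuts off before the return
    for prob in probabilities:
        p *= prob  # ~0.0293

The cache-filling loop below then samples with weights [1. - p, p], so a chain like this one would be kept on roughly 3% of the draws in which it is proposed.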
@@ -255,7 +267,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
         datasets: Sequence[OpenFoldSingleDataset],
         probabilities: Sequence[int],
         epoch_len: int,
-        chain_data_cache_paths: List[str],
+        prot_data_cache_paths: List[str],
         generator: torch.Generator = None,
         _roll_at_init: bool = True,
     ):
@@ -264,10 +276,10 @@ class OpenFoldDataset(torch.utils.data.Dataset):
         self.epoch_len = epoch_len
         self.generator = generator

-        self.chain_data_caches = []
-        for path in chain_data_cache_paths:
+        self.prot_data_caches = []
+        for path in prot_data_cache_paths:
             with open(path, "r") as fp:
-                self.chain_data_caches.append(json.load(fp))
+                self.prot_data_caches.append(json.load(fp))

         def looped_shuffled_dataset_idx(dataset_len):
             while True:
@@ -286,19 +298,19 @@ class OpenFoldDataset(torch.utils.data.Dataset):
             max_cache_len = int(epoch_len * probabilities[dataset_idx])
             dataset = self.datasets[dataset_idx]
             idx_iter = looped_shuffled_dataset_idx(len(dataset))
-            chain_data_cache = self.chain_data_caches[dataset_idx]
+            prot_data_cache = self.prot_data_caches[dataset_idx]
             while True:
                 weights = []
                 idx = []
                 for _ in range(max_cache_len):
                     candidate_idx = next(idx_iter)
                     chain_id = dataset.idx_to_chain_id(candidate_idx)
-                    chain_data_cache_entry = chain_data_cache[chain_id]
-                    if(not deterministic_train_filter(chain_data_cache_entry)):
+                    prot_data_cache_entry = prot_data_cache[chain_id]
+                    if(not deterministic_train_filter(prot_data_cache_entry)):
                         continue

                     p = get_stochastic_train_filter_prob(
-                        chain_data_cache_entry,
+                        prot_data_cache_entry,
                     )
                     weights.append([1. - p, p])
                     idx.append(candidate_idx)
@@ -459,10 +471,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
         max_template_date: str,
         train_data_dir: Optional[str] = None,
         train_alignment_dir: Optional[str] = None,
-        train_chain_data_cache_path: Optional[str] = None,
+        train_prot_data_cache_path: Optional[str] = None,
         distillation_data_dir: Optional[str] = None,
         distillation_alignment_dir: Optional[str] = None,
-        distillation_chain_data_cache_path: Optional[str] = None,
+        distillation_prot_data_cache_path: Optional[str] = None,
         val_data_dir: Optional[str] = None,
         val_alignment_dir: Optional[str] = None,
         predict_data_dir: Optional[str] = None,
@@ -474,6 +486,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
         template_release_dates_cache_path: Optional[str] = None,
         batch_seed: Optional[int] = None,
         train_epoch_len: int = 50000,
+        _alignment_index_path: Optional[str] = None,
         **kwargs
     ):
         super(OpenFoldDataModule, self).__init__()
@@ -483,11 +496,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
         self.max_template_date = max_template_date
         self.train_data_dir = train_data_dir
         self.train_alignment_dir = train_alignment_dir
-        self.train_chain_data_cache_path = train_chain_data_cache_path
+        self.train_prot_data_cache_path = train_prot_data_cache_path
         self.distillation_data_dir = distillation_data_dir
         self.distillation_alignment_dir = distillation_alignment_dir
-        self.distillation_chain_data_cache_path = (
-            distillation_chain_data_cache_path
+        self.distillation_prot_data_cache_path = (
+            distillation_prot_data_cache_path
         )
         self.val_data_dir = val_data_dir
         self.val_alignment_dir = val_alignment_dir
@@ -525,6 +538,12 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 'be specified as well'
             )

+        # An ad-hoc measure for our particular filesystem restrictions
+        self._alignment_index = None
+        if(_alignment_index_path is not None):
+            with open(_alignment_index_path, "r") as fp:
+                self._alignment_index = json.load(fp)
+
     def setup(self):
         # Most of the arguments are the same for the three datasets
         dataset_gen = partial(OpenFoldSingleDataset,
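
The layout of the JSON behind _alignment_index_path is not spelled out in this commit. From how it is consumed (the per-chain lookup self._alignment_index[name] above, and the _alignment_index["db"] and _alignment_index["files"] accesses in the DataPipeline hunks below), it evidently maps each chain id to byte ranges inside one packed alignment file. A minimal sketch of that assumed layout; only the "db" and "files" keys come from the code, every concrete chain id, file name, and offset is invented:

    {
        "1abc_A": {
            "db": "alignments.db",
            "files": [
                ["uniref90_hits.a3m", 0, 1048576],
                ["bfd_uniclust_hits.a3m", 1048576, 2097152],
                ["pdb70_hits.hhr", 3145728, 65536]
            ]
        }
    }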
@@ -549,6 +568,7 @@ class OpenFoldDataModule(pl.LightningDataModule):
             treat_pdb_as_distillation=False,
             mode="train",
             _output_raw=True,
+            _alignment_index=self._alignment_index,
         )

         distillation_dataset = None
@@ -569,22 +589,22 @@ class OpenFoldDataModule(pl.LightningDataModule):
             datasets = [train_dataset, distillation_dataset]
             d_prob = self.config.train.distillation_prob
             probabilities = [1 - d_prob, d_prob]
-            chain_data_cache_paths = [
-                self.train_chain_data_cache_path,
-                self.distillation_chain_data_cache_path,
+            prot_data_cache_paths = [
+                self.train_prot_data_cache_path,
+                self.distillation_prot_data_cache_path,
             ]
         else:
             datasets = [train_dataset]
             probabilities = [1.]
-            chain_data_cache_paths = [
-                self.train_chain_data_cache_path,
+            prot_data_cache_paths = [
+                self.train_prot_data_cache_path,
             ]

         self.train_dataset = OpenFoldDataset(
             datasets=datasets,
             probabilities=probabilities,
             epoch_len=self.train_epoch_len,
-            chain_data_cache_paths=chain_data_cache_paths,
+            prot_data_cache_paths=prot_data_cache_paths,
             _roll_at_init=False,
         )
@@ -422,8 +422,38 @@ class DataPipeline:
     def _parse_msa_data(
         self,
         alignment_dir: str,
+        _alignment_index: Optional[Any] = None,
     ) -> Mapping[str, Any]:
         msa_data = {}
+        if(_alignment_index is not None):
+            fp = open(os.path.join(alignment_dir, _alignment_index["db"]), "rb")
+
+            def read_msa(start, size):
+                fp.seek(start)
+                msa = fp.read(size).decode("utf-8")
+                return msa
+
+            for (name, start, size) in _alignment_index["files"]:
+                ext = os.path.splitext(name)[-1]
+
+                if(ext == ".a3m"):
+                    msa, deletion_matrix = parsers.parse_a3m(
+                        read_msa(start, size)
+                    )
+                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
+                elif(ext == ".sto"):
+                    msa, deletion_matrix, _ = parsers.parse_stockholm(
+                        read_msa(start, size)
+                    )
+                    data = {"msa": msa, "deletion_matrix": deletion_matrix}
+                else:
+                    continue
+
+                msa_data[name] = data
+
+            fp.close()
+        else:
+            for f in os.listdir(alignment_dir):
+                path = os.path.join(alignment_dir, f)
+                ext = os.path.splitext(f)[-1]
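
With an index, _parse_msa_data reads every alignment out of one packed file instead of a directory of small files. A database of that shape could be built by concatenating the per-chain alignment files and recording their offsets. The packer below is a sketch under the assumptions above (the helper name and on-disk layout are invented, not a script shipped with this commit):

    import json
    import os

    def build_alignment_db(alignment_root, db_path, index_path):
        # Concatenates alignment_root/<chain_id>/* into one binary file and
        # records the (name, start, size) triples _parse_msa_data expects.
        index = {}
        with open(db_path, "wb") as db:
            for chain_id in sorted(os.listdir(alignment_root)):
                chain_dir = os.path.join(alignment_root, chain_id)
                files = []
                for fname in sorted(os.listdir(chain_dir)):
                    with open(os.path.join(chain_dir, fname), "rb") as f:
                        data = f.read()
                    files.append([fname, db.tell(), len(data)])
                    db.write(data)
                index[chain_id] = {
                    "db": os.path.basename(db_path), "files": files,
                }

        with open(index_path, "w") as fp:
            json.dump(index, fp)

Packing everything into one file keeps the inode count low, which matches the "filesystem restrictions" comment in the OpenFoldDataModule hunk above.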
@@ -448,8 +478,25 @@ class DataPipeline:
     def _parse_template_hits(
         self,
         alignment_dir: str,
+        _alignment_index: Optional[Any] = None
     ) -> Mapping[str, Any]:
         all_hits = {}
+        if(_alignment_index is not None):
+            fp = open(os.path.join(alignment_dir, _alignment_index["db"]), 'rb')
+
+            def read_template(start, size):
+                fp.seek(start)
+                return fp.read(size).decode("utf-8")
+
+            for (name, start, size) in _alignment_index["files"]:
+                ext = os.path.splitext(name)[-1]
+
+                if(ext == ".hhr"):
+                    hits = parsers.parse_hhr(read_template(start, size))
+                    all_hits[name] = hits
+
+            fp.close()
+        else:
+            for f in os.listdir(alignment_dir):
+                path = os.path.join(alignment_dir, f)
+                ext = os.path.splitext(f)[-1]
@@ -465,8 +512,9 @@ class DataPipeline:
         self,
         alignment_dir: str,
         input_sequence: Optional[str] = None,
+        _alignment_index: Optional[str] = None
     ) -> Mapping[str, Any]:
-        msa_data = self._parse_msa_data(alignment_dir)
+        msa_data = self._parse_msa_data(alignment_dir, _alignment_index)

         if(len(msa_data) == 0):
             if(input_sequence is None):
@@ -496,6 +544,7 @@ class DataPipeline:
         self,
         fasta_path: str,
         alignment_dir: str,
+        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """Assembles features for a single sequence in a FASTA file"""
         with open(fasta_path) as f:
@@ -509,7 +558,7 @@ class DataPipeline:
         input_description = input_descs[0]
         num_res = len(input_sequence)

-        hits = self._parse_template_hits(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir, _alignment_index)
         template_features = make_template_features(
             input_sequence,
             hits,
@@ -522,7 +571,7 @@ class DataPipeline:
             num_res=num_res,
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)

         return {
             **sequence_features,
@@ -535,6 +584,7 @@ class DataPipeline:
         mmcif: mmcif_parsing.MmcifObject,  # parsing is expensive, so no path
         alignment_dir: str,
         chain_id: Optional[str] = None,
+        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
             Assembles features for a specific chain in an mmCIF object.
@@ -552,7 +602,7 @@ class DataPipeline:
         mmcif_feats = make_mmcif_features(mmcif, chain_id)

         input_sequence = mmcif.chain_to_seqres[chain_id]
-        hits = self._parse_template_hits(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir, _alignment_index)
         template_features = make_template_features(
             input_sequence,
             hits,
@@ -560,7 +610,7 @@ class DataPipeline:
             query_release_date=to_date(mmcif.header["release_date"])
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)

         return {**mmcif_feats, **template_features, **msa_features}
@@ -570,6 +620,7 @@ class DataPipeline:
         alignment_dir: str,
         is_distillation: bool = True,
         chain_id: Optional[str] = None,
+        _alignment_index: Optional[str] = None,
    ) -> FeatureDict:
         """
             Assembles features for a protein in a PDB file.
@@ -586,14 +637,14 @@ class DataPipeline:
             is_distillation
         )

-        hits = self._parse_template_hits(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir, _alignment_index)
         template_features = make_template_features(
             input_sequence,
             hits,
             self.template_featurizer,
         )

-        msa_features = self._process_msa_feats(alignment_dir, input_sequence)
+        msa_features = self._process_msa_feats(alignment_dir, input_sequence, _alignment_index)

         return {**pdb_feats, **template_features, **msa_features}
@@ -601,6 +652,7 @@ class DataPipeline:
         self,
         core_path: str,
         alignment_dir: str,
+        _alignment_index: Optional[str] = None,
     ) -> FeatureDict:
         """
             Assembles features for a protein in a ProteinNet .core file.
@@ -613,7 +665,7 @@ class DataPipeline:
         description = os.path.splitext(os.path.basename(core_path))[0].upper()
         core_feats = make_protein_features(protein_object, description)

-        hits = self._parse_template_hits(alignment_dir)
+        hits = self._parse_template_hits(alignment_dir, _alignment_index)
         template_features = make_template_features(
             input_sequence,
             hits,
@@ -1301,3 +1301,10 @@ def _make_atom14_ambiguity_feats():
 _make_atom14_ambiguity_feats()

+def aatype_to_str_sequence(aatype):
+    return ''.join([
+        residue_constants.restypes_with_x[aatype[i]]
+        for i in range(len(aatype))
+    ])
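
This last hunk is the "missing function" of the commit title: it decodes an integer aatype array back into a one-letter amino-acid string. A quick usage sketch; it assumes the standard residue_constants.restypes_with_x ordering, which places the twenty standard residues as 'A', 'R', 'N', 'D', 'C', ... with 'X' (unknown) at index 20, and residue_constants is already imported in the patched module:

    import numpy as np

    aatype = np.array([0, 1, 2, 3, 4, 20])  # invented example indices
    print(aatype_to_str_sequence(aatype))   # -> "ARNDCX"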