Commit b9faee76 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Change name of prot data cache

parent 2864b7ca
...@@ -215,7 +215,7 @@ python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ \ ...@@ -215,7 +215,7 @@ python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ \
--deepspeed_config_path deepspeed_config.json \ --deepspeed_config_path deepspeed_config.json \
--checkpoint_every_epoch \ --checkpoint_every_epoch \
--resume_from_ckpt ckpt_dir/ \ --resume_from_ckpt ckpt_dir/ \
--train_prot_data_cache_path chain_data_cache.json --train_chain_data_cache_path chain_data_cache.json
``` ```
where `--template_release_dates_cache_path` is a path to the `.json` file where `--template_release_dates_cache_path` is a path to the `.json` file
......
...@@ -201,16 +201,16 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset): ...@@ -201,16 +201,16 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def deterministic_train_filter( def deterministic_train_filter(
prot_data_cache_entry: Any, chain_data_cache_entry: Any,
max_resolution: float = 9., max_resolution: float = 9.,
max_single_aa_prop: float = 0.8, max_single_aa_prop: float = 0.8,
) -> bool: ) -> bool:
# Hard filters # Hard filters
resolution = prot_data_cache_entry.get("resolution", None) resolution = chain_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution): if(resolution is not None and resolution > max_resolution):
return False return False
seq = prot_data_cache_entry["seq"] seq = chain_data_cache_entry["seq"]
counts = {} counts = {}
for aa in seq: for aa in seq:
counts.setdefault(aa, 0) counts.setdefault(aa, 0)
...@@ -224,16 +224,16 @@ def deterministic_train_filter( ...@@ -224,16 +224,16 @@ def deterministic_train_filter(
def get_stochastic_train_filter_prob( def get_stochastic_train_filter_prob(
prot_data_cache_entry: Any, chain_data_cache_entry: Any,
) -> List[float]: ) -> List[float]:
# Stochastic filters # Stochastic filters
probabilities = [] probabilities = []
cluster_size = prot_data_cache_entry.get("cluster_size", None) cluster_size = chain_data_cache_entry.get("cluster_size", None)
if(cluster_size is not None and cluster_size > 0): if(cluster_size is not None and cluster_size > 0):
probabilities.append(1 / cluster_size) probabilities.append(1 / cluster_size)
chain_length = len(prot_data_cache_entry["seq"]) chain_length = len(chain_data_cache_entry["seq"])
probabilities.append((1 / 512) * (max(min(chain_length, 512), 256))) probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
# Risk of underflow here? # Risk of underflow here?
...@@ -255,7 +255,7 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -255,7 +255,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datasets: Sequence[OpenFoldSingleDataset], datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[int], probabilities: Sequence[int],
epoch_len: int, epoch_len: int,
prot_data_cache_paths: List[str], chain_data_cache_paths: List[str],
generator: torch.Generator = None, generator: torch.Generator = None,
_roll_at_init: bool = True, _roll_at_init: bool = True,
): ):
...@@ -264,10 +264,10 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -264,10 +264,10 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self.epoch_len = epoch_len self.epoch_len = epoch_len
self.generator = generator self.generator = generator
self.prot_data_caches = [] self.chain_data_caches = []
for path in prot_data_cache_paths: for path in chain_data_cache_paths:
with open(path, "r") as fp: with open(path, "r") as fp:
self.prot_data_caches.append(json.load(fp)) self.chain_data_caches.append(json.load(fp))
def looped_shuffled_dataset_idx(dataset_len): def looped_shuffled_dataset_idx(dataset_len):
while True: while True:
...@@ -286,19 +286,19 @@ class OpenFoldDataset(torch.utils.data.Dataset): ...@@ -286,19 +286,19 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len = int(epoch_len * probabilities[dataset_idx]) max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx] dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset)) idx_iter = looped_shuffled_dataset_idx(len(dataset))
prot_data_cache = self.prot_data_caches[dataset_idx] chain_data_cache = self.chain_data_caches[dataset_idx]
while True: while True:
weights = [] weights = []
idx = [] idx = []
for _ in range(max_cache_len): for _ in range(max_cache_len):
candidate_idx = next(idx_iter) candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx) chain_id = dataset.idx_to_chain_id(candidate_idx)
prot_data_cache_entry = prot_data_cache[chain_id] chain_data_cache_entry = chain_data_cache[chain_id]
if(not deterministic_train_filter(prot_data_cache_entry)): if(not deterministic_train_filter(chain_data_cache_entry)):
continue continue
p = get_stochastic_train_filter_prob( p = get_stochastic_train_filter_prob(
prot_data_cache_entry, chain_data_cache_entry,
) )
weights.append([1. - p, p]) weights.append([1. - p, p])
idx.append(candidate_idx) idx.append(candidate_idx)
...@@ -459,10 +459,10 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -459,10 +459,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_date: str, max_template_date: str,
train_data_dir: Optional[str] = None, train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None, train_alignment_dir: Optional[str] = None,
train_prot_data_cache_path: Optional[str] = None, train_chain_data_cache_path: Optional[str] = None,
distillation_data_dir: Optional[str] = None, distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None, distillation_alignment_dir: Optional[str] = None,
distillation_prot_data_cache_path: Optional[str] = None, distillation_chain_data_cache_path: Optional[str] = None,
val_data_dir: Optional[str] = None, val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None, val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None, predict_data_dir: Optional[str] = None,
...@@ -483,11 +483,11 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -483,11 +483,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.max_template_date = max_template_date self.max_template_date = max_template_date
self.train_data_dir = train_data_dir self.train_data_dir = train_data_dir
self.train_alignment_dir = train_alignment_dir self.train_alignment_dir = train_alignment_dir
self.train_prot_data_cache_path = train_prot_data_cache_path self.train_chain_data_cache_path = train_chain_data_cache_path
self.distillation_data_dir = distillation_data_dir self.distillation_data_dir = distillation_data_dir
self.distillation_alignment_dir = distillation_alignment_dir self.distillation_alignment_dir = distillation_alignment_dir
self.distillation_prot_data_cache_path = ( self.distillation_chain_data_cache_path = (
distillation_prot_data_cache_path distillation_chain_data_cache_path
) )
self.val_data_dir = val_data_dir self.val_data_dir = val_data_dir
self.val_alignment_dir = val_alignment_dir self.val_alignment_dir = val_alignment_dir
...@@ -569,22 +569,22 @@ class OpenFoldDataModule(pl.LightningDataModule): ...@@ -569,22 +569,22 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets = [train_dataset, distillation_dataset] datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob d_prob = self.config.train.distillation_prob
probabilities = [1 - d_prob, d_prob] probabilities = [1 - d_prob, d_prob]
prot_data_cache_paths = [ chain_data_cache_paths = [
self.train_prot_data_cache_path, self.train_chain_data_cache_path,
self.distillation_prot_data_cache_path, self.distillation_chain_data_cache_path,
] ]
else: else:
datasets = [train_dataset] datasets = [train_dataset]
probabilities = [1.] probabilities = [1.]
prot_data_cache_paths = [ chain_data_cache_paths = [
self.train_prot_data_cache_path, self.train_chain_data_cache_path,
] ]
self.train_dataset = OpenFoldDataset( self.train_dataset = OpenFoldDataset(
datasets=datasets, datasets=datasets,
probabilities=probabilities, probabilities=probabilities,
epoch_len=self.train_epoch_len, epoch_len=self.train_epoch_len,
prot_data_cache_paths=prot_data_cache_paths, chain_data_cache_paths=chain_data_cache_paths,
_roll_at_init=False, _roll_at_init=False,
) )
......
...@@ -358,10 +358,10 @@ if __name__ == "__main__": ...@@ -358,10 +358,10 @@ if __name__ == "__main__":
help="Whether to TorchScript eligible components of them model" help="Whether to TorchScript eligible components of them model"
) )
parser.add_argument( parser.add_argument(
"--train_prot_data_cache_path", type=str, default=None, "--train_chain_data_cache_path", type=str, default=None,
) )
parser.add_argument( parser.add_argument(
"--distillation_prot_data_cache_path", type=str, default=None, "--distillation_chain_data_cache_path", type=str, default=None,
) )
parser.add_argument( parser.add_argument(
"--train_epoch_len", type=int, default=10000, "--train_epoch_len", type=int, default=10000,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment