"vscode:/vscode.git/clone" did not exist on "d675d2218e5b271e8434cd03bb3384a2641f12b1"
Commit b9faee76 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Change name of prot data cache

parent 2864b7ca
......@@ -215,7 +215,7 @@ python3 train_openfold.py mmcif_dir/ alignment_dir/ template_mmcif_dir/ \
--deepspeed_config_path deepspeed_config.json \
--checkpoint_every_epoch \
--resume_from_ckpt ckpt_dir/ \
--train_prot_data_cache_path chain_data_cache.json
--train_chain_data_cache_path chain_data_cache.json
```
where `--template_release_dates_cache_path` is a path to the `.json` file
......
......@@ -201,16 +201,16 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
def deterministic_train_filter(
prot_data_cache_entry: Any,
chain_data_cache_entry: Any,
max_resolution: float = 9.,
max_single_aa_prop: float = 0.8,
) -> bool:
# Hard filters
resolution = prot_data_cache_entry.get("resolution", None)
resolution = chain_data_cache_entry.get("resolution", None)
if(resolution is not None and resolution > max_resolution):
return False
seq = prot_data_cache_entry["seq"]
seq = chain_data_cache_entry["seq"]
counts = {}
for aa in seq:
counts.setdefault(aa, 0)
......@@ -224,16 +224,16 @@ def deterministic_train_filter(
def get_stochastic_train_filter_prob(
prot_data_cache_entry: Any,
chain_data_cache_entry: Any,
) -> List[float]:
# Stochastic filters
probabilities = []
cluster_size = prot_data_cache_entry.get("cluster_size", None)
cluster_size = chain_data_cache_entry.get("cluster_size", None)
if(cluster_size is not None and cluster_size > 0):
probabilities.append(1 / cluster_size)
chain_length = len(prot_data_cache_entry["seq"])
chain_length = len(chain_data_cache_entry["seq"])
probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
# Risk of underflow here?
......@@ -255,7 +255,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
datasets: Sequence[OpenFoldSingleDataset],
probabilities: Sequence[int],
epoch_len: int,
prot_data_cache_paths: List[str],
chain_data_cache_paths: List[str],
generator: torch.Generator = None,
_roll_at_init: bool = True,
):
......@@ -264,10 +264,10 @@ class OpenFoldDataset(torch.utils.data.Dataset):
self.epoch_len = epoch_len
self.generator = generator
self.prot_data_caches = []
for path in prot_data_cache_paths:
self.chain_data_caches = []
for path in chain_data_cache_paths:
with open(path, "r") as fp:
self.prot_data_caches.append(json.load(fp))
self.chain_data_caches.append(json.load(fp))
def looped_shuffled_dataset_idx(dataset_len):
while True:
......@@ -286,19 +286,19 @@ class OpenFoldDataset(torch.utils.data.Dataset):
max_cache_len = int(epoch_len * probabilities[dataset_idx])
dataset = self.datasets[dataset_idx]
idx_iter = looped_shuffled_dataset_idx(len(dataset))
prot_data_cache = self.prot_data_caches[dataset_idx]
chain_data_cache = self.chain_data_caches[dataset_idx]
while True:
weights = []
idx = []
for _ in range(max_cache_len):
candidate_idx = next(idx_iter)
chain_id = dataset.idx_to_chain_id(candidate_idx)
prot_data_cache_entry = prot_data_cache[chain_id]
if(not deterministic_train_filter(prot_data_cache_entry)):
chain_data_cache_entry = chain_data_cache[chain_id]
if(not deterministic_train_filter(chain_data_cache_entry)):
continue
p = get_stochastic_train_filter_prob(
prot_data_cache_entry,
chain_data_cache_entry,
)
weights.append([1. - p, p])
idx.append(candidate_idx)
......@@ -459,10 +459,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
max_template_date: str,
train_data_dir: Optional[str] = None,
train_alignment_dir: Optional[str] = None,
train_prot_data_cache_path: Optional[str] = None,
train_chain_data_cache_path: Optional[str] = None,
distillation_data_dir: Optional[str] = None,
distillation_alignment_dir: Optional[str] = None,
distillation_prot_data_cache_path: Optional[str] = None,
distillation_chain_data_cache_path: Optional[str] = None,
val_data_dir: Optional[str] = None,
val_alignment_dir: Optional[str] = None,
predict_data_dir: Optional[str] = None,
......@@ -483,11 +483,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
self.max_template_date = max_template_date
self.train_data_dir = train_data_dir
self.train_alignment_dir = train_alignment_dir
self.train_prot_data_cache_path = train_prot_data_cache_path
self.train_chain_data_cache_path = train_chain_data_cache_path
self.distillation_data_dir = distillation_data_dir
self.distillation_alignment_dir = distillation_alignment_dir
self.distillation_prot_data_cache_path = (
distillation_prot_data_cache_path
self.distillation_chain_data_cache_path = (
distillation_chain_data_cache_path
)
self.val_data_dir = val_data_dir
self.val_alignment_dir = val_alignment_dir
......@@ -569,22 +569,22 @@ class OpenFoldDataModule(pl.LightningDataModule):
datasets = [train_dataset, distillation_dataset]
d_prob = self.config.train.distillation_prob
probabilities = [1 - d_prob, d_prob]
prot_data_cache_paths = [
self.train_prot_data_cache_path,
self.distillation_prot_data_cache_path,
chain_data_cache_paths = [
self.train_chain_data_cache_path,
self.distillation_chain_data_cache_path,
]
else:
datasets = [train_dataset]
probabilities = [1.]
prot_data_cache_paths = [
self.train_prot_data_cache_path,
chain_data_cache_paths = [
self.train_chain_data_cache_path,
]
self.train_dataset = OpenFoldDataset(
datasets=datasets,
probabilities=probabilities,
epoch_len=self.train_epoch_len,
prot_data_cache_paths=prot_data_cache_paths,
chain_data_cache_paths=chain_data_cache_paths,
_roll_at_init=False,
)
......
......@@ -358,10 +358,10 @@ if __name__ == "__main__":
help="Whether to TorchScript eligible components of them model"
)
parser.add_argument(
"--train_prot_data_cache_path", type=str, default=None,
"--train_chain_data_cache_path", type=str, default=None,
)
parser.add_argument(
"--distillation_prot_data_cache_path", type=str, default=None,
"--distillation_chain_data_cache_path", type=str, default=None,
)
parser.add_argument(
"--train_epoch_len", type=int, default=10000,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment