Commit 676b6668 authored by Gustaf Ahdritz

Make filtering more efficient

parent fb341b17
@@ -212,9 +212,8 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
         return len(self._chain_ids)
 
-def train_filter(
+def deterministic_train_filter(
     prot_data_cache_entry: Any,
-    generator: torch.Generator,
     max_resolution: float = 9.,
     max_single_aa_prop: float = 0.8,
 ) -> bool:
@@ -233,6 +232,12 @@ def train_filter(
     if(largest_single_aa_prop > max_single_aa_prop):
         return False
 
+    return True
+
+
+def get_stochastic_train_filter_prob(
+    prot_data_cache_entry: Any,
+) -> float:
     # Stochastic filters
     probabilities = []
@@ -243,14 +248,12 @@ def train_filter(
     chain_length = len(prot_data_cache_entry["seq"])
     probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))
 
-    weights = [[1 - p, p] for p in probabilities]
-    results = torch.multinomial(
-        torch.tensor(weights),
-        num_samples=1,
-        generator=generator,
-    )
-
-    return torch.all(results)
+    # Risk of underflow here?
+    out = 1
+    for p in probabilities:
+        out *= p
+
+    return out
 
 
 class OpenFoldDataset(torch.utils.data.Dataset):
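Taken together, the old train_filter's per-candidate multinomial draw is replaced by a pure probability computation: deterministic_train_filter gates candidates outright, and get_stochastic_train_filter_prob returns one combined acceptance probability. A minimal sketch of how the two pieces compose, assuming a simplified cache entry that only carries the "seq" key (real entries also hold whatever fields the deterministic filter inspects, such as resolution):

import math
import torch

entry = {"seq": "MKV" * 100}  # hypothetical 300-residue chain

def stochastic_prob(entry):
    # Mirrors get_stochastic_train_filter_prob's length term: clamp the chain
    # length to [256, 512] and divide by 512, so p falls in [0.5, 1.0]
    chain_length = len(entry["seq"])
    return (1 / 512) * max(min(chain_length, 512), 256)

p = stochastic_prob(entry)                     # 300 / 512 ~= 0.586
keep = bool(torch.bernoulli(torch.tensor(p)))  # accept with probability p

As for the "# Risk of underflow here?" comment: multiplying many small probabilities can indeed underflow, and accumulating in log space is the standard remedy, e.g.

def combined_prob(probabilities):
    if any(p == 0.0 for p in probabilities):
        return 0.0
    # Sum of logs instead of a product of small floats
    return math.exp(sum(math.log(p) for p in probabilities))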
@@ -265,7 +268,6 @@ class OpenFoldDataset(torch.utils.data.Dataset):
         probabilities: Sequence[int],
         epoch_len: int,
         prot_data_cache_paths: List[str],
-        filter_fn: Optional[Any] = train_filter,
         generator: torch.Generator = None,
         _roll_at_init: bool = True,
     ):
@@ -273,7 +275,11 @@ class OpenFoldDataset(torch.utils.data.Dataset):
         self.probabilities = probabilities
         self.epoch_len = epoch_len
         self.generator = generator
-        self.filter_fn = filter_fn
+
+        self.prot_data_caches = []
+        for path in prot_data_cache_paths:
+            with open(path, "r") as fp:
+                self.prot_data_caches.append(json.load(fp))
 
         def looped_shuffled_dataset_idx(dataset_len):
             while True:
@@ -288,16 +294,40 @@ class OpenFoldDataset(torch.utils.data.Dataset):
                 for idx in shuf:
                     yield idx
 
-        self.shuffled_idx_iters = []
-        for d in datasets:
-            self.shuffled_idx_iters.append(
-                looped_shuffled_dataset_idx(len(d))
-            )
-
-        self.prot_data_caches = []
-        for path in prot_data_cache_paths:
-            with open(path, "r") as fp:
-                self.prot_data_caches.append(json.load(fp))
+        def looped_samples(dataset_idx):
+            max_cache_len = int(epoch_len * probabilities[dataset_idx])
+            dataset = self.datasets[dataset_idx]
+            idx_iter = looped_shuffled_dataset_idx(len(dataset))
+            prot_data_cache = self.prot_data_caches[dataset_idx]
+            while True:
+                weights = []
+                idx = []
+                for _ in range(max_cache_len):
+                    candidate_idx = next(idx_iter)
+                    chain_id = dataset.idx_to_chain_id(candidate_idx)
+                    prot_data_cache_entry = prot_data_cache[chain_id]
+                    if(not deterministic_train_filter(prot_data_cache_entry)):
+                        continue
+
+                    p = get_stochastic_train_filter_prob(
+                        prot_data_cache_entry,
+                    )
+                    weights.append([1. - p, p])
+                    idx.append(candidate_idx)
+
+                samples = torch.multinomial(
+                    torch.tensor(weights),
+                    num_samples=1,
+                    generator=self.generator,
+                )
+                samples = samples.squeeze()
+
+                cache = [i for i, s in zip(idx, samples) if s]
+
+                for datapoint_idx in cache:
+                    yield datapoint_idx
+
+        self._samples = [looped_samples(i) for i in range(len(self.datasets))]
 
         if(_roll_at_init):
             self.reroll()
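The acceptance step inside looped_samples batches every candidate into a single draw: each row of weights is [1 - p, p], so torch.multinomial with num_samples=1 picks column 1 (keep) with probability exactly p. A self-contained sketch of the trick with a made-up probability vector; torch.bernoulli draws from the same distribution row-wise:

import torch

g = torch.Generator().manual_seed(0)
probs = torch.tensor([0.5, 0.9, 0.6, 1.0])  # toy acceptance probabilities

# Rows of [1 - p, p]: column index 1 is sampled with probability p
weights = torch.stack([1.0 - probs, probs], dim=-1)
samples = torch.multinomial(weights, num_samples=1, generator=g).squeeze(-1)

kept = [i for i, s in enumerate(samples) if s]  # surviving candidate indices

# Equivalent distribution in one call
same_law = torch.bernoulli(probs, generator=g)

One sharp edge worth noting: if every candidate in a window fails deterministic_train_filter, weights is empty and torch.multinomial raises, so the code implicitly assumes each max_cache_len window yields at least one survivor.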
@@ -319,15 +349,8 @@ class OpenFoldDataset(torch.utils.data.Dataset):
         self.datapoints = []
         for dataset_idx in dataset_choices:
-            dataset = self.datasets[dataset_idx]
-            idx_iter = self.shuffled_idx_iters[dataset_idx]
-            prot_data_cache = self.prot_data_caches[dataset_idx]
-
-            datapoint_idx = None
-            while datapoint_idx is None:
-                candidate_idx = next(idx_iter)
-                chain_id = dataset.idx_to_chain_id(candidate_idx)
-                if(self.filter_fn(prot_data_cache[chain_id], self.generator)):
-                    datapoint_idx = candidate_idx
+            samples = self._samples[dataset_idx]
+            datapoint_idx = next(samples)
             self.datapoints.append((dataset_idx, datapoint_idx))
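With filtering amortized into the infinite looped_samples generators, reroll reduces to a weighted choice of dataset followed by a single next() on the chosen stream. A toy version of that pattern, with stand-in streams and invented tags in place of the real datasets:

import torch

def make_stream(tag):
    # Stand-in for looped_samples: an endless generator of datapoint indices
    i = 0
    while True:
        yield (tag, i)
        i += 1

streams = [make_stream("pdb"), make_stream("distillation")]
probabilities = torch.tensor([0.75, 0.25])
g = torch.Generator().manual_seed(0)

dataset_choices = torch.multinomial(
    probabilities, num_samples=8, replacement=True, generator=g,
)
datapoints = [next(streams[int(c)]) for c in dataset_choices]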
@@ -448,7 +471,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
         max_template_date: str,
         train_data_dir: Optional[str] = None,
         train_alignment_dir: Optional[str] = None,
-        train_filter_fn: Optional[Any] = train_filter,
         train_prot_data_cache_path: Optional[str] = None,
         distillation_data_dir: Optional[str] = None,
         distillation_alignment_dir: Optional[str] = None,
@@ -474,7 +496,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
         self.max_template_date = max_template_date
         self.train_data_dir = train_data_dir
         self.train_alignment_dir = train_alignment_dir
-        self.train_filter_fn = train_filter_fn
         self.train_prot_data_cache_path = train_prot_data_cache_path
         self.distillation_data_dir = distillation_data_dir
         self.distillation_alignment_dir = distillation_alignment_dir
@@ -517,21 +538,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
                 'be specified as well'
             )
 
-        cache_missing = (
-            train_filter_fn and
-            (
-                train_prot_data_cache_path is None or
-                (
-                    distillation_data_dir is not None and
-                    distillation_prot_data_cache_path is None
-                )
-            )
-        )
-        if(cache_missing):
-            raise ValueError(
-                "If train_filter_fn is given, so must the protein data caches"
-            )
-
         # An ad-hoc measure for our particular filesystem restrictions
         self._alignment_index = None
         if(_alignment_index_path is not None):
@@ -599,7 +605,6 @@ class OpenFoldDataModule(pl.LightningDataModule):
             probabilities=probabilities,
             epoch_len=self.train_epoch_len,
             prot_data_cache_paths=prot_data_cache_paths,
-            filter_fn=self.train_filter_fn,
             _roll_at_init=False,
         )
...
@@ -68,7 +68,7 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss
         loss = self.loss(outputs, batch)
-        self.log("train/loss", loss, logger=True)
+        self.log("train/loss", loss, on_step=True, logger=True)
 
         return loss
@@ -151,9 +151,9 @@ def main(args):
     if(args.checkpoint_best_val):
         checkpoint_dir = os.path.join(args.output_dir, "checkpoints")
         mc = ModelCheckpoint(
+            dirpath=checkpoint_dir,
             filename="openfold_{epoch}_{step}_{val_loss:.2f}",
             monitor="val/loss",
-            mode="max",
         )
         callbacks.append(mc)
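Two things change in the checkpoint callback: the checkpoint_dir computed on the previous line is now actually passed as dirpath, and the explicit mode="max" goes away. With monitor="val/loss", falling back to ModelCheckpoint's default mode="min" keeps the checkpoints with the lowest validation loss, which is what maximizing had backwards. A minimal sketch of the resulting callback, with a placeholder output directory:

import os
from pytorch_lightning.callbacks import ModelCheckpoint

output_dir = "out"  # placeholder for args.output_dir
checkpoint_dir = os.path.join(output_dir, "checkpoints")

mc = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename="openfold_{epoch}_{step}_{val_loss:.2f}",
    monitor="val/loss",  # default mode="min": smaller validation loss is better
)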
@@ -200,6 +200,7 @@ def main(args):
         )
         if(args.wandb):
             wdb_logger.experiment.save(args.deepspeed_config_path)
+            wdb_logger.experiment.save("openfold/config.py")
     elif (args.gpus is not None and args.gpus > 1) or args.num_nodes > 1:
         strategy = DDPPlugin(find_unused_parameters=False)
     else:
@@ -373,9 +374,6 @@ if __name__ == "__main__":
     parser.add_argument(
         "--train_epoch_len", type=int, default=10000,
     )
-    parser.add_argument(
-        "--obsolete_pdbs_file_path", type=str,
-    )
     parser.add_argument(
         "--_alignment_index_path", type=str, default=None,
    )