Commit 4c7c30b1 authored by Gustaf Ahdritz

Fix chain data cache script bug, improve logging

parent 61d004a2
@@ -1627,6 +1627,8 @@ class AlphaFoldLoss(nn.Module):
         crop_len = batch["aatype"].shape[-1]
         cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
 
+        losses["loss"] = cum_loss.detach().clone()
+
         if(not _return_breakdown):
             return cum_loss
...
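The new `losses["loss"]` entry stashes a detached copy of the total loss so the breakdown can be logged without holding on to the autograd graph. A standalone sketch of the idea (illustrative names, not OpenFold code):

```python
import torch

# cum_loss participates in autograd, like the cumulative loss above.
w = torch.tensor(2.0, requires_grad=True)
cum_loss = w * 3.0

# detach() stops the stored copy from carrying the graph; clone() makes it
# independent of any later in-place modification of cum_loss.
losses = {"loss": cum_loss.detach().clone()}

assert not losses["loss"].requires_grad   # safe to log or accumulate
cum_loss.backward()                       # training backprop still works
assert w.grad is not None
```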
@@ -39,13 +39,8 @@ def parse_file(
             local_data["seq"] = seq
             local_data["resolution"] = mmcif.header["resolution"]
-            cluster_size = chain_cluster_size_dict.get(full_name.upper(), None)
-            if(cluster_size is None):
-                print(file_id)
-                out.pop(full_name)
-                continue
-            else:
-                local_data["cluster_size"] = cluster_size
+            cluster_size = chain_cluster_size_dict.get(full_name.upper(), -1)
+            local_data["cluster_size"] = cluster_size
     elif(ext == ".pdb"):
         with open(os.path.join(args.data_dir, f), "r") as fp:
             pdb_string = fp.read()
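With this fix, chains missing from the cluster file no longer abort the cache entry; they are recorded with the sentinel cluster size -1, which downstream filtering can treat as "do not filter". A hypothetical consumer-side check (the filter itself is not part of this diff):

```python
# -1 marks chains absent from the cluster file, which the documented
# behavior says must never be filtered by cluster size.
def passes_cluster_size_filter(local_data, predicate):
    cluster_size = local_data["cluster_size"]
    if cluster_size == -1:  # not in the cluster file: always keep
        return True
    return predicate(cluster_size)

# e.g. keep only chains from clusters with at least 2 members
assert passes_cluster_size_filter({"cluster_size": -1}, lambda s: s >= 2)
```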
@@ -112,7 +107,12 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--cluster_file", type=str, default=None,
-        help="Path to a cluster file (e.g. PDB40), one cluster per line"
+        help=(
+            "Path to a cluster file (e.g. PDB40), one cluster "
+            "({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
+            "Chains not in this cluster file will NOT be filtered by cluster "
+            "size."
+        )
     )
     parser.add_argument(
         "--no_workers", type=int, default=4,
...
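For reference, the format documented above maps naturally onto a dict from chain name to cluster size, as consumed by `chain_cluster_size_dict.get(...)` in the hunk before this one. A hypothetical parser, not the script's actual implementation:

```python
# One cluster per line, members written as {PROT_ID}_{CHAIN_ID};
# every member maps to the size of the cluster it belongs to.
def parse_cluster_file(path):
    chain_cluster_size_dict = {}
    with open(path, "r") as fp:
        for line in fp:
            members = line.strip().split()
            for member in members:
                chain_cluster_size_dict[member.upper()] = len(members)
    return chain_cluster_size_dict
```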
@@ -63,6 +63,36 @@ class OpenFoldWrapper(pl.LightningModule):
     def forward(self, batch):
         return self.model(batch)
 
+    def _log(self, loss_breakdown, batch, outputs, train=True):
+        phase = "train" if train else "val"
+        for loss_name, indiv_loss in loss_breakdown.items():
+            self.log(
+                f"{phase}/{loss_name}",
+                indiv_loss,
+                on_step=train, on_epoch=(not train), logger=True,
+            )
+
+        if(train):
+            self.log(
+                "train/loss_epoch",
+                loss_breakdown["loss"],
+                on_step=False, on_epoch=True, logger=True,
+            )
+
+        with torch.no_grad():
+            other_metrics = self._compute_validation_metrics(
+                batch,
+                outputs,
+                superimposition_metrics=(not train)
+            )
+
+        for k, v in other_metrics.items():
+            self.log(
+                f"{phase}/{k}",
+                v,
+                on_step=False, on_epoch=True, logger=True
+            )
+
     def training_step(self, batch, batch_idx):
         if(self.ema.device != batch["aatype"].device):
             self.ema.to(batch["aatype"].device)
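The consolidated helper leans on PyTorch Lightning's `self.log` flags: `on_step=True` emits per-batch values during training, while `on_epoch=True` accumulates and emits one value per validation epoch. A minimal toy module illustrating the same pattern (toy names and shapes, not OpenFold code):

```python
import pytorch_lightning as pl
import torch

class ToyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def _log_scalar(self, name, value, train=True):
        phase = "train" if train else "val"
        # Per-step curves in training; per-epoch averages in validation.
        self.log(
            f"{phase}/{name}", value,
            on_step=train, on_epoch=(not train), logger=True,
        )

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).pow(2).mean()
        self._log_scalar("loss", loss, train=True)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch).pow(2).mean()
        self._log_scalar("loss", loss, train=False)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```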
@@ -79,28 +109,7 @@ class OpenFoldWrapper(pl.LightningModule):
         )
 
         # Log it
-        self.log(
-            "train/loss",
-            loss,
-            on_step=True, logger=True,
-        )
-
-        self.log(
-            "train/loss_epoch",
-            loss,
-            on_step=False, on_epoch=True, logger=True,
-        )
-
-        for loss_name, indiv_loss in loss_breakdown.items():
-            self.log(
-                f"train/{loss_name}",
-                indiv_loss,
-                on_step=True, logger=True,
-            )
-
-        with torch.no_grad():
-            other_metrics = self.compute_validation_metrics(batch, outputs)
-
-        for k,v in other_metrics.items():
-            self.log(f"train/{k}", v, on_step=False, on_epoch=True, logger=True)
+        self._log(loss_breakdown, batch, outputs)
 
         return loss
@@ -125,23 +134,12 @@ class OpenFoldWrapper(pl.LightningModule):
         # Compute loss and other metrics
         batch["use_clamped_fape"] = 0.
-        loss, loss_breakdown = self.loss(
+        _, loss_breakdown = self.loss(
             outputs, batch, _return_breakdown=True
         )
 
-        self.log("val/loss", loss, on_step=False, on_epoch=True, logger=True)
-
-        for loss_name, indiv_loss in loss_breakdown.items():
-            self.log(
-                f"val/{loss_name}",
-                indiv_loss,
-                on_step=False, on_epoch=True, logger=True,
-            )
-
-        other_metrics = self.compute_validation_metrics(
-            batch, outputs, superimposition_metrics=True,
-        )
-
-        for k,v in other_metrics.items():
-            self.log(f"val/{k}", v, on_step=False, on_epoch=True, logger=True)
+        self._log(loss_breakdown, batch, outputs, train=False)
 
     def validation_epoch_end(self, _):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
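Both call sites rely on the loss returning `(total, breakdown)` when `_return_breakdown=True`; since the breakdown now carries the `"loss"` key itself, validation can discard the first return value, as `validation_step` does above. A toy sketch of that convention (illustrative term names, not the real AlphaFoldLoss):

```python
import torch

def toy_loss(outputs, batch, _return_breakdown=False):
    fape = outputs["fape"]           # per-term losses, already computed
    plddt = outputs["plddt_loss"]
    total = fape + plddt
    if not _return_breakdown:
        return total
    # The breakdown includes the total under "loss", so callers that only
    # log can ignore the first return value entirely.
    breakdown = {"fape": fape, "plddt_loss": plddt,
                 "loss": total.detach().clone()}
    return total, breakdown

outputs = {"fape": torch.tensor(1.5), "plddt_loss": torch.tensor(0.3)}
_, breakdown = toy_loss(outputs, batch=None, _return_breakdown=True)
assert torch.isclose(breakdown["loss"], torch.tensor(1.8))
```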
@@ -440,18 +438,23 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--wandb", action="store_true", default=False,
+        help="Whether to log metrics to Weights & Biases"
     )
     parser.add_argument(
         "--experiment_name", type=str, default=None,
+        help="Name of the current experiment. Used for wandb logging"
     )
     parser.add_argument(
         "--wandb_id", type=str, default=None,
+        help="ID of a previous run to be resumed"
     )
     parser.add_argument(
         "--wandb_project", type=str, default=None,
+        help="Name of the wandb project to which this run will belong"
     )
     parser.add_argument(
         "--wandb_entity", type=str, default=None,
+        help="wandb username or team name to which runs are attributed"
     )
     parser.add_argument(
         "--script_modules", type=bool_type, default=False,
@@ -465,16 +468,27 @@ if __name__ == "__main__":
     )
     parser.add_argument(
         "--train_epoch_len", type=int, default=10000,
+        help=(
+            "The virtual length of each training epoch. Stochastic filtering "
+            "of training data means that training datasets have no "
+            "well-defined length. This virtual length affects frequency of "
+            "validation & checkpointing (by default, one of each per epoch)."
+        )
     )
     parser.add_argument(
-        "--_alignment_index_path", type=str, default=None,
-    )
-    parser.add_argument(
         "--log_lr", action="store_true", default=False,
+        help="Whether to log the actual learning rate"
     )
     parser.add_argument(
         "--config_preset", type=str, default="initial_training",
-        help='Config setting. Choose e.g. "initial_training", "finetuning", "model_1", etc.'
+        help=(
+            'Config setting. Choose e.g. "initial_training", "finetuning", '
+            '"model_1", etc. By default, the actual values in the config are '
+            'used.'
+        )
+    )
+    parser.add_argument(
+        "--_alignment_index_path", type=str, default=None,
     )
     parser = pl.Trainer.add_argparse_args(parser)
...
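A hypothetical invocation exercising the newly documented flags; the script name, positional arguments, and values below are placeholders, not taken from this commit:

```bash
python3 train_openfold.py mmcif_dir/ alignment_dir/ output_dir/ \
    --config_preset initial_training \
    --train_epoch_len 10000 \
    --log_lr \
    --wandb \
    --wandb_project openfold \
    --wandb_entity my_team \
    --experiment_name baseline_run
```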