Commit 4c7c30b1 authored by Gustaf Ahdritz's avatar Gustaf Ahdritz
Browse files

Fix chain data cache script bug, improve logging

parent 61d004a2
......@@ -1627,6 +1627,8 @@ class AlphaFoldLoss(nn.Module):
crop_len = batch["aatype"].shape[-1]
cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))
losses["loss"] = cum_loss.detach().clone()
if(not _return_breakdown):
return cum_loss
......
......@@ -39,13 +39,8 @@ def parse_file(
local_data["seq"] = seq
local_data["resolution"] = mmcif.header["resolution"]
cluster_size = chain_cluster_size_dict.get(full_name.upper(), None)
if(cluster_size is None):
print(file_id)
out.pop(full_name)
continue
else:
local_data["cluster_size"] = cluster_size
cluster_size = chain_cluster_size_dict.get(full_name.upper(), -1)
local_data["cluster_size"] = cluster_size
elif(ext == ".pdb"):
with open(os.path.join(args.data_dir, f), "r") as fp:
pdb_string = fp.read()
......@@ -112,7 +107,12 @@ if __name__ == "__main__":
)
parser.add_argument(
"--cluster_file", type=str, default=None,
help="Path to a cluster file (e.g. PDB40), one cluster per line"
help=(
"Path to a cluster file (e.g. PDB40), one cluster "
"({PROT1_ID}_{CHAIN_ID} {PROT2_ID}_{CHAIN_ID} ...) per line. "
"Chains not in this cluster file will NOT be filtered by cluster "
"size."
)
)
parser.add_argument(
"--no_workers", type=int, default=4,
......
......@@ -63,6 +63,36 @@ class OpenFoldWrapper(pl.LightningModule):
def forward(self, batch):
return self.model(batch)
def _log(self, loss, loss_breakdown, train=True):
phase = "train" if train else "val"
for loss_name, indiv_loss in loss_breakdown.items():
self.log(
f"{phase}/{loss_name}",
indiv_loss,
on_step=train, on_epoch=(not train), logger=True,
)
if(train):
self.log(
f"train/loss_epoch",
loss_breakdown["loss"],
on_step=False, on_epoch=True, logger=True,
)
with torch.no_grad():
other_metrics = self._compute_validation_metrics(
batch,
outputs,
superimposition_metrics=(not train)
)
for k,v in other_metrics.items():
self.log(
f"{phase}/{k}",
v,
on_step=False, on_epoch=True, logger=True
)
def training_step(self, batch, batch_idx):
if(self.ema.device != batch["aatype"].device):
self.ema.to(batch["aatype"].device)
......@@ -79,28 +109,7 @@ class OpenFoldWrapper(pl.LightningModule):
)
# Log it
self.log(
"train/loss",
loss,
on_step=True, logger=True,
)
self.log(
"train/loss_epoch",
loss,
on_step=False, on_epoch=True, logger=True,
)
for loss_name, indiv_loss in loss_breakdown.items():
self.log(
f"train/{loss_name}",
indiv_loss,
on_step=True, logger=True,
)
with torch.no_grad():
other_metrics = self.compute_validation_metrics(batch, outputs)
for k,v in other_metrics.items():
self.log(f"train/{k}", v, on_step=False, on_epoch=True, logger=True)
self._log(loss_breakdown)
return loss
......@@ -125,23 +134,12 @@ class OpenFoldWrapper(pl.LightningModule):
# Compute loss and other metrics
batch["use_clamped_fape"] = 0.
loss, loss_breakdown = self.loss(
_, loss_breakdown = self.loss(
outputs, batch, _return_breakdown=True
)
self.log("val/loss", loss, on_step=False, on_epoch=True, logger=True)
for loss_name, indiv_loss in loss_breakdown.items():
self.log(
f"val/{loss_name}",
indiv_loss,
on_step=False, on_epoch=True, logger=True,
)
other_metrics = self.compute_validation_metrics(
batch, outputs, superimposition_metrics=True,
)
for k,v in other_metrics.items():
self.log(f"val/{k}", v, on_step=False, on_epoch=True, logger=True)
self._log(loss_breakdown, train=False)
def validation_epoch_end(self, _):
# Restore the model weights to normal
self.model.load_state_dict(self.cached_weights)
......@@ -440,18 +438,23 @@ if __name__ == "__main__":
)
parser.add_argument(
"--wandb", action="store_true", default=False,
help="Whether to log metrics to Weights & Biases"
)
parser.add_argument(
"--experiment_name", type=str, default=None,
help="Name of the current experiment. Used for wandb logging"
)
parser.add_argument(
"--wandb_id", type=str, default=None,
help="ID of a previous run to be resumed"
)
parser.add_argument(
"--wandb_project", type=str, default=None,
help="Name of the wandb project to which this run will belong"
)
parser.add_argument(
"--wandb_entity", type=str, default=None,
help="wandb username or team name to which runs are attributed"
)
parser.add_argument(
"--script_modules", type=bool_type, default=False,
......@@ -465,16 +468,27 @@ if __name__ == "__main__":
)
parser.add_argument(
"--train_epoch_len", type=int, default=10000,
)
parser.add_argument(
"--_alignment_index_path", type=str, default=None,
help=(
"The virtual length of each training epoch. Stochastic filtering "
"of training data means that training datasets have no "
"well-defined length. This virtual length affects frequency of "
"validation & checkpointing (by default, one of each per epoch)."
)
)
parser.add_argument(
"--log_lr", action="store_true", default=False,
help="Whether to log the actual learning rate"
)
parser.add_argument(
"--config_preset", type=str, default="initial_training",
help='Config setting. Choose e.g. "initial_training", "finetuning", "model_1", etc.'
help=(
'Config setting. Choose e.g. "initial_training", "finetuning", '
'"model_1", etc. By default, the actual values in the config are '
'used.'
)
)
parser.add_argument(
"--_alignment_index_path", type=str, default=None,
)
parser = pl.Trainer.add_argparse_args(parser)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment