Commit 50f72d45 authored by Gustaf Ahdritz

Improve logging

parent a3e8ebbc
openfold/data/data_modules.py:

@@ -213,16 +213,16 @@ class OpenFoldSingleDataset(torch.utils.data.Dataset):
 def deterministic_train_filter(
-    prot_data_cache_entry: Any,
+    chain_data_cache_entry: Any,
     max_resolution: float = 9.,
     max_single_aa_prop: float = 0.8,
 ) -> bool:
     # Hard filters
-    resolution = prot_data_cache_entry.get("resolution", None)
+    resolution = chain_data_cache_entry.get("resolution", None)
     if(resolution is not None and resolution > max_resolution):
         return False

-    seq = prot_data_cache_entry["seq"]
+    seq = chain_data_cache_entry["seq"]
     counts = {}
     for aa in seq:
         counts.setdefault(aa, 0)
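The hunk above ends mid-function, before the check that `max_single_aa_prop` feeds. A minimal self-contained sketch of the full filter, with the truncated tail filled in as an assumption (the real function may differ in details):

from typing import Any

def deterministic_train_filter(
    chain_data_cache_entry: Any,
    max_resolution: float = 9.,
    max_single_aa_prop: float = 0.8,
) -> bool:
    # Hard filter 1: reject structures that are too low-resolution
    resolution = chain_data_cache_entry.get("resolution", None)
    if resolution is not None and resolution > max_resolution:
        return False

    # Hard filter 2 (assumed continuation): reject chains dominated
    # by a single amino acid type
    seq = chain_data_cache_entry["seq"]
    counts = {}
    for aa in seq:
        counts.setdefault(aa, 0)
        counts[aa] += 1
    largest_prop = max(counts.values()) / len(seq)
    if largest_prop > max_single_aa_prop:
        return False

    return True

entry = {"resolution": 1.8, "seq": "AAAAAAAAAA"}
print(deterministic_train_filter(entry))  # False: 100% one residue type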
@@ -236,16 +236,16 @@ def deterministic_train_filter(
 def get_stochastic_train_filter_prob(
-    prot_data_cache_entry: Any,
+    chain_data_cache_entry: Any,
 ) -> List[float]:
     # Stochastic filters
     probabilities = []
-    cluster_size = prot_data_cache_entry.get("cluster_size", None)
+    cluster_size = chain_data_cache_entry.get("cluster_size", None)
     if(cluster_size is not None and cluster_size > 0):
         probabilities.append(1 / cluster_size)

-    chain_length = len(prot_data_cache_entry["seq"])
+    chain_length = len(chain_data_cache_entry["seq"])
     probabilities.append((1 / 512) * (max(min(chain_length, 512), 256)))

     # Risk of underflow here?
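The stochastic filter mirrors AlphaFold's datapoint sampling probabilities (supplement subsection 1.2.5): the inverse-cluster-size term downweights overrepresented sequence clusters, and the length term, clamped to [256, 512] and divided by 512, downweights short chains. The call site further below treats the return value as a single probability, and the "underflow" comment suggests the terms are multiplied together; a sketch under that assumption:

from functools import reduce
from operator import mul

def combine_filter_probs(probabilities) -> float:
    # Assumed reduction: treat the terms as independent and multiply
    return reduce(mul, probabilities, 1.0)

# E.g. a 300-residue chain in a cluster of 4 sequences:
#   cluster term: 1 / 4 = 0.25
#   length term:  max(min(300, 512), 256) / 512 ~= 0.586
#   combined keep-probability: ~0.15
print(combine_filter_probs([0.25, 300 / 512]))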
@@ -267,7 +267,7 @@ class OpenFoldDataset(torch.utils.data.Dataset):
         datasets: Sequence[OpenFoldSingleDataset],
         probabilities: Sequence[int],
         epoch_len: int,
-        prot_data_cache_paths: List[str],
+        chain_data_cache_paths: List[str],
         generator: torch.Generator = None,
         _roll_at_init: bool = True,
     ):
@@ -276,10 +276,10 @@ class OpenFoldDataset(torch.utils.data.Dataset):
         self.epoch_len = epoch_len
         self.generator = generator

-        self.prot_data_caches = []
-        for path in prot_data_cache_paths:
+        self.chain_data_caches = []
+        for path in chain_data_cache_paths:
             with open(path, "r") as fp:
-                self.prot_data_caches.append(json.load(fp))
+                self.chain_data_caches.append(json.load(fp))

         def looped_shuffled_dataset_idx(dataset_len):
             while True:
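The chain data cache loaded here is a JSON file mapping chain IDs to per-chain metadata. The schema isn't shown in this diff, but the filter functions above read at least "seq" plus optional "resolution" and "cluster_size"; a hypothetical entry for illustration:

import json

# Hypothetical cache entry; the key format and any fields beyond
# "seq", "resolution", and "cluster_size" are illustrative only.
chain_data_cache = {
    "1abc_A": {
        "resolution": 2.3,
        "seq": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ",
        "cluster_size": 12,
    },
}
print(json.dumps(chain_data_cache, indent=2))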
@@ -298,19 +298,19 @@ class OpenFoldDataset(torch.utils.data.Dataset):
             max_cache_len = int(epoch_len * probabilities[dataset_idx])
             dataset = self.datasets[dataset_idx]
             idx_iter = looped_shuffled_dataset_idx(len(dataset))
-            prot_data_cache = self.prot_data_caches[dataset_idx]
+            chain_data_cache = self.chain_data_caches[dataset_idx]
             while True:
                 weights = []
                 idx = []
                 for _ in range(max_cache_len):
                     candidate_idx = next(idx_iter)
                     chain_id = dataset.idx_to_chain_id(candidate_idx)
-                    prot_data_cache_entry = prot_data_cache[chain_id]
-                    if(not deterministic_train_filter(prot_data_cache_entry)):
+                    chain_data_cache_entry = chain_data_cache[chain_id]
+                    if(not deterministic_train_filter(chain_data_cache_entry)):
                         continue

                     p = get_stochastic_train_filter_prob(
-                        prot_data_cache_entry,
+                        chain_data_cache_entry,
                     )
                     weights.append([1. - p, p])
                     idx.append(candidate_idx)
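Each surviving candidate contributes a [1 - p, p] row, i.e. the parameters of a Bernoulli keep/drop draw. The hunk ends before the draw itself; a sketch of how such rows could be consumed with torch.multinomial (the continuation in the real code may differ):

import torch

# Three candidates with keep-probabilities 0.1, 0.8, and 0.5
weights = torch.tensor([[0.9, 0.1], [0.2, 0.8], [0.5, 0.5]])
idx = [17, 42, 99]

g = torch.Generator().manual_seed(0)
# One categorical draw per row: index 1 means "keep", 0 means "drop"
samples = torch.multinomial(weights, num_samples=1, generator=g)
kept = [i for i, s in zip(idx, samples.squeeze(-1).tolist()) if s == 1]
print(kept)  # the subset of candidate indices accepted this epoch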
@@ -471,10 +471,10 @@ class OpenFoldDataModule(pl.LightningDataModule):
         max_template_date: str,
         train_data_dir: Optional[str] = None,
         train_alignment_dir: Optional[str] = None,
-        train_prot_data_cache_path: Optional[str] = None,
+        train_chain_data_cache_path: Optional[str] = None,
         distillation_data_dir: Optional[str] = None,
         distillation_alignment_dir: Optional[str] = None,
-        distillation_prot_data_cache_path: Optional[str] = None,
+        distillation_chain_data_cache_path: Optional[str] = None,
         val_data_dir: Optional[str] = None,
         val_alignment_dir: Optional[str] = None,
         predict_data_dir: Optional[str] = None,
@@ -496,11 +496,11 @@ class OpenFoldDataModule(pl.LightningDataModule):
         self.max_template_date = max_template_date
         self.train_data_dir = train_data_dir
         self.train_alignment_dir = train_alignment_dir
-        self.train_prot_data_cache_path = train_prot_data_cache_path
+        self.train_chain_data_cache_path = train_chain_data_cache_path
         self.distillation_data_dir = distillation_data_dir
         self.distillation_alignment_dir = distillation_alignment_dir
-        self.distillation_prot_data_cache_path = (
-            distillation_prot_data_cache_path
+        self.distillation_chain_data_cache_path = (
+            distillation_chain_data_cache_path
         )
         self.val_data_dir = val_data_dir
         self.val_alignment_dir = val_alignment_dir
@@ -589,22 +589,22 @@ class OpenFoldDataModule(pl.LightningDataModule):
             datasets = [train_dataset, distillation_dataset]
             d_prob = self.config.train.distillation_prob
             probabilities = [1 - d_prob, d_prob]
-            prot_data_cache_paths = [
-                self.train_prot_data_cache_path,
-                self.distillation_prot_data_cache_path,
+            chain_data_cache_paths = [
+                self.train_chain_data_cache_path,
+                self.distillation_chain_data_cache_path,
             ]
         else:
             datasets = [train_dataset]
             probabilities = [1.]
-            prot_data_cache_paths = [
-                self.train_prot_data_cache_path,
+            chain_data_cache_paths = [
+                self.train_chain_data_cache_path,
             ]

         self.train_dataset = OpenFoldDataset(
             datasets=datasets,
             probabilities=probabilities,
             epoch_len=self.train_epoch_len,
-            prot_data_cache_paths=prot_data_cache_paths,
+            chain_data_cache_paths=chain_data_cache_paths,
             _roll_at_init=False,
         )
openfold/utils/loss.py:

@@ -1552,7 +1552,7 @@ class AlphaFoldLoss(nn.Module):
         super(AlphaFoldLoss, self).__init__()
         self.config = config

-    def forward(self, out, batch):
+    def forward(self, out, batch, _return_breakdown=False):
         if "violation" not in out.keys():
             out["violation"] = find_structural_violations(
                 batch,
@@ -1609,6 +1609,7 @@ class AlphaFoldLoss(nn.Module):
         )

         cum_loss = 0.
+        losses = {}
         for loss_name, loss_fn in loss_fns.items():
             weight = self.config[loss_name].weight
             loss = loss_fn()
@@ -1616,6 +1617,9 @@ class AlphaFoldLoss(nn.Module):
                 logging.warning(f"{loss_name} loss is NaN. Skipping...")
                 loss = loss.new_tensor(0., requires_grad=True)
             cum_loss = cum_loss + weight * loss
+            losses[loss_name] = loss.detach().clone()
+
+        losses["unscaled_loss"] = cum_loss.detach().clone()

         # Scale the loss by the square root of the minimum of the crop size and
         # the (average) sequence length. See subsection 1.9.
@@ -1623,4 +1627,7 @@ class AlphaFoldLoss(nn.Module):
         crop_len = batch["aatype"].shape[-1]
         cum_loss = cum_loss * torch.sqrt(min(seq_len, crop_len))

-        return cum_loss
+        if(not _return_breakdown):
+            return cum_loss
+
+        return cum_loss, losses
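With the new `_return_breakdown` flag, `forward` optionally returns the per-term losses alongside the cumulative loss; the dict values are detached clones, so logging them cannot grow the autograd graph. A hypothetical call site (the `config.loss` argument path is an assumption):

loss_fn = AlphaFoldLoss(config.loss)  # config path is an assumption

loss = loss_fn(out, batch)  # default behavior is unchanged
loss, loss_breakdown = loss_fn(out, batch, _return_breakdown=True)

# The breakdown holds each individual term before weighting, plus
# "unscaled_loss": the weighted sum before the length scaling above.
for name, value in loss_breakdown.items():
    print(f"{name}: {value.item():.4f}")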
openfold/utils/superimposition.py (new file):

# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from Bio.SVDSuperimposer import SVDSuperimposer
import numpy as np
import torch


def _superimpose_np(reference, coords):
    """
    Superimposes coordinates onto a reference by minimizing RMSD using SVD.

    Args:
        reference:
            [N, 3] reference array
        coords:
            [N, 3] array
    Returns:
        A tuple of [N, 3] superimposed coords and the final RMSD.
    """
    sup = SVDSuperimposer()
    sup.set(reference, coords)
    sup.run()
    return sup.get_transformed(), sup.get_rms()


def _superimpose_single(reference, coords):
    reference_np = reference.detach().cpu().numpy()
    coords_np = coords.detach().cpu().numpy()
    superimposed, rmsd = _superimpose_np(reference_np, coords_np)
    return coords.new_tensor(superimposed), coords.new_tensor(rmsd)


def superimpose(reference, coords):
    """
    Superimposes coordinates onto a reference by minimizing RMSD using SVD.

    Args:
        reference:
            [*, N, 3] reference tensor
        coords:
            [*, N, 3] tensor
    Returns:
        A tuple of [*, N, 3] superimposed coords and [*] final RMSDs.
    """
    batch_dims = reference.shape[:-2]
    flat_reference = reference.reshape((-1,) + reference.shape[-2:])
    flat_coords = coords.reshape((-1,) + reference.shape[-2:])
    superimposed_list = []
    rmsds = []
    for r, c in zip(flat_reference, flat_coords):
        superimposed, rmsd = _superimpose_single(r, c)
        superimposed_list.append(superimposed)
        rmsds.append(rmsd)

    superimposed_stacked = torch.stack(superimposed_list, dim=0)
    rmsds_stacked = torch.stack(rmsds, dim=0)

    superimposed_reshaped = superimposed_stacked.reshape(
        batch_dims + coords.shape[-2:]
    )
    rmsds_reshaped = rmsds_stacked.reshape(batch_dims)

    return superimposed_reshaped, rmsds_reshaped
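A minimal usage sketch for the new `superimpose` helper: align a slightly perturbed copy of a random structure to the original. Shapes follow the docstring, [*, N, 3] with an arbitrary batch prefix:

import torch

reference = torch.randn(2, 10, 3)  # batch of 2 ten-atom "structures"
coords = reference + 0.01 * torch.randn_like(reference)

superimposed, rmsds = superimpose(reference, coords)
print(superimposed.shape)  # torch.Size([2, 10, 3])
print(rmsds.shape)         # torch.Size([2]), one small RMSD per batch item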
openfold/utils/validation_metrics.py (new file):

# Copyright 2021 AlQuraishi Laboratory
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch


def gdt(p1, p2, mask, cutoffs):
    n = torch.sum(mask, dim=-1)

    p1 = p1.float()
    p2 = p2.float()
    distances = torch.sqrt(torch.sum((p1 - p2)**2, dim=-1))
    scores = []
    for c in cutoffs:
        score = torch.sum((distances <= c) * mask, dim=-1) / n
        scores.append(score)

    return sum(scores) / len(scores)


def gdt_ts(p1, p2, mask):
    return gdt(p1, p2, mask, [1., 2., 4., 8.])


def gdt_ha(p1, p2, mask):
    return gdt(p1, p2, mask, [0.5, 1., 2., 4.])
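`gdt` averages, over the given distance cutoffs, the masked fraction of atom pairs that fall within each cutoff; `gdt_ts` and `gdt_ha` fix the standard CASP cutoff sets (GDT_TS: 1/2/4/8 Angstrom, GDT_HA: 0.5/1/2/4 Angstrom). A quick sanity check:

import torch

p1 = torch.zeros(10, 3)
p2 = p1 + torch.tensor([3., 0., 0.])  # every atom displaced by 3 Angstrom
mask = torch.ones(10)

print(gdt_ts(p1, p1, mask))  # tensor(1.): identical structures
print(gdt_ts(p1, p2, mask))  # tensor(0.5000): passes only the 4 and 8 cutoffs
print(gdt_ha(p1, p2, mask))  # tensor(0.2500): passes only the 4 cutoff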
train_openfold.py:

@@ -26,14 +26,20 @@ from openfold.data.data_modules import (
 )
 from openfold.model.model import AlphaFold
 from openfold.model.torchscript import script_preset_
+from openfold.np import residue_constants
 from openfold.utils.callbacks import (
     EarlyStoppingVerbose,
 )
 from openfold.utils.exponential_moving_average import ExponentialMovingAverage
 from openfold.utils.argparse import remove_arguments
-from openfold.utils.loss import AlphaFoldLoss, lddt_ca
+from openfold.utils.loss import AlphaFoldLoss, lddt_ca, compute_drmsd
 from openfold.utils.seed import seed_everything
+from openfold.utils.superimposition import superimpose
 from openfold.utils.tensor_utils import tensor_tree_map
+from openfold.utils.validation_metrics import (
+    gdt_ts,
+    gdt_ha,
+)
 from scripts.zero_to_fp32 import (
     get_fp32_state_dict_from_zero_checkpoint
 )
@@ -52,6 +58,7 @@ class OpenFoldWrapper(pl.LightningModule):
         )
         self.cached_weights = None
+        self.last_lr_step = 0

     def forward(self, batch):
         return self.model(batch)
@@ -67,43 +74,132 @@ class OpenFoldWrapper(pl.LightningModule):
         batch = tensor_tree_map(lambda t: t[..., -1], batch)

         # Compute loss
-        loss = self.loss(outputs, batch)
+        loss, loss_breakdown = self.loss(
+            outputs, batch, _return_breakdown=True
+        )

         # Log it
-        self.log("train/loss", loss, on_step=True, logger=True)
+        self.log(
+            "train/loss",
+            loss,
+            on_step=True, logger=True,
+        )
+        self.log(
+            "train/loss_epoch",
+            loss,
+            on_step=False, on_epoch=True, logger=True,
+        )
+
+        for loss_name, indiv_loss in loss_breakdown.items():
+            self.log(
+                f"train/{loss_name}",
+                indiv_loss,
+                on_step=True, logger=True,
+            )
+
+        with torch.no_grad():
+            other_metrics = self.compute_validation_metrics(batch, outputs)
+
+        for k,v in other_metrics.items():
+            self.log(f"train/{k}", v, on_step=False, on_epoch=True, logger=True)

         return loss

     def on_before_zero_grad(self, *args, **kwargs):
         self.ema.update(self.model)

+    # def training_step_end(self, outputs):
+    #     # Temporary measure to address DeepSpeed scheduler bug
+    #     if(self.trainer.global_step != self.last_lr_step):
+    #         self.lr_schedulers().step()
+    #         self.last_lr_step = self.trainer.global_step
+
     def validation_step(self, batch, batch_idx):
         # At the start of validation, load the EMA weights
         if(self.cached_weights is None):
             self.cached_weights = self.model.state_dict()
             self.model.load_state_dict(self.ema.state_dict()["params"])

-        # Calculate validation loss
+        # Run the model
         outputs = self(batch)
         batch = tensor_tree_map(lambda t: t[..., -1], batch)

-        lddt_ca_score = lddt_ca(
-            outputs["final_atom_positions"],
-            batch["all_atom_positions"],
-            batch["all_atom_mask"],
-            eps=self.config.globals.eps,
-            per_residue=False,
-        )
-
-        self.log("val/lddt_ca", lddt_ca_score, logger=True)
-
+        # Compute loss and other metrics
         batch["use_clamped_fape"] = 0.
-        loss = self.loss(outputs, batch)
-        self.log("val/loss", loss, logger=True)
+        loss, loss_breakdown = self.loss(
+            outputs, batch, _return_breakdown=True
+        )
+
+        self.log("val/loss", loss, on_step=False, on_epoch=True, logger=True)
+
+        for loss_name, indiv_loss in loss_breakdown.items():
+            self.log(
+                f"val/{loss_name}",
+                indiv_loss,
+                on_step=False, on_epoch=True, logger=True,
+            )
+
+        other_metrics = self.compute_validation_metrics(
+            batch, outputs, superimposition_metrics=True,
+        )
+
+        for k,v in other_metrics.items():
+            self.log(f"val/{k}", v, on_step=False, on_epoch=True, logger=True)

     def validation_epoch_end(self, _):
         # Restore the model weights to normal
         self.model.load_state_dict(self.cached_weights)
         self.cached_weights = None

+    def compute_validation_metrics(self,
+        batch,
+        outputs,
+        superimposition_metrics=False
+    ):
+        metrics = {}
+
+        gt_coords = batch["all_atom_positions"]
+        pred_coords = outputs["final_atom_positions"]
+        all_atom_mask = batch["all_atom_mask"]
+
+        # This is super janky for superimposition. Fix later
+        gt_coords_masked = gt_coords * all_atom_mask[..., None]
+        pred_coords_masked = pred_coords * all_atom_mask[..., None]
+        ca_pos = residue_constants.atom_order["CA"]
+        gt_coords_masked_ca = gt_coords_masked[..., ca_pos, :]
+        pred_coords_masked_ca = pred_coords_masked[..., ca_pos, :]
+        all_atom_mask_ca = all_atom_mask[..., ca_pos]
+
+        lddt_ca_score = lddt_ca(
+            pred_coords,
+            gt_coords,
+            all_atom_mask,
+            eps=self.config.globals.eps,
+            per_residue=False,
+        )
+
+        metrics["lddt_ca"] = lddt_ca_score
+
+        drmsd_ca_score = compute_drmsd(
+            pred_coords_masked_ca,
+            gt_coords_masked_ca,
+            mask=all_atom_mask_ca,
+        )
+
+        metrics["drmsd_ca"] = drmsd_ca_score
+
+        if(superimposition_metrics):
+            superimposed_pred, _ = superimpose(
+                gt_coords_masked_ca, pred_coords_masked_ca
+            )
+            gdt_ts_score = gdt_ts(
+                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
+            )
+            gdt_ha_score = gdt_ha(
+                superimposed_pred, gt_coords_masked_ca, all_atom_mask_ca
+            )
+            metrics["gdt_ts"] = gdt_ts_score
metrics["gdt_ta"] = gdt_ha_score
+
+        return metrics
+
     def configure_optimizers(self,
         learning_rate: float = 1e-3,
         eps: float = 1e-5,
@@ -180,6 +276,10 @@ def main(args):
         )
         callbacks.append(perf)

+    if(args.log_lr):
+        lr_monitor = LearningRateMonitor(logging_interval="step")
+        callbacks.append(lr_monitor)
+
     loggers = []
     if(args.wandb):
         wdb_logger = WandbLogger(
@@ -202,7 +302,7 @@ def main(args):
         strategy = DDPPlugin(find_unused_parameters=False)
     else:
         strategy = None

     trainer = pl.Trainer.from_argparse_args(
         args,
         default_root_dir=args.output_dir,
@@ -366,6 +466,12 @@ if __name__ == "__main__":
     parser.add_argument(
         "--train_epoch_len", type=int, default=10000,
     )
+    parser.add_argument(
+        "--_alignment_index_path", type=str, default=None,
+    )
+    parser.add_argument(
+        "--log_lr", action="store_true", default=False,
+    )
     parser = pl.Trainer.add_argparse_args(parser)

     # Disable the initial validation pass