from collections import namedtuple
from typing import Callable, Optional

import pytorch_lightning as pl
import torch
import torch.nn as nn
from torch.optim.optimizer import Optimizer

# Lightweight container pairing model inputs with their target labels.
# Both fields are expected to be iterables that are splatted into the
# model / loss function (see training_step).
Batch = namedtuple("Batch", ["inputs", "labels"])


class SSLPretrainModule(pl.LightningModule):
    """Lightning module for self-supervised (SSL) pretraining.

    Wraps an arbitrary model, loss function, optimizer, and optional
    learning-rate scheduler so the train/val loops can be driven by
    PyTorch Lightning.
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: Callable,
        optimizer: Optimizer,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    ):
        """
        Args:
            model (nn.Module): Model to pretrain; called as
                ``model(*batch.inputs)``.
            loss_fn (Callable): Loss function; called as
                ``loss_fn(*model_output, *batch.labels)`` and expected to
                return a ``(loss, num_frame)`` pair.
            optimizer (Optimizer): Optimizer over the model parameters.
            lr_scheduler (Optional[_LRScheduler]): Optional learning-rate
                scheduler wrapping ``optimizer``.
        """
        super().__init__()
        self.model = model
        self.loss_fn = loss_fn
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def configure_optimizers(self):
        """Expose the optimizer (and scheduler, if any) to Lightning.

        Fix: the optimizer/scheduler passed to ``__init__`` were stored but
        never returned to the trainer, so Lightning would not have stepped
        them. Subclasses may still override this hook.
        """
        if self.lr_scheduler is None:
            return self.optimizer
        return [self.optimizer], [self.lr_scheduler]

    def log_metrics(self, batch: Batch, output, loss, step_type):
        """Log useful information to TensorBoard.

        Users are expected to write their customized `log_metrics` method
        to log information such as loss values, metric scores, etc.

        Args:
            batch (Batch): Batch tuple from the dataloader.
            output: Output generated by the model.
            loss (Tensor): Loss value returned by ``loss_fn``.
            step_type (str): Type of step. Choices are "train", "val",
                and "test".
        """
        pass

    def training_step(self, batch: Batch, batch_idx):
        """Run one training step and return the DDP-normalized loss."""
        out = self.model(*batch.inputs)
        loss, num_frame = self.loss_fn(*out, *batch.labels)
        # Fix: was `self.log_metric(...)`, which does not exist
        # (AttributeError at runtime); the hook is named `log_metrics`.
        self.log_metrics(batch, out, loss, "train")
        # Normalize the loss based on the sum of num_frame across all GPUs.
        num_frames = self.all_gather(num_frame)
        self.log(
            "Gathered number of frames",
            num_frames.float().sum(),
            on_step=True,
            on_epoch=True,
        )
        # world size / total frames — assumes num_frame is a scalar per rank,
        # so dim 0 of the gathered tensor is the world size (TODO confirm).
        loss *= num_frames.size(0) / num_frames.sum()
        return loss

    def validation_step(self, batch, batch_idx):
        """Run one validation step and return the (unnormalized) loss."""
        out = self.model(*batch.inputs)
        loss, _ = self.loss_fn(*out, *batch.labels)
        # Fix: was `self.log_metric(...)` — method is named `log_metrics`.
        self.log_metrics(batch, out, loss, "val")
        return loss