"""Custom torchmetrics ``Metric`` implementations.

* ``Accuracy`` -- argmax accuracy that skips positions whose target equals an
  ignore value (``-100`` by default).
* ``Scalar`` -- running mean of one value recorded per ``update`` call.
* ``VQAScore`` -- mean, over rows, of the target score at each row's argmax.

Every state is registered with ``dist_reduce_fx="sum"``, so torchmetrics sums
the states across processes when it synchronizes them.
"""

import torch
from torchmetrics import Metric


class Accuracy(Metric):
    """Accuracy of ``logits.argmax(dim=-1)`` against integer targets.

    Positions whose target equals ``ignore_index`` are excluded from both the
    count of correct predictions and the count of scored positions.
    """

    def __init__(self, dist_sync_on_step=False, ignore_index=-100):
        """Create the metric.

        Args:
            dist_sync_on_step: Forwarded unchanged to ``torchmetrics.Metric``.
            ignore_index: Target value marking positions to skip. Defaults to
                ``-100``, the value this class previously hard-coded.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.ignore_index = ignore_index
        self.add_state("correct", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, logits: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate correct / scored counts for one batch.

        Args:
            logits: Scores whose last dimension indexes classes. Predictions
                are taken as ``logits.argmax(dim=-1)``.
            target: Class indices. These are expected to have the same shape
                as the predictions (``logits.shape[:-1]``).

        Raises:
            ValueError: If the masked predictions and targets differ in shape.
        """
        device = self.correct.device
        logits = logits.detach().to(device)
        target = target.detach().to(device)

        preds = logits.argmax(dim=-1)
        keep = target != self.ignore_index
        preds = preds[keep]
        target = target[keep]

        # Nothing left to score in this batch (e.g. every target was ignored).
        if target.numel() == 0:
            return

        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if preds.shape != target.shape:
            raise ValueError(
                "preds and target must have matching shapes after masking, got "
                f"{tuple(preds.shape)} and {tuple(target.shape)}"
            )

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self) -> torch.Tensor:
        """Return ``correct / total``.

        This is NaN (0/0) when no position has been scored.
        """
        return self.correct / self.total


class Scalar(Metric):
    """Running mean of a value recorded once per ``update`` call."""

    def __init__(self, dist_sync_on_step=False):
        """Create the metric.

        Args:
            dist_sync_on_step: Forwarded unchanged to ``torchmetrics.Metric``.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("scalar", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, scalar) -> None:
        """Add one value to the running sum.

        Args:
            scalar: A tensor, or anything ``torch.tensor`` accepts. Because it
                is added in place to a 0-dim state, it presumably must be a
                single value. NOTE(review): confirm that callers never pass
                multi-element tensors.
        """
        if isinstance(scalar, torch.Tensor):
            value = scalar.detach().to(self.scalar.device)
        else:
            value = torch.tensor(scalar).float().to(self.scalar.device)
        self.scalar += value
        # Each call counts once, regardless of how many samples produced `scalar`.
        self.total += 1

    def compute(self) -> torch.Tensor:
        """Return the sum of recorded values divided by the number of calls."""
        return self.scalar / self.total


class VQAScore(Metric):
    """Mean target score at each row's highest-scoring answer."""

    def __init__(self, dist_sync_on_step=False):
        """Create the metric.

        Args:
            dist_sync_on_step: Forwarded unchanged to ``torchmetrics.Metric``.
        """
        super().__init__(dist_sync_on_step=dist_sync_on_step)
        self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, logits: torch.Tensor, target: torch.Tensor) -> None:
        """Accumulate the target score of each row's argmax answer.

        Args:
            logits: 2-D ``(batch, num_answers)`` scores. The argmax is taken
                over dim 1.
            target: 2-D per-answer scores with the same shape as ``logits``.
                These are presumably soft answer accuracies. NOTE(review):
                confirm against the data pipeline.
        """
        device = self.score.device
        logits = logits.detach().float().to(device)
        target = target.detach().float().to(device)

        preds = torch.max(logits, 1)[1]
        # One-hot mask selecting the predicted answer in each row.
        one_hots = torch.zeros_like(target)
        one_hots.scatter_(1, preds.view(-1, 1), 1)

        self.score += (one_hots * target).sum()
        self.total += preds.size(0)

    def compute(self) -> torch.Tensor:
        """Return the accumulated score divided by the number of rows seen."""
        return self.score / self.total