import torch
from torch import Tensor
from torchmetrics import Metric, Accuracy


class AccuracyMine(Accuracy):
    """``torchmetrics.Accuracy`` that also accepts soft (e.g. Mixup) targets.

    Mixup-style training produces floating-point target distributions instead
    of integer class indices. Before delegating to ``Accuracy.update``, such
    targets are collapsed to hard labels by taking the argmax over their last
    dimension. Integer targets are passed through unchanged.

    NOTE(review): in newer torchmetrics releases (task-based API, >= 0.11),
    ``Accuracy.__new__`` returns a task-specific metric (e.g.
    ``MulticlassAccuracy``) instead of an instance of this subclass. In that
    case this ``update`` override would never be called. Confirm the
    torchmetrics version this project pins.
    """

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Update the metric state, converting soft targets to class indices.

        Args:
            preds: Predictions, forwarded unchanged to ``Accuracy.update``.
            target: Hard labels (integer tensor) or soft labels (floating-point
                tensor whose last dimension is assumed to index classes).
        """
        # Only reduce when there is a class dimension to reduce over. A 1-D
        # float target would otherwise be argmax'ed into a single 0-d index,
        # silently discarding every per-sample label.
        if target.is_floating_point() and target.ndim > 1:
            target = target.argmax(dim=-1)
        super().update(preds, target)