"R-package/vscode:/vscode.git/clone" did not exist on "8b61a15085cf2ac88792341511083c53737880b1"
tvd.py 446 Bytes
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch.nn as nn

from liger_kernel.ops import LigerTVDLossFunction


class LigerTVDLoss(nn.Module):
    """``nn.Module`` front-end for ``LigerTVDLossFunction``.

    The reduction mode and ignore index are captured once at construction
    and handed, together with the forward inputs, to the autograd function
    on every call. ("TVD" is presumably total variation distance; confirm
    against ``LigerTVDLossFunction``.)

    Args:
        reduction: Reduction mode forwarded verbatim to the autograd
            function. The set of accepted values is decided there.
        ignore_index: Label value forwarded verbatim to the autograd
            function, alongside ``shift_labels``.
    """

    def __init__(self, reduction="batchmean", ignore_index: int = -100):
        super().__init__()
        # Plain (non-parameter) configuration; read back in forward().
        self.ignore_index = ignore_index
        self.reduction = reduction

    def forward(self, p, q, shift_labels=None):
        """Evaluate the loss for inputs ``p`` and ``q``.

        Args:
            p: First input. Presumably a tensor of distributions; TODO
                confirm the expected shape/dtype with LigerTVDLossFunction.
            q: Second input, presumably matching ``p``.
            shift_labels: Optional labels passed through unchanged. Presumably
                positions equal to ``ignore_index`` are excluded; verify.

        Returns:
            Whatever ``LigerTVDLossFunction.apply`` returns.
        """
        reduction, ignore_index = self.reduction, self.ignore_index
        return LigerTVDLossFunction.apply(
            p, q, shift_labels, reduction, ignore_index
        )