"configs/rec/rec_resnet_stn_bilstm_att.yml" did not exist on "2cda3b614a5e059fae3053f5f616b52fa9fee421"
cross_entropy_apex.py 2.07 KB
Newer Older
Tri Dao's avatar
Tri Dao committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
import torch.nn as nn

# CUDA extension providing fused softmax + cross-entropy forward/backward kernels
# (adapted from apex's contrib xentropy); must be built separately.
import xentropy_cuda_lib


# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/xentropy/softmax_xentropy.py
class SoftmaxCrossEntropyLossFn(torch.autograd.Function):
    @staticmethod
    def forward(ctx, logits, labels, smoothing=0.0, padding_idx=0, inplace_backward=False):
        # The fused kernel returns the per-token losses and the logsumexp of the
        # logits, which is saved for the backward pass.
        losses, max_log_sum_exp = xentropy_cuda_lib.forward(logits, labels, smoothing)
        # Zero out the loss at padding positions.
        losses.masked_fill_(labels == padding_idx, 0)
        ctx.save_for_backward(logits, max_log_sum_exp, labels)
        ctx.smoothing = smoothing
        ctx.padding_idx = padding_idx
        ctx.inplace_backward = inplace_backward
        return losses

    @staticmethod
    def backward(ctx, grad_loss):
        logits, max_log_sum_exp, labels = ctx.saved_tensors
        if not grad_loss.is_contiguous():
            grad_loss = grad_loss.contiguous()
        # Padding positions receive zero gradient.
        grad_loss.masked_fill_(labels == ctx.padding_idx, 0)
        # With inplace_backward=True, the kernel reuses the logits buffer for the
        # gradient to save memory.
        grad_logits = xentropy_cuda_lib.backward(grad_loss, logits, max_log_sum_exp,
                                                 labels, ctx.smoothing, ctx.inplace_backward)
        return grad_logits, None, None, None, None


class CrossEntropyLossApex(nn.Module):
    """CUDA-only replacement for nn.CrossEntropyLoss backed by apex's fused
    xentropy kernel, with label smoothing and an optional in-place backward."""

    def __init__(self, ignore_index=-100, reduction='mean', label_smoothing=0.0,
                 inplace_backward=False):
        super().__init__()
        if reduction not in ['mean', 'none']:
            raise NotImplementedError("Only support reduction = 'mean' or 'none'")
        self.ignore_index = ignore_index
        self.reduction = reduction
        self.label_smoothing = label_smoothing
        self.inplace_backward = inplace_backward

    def forward(self, input, target):
        assert input.is_cuda and target.is_cuda
        # SoftmaxCrossEntropyLossFn implicitly casts the loss to float
        loss = SoftmaxCrossEntropyLossFn.apply(input, target, self.label_smoothing,
                                               self.ignore_index, self.inplace_backward)
        if self.reduction == 'mean':
            # Average over non-ignored positions only.
            return loss.sum() / (target != self.ignore_index).sum()
        else:
            return loss
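

# A minimal usage sketch, assuming the xentropy_cuda_lib extension has been built
# and a CUDA device is available. The shapes and hyperparameters below are
# illustrative, not taken from the original file.
if __name__ == '__main__':
    torch.manual_seed(0)
    batch, vocab_size = 8, 32000
    logits = torch.randn(batch, vocab_size, device='cuda', dtype=torch.float16,
                         requires_grad=True)
    target = torch.randint(0, vocab_size, (batch,), device='cuda')
    loss_fn = CrossEntropyLossApex(reduction='mean', label_smoothing=0.1)
    loss = loss_fn(logits, target)  # loss is float even for fp16 logits
    loss.backward()
    # Cross-check against the PyTorch reference, computed in fp32.
    ref = nn.CrossEntropyLoss(label_smoothing=0.1)(logits.detach().float(), target)
    print(f'apex xentropy: {loss.item():.6f}  reference: {ref.item():.6f}')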