import math

import torch
from torch import nn as nn
from torch.nn import functional as F


class LongCrossEntropyLoss(nn.Module):
    r"""CrossEntropy loss"""

    def __init__(self):
        super(LongCrossEntropyLoss, self).__init__()

    def forward(self, output, target):
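        # nn.CrossEntropyLoss expects class logits on dim 1: (n_batch, n_classes, n_time)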
        output = output.transpose(1, 2)
        target = target.long()

        criterion = nn.CrossEntropyLoss()
        return criterion(output, target)

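# Usage sketch for LongCrossEntropyLoss (the shapes below are illustrative
# assumptions, not requirements of the class):
#
#     criterion = LongCrossEntropyLoss()
#     logits = torch.randn(8, 100, 30)           # (n_batch, n_time, n_classes)
#     targets = torch.randint(0, 30, (8, 100))   # (n_batch, n_time) class indices
#     loss = criterion(logits, targets)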

class MoLLoss(nn.Module):
    r"""Discretized mixture of logistic distributions loss

    Adapted from wavenet vocoder
    (https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py)
    Explanation of loss (https://github.com/Rayhane-mamah/Tacotron-2/issues/155)

    Args:
        y_hat (Tensor): Predicted output (n_batch x n_time x (3 * num_mixtures))
        y (Tensor): Target (n_batch x n_time), with values scaled to [-1, 1]
        num_classes (int): Number of discretized classes
        log_scale_min (float): Minimum value for the predicted log scales
            (defaults to ``math.log(1e-14)``)
        reduce (bool): If True, return the mean loss over the minibatch;
            otherwise return the per-element losses

    Returns:
        Tensor: loss
    """

    def __init__(self, num_classes=65536, log_scale_min=None, reduce=True):
        super(MoLLoss, self).__init__()
        self.num_classes = num_classes
        self.log_scale_min = log_scale_min if log_scale_min is not None else math.log(1e-14)
        self.reduce = reduce

    def forward(self, y_hat, y):
        y = y.unsqueeze(-1)

        assert y_hat.dim() == 3
        assert y_hat.size(-1) % 3 == 0

        nr_mix = y_hat.size(-1) // 3

        # unpack parameters (n_batch, n_time, num_mixtures) x 3
        logit_probs = y_hat[:, :, :nr_mix]
        means = y_hat[:, :, nr_mix : 2 * nr_mix]
        log_scales = torch.clamp(y_hat[:, :, 2 * nr_mix : 3 * nr_mix], min=self.log_scale_min)

        # (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures)
        y = y.expand_as(means)

        centered_y = y - means
        inv_stdv = torch.exp(-log_scales)
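        # logistic CDF evaluated at the upper and lower edges of the target bin;
        # the bin half-width is 1 / (num_classes - 1) on the [-1, 1] scale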
        plus_in = inv_stdv * (centered_y + 1.0 / (self.num_classes - 1))
        cdf_plus = torch.sigmoid(plus_in)
        min_in = inv_stdv * (centered_y - 1.0 / (self.num_classes - 1))
        cdf_min = torch.sigmoid(min_in)

        # log probability for the lowest bin (y == -1, i.e. 0 before scaling)
        # equivalent: torch.log(torch.sigmoid(plus_in))
        log_cdf_plus = plus_in - F.softplus(plus_in)

        # log probability for the highest bin (y == 1, i.e. num_classes - 1 before scaling)
        # equivalent: (1 - torch.sigmoid(min_in)).log()
        log_one_minus_cdf_min = -F.softplus(min_in)

        # probability for all other cases
        cdf_delta = cdf_plus - cdf_min

        mid_in = inv_stdv * centered_y
        # log probability in the center of the bin, to be used in extreme cases
        log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)

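        # piece together the per-element log probability: use log(cdf_delta) when the
        # bin probability is large enough, fall back to the bin-centre pdf otherwise;
        # the two edge bins (y close to -1 and to +1) are handled separately below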
        inner_inner_cond = (cdf_delta > 1e-5).float()

        inner_inner_out = inner_inner_cond * torch.log(torch.clamp(cdf_delta, min=1e-12)) + (1.0 - inner_inner_cond) * (
            log_pdf_mid - math.log((self.num_classes - 1) / 2)
        )
        inner_cond = (y > 0.999).float()
        inner_out = inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
        cond = (y < -0.999).float()
        log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out

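        # add the log mixture weights; the log-sum-exp below marginalizes over
        # the mixture components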
        log_probs = log_probs + F.log_softmax(logit_probs, -1)

        if self.reduce:
            return -torch.mean(_log_sum_exp(log_probs))
        else:
            return -_log_sum_exp(log_probs).unsqueeze(-1)

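# Usage sketch for MoLLoss (batch, time, and mixture sizes below are illustrative
# assumptions, not requirements of the class):
#
#     num_mixtures = 10
#     criterion = MoLLoss(num_classes=65536)
#     y_hat = torch.randn(8, 100, 3 * num_mixtures)   # (n_batch, n_time, 3 * num_mixtures)
#     y = torch.empty(8, 100).uniform_(-1, 1)         # targets scaled to [-1, 1]
#     loss = criterion(y_hat, y)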

def _log_sum_exp(x):
    r"""Numerically stable log_sum_exp implementation that prevents overflow"""

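    # subtract the per-row maximum before exponentiating so exp() cannot overflow;
    # the result should match torch.logsumexp(x, dim=-1)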
    axis = len(x.size()) - 1
    m, _ = torch.max(x, dim=axis)
    m2, _ = torch.max(x, dim=axis, keepdim=True)
    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))