losses.py
import math

import torch
from torch import nn as nn
from torch.nn import functional as F


class LongCrossEntropyLoss(nn.Module):
    r"""CrossEntropy loss over (n_batch, n_time, n_classes) logits and
    (n_batch, n_time) integer targets.
    """

    def __init__(self):
        super(LongCrossEntropyLoss, self).__init__()

    def forward(self, output, target):
        # nn.CrossEntropyLoss expects the class dimension at position 1:
        # (n_batch, n_time, n_classes) -> (n_batch, n_classes, n_time)
        output = output.transpose(1, 2)
        target = target.long()

        criterion = nn.CrossEntropyLoss()
        return criterion(output, target)

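# Illustrative usage sketch (shapes are assumptions, inferred from the
# transpose in LongCrossEntropyLoss.forward):
#
#   criterion = LongCrossEntropyLoss()
#   logits = torch.randn(4, 100, 256)            # (n_batch, n_time, n_classes)
#   targets = torch.randint(0, 256, (4, 100))    # (n_batch, n_time)
#   loss = criterion(logits, targets)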

class MoLLoss(nn.Module):
    r""" Discretized mixture of logistic distributions loss

    Adapted from wavenet vocoder
    (https://github.com/r9y9/wavenet_vocoder/blob/master/wavenet_vocoder/mixture.py)
    Explanation of loss (https://github.com/Rayhane-mamah/Tacotron-2/issues/155)

    Args:
        y_hat (Tensor): Predicted output (n_batch x n_time x 3 * num_mixtures)
        y (Tensor): Target (n_batch x n_time)
        num_classes (int): Number of classes
        log_scale_min (float): Log scale minimum value
        reduce (bool): If True, the loss is averaged over the minibatch

    Returns:
        Tensor: loss
    """

    def __init__(self, num_classes=65536, log_scale_min=None, reduce=True):
        super(MoLLoss, self).__init__()
        self.num_classes = num_classes
        self.log_scale_min = log_scale_min
        self.reduce = reduce

    def forward(self, y_hat, y):
        # (n_batch, n_time) -> (n_batch, n_time, 1)
        y = y.unsqueeze(-1)

        # resolve the default log-scale floor lazily
        if self.log_scale_min is None:
            self.log_scale_min = math.log(1e-14)

        assert y_hat.dim() == 3
        assert y_hat.size(-1) % 3 == 0

        nr_mix = y_hat.size(-1) // 3

        # unpack parameters (n_batch, n_time, num_mixtures) x 3
        logit_probs = y_hat[:, :, :nr_mix]
        means = y_hat[:, :, nr_mix: 2 * nr_mix]
        log_scales = torch.clamp(
            y_hat[:, :, 2 * nr_mix: 3 * nr_mix], min=self.log_scale_min
        )

        # (n_batch x n_time x 1) to (n_batch x n_time x num_mixtures)
        y = y.expand_as(means)

        centered_y = y - means
        inv_stdv = torch.exp(-log_scales)
        plus_in = inv_stdv * (centered_y + 1.0 / (self.num_classes - 1))
        cdf_plus = torch.sigmoid(plus_in)
        min_in = inv_stdv * (centered_y - 1.0 / (self.num_classes - 1))
        cdf_min = torch.sigmoid(min_in)

        # log probability for edge case of 0 (before scaling)
        # equivalent: torch.log(F.sigmoid(plus_in))
        log_cdf_plus = plus_in - F.softplus(plus_in)

        # log probability for edge case of num_classes - 1 (before scaling)
        # equivalent: (1 - F.sigmoid(min_in)).log()
        log_one_minus_cdf_min = -F.softplus(min_in)

        # probability for all other cases
        cdf_delta = cdf_plus - cdf_min

        mid_in = inv_stdv * centered_y
        # log probability in the center of the bin, to be used in extreme cases
        log_pdf_mid = mid_in - log_scales - 2.0 * F.softplus(mid_in)

        # use the bin probability where it is numerically safe, otherwise fall
        # back to the (scaled) pdf at the bin centre
        inner_inner_cond = (cdf_delta > 1e-5).float()

        inner_inner_out = inner_inner_cond * torch.log(
            torch.clamp(cdf_delta, min=1e-12)
        ) + (1.0 - inner_inner_cond) * (
            log_pdf_mid - math.log((self.num_classes - 1) / 2)
        )
        # targets in the highest bin use the right-tail log probability
        inner_cond = (y > 0.999).float()
        inner_out = (
            inner_cond * log_one_minus_cdf_min + (1.0 - inner_cond) * inner_inner_out
        )
        # targets in the lowest bin use the left-tail log probability
        cond = (y < -0.999).float()
        log_probs = cond * log_cdf_plus + (1.0 - cond) * inner_out

        # add the mixture weights in log space before reducing over mixtures
        log_probs = log_probs + F.log_softmax(logit_probs, -1)

        if self.reduce:
            return -torch.mean(_log_sum_exp(log_probs))
        else:
            return -_log_sum_exp(log_probs).unsqueeze(-1)

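# Illustrative usage sketch (mixture count and shapes are assumptions): with
# 10 mixtures the last dimension of y_hat holds 3 * 10 = 30 values, laid out
# as [logit_probs | means | log_scales], and targets lie in [-1, 1].
#
#   criterion = MoLLoss(num_classes=65536)
#   y_hat = torch.randn(4, 100, 30)              # (n_batch, n_time, 3 * nr_mix)
#   y = torch.empty(4, 100).uniform_(-1.0, 1.0)  # (n_batch, n_time)
#   loss = criterion(y_hat, y)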

def _log_sum_exp(x):
    r""" Numerically stable log_sum_exp over the last dimension that prevents overflow
    """

    axis = len(x.size()) - 1
    # subtract the per-row maximum before exponentiating to avoid overflow;
    # the keepdim variant broadcasts against x, the reduced one is added back
    m, _ = torch.max(x, dim=axis)
    m2, _ = torch.max(x, dim=axis, keepdim=True)
    return m + torch.log(torch.sum(torch.exp(x - m2), dim=axis))
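

# Minimal sanity-check sketch (not part of the original module): the custom
# _log_sum_exp over the last dimension should agree with torch.logsumexp.
if __name__ == "__main__":
    x = torch.randn(2, 50, 10)
    assert torch.allclose(_log_sum_exp(x), torch.logsumexp(x, dim=-1), atol=1e-6)
    print("_log_sum_exp matches torch.logsumexp")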