import torch

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D


class Mish(torch.nn.Module):
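    # Mish activation (Misra, 2019): x * tanh(softplus(x)); used as the non-linearity throughout this model.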
    def forward(self, x):
        return x * torch.tanh(torch.nn.functional.softplus(x))


class Rezero(torch.nn.Module):
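    # ReZero-style gate (Bachlechner et al., 2020): scales the wrapped module's output by a learnable
    # scalar initialized at zero, so a Residual(Rezero(fn)) branch starts out as the identity.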
    def __init__(self, fn):
        super(Rezero, self).__init__()
        self.fn = fn
        self.g = torch.nn.Parameter(torch.zeros(1))

    def forward(self, x):
        return self.fn(x) * self.g


class Block(torch.nn.Module):
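    # Masked convolution block: 3x3 conv -> GroupNorm -> Mish, with the padding mask applied
    # before and after so padded frames stay zero.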
    def __init__(self, dim, dim_out, groups=8):
        super(Block, self).__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_out, 3, padding=1), torch.nn.GroupNorm(groups, dim_out), Mish()
        )

    def forward(self, x, mask):
        output = self.block(x * mask)
        return output * mask


class Residual(torch.nn.Module):
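    # Thin residual wrapper: adds the wrapped module's output to its input.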
    def __init__(self, fn):
        super(Residual, self).__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        output = self.fn(x, *args, **kwargs) + x
        return output


class UNetGradTTSModel(ModelMixin, ConfigMixin):
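    # 2D UNet used as the noise/score estimator in Grad-TTS. Mel-spectrograms are treated as
    # single-channel images: the text-encoder output mu and the noisy sample x_t are stacked along the
    # channel axis (plus an optional broadcast speaker channel), passed through masked ResNet blocks
    # with linear attention at each resolution, and projected back to a single channel.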
    def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
        super(UNetGradTTSModel, self).__init__()

        self.register_to_config(
            dim=dim,
            dim_mults=dim_mults,
            groups=groups,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
            n_feats=n_feats,
            pe_scale=pe_scale,
        )

        self.dim = dim
        self.dim_mults = dim_mults
        self.groups = groups
        self.n_spks = n_spks if n_spks is not None else 1
        self.spk_emb_dim = spk_emb_dim
        self.pe_scale = pe_scale

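        # Multi-speaker conditioning: embed the speaker id and project it to n_feats so it can be
        # appended as an extra channel in forward().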
        if self.n_spks > 1:
            self.spk_emb = torch.nn.Embedding(n_spks, spk_emb_dim)
            self.spk_mlp = torch.nn.Sequential(
                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
            )

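        # MLP applied to the sinusoidal timestep embedding before it conditions the ResNet blocks.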
        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

        dims = [2 + (1 if self.n_spks > 1 else 0), *map(lambda m: dim * m, dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
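        # With the default dim_mults=(1, 2, 4) and a single speaker this gives dims = [2, dim, 2 * dim, 4 * dim]
        # and in_out = [(2, dim), (dim, 2 * dim), (2 * dim, 4 * dim)]: 2 input channels for (mu, x_t),
        # 3 when a speaker channel is added.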
        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        for ind, (dim_in, dim_out) in enumerate(in_out):
            is_last = ind >= (num_resolutions - 1)
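            # Each resolution: two masked ResNet blocks, linear attention, then a strided
            # downsample (identity at the last resolution).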
            self.downs.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock2D(
                            in_channels=dim_in,
                            out_channels=dim_out,
                            temb_channels=dim,
                            groups=8,
                            pre_norm=False,
                            eps=1e-5,
                            non_linearity="mish",
                            overwrite_for_grad_tts=True,
                        ),
                        ResnetBlock2D(
                            in_channels=dim_out,
                            out_channels=dim_out,
                            temb_channels=dim,
                            groups=8,
                            pre_norm=False,
                            eps=1e-5,
                            non_linearity="mish",
                            overwrite_for_grad_tts=True,
                        ),
                        Residual(Rezero(LinearAttention(dim_out))),
                        Downsample2D(dim_out, use_conv=True, padding=1) if not is_last else torch.nn.Identity(),
                    ]
                )
            )

        mid_dim = dims[-1]
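        # Bottleneck at the coarsest resolution: ResNet block -> linear attention -> ResNet block.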
        self.mid_block1 = ResnetBlock2D(
            in_channels=mid_dim,
            out_channels=mid_dim,
            temb_channels=dim,
            groups=8,
            pre_norm=False,
            eps=1e-5,
            non_linearity="mish",
            overwrite_for_grad_tts=True,
        )
        self.mid_attn = Residual(Rezero(LinearAttention(mid_dim)))
        self.mid_block2 = ResnetBlock2D(
            in_channels=mid_dim,
            out_channels=mid_dim,
            temb_channels=dim,
            groups=8,
            pre_norm=False,
            eps=1e-5,
            non_linearity="mish",
            overwrite_for_grad_tts=True,
        )

        for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
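            # Up path mirrors the down path; the skip connection from the encoder doubles the number
            # of input channels of the first ResNet block at each resolution.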
            self.ups.append(
                torch.nn.ModuleList(
                    [
                        ResnetBlock2D(
                            in_channels=dim_out * 2,
                            out_channels=dim_in,
                            temb_channels=dim,
                            groups=8,
                            pre_norm=False,
                            eps=1e-5,
                            non_linearity="mish",
                            overwrite_for_grad_tts=True,
                        ),
                        ResnetBlock2D(
                            in_channels=dim_in,
                            out_channels=dim_in,
                            temb_channels=dim,
                            groups=8,
                            pre_norm=False,
                            eps=1e-5,
                            non_linearity="mish",
                            overwrite_for_grad_tts=True,
                        ),
                        Residual(Rezero(LinearAttention(dim_in))),
                        Upsample2D(dim_in, use_conv_transpose=True),
                    ]
                )
            )
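        # Final masked conv block and 1x1 projection back to a single-channel spectrogram.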
        self.final_block = Block(dim, dim)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, x, timesteps, mu, mask, spk=None):
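        # Expected shapes (following the Grad-TTS convention): x and mu are (batch, n_feats, length)
        # mel tensors, mask is a (batch, 1, length) binary frame mask, timesteps is a 1D tensor of
        # diffusion times, and spk holds integer speaker ids when n_spks > 1. mu and x are stacked
        # along a new channel axis below (plus a broadcast speaker channel in the multi-speaker case).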
        if self.n_spks > 1:
            # Get speaker embedding
            spk = self.spk_emb(spk)

        if spk is not None:
            s = self.spk_mlp(spk)

        t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
        t = self.mlp(t)

        if self.n_spks < 2:
            x = torch.stack([mu, x], 1)
        else:
            s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
            x = torch.stack([mu, x, s], 1)
        mask = mask.unsqueeze(1)

        hiddens = []
        masks = [mask]
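        # Encoder pass: keep each resolution's hidden state for the skip connections and a
        # correspondingly downsampled mask (the time axis is halved after every downsample).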
        for resnet1, resnet2, attn, downsample in self.downs:
            mask_down = masks[-1]
            x = resnet1(x, t, mask_down)
            x = resnet2(x, t, mask_down)
            x = attn(x)
            hiddens.append(x)
            x = downsample(x * mask_down)
            masks.append(mask_down[:, :, :, ::2])

        masks = masks[:-1]
        mask_mid = masks[-1]
        x = self.mid_block1(x, t, mask_mid)
        x = self.mid_attn(x)
        x = self.mid_block2(x, t, mask_mid)

        for resnet1, resnet2, attn, upsample in self.ups:
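            # Decoder pass: pop the matching mask and encoder hidden state, concatenate as a skip
            # connection, and upsample back toward the input resolution.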
            mask_up = masks.pop()
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, t, mask_up)
            x = resnet2(x, t, mask_up)
            x = attn(x)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask)
        output = self.final_conv(x * mask)

        return (output * mask).squeeze(1)
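

# A minimal usage sketch (illustrative assumption, not part of the original module): shapes follow the
# Grad-TTS convention of (batch, n_feats, length) mel tensors with a (batch, 1, length) mask, and the
# sequence length is kept divisible by 4 so the two down/upsampling stages line up:
#
#     model = UNetGradTTSModel(dim=64, n_feats=80)
#     x = torch.randn(2, 80, 172)             # noisy mel-spectrogram x_t
#     mu = torch.randn(2, 80, 172)            # aligned text-encoder output
#     mask = torch.ones(2, 1, 172)            # frame-validity mask
#     timesteps = torch.rand(2)               # continuous diffusion times in [0, 1]
#     score = model(x, timesteps, mu, mask)   # -> tensor of shape (2, 80, 172)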