"docs/vscode:/vscode.git/clone" did not exist on "d8d75d256a7b31edd7bf6c6d6a5ad5df66bf2105"
import torch

from ..configuration_utils import ConfigMixin
from ..modeling_utils import ModelMixin
from .attention import LinearAttention
from .embeddings import get_timestep_embedding
from .resnet import Downsample2D, ResnetBlock2D, Upsample2D
from .unet_new import UNetMidBlock2D


class Mish(torch.nn.Module):
    """Mish activation, computed element-wise as ``x * tanh(softplus(x))``."""

    def forward(self, x):
        """Return the Mish activation of ``x``."""
        softplus_x = torch.nn.functional.softplus(x)
        gate = torch.tanh(softplus_x)
        return x * gate


class Rezero(torch.nn.Module):
    """Multiply the output of a wrapped callable by a single learnable gain.

    The gain ``g`` is a one-element parameter initialised to zero, so the
    wrapped output is initially multiplied by zero.
    """

    def __init__(self, fn):
        """Store the wrapped callable ``fn`` and create the zero-initialised gain."""
        super().__init__()
        self.fn = fn
        initial_gain = torch.zeros(1)
        self.g = torch.nn.Parameter(initial_gain)

    def forward(self, x):
        """Return ``fn(x)`` scaled by the gain ``g``."""
        inner = self.fn(x)
        return inner * self.g


class Block(torch.nn.Module):
    """Masked 3x3 convolution followed by group normalisation and Mish.

    The input is multiplied by ``mask`` before the layers are applied and the
    result is multiplied by ``mask`` again afterwards.
    """

    def __init__(self, dim, dim_out, groups=8):
        """Build the conv -> group-norm -> Mish stack.

        Args:
            dim: Number of input channels of the convolution.
            dim_out: Number of output channels of the convolution.
            groups: Number of groups for ``GroupNorm``.
        """
        super().__init__()
        self.block = torch.nn.Sequential(
            torch.nn.Conv2d(dim, dim_out, 3, padding=1),
            torch.nn.GroupNorm(groups, dim_out),
            Mish(),
        )

    def forward(self, x, mask):
        """Apply the block to ``x * mask`` and mask the result again.

        Args:
            x: Input tensor passed to the 2D convolution.
            mask: Tensor multiplied with both the input and the output; it
                must be broadcastable against each of them.

        Returns:
            The masked output of the conv/norm/activation stack.
        """
        output = self.block(x * mask)
        # Mask again: the convolution's 3x3 receptive field writes non-zero
        # values into positions whose own input was masked out.
        return output * mask


class Residual(torch.nn.Module):
    """Add a skip connection around a wrapped callable: ``fn(x, ...) + x``."""

    def __init__(self, fn):
        """Store the callable ``fn`` that the skip connection wraps."""
        super().__init__()
        self.fn = fn

    def forward(self, x, *args, **kwargs):
        """Call ``fn`` with ``x`` and any extra arguments, then add ``x`` to the result."""
        branch = self.fn(x, *args, **kwargs)
        return branch + x


class UNetGradTTSModel(ModelMixin, ConfigMixin):
    """U-Net made of time-conditioned ResNet blocks and linear attention.

    The network stacks its inputs (``mu``, ``x`` and, in multi-speaker mode, a
    speaker feature map) along a new channel axis, runs them through a
    down-sampling path, a middle block and an up-sampling path with skip
    connections, and projects the result back to a single channel.

    Args:
        dim: Base channel width; also the size of the timestep embedding.
        dim_mults: Channel multipliers, one per resolution level.
        groups: Stored on the model and in the config.
            NOTE(review): the ResNet blocks and the final ``Block`` below use a
            fixed group count of 8 and do not read this value -- confirm
            whether that is intended.
        n_spks: Number of speakers. ``None`` is treated as a single speaker.
        spk_emb_dim: Size of the speaker embedding (multi-speaker mode only).
        n_feats: Output size of the speaker MLP (multi-speaker mode only).
        pe_scale: ``scale`` passed to ``get_timestep_embedding``.
    """

    def __init__(self, dim, dim_mults=(1, 2, 4), groups=8, n_spks=None, spk_emb_dim=64, n_feats=80, pe_scale=1000):
        super().__init__()

        self.register_to_config(
            dim=dim,
            dim_mults=dim_mults,
            groups=groups,
            n_spks=n_spks,
            spk_emb_dim=spk_emb_dim,
            n_feats=n_feats,
            pe_scale=pe_scale,
        )

        self.dim = dim
        self.dim_mults = dim_mults
        self.groups = groups
        self.n_spks = 1 if n_spks is None else n_spks
        self.spk_emb_dim = spk_emb_dim
        self.pe_scale = pe_scale

        # Always branch on the normalised ``self.n_spks``: comparing the raw
        # ``n_spks`` argument with ``> 1`` raises TypeError for the default None.
        multi_speaker = self.n_spks > 1

        if multi_speaker:
            self.spk_emb = torch.nn.Embedding(self.n_spks, spk_emb_dim)
            self.spk_mlp = torch.nn.Sequential(
                torch.nn.Linear(spk_emb_dim, spk_emb_dim * 4), Mish(), torch.nn.Linear(spk_emb_dim * 4, n_feats)
            )

        self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(), torch.nn.Linear(dim * 4, dim))

        def make_resnet(in_channels, out_channels):
            # Every ResNet block in this model shares the same settings and
            # differs only in its channel counts.
            return ResnetBlock2D(
                in_channels=in_channels,
                out_channels=out_channels,
                temb_channels=dim,
                groups=8,
                pre_norm=False,
                eps=1e-5,
                non_linearity="mish",
                overwrite_for_grad_tts=True,
            )

        def make_attention(channels):
            return Residual(Rezero(LinearAttention(channels)))

        # Input channels: ``mu`` and ``x``, plus the speaker map when multi-speaker.
        in_channels = 3 if multi_speaker else 2
        dims = [in_channels, *(dim * mult for mult in dim_mults)]
        in_out = list(zip(dims[:-1], dims[1:]))
        self.downs = torch.nn.ModuleList([])
        self.ups = torch.nn.ModuleList([])
        num_resolutions = len(in_out)

        for index, (dim_in, dim_out) in enumerate(in_out):
            is_last = index >= (num_resolutions - 1)
            self.downs.append(
                torch.nn.ModuleList(
                    [
                        make_resnet(dim_in, dim_out),
                        make_resnet(dim_out, dim_out),
                        make_attention(dim_out),
                        # The last level keeps its resolution.
                        torch.nn.Identity() if is_last else Downsample2D(dim_out, use_conv=True, padding=1),
                    ]
                )
            )

        mid_dim = dims[-1]

        self.mid = UNetMidBlock2D(
            in_channels=mid_dim,
            temb_channels=dim,
            resnet_groups=8,
            resnet_pre_norm=False,
            resnet_eps=1e-5,
            resnet_act_fn="mish",
            attention_layer_type="linear",
        )

        self.mid_block1 = make_resnet(mid_dim, mid_dim)
        self.mid_attn = make_attention(mid_dim)
        self.mid_block2 = make_resnet(mid_dim, mid_dim)
        # The same three module objects are also attached to ``self.mid``, so
        # they are reachable both as ``mid_block1``/``mid_attn``/``mid_block2``
        # and as ``mid.resnet_1``/``mid.attn``/``mid.resnet_2``.
        # NOTE(review): presumably kept for checkpoint-key compatibility -- confirm.
        self.mid.resnet_1 = self.mid_block1
        self.mid.attn = self.mid_attn
        self.mid.resnet_2 = self.mid_block2

        # The first (input) level has no matching up-sampling stage.
        for dim_in, dim_out in reversed(in_out[1:]):
            self.ups.append(
                torch.nn.ModuleList(
                    [
                        # ``dim_out * 2``: the skip connection is concatenated on the channel axis.
                        make_resnet(dim_out * 2, dim_in),
                        make_resnet(dim_in, dim_in),
                        make_attention(dim_in),
                        Upsample2D(dim_in, use_conv_transpose=True),
                    ]
                )
            )
        self.final_block = Block(dim, dim)
        self.final_conv = torch.nn.Conv2d(dim, 1, 1)

    def forward(self, x, timesteps, mu, mask, spk=None):
        """Run the U-Net.

        Args:
            x: Input tensor; stacked with ``mu`` along a new axis 1.
                Assumed to be (batch, n_feats, length) -- TODO confirm with callers.
            timesteps: Passed to ``get_timestep_embedding``.
            mu: Tensor with the same shape as ``x``.
            mask: Tensor multiplied with the activations at every level; it is
                given a new axis 1 and sub-sampled by 2 along its last axis
                after every down-sampling level.
            spk: Speaker indices fed to the speaker embedding. Required when
                the model was built with ``n_spks > 1``; ignored otherwise.

        Returns:
            The masked single-channel output with the channel axis squeezed out.

        Raises:
            ValueError: If the model is multi-speaker and ``spk`` is None.
        """
        multi_speaker = self.n_spks > 1

        if multi_speaker:
            if spk is None:
                raise ValueError("`spk` must be provided when the model is built with n_spks > 1.")
            # Speaker embedding followed by the speaker MLP. These layers only
            # exist in multi-speaker mode, so they are never touched otherwise.
            s = self.spk_mlp(self.spk_emb(spk))

        t = get_timestep_embedding(timesteps, self.dim, scale=self.pe_scale)
        t = self.mlp(t)

        if multi_speaker:
            # Repeat the per-sample speaker vector along the last axis of ``x``
            # so it can be stacked as an extra channel.
            s = s.unsqueeze(-1).repeat(1, 1, x.shape[-1])
            x = torch.stack([mu, x, s], 1)
        else:
            x = torch.stack([mu, x], 1)
        mask = mask.unsqueeze(1)

        hiddens = []
        masks = [mask]
        for resnet1, resnet2, attn, downsample in self.downs:
            mask_down = masks[-1]
            x = resnet1(x, t, mask_down)
            x = resnet2(x, t, mask_down)
            x = attn(x)
            hiddens.append(x)
            x = downsample(x * mask_down)
            # Keep every second position of the last axis for the next level.
            masks.append(mask_down[:, :, :, ::2])

        # Drop the mask appended after the last level, which does not down-sample.
        masks = masks[:-1]
        mask_mid = masks[-1]

        x = self.mid(x, t, mask=mask_mid)

        for resnet1, resnet2, attn, upsample in self.ups:
            mask_up = masks.pop()
            x = torch.cat((x, hiddens.pop()), dim=1)
            x = resnet1(x, t, mask_up)
            x = resnet2(x, t, mask_up)
            x = attn(x)
            x = upsample(x * mask_up)

        x = self.final_block(x, mask)
        output = self.final_conv(x * mask)

        return (output * mask).squeeze(1)