# Imports grouped per PEP 8: standard library, third-party, then local modules.
import argparse
import logging
import os
from collections import defaultdict
from datetime import datetime
from time import time
from typing import List

import torch
import torchaudio
from torch.optim import Adam
from torch.utils.data import DataLoader
from torchaudio.models.wavernn import WaveRNN

from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss
from processing import NormalizeDB
from utils import count_parameters, MetricLogger, save_checkpoint


def parse_args():
    """Parse command-line arguments for WaveRNN training.

    Returns:
        argparse.Namespace: the parsed training, dataset, audio-processing,
        and model hyper-parameters.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--workers",
        default=4,
        type=int,
        metavar="N",
        help="number of data loading workers",
    )
    parser.add_argument(
        "--checkpoint",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint",
    )
    parser.add_argument(
        "--epochs",
        default=8000,
        type=int,
        metavar="N",
        help="number of total epochs to run",
    )
    parser.add_argument("--start-epoch", default=0, type=int, metavar="N", help="manual epoch number")
    parser.add_argument(
        "--print-freq",
        default=10,
        type=int,
        metavar="N",
        help="print frequency in epochs",
    )
    parser.add_argument(
        "--dataset",
        default="ljspeech",
        choices=["ljspeech", "libritts"],
        type=str,
        help="select dataset to train with",
    )
    parser.add_argument("--batch-size", default=256, type=int, metavar="N", help="mini-batch size")
    parser.add_argument(
        "--learning-rate",
        default=1e-4,
        type=float,
        metavar="LR",
        help="learning rate",
    )
    parser.add_argument("--clip-grad", metavar="NORM", type=float, default=4.0)
    # NOTE(review): default=True together with store_true means this flag is
    # effectively always True and cannot be disabled from the command line.
    # Kept as-is for backward compatibility — confirm whether a --no-mulaw
    # switch is wanted.
    parser.add_argument(
        "--mulaw",
        default=True,
        action="store_true",
        help="if used, waveform is mulaw encoded",
    )
    parser.add_argument("--jit", default=False, action="store_true", help="if used, model is jitted")
    # BUG FIX: the original used ``type=List[int]``, which argparse cannot use
    # as a value converter (each CLI token would be turned into a list of
    # characters and multiple tokens were rejected). ``nargs="+"`` with
    # ``type=int`` accepts e.g. ``--upsample-scales 5 5 11`` and keeps the
    # same default.
    parser.add_argument(
        "--upsample-scales",
        default=[5, 5, 11],
        nargs="+",
        type=int,
        help="the list of upsample scales",
    )
    parser.add_argument(
        "--n-bits",
        default=8,
        type=int,
        help="the bits of output waveform",
    )
    parser.add_argument(
        "--sample-rate",
        default=22050,
        type=int,
        help="the rate of audio dimensions (samples per second)",
    )
    parser.add_argument(
        "--hop-length",
        default=275,
        type=int,
        help="the number of samples between the starts of consecutive frames",
    )
    parser.add_argument(
        "--win-length",
        default=1100,
        type=int,
        help="the length of the STFT window",
    )
    parser.add_argument(
        "--f-min",
        default=40.0,
        type=float,
        help="the minimum frequency",
    )
    parser.add_argument(
        "--min-level-db",
        default=-100,
        type=float,
        help="the minimum db value for spectrogam normalization",
    )
    parser.add_argument(
        "--n-res-block",
        default=10,
        type=int,
        help="the number of ResBlock in stack",
    )
    parser.add_argument(
        "--n-rnn",
        default=512,
        type=int,
        help="the dimension of RNN layer",
    )
    parser.add_argument(
        "--n-fc",
        default=512,
        type=int,
        help="the dimension of fully connected layer",
    )
    parser.add_argument(
        "--kernel-size",
        default=5,
        type=int,
        help="the number of kernel size in the first Conv1d layer",
    )
    parser.add_argument(
        "--n-freq",
        default=80,
        type=int,
        help="the number of spectrogram bins to use",
    )
    parser.add_argument(
        "--n-hidden-melresnet",
        default=128,
        type=int,
        help="the number of hidden dimensions of resblock in melresnet",
    )
    parser.add_argument(
        "--n-output-melresnet",
        default=128,
        type=int,
        help="the output dimension of melresnet",
    )
    parser.add_argument(
        "--n-fft",
        default=2048,
        type=int,
        help="the number of Fourier bins",
    )
    parser.add_argument(
        "--loss",
        default="crossentropy",
        choices=["crossentropy", "mol"],
        type=str,
        help="the type of loss",
    )
    parser.add_argument(
        "--seq-len-factor",
        default=5,
        type=int,
        help="the length of each waveform to process per batch = hop_length * seq_len_factor",
    )
    parser.add_argument(
        "--val-ratio",
        default=0.1,
        type=float,
        help="the ratio of waveforms for validation",
    )
    parser.add_argument(
        "--file-path",
        default="",
        type=str,
        help="the path of audio files",
    )
    # NOTE(review): same default=True + store_true quirk as --mulaw above.
    parser.add_argument(
        "--normalization",
        default=True,
        action="store_true",
        help="if True, spectrogram is normalized",
    )

    args = parser.parse_args()
    return args


def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch):
    """Train ``model`` for a single epoch over ``data_loader``.

    Logs per-iteration and per-epoch metrics through ``MetricLogger``.

    NOTE(review): this function reads the module-level ``args`` (for
    ``args.clip_grad``) rather than taking it as a parameter — it therefore
    only works when the script is run as ``__main__``.

    Args:
        model: the WaveRNN model (already on ``device``).
        criterion: loss callable taking ``(output, target)``.
        optimizer: optimizer stepping ``model.parameters()``.
        data_loader: yields ``(waveform, specgram, target)`` batches.
        device: device to move batches to.
        epoch (int): current epoch index, for logging only.
    """
    model.train()

    sums = defaultdict(lambda: 0.0)
    start1 = time()

    metric = MetricLogger("train_iteration")
    metric["epoch"] = epoch

    for waveform, specgram, target in data_loader:

        start2 = time()

        waveform = waveform.to(device)
        specgram = specgram.to(device)
        target = target.to(device)

        output = model(waveform, specgram)
        output, target = output.squeeze(1), target.squeeze(1)

        loss = criterion(output, target)
        loss_item = loss.item()
        sums["loss"] += loss_item
        metric["loss"] = loss_item

        optimizer.zero_grad()
        loss.backward()

        if args.clip_grad > 0:
            gradient = torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip_grad)
            sums["gradient"] += gradient.item()
            metric["gradient"] = gradient.item()

        optimizer.step()

        metric["iteration"] = sums["iteration"]
        metric["time"] = time() - start2
        metric()
        sums["iteration"] += 1

    avg_loss = sums["loss"] / len(data_loader)

    metric = MetricLogger("train_epoch")
    metric["epoch"] = epoch
    metric["loss"] = avg_loss
    # BUG FIX: the epoch-level "gradient" metric previously logged the average
    # loss; report the mean accumulated gradient norm instead.
    metric["gradient"] = sums["gradient"] / len(data_loader)
    metric["time"] = time() - start1
    metric()


def validate(model, criterion, data_loader, device, epoch):
    """Evaluate ``model`` on ``data_loader`` and return the average loss.

    Runs under ``torch.no_grad()`` with the model in eval mode; logs a
    "validation" metric entry for the epoch.
    """
    model.eval()
    totals = defaultdict(lambda: 0.0)
    started = time()

    with torch.no_grad():
        for waveform, specgram, target in data_loader:
            waveform = waveform.to(device)
            specgram = specgram.to(device)
            target = target.to(device)

            # Model emits an extra channel dim; squeeze it on both sides
            # before computing the loss.
            output = model(waveform, specgram).squeeze(1)
            totals["loss"] += criterion(output, target.squeeze(1)).item()

    avg_loss = totals["loss"] / len(data_loader)

    metric = MetricLogger("validation")
    metric["epoch"] = epoch
    metric["loss"] = avg_loss
    metric["time"] = time() - started
    metric()

    return avg_loss


def main(args):
    """Train a WaveRNN vocoder end to end.

    Builds the mel-spectrogram transform, datasets, data loaders, model,
    optimizer and loss, optionally resumes from ``args.checkpoint``, then
    runs the train/validate loop, checkpointing on every validation.
    """
    # Single-entry list so one code path handles both CUDA and CPU hosts.
    devices = ["cuda" if torch.cuda.is_available() else "cpu"]

    logging.info("Start time: {}".format(str(datetime.now())))

    melkwargs = {
        "n_fft": args.n_fft,
        "power": 1,
        "hop_length": args.hop_length,
        "win_length": args.win_length,
    }

    transforms = torch.nn.Sequential(
        torchaudio.transforms.MelSpectrogram(
            sample_rate=args.sample_rate,
            n_mels=args.n_freq,
            f_min=args.f_min,
            mel_scale="slaney",
            norm="slaney",
            **melkwargs,
        ),
        NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization),
    )

    train_dataset, val_dataset = split_process_dataset(args, transforms)

    loader_training_params = {
        "num_workers": args.workers,
        "pin_memory": False,
        "shuffle": True,
        "drop_last": False,
    }
    loader_validation_params = loader_training_params.copy()
    loader_validation_params["shuffle"] = False

    collate_fn = collate_factory(args)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        **loader_training_params,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        collate_fn=collate_fn,
        **loader_validation_params,
    )

    # Cross-entropy predicts one class per quantization level; MoL uses a
    # fixed 30-dimensional mixture parameterization.
    n_classes = 2**args.n_bits if args.loss == "crossentropy" else 30

    model = WaveRNN(
        upsample_scales=args.upsample_scales,
        n_classes=n_classes,
        hop_length=args.hop_length,
        n_res_block=args.n_res_block,
        n_rnn=args.n_rnn,
        n_fc=args.n_fc,
        kernel_size=args.kernel_size,
        n_freq=args.n_freq,
        n_hidden=args.n_hidden_melresnet,
        n_output=args.n_output_melresnet,
    )

    if args.jit:
        model = torch.jit.script(model)

    model = torch.nn.DataParallel(model)
    model = model.to(devices[0], non_blocking=True)

    n = count_parameters(model)
    logging.info(f"Number of parameters: {n}")

    # Optimizer
    optimizer_params = {
        "lr": args.learning_rate,
    }

    optimizer = Adam(model.parameters(), **optimizer_params)

    criterion = LongCrossEntropyLoss() if args.loss == "crossentropy" else MoLLoss()

    # Initial "best" threshold a first validation pass is expected to beat.
    best_loss = 10.0

    if args.checkpoint and os.path.isfile(args.checkpoint):
        logging.info(f"Checkpoint: loading '{args.checkpoint}'")
        # BUG FIX: map the checkpoint onto the active device so a run saved
        # on GPU can be resumed on a CPU-only host (and vice versa).
        checkpoint = torch.load(args.checkpoint, map_location=devices[0])

        args.start_epoch = checkpoint["epoch"]
        best_loss = checkpoint["best_loss"]

        model.load_state_dict(checkpoint["state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer"])

        logging.info(f"Checkpoint: loaded '{args.checkpoint}' at epoch {checkpoint['epoch']}")
    else:
        logging.info("Checkpoint: not found")

        # Write an initial checkpoint so the path exists from epoch 0.
        save_checkpoint(
            {
                "epoch": args.start_epoch,
                "state_dict": model.state_dict(),
                "best_loss": best_loss,
                "optimizer": optimizer.state_dict(),
            },
            False,
            args.checkpoint,
        )

    for epoch in range(args.start_epoch, args.epochs):

        train_one_epoch(
            model,
            criterion,
            optimizer,
            train_loader,
            devices[0],
            epoch,
        )

        # Validate (and checkpoint) every `print_freq` epochs and on the
        # final epoch.
        if not (epoch + 1) % args.print_freq or epoch == args.epochs - 1:

            sum_loss = validate(model, criterion, val_loader, devices[0], epoch)

            is_best = sum_loss < best_loss
            best_loss = min(sum_loss, best_loss)
            save_checkpoint(
                {
                    "epoch": epoch + 1,
                    "state_dict": model.state_dict(),
                    "best_loss": best_loss,
                    "optimizer": optimizer.state_dict(),
                },
                is_best,
                args.checkpoint,
            )

    logging.info(f"End time: {datetime.now()}")


if __name__ == "__main__":

    # Configure the root logger so the info-level start/end and checkpoint
    # messages are actually emitted.
    logging.basicConfig(level=logging.INFO)
    # `args` is bound at module level on purpose: train_one_epoch reads the
    # global `args` for gradient clipping.
    args = parse_args()
    main(args)