import time
import argparse
import json
import logging
import math
import os
import yaml
from pathlib import Path
import diffusers
import datasets
import numpy as np
import pandas as pd
import wandb
import transformers
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset, Audio
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import SchedulerType, get_scheduler
from model import TangoFlux
from utils import Text2AudioDataset, read_wav_file, pad_wav

from diffusers import AutoencoderOobleck
import torchaudio

logger = get_logger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Rectified flow for text to audio generation task."
    )

    parser.add_argument(
        "--num_examples",
        type=int,
        default=-1,
        help="How many examples to use for training and validation.",
    )

    parser.add_argument(
        "--text_column",
        type=str,
        default="captions",
        help="The name of the column in the datasets containing the input texts.",
    )
    parser.add_argument(
        "--audio_column",
        type=str,
        default="location",
        help="The name of the column in the datasets containing the audio paths.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.95,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="tangoflux_config.yaml",
        help="Config file defining the model size as well as other hyper parameter.",
    )
    parser.add_argument(
        "--prefix",
        type=str,
        default="",
        help="Add prefix in text prompts.",
    )

    parser.add_argument(
        "--learning_rate",
        type=float,
        default=3e-5,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--weight_decay", type=float, default=1e-8, help="Weight decay to use."
    )

    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )

    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=[
            "linear",
            "cosine",
            "cosine_with_restarts",
            "polynomial",
            "constant",
            "constant_with_warmup",
        ],
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default="best",
        help="Whether the various states should be saved at the end of every 'epoch' or 'best' whenever validation loss decreases.",
    )
    parser.add_argument(
        "--save_every",
        type=int,
        default=5,
        help="Save model after every how many epochs when checkpointing_steps is set to best.",
    )

    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a local checkpoint folder.",
    )

    parser.add_argument(
        "--load_from_checkpoint",
        type=str,
        default=None,
        help="Whether to continue training from a model weight",
    )

    args = parser.parse_args()

    return args


def main():
    args = parse_args()
    accelerator_log_kwargs = {}

    def load_config(config_path):
        with open(config_path, "r") as file:
            return yaml.safe_load(file)

    config = load_config(args.config)
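    # Illustrative sketch of the expected config layout (keys are inferred from the fields
    # read below; the actual values in tangoflux_config.yaml may differ):
    #
    #   model: {...}            # passed directly to TangoFlux(config=...)
    #   training:
    #     learning_rate: 3e-5
    #     num_train_epochs: 80
    #     num_warmup_steps: 0
    #     per_device_batch_size: 4
    #     gradient_accumulation_steps: 1
    #     max_audio_duration: 30
    #   paths:
    #     train_file: "train.json"
    #     val_file: "val.json"
    #     test_file: ""           # falls back to val_file when empty
    #     output_dir: "saved/tangoflux"
    #     resume_from_checkpoint: ""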

    learning_rate = float(config["training"]["learning_rate"])
    num_train_epochs = int(config["training"]["num_train_epochs"])
    num_warmup_steps = int(config["training"]["num_warmup_steps"])
    per_device_batch_size = int(config["training"]["per_device_batch_size"])
    gradient_accumulation_steps = int(config["training"]["gradient_accumulation_steps"])

    output_dir = config["paths"]["output_dir"]

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        **accelerator_log_kwargs,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    datasets.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle output directory creation and wandb tracking
    if accelerator.is_main_process:
        if output_dir is None or output_dir == "":
            output_dir = "saved/" + str(int(time.time()))

            if not os.path.exists("saved"):
                os.makedirs("saved")

            os.makedirs(output_dir, exist_ok=True)

        elif output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)

        os.makedirs("{}/{}".format(output_dir, "outputs"), exist_ok=True)
        with open("{}/summary.jsonl".format(output_dir), "a") as f:
            f.write(json.dumps(dict(vars(args))) + "\n\n")

        accelerator.project_configuration.automatic_checkpoint_naming = False

        wandb.init(
            project="Text to Audio Flow matching",
            settings=wandb.Settings(_disable_stats=True),
        )

    accelerator.wait_for_everyone()

    # Get the datasets
    data_files = {}
    # if args.train_file is not None:
    if config["paths"]["train_file"] != "":
        data_files["train"] = config["paths"]["train_file"]
    # if args.validation_file is not None:
    if config["paths"]["val_file"] != "":
        data_files["validation"] = config["paths"]["val_file"]
    if config["paths"]["test_file"] != "":
        data_files["test"] = config["paths"]["test_file"]
    else:
        data_files["test"] = config["paths"]["val_file"]

    extension = "json"
    raw_datasets = load_dataset(extension, data_files=data_files)
    text_column, audio_column = args.text_column, args.audio_column

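    # TangoFlux bundles the text encoder and the rectified-flow transformer; the pretrained
    # Oobleck VAE from Stable Audio Open is used as the (frozen) waveform <-> latent codec.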
    model = TangoFlux(config=config["model"])
    vae = AutoencoderOobleck.from_pretrained(
        "stabilityai/stable-audio-open-1.0", subfolder="vae"
    )

    ## Freeze vae
    for param in vae.parameters():
        param.requires_grad = False
    vae.eval()
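    # The VAE only encodes waveforms into latents during training, so no VAE gradients are needed.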

    ## Freeze text encoder params
    for param in model.text_encoder.parameters():
        param.requires_grad = False
    model.text_encoder.eval()

    prefix = args.prefix

    with accelerator.main_process_first():
        train_dataset = Text2AudioDataset(
            raw_datasets["train"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )
        eval_dataset = Text2AudioDataset(
            raw_datasets["validation"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )
        test_dataset = Text2AudioDataset(
            raw_datasets["test"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )

        accelerator.print(
            "Num instances in train: {}, validation: {}, test: {}".format(
                train_dataset.get_num_instances(),
                eval_dataset.get_num_instances(),
                test_dataset.get_num_instances(),
            )
        )

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=train_dataset.collate_fn,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        shuffle=True,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=eval_dataset.collate_fn,
    )
    test_dataloader = DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=test_dataset.collate_fn,
    )

    # Optimizer

    optimizer_parameters = list(model.transformer.parameters()) + list(
        model.fc.parameters()
    )
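    # Only the rectified-flow transformer and the model's `fc` projection are optimized;
    # the frozen VAE and text encoder are excluded from the parameter list above.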
    num_trainable_parameters = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))

    if args.load_from_checkpoint:
        from safetensors.torch import load_file

        w1 = load_file(args.load_from_checkpoint)
        model.load_state_dict(w1, strict=False)
        logger.info("Weights loaded from{}".format(args.load_from_checkpoint))

    optimizer = torch.optim.AdamW(
        optimizer_parameters,
        lr=learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps
        * gradient_accumulation_steps
        * accelerator.num_processes,
        num_training_steps=args.max_train_steps * gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    vae, model, optimizer, lr_scheduler = accelerator.prepare(
        vae, model, optimizer, lr_scheduler
    )

    train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
        train_dataloader, eval_dataloader, test_dataloader
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.

    # Train!
    total_batch_size = (
        per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {per_device_batch_size}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )

    completed_steps = 0
    starting_epoch = 0
    # Potentially load in the weights and states from a previous save
    resume_from_checkpoint = config["paths"]["resume_from_checkpoint"]
    if resume_from_checkpoint != "":
        accelerator.load_state(resume_from_checkpoint)
        accelerator.print(f"Resumed from local checkpoint: {resume_from_checkpoint}")

    # Maximum audio clip duration in seconds: waveforms are truncated and durations clamped to this value
    best_loss = np.inf
    length = config["training"]["max_audio_duration"]

    for epoch in range(starting_epoch, num_train_epochs):
        model.train()
        total_loss, total_val_loss = 0, 0
        for step, batch in enumerate(train_dataloader):

            with accelerator.accumulate(model):
                optimizer.zero_grad()
                device = model.device
                text, audios, duration, _ = batch

                with torch.no_grad():
                    audio_list = []

                    for audio_path in audios:

                        wav = read_wav_file(
                            audio_path, length
                        )  ## Read at most `max_audio_duration` seconds of audio
                        if (
                            wav.shape[0] == 1
                        ):  ## If the audio is mono, repeat the channel to make it "fake stereo"
                            wav = wav.repeat(2, 1)
                        audio_list.append(wav)

                    audio_input = torch.stack(audio_list, dim=0)
                    audio_input = audio_input.to(device)
                    unwrapped_vae = accelerator.unwrap_model(vae)

                    duration = torch.tensor(duration, device=device)
                    duration = torch.clamp(
                        duration, max=length
                    )  ## clamp duration to max audio length

                    audio_latent = unwrapped_vae.encode(
                        audio_input
                    ).latent_dist.sample()
                    audio_latent = audio_latent.transpose(
                        1, 2
                    )  ## Transpose to (bsz, seq_len, channels)

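                # Rectified-flow (flow-matching) training loss on the VAE latents,
                # conditioned on the text prompt and the clamped audio duration.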
                loss, _, _, _ = model(audio_latent, text, duration=duration)
                total_loss += loss.detach().float()
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    progress_bar.update(1)
                    completed_steps += 1

                optimizer.step()
                lr_scheduler.step()

            if completed_steps % 10 == 0 and accelerator.is_main_process:

                total_norm = 0.0
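                # Accumulate squared per-parameter gradient norms to report the global L2 grad norm.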
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2

                total_norm = total_norm**0.5
                logger.info(
                    f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}"
                )

                lr = lr_scheduler.get_last_lr()[0]
                result = {
                    "train_loss": loss.item(),
                    "grad_norm": total_norm,
                    "learning_rate": lr,
                }

                # result["val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
                wandb.log(result, step=completed_steps)

            # Checks if the accelerator has performed an optimization step behind the scenes

            if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0:
                    output_dir = f"step_{completed_steps }"
                    if output_dir is not None:
                        output_dir = os.path.join(output_dir, output_dir)
                    accelerator.save_state(output_dir)

        if completed_steps >= args.max_train_steps:
            break

        model.eval()
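        # Validation pass: accumulate the flow-matching loss on the held-out set;
        # the epoch-average below drives "best" checkpointing.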
        eval_progress_bar = tqdm(
            range(len(eval_dataloader)), disable=not accelerator.is_local_main_process
        )
        for step, batch in enumerate(eval_dataloader):
            with torch.no_grad():
                device = model.device
                text, audios, duration, _ = batch

                audio_list = []
                for audio_path in audios:

                    wav = read_wav_file(
                        audio_path, length
                    )  ## Read at most `max_audio_duration` seconds of audio
                    if (
                        wav.shape[0] == 1
                    ):  ## If the audio is mono, repeat the channel to make it "fake stereo"
                        wav = wav.repeat(2, 1)
                    audio_list.append(wav)

                audio_input = torch.stack(audio_list, dim=0)
                audio_input = audio_input.to(device)
                duration = torch.tensor(duration, device=device)
                unwrapped_vae = accelerator.unwrap_model(vae)
                audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
                audio_latent = audio_latent.transpose(
                    1, 2
                )  ## Transpose to (bsz, seq_len, channels)

                val_loss, _, _, _ = model(audio_latent, text, duration=duration)

                total_val_loss += val_loss.detach().float()
                eval_progress_bar.update(1)

        if accelerator.is_main_process:

            result = {}
            result["epoch"] = float(epoch + 1)

            result["epoch/train_loss"] = round(
                total_loss.item() / len(train_dataloader), 4
            )
            result["epoch/val_loss"] = round(
                total_val_loss.item() / len(eval_dataloader), 4
            )

            wandb.log(result, step=completed_steps)

            result_string = "Epoch: {}, Loss Train: {}, Val: {}\n".format(
                epoch + 1, result["epoch/train_loss"], result["epoch/val_loss"]
            )

            accelerator.print(result_string)

            with open("{}/summary.jsonl".format(output_dir), "a") as f:
                f.write(json.dumps(result) + "\n\n")

            logger.info(result)

            if result["epoch/val_loss"] < best_loss:
                best_loss = result["epoch/val_loss"]
                save_checkpoint = True
            else:
                save_checkpoint = False

        accelerator.wait_for_everyone()
        if accelerator.is_main_process and args.checkpointing_steps == "best":
            if save_checkpoint:
                accelerator.save_state("{}/{}".format(output_dir, "best"))

            if (epoch + 1) % args.save_every == 0:
                accelerator.save_state(
                    "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
                )

        if accelerator.is_main_process and args.checkpointing_steps == "epoch":
            accelerator.save_state(
                "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
            )


if __name__ == "__main__":
    main()