# train_dpo.py
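"""DPO (Direct Preference Optimization) fine-tuning of the TangoFlux
text-to-audio rectified-flow model. Expects a config YAML plus a JSON
dataset with "chosen"/"reject" audio pairs per caption."""
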
import time
import argparse
import json
import logging
import math
import os
import yaml

# from tqdm import tqdm
import copy
from pathlib import Path
import diffusers
import datasets
import numpy as np
import pandas as pd
import wandb
import transformers
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import SchedulerType, get_scheduler
from tangoflux.model import TangoFlux
from datasets import load_dataset, Audio
from tangoflux.utils import Text2AudioDataset, read_wav_file, DPOText2AudioDataset

from diffusers import AutoencoderOobleck
import torchaudio

logger = get_logger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        description="DPO fine-tuning of a rectified flow model for the text-to-audio generation task."
    )

    parser.add_argument(
        "--num_examples",
        type=int,
        default=-1,
        help="How many examples to use for training and validation.",
    )

    parser.add_argument(
        "--text_column",
        type=str,
        default="captions",
        help="The name of the column in the datasets containing the input texts.",
    )
    parser.add_argument(
        "--audio_column",
        type=str,
        default="location",
        help="The name of the column in the datasets containing the audio paths.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.95,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="tangoflux_config.yaml",
        help="Config file defining the model size.",
    )

    parser.add_argument(
        "--weight_decay", type=float, default=1e-8, help="Weight decay to use."
    )

    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )

    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=[
            "linear",
            "cosine",
            "cosine_with_restarts",
            "polynomial",
            "constant",
            "constant_with_warmup",
        ],
    )

    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default="best",
        help="Whether the various states should be saved at the end of every 'epoch' or 'best' whenever validation loss decreases.",
    )
    parser.add_argument(
        "--save_every",
        type=int,
        default=5,
        help="Save model after every how many epochs when checkpointing_steps is set to best.",
    )

    parser.add_argument(
        "--load_from_checkpoint",
        type=str,
        default=None,
        help="Whether to continue training from a model weight",
    )

    args = parser.parse_args()

    # Sanity checks
    # if args.train_file is None and args.validation_file is None:
    #   raise ValueError("Need a training/validation file.")
    # else:
    #  if args.train_file is not None:
    #     extension = args.train_file.split(".")[-1]
    #    assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
    # if args.validation_file is not None:
    #   extension = args.validation_file.split(".")[-1]
    #  assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."

    return args


def main():
    args = parse_args()
    accelerator_log_kwargs = {}

    def load_config(config_path):
        with open(config_path, "r") as file:
            return yaml.safe_load(file)

    config = load_config(args.config)

    learning_rate = float(config["training"]["learning_rate"])
    num_train_epochs = int(config["training"]["num_train_epochs"])
    num_warmup_steps = int(config["training"]["num_warmup_steps"])
    per_device_batch_size = int(config["training"]["per_device_batch_size"])
    gradient_accumulation_steps = int(config["training"]["gradient_accumulation_steps"])

    output_dir = config["paths"]["output_dir"]

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        **accelerator_log_kwargs,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    datasets.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle output directory creation and wandb tracking
    if accelerator.is_main_process:
        if output_dir is None or output_dir == "":
            output_dir = "saved/" + str(int(time.time()))

            if not os.path.exists("saved"):
                os.makedirs("saved")

            os.makedirs(output_dir, exist_ok=True)

        elif output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)

        os.makedirs("{}/{}".format(output_dir, "outputs"), exist_ok=True)
        with open("{}/summary.jsonl".format(output_dir), "a") as f:
            f.write(json.dumps(dict(vars(args))) + "\n\n")

        accelerator.project_configuration.automatic_checkpoint_naming = False

        wandb.init(
            project="Text to Audio Flow matching DPO",
            settings=wandb.Settings(_disable_stats=True),
        )

    accelerator.wait_for_everyone()

    # Get the datasets
    data_files = {}
    # if args.train_file is not None:
    if config["paths"]["train_file"] != "":
        data_files["train"] = config["paths"]["train_file"]
    # if args.validation_file is not None:
    if config["paths"]["val_file"] != "":
        data_files["validation"] = config["paths"]["val_file"]
    if config["paths"]["test_file"] != "":
        data_files["test"] = config["paths"]["test_file"]
    else:
        data_files["test"] = config["paths"]["val_file"]

    extension = "json"
    train_dataset = load_dataset(extension, data_files=data_files["train"])
    data_files.pop("train")
    raw_datasets = load_dataset(extension, data_files=data_files)
    text_column, audio_column = args.text_column, args.audio_column

    model = TangoFlux(config=config["model"], initialize_reference_model=True)
    vae = AutoencoderOobleck.from_pretrained(
        "stabilityai/stable-audio-open-1.0", subfolder="vae"
    )

    ## Freeze vae
    for param in vae.parameters():
        param.requires_grad = False
    vae.eval()

    ## Freeze text encoder params
    for param in model.text_encoder.parameters():
        param.requires_grad = False
    model.text_encoder.eval()

    prefix = ""

    with accelerator.main_process_first():
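        # Each DPO training record pairs a caption with a preferred ("chosen")
        # and a rejected ("reject") audio path, plus a duration column.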
        train_dataset = DPOText2AudioDataset(
            train_dataset["train"],
            prefix,
            text_column,
            "chosen",
            "reject",
            "duration",
            args.num_examples,
        )
        eval_dataset = Text2AudioDataset(
            raw_datasets["validation"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )
        test_dataset = Text2AudioDataset(
            raw_datasets["test"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )

        accelerator.print(
            "Num instances in train: {}, validation: {}, test: {}".format(
                train_dataset.get_num_instances(),
                eval_dataset.get_num_instances(),
                test_dataset.get_num_instances(),
            )
        )

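    # Each dataset supplies its own collate_fn; train/eval are shuffled,
    # while the test split keeps its original order.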
    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=train_dataset.collate_fn,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        shuffle=True,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=eval_dataset.collate_fn,
    )
    test_dataloader = DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=test_dataset.collate_fn,
    )

    # Optimizer

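    # Only the flow transformer and the fc projection head are optimized;
    # the VAE and text encoder above stay frozen.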
    optimizer_parameters = list(model.transformer.parameters()) + list(
        model.fc.parameters()
    )
    num_trainable_parameters = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))

    if args.load_from_checkpoint:
        from safetensors.torch import load_file

        w1 = load_file(args.load_from_checkpoint)
        model.load_state_dict(w1, strict=False)
        logger.info("Weights loaded from{}".format(args.load_from_checkpoint))

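    # DPO needs a frozen reference policy: snapshot the current transformer
    # weights and keep them fixed for the rest of training.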
    model.ref_transformer = copy.deepcopy(model.transformer)
    model.ref_transformer.requires_grad_(False)
    model.ref_transformer.eval()
    for param in model.ref_transformer.parameters():
        param.requires_grad = False

    optimizer = torch.optim.AdamW(
        optimizer_parameters,
        lr=learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

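    # Warmup/total steps are scaled by gradient-accumulation and process count
    # because the scheduler is stepped once per micro-batch on every process.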
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps
        * gradient_accumulation_steps
        * accelerator.num_processes,
        num_training_steps=args.max_train_steps * gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    vae, model, optimizer, lr_scheduler = accelerator.prepare(
        vae, model, optimizer, lr_scheduler
    )

    train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
        train_dataloader, eval_dataloader, test_dataloader
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)


    # Train!
    total_batch_size = (
        per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {per_device_batch_size}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )

    completed_steps = 0
    starting_epoch = 0
    # Potentially load in the weights and states from a previous save
    resume_from_checkpoint = config["paths"]["resume_from_checkpoint"]
    if resume_from_checkpoint != "":
        accelerator.load_state(resume_from_checkpoint)
        accelerator.print(f"Resumed from local checkpoint: {resume_from_checkpoint}")

    # Duration of the audio clips in seconds
    best_loss = np.inf
    length = config["training"]["max_audio_duration"]

    for epoch in range(starting_epoch, num_train_epochs):
        model.train()
        total_loss, total_val_loss = 0, 0

        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                device = accelerator.device
                text, audio_w, audio_l, duration, _ = batch

                with torch.no_grad():
                    audio_list_w = []
                    audio_list_l = []
                    for audio_path in audio_w:

                        wav = read_wav_file(
                            audio_path, length
                        )  ## Only read the first `length` seconds of audio
                        if (
                            wav.shape[0] == 1
                        ):  ## If this audio is mono, repeat the channel so it becomes "fake stereo"
                            wav = wav.repeat(2, 1)
                        audio_list_w.append(wav)

                    for audio_path in audio_l:
                        wav = read_wav_file(
                            audio_path, length
                        )  ## Only read the first `length` seconds of audio
                        if (
                            wav.shape[0] == 1
                        ):  ## If this audio is mono, repeat the channel so it becomes "fake stereo"
                            wav = wav.repeat(2, 1)
                        audio_list_l.append(wav)

                    audio_input_w = torch.stack(audio_list_w, dim=0).to(device)
                    audio_input_l = torch.stack(audio_list_l, dim=0).to(device)
                    # audio_input_ = audio_input.to(device)
                    unwrapped_vae = accelerator.unwrap_model(vae)

                    duration = torch.tensor(duration, device=device)
                    duration = torch.clamp(
                        duration, max=length
                    )  ## Clamp duration to max_audio_duration

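                    # Encode both waveform batches into the VAE latent space,
                    # sampling from the posterior rather than taking its mean.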
                    audio_latent_w = unwrapped_vae.encode(
                        audio_input_w
                    ).latent_dist.sample()
                    audio_latent_l = unwrapped_vae.encode(
                        audio_input_l
                    ).latent_dist.sample()
                    audio_latent = torch.cat((audio_latent_w, audio_latent_l), dim=0)
                    audio_latent = audio_latent.transpose(
                        1, 2
                    )  ## Transpose to (bsz, seq_len, channel)

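                # Chosen and rejected latents are concatenated along the batch
                # axis; sft=False selects the DPO objective, which also returns
                # raw model/reference losses and the implicit reward accuracy.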
                loss, raw_model_loss, raw_ref_loss, implicit_acc = model(
                    audio_latent, text, duration=duration, sft=False
                )

                total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                # if accelerator.sync_gradients:
                if accelerator.sync_gradients:
                    # accelerator.clip_grad_value_(model.parameters(),1.0)
                    progress_bar.update(1)
                    completed_steps += 1

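            # Every 10 completed optimizer steps, log the loss and the global
            # L2 grad norm (gradients are still populated from the last backward).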
            if completed_steps % 10 == 0 and accelerator.is_main_process:

                total_norm = 0.0
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2

                total_norm = total_norm**0.5
                logger.info(
                    f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}"
                )

                lr = lr_scheduler.get_last_lr()[0]

                result = {
                    "train_loss": loss.item(),
                    "grad_norm": total_norm,
                    "learning_rate": lr,
                    "raw_model_loss": raw_model_loss,
                    "raw_ref_loss": raw_ref_loss,
                    "implicit_acc": implicit_acc,
                }

                # result["val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
                wandb.log(result, step=completed_steps)

            # Checks if the accelerator has performed an optimization step behind the scenes

            if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0:
                    step_dir = f"step_{completed_steps}"
                    if output_dir is not None:
                        step_dir = os.path.join(output_dir, step_dir)
                    accelerator.save_state(step_dir)

            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        eval_progress_bar = tqdm(
            range(len(eval_dataloader)), disable=not accelerator.is_local_main_process
        )
        for step, batch in enumerate(eval_dataloader):
            with accelerator.accumulate(model), torch.no_grad():
                device = model.device
                text, audios, duration, _ = batch

                audio_list = []
                for audio_path in audios:

                    wav = read_wav_file(
                        audio_path, length
                    )  ## Only read the first `length` seconds of audio
                    if (
                        wav.shape[0] == 1
                    ):  ## If this audio is mono, repeat the channel so it becomes "fake stereo"
                        wav = wav.repeat(2, 1)
                    audio_list.append(wav)

                audio_input = torch.stack(audio_list, dim=0)
                audio_input = audio_input.to(device)
                duration = torch.tensor(duration, device=device)
559
                unwrapped_vae = accelerator.unwrap_model(vae)
                audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
                audio_latent = audio_latent.transpose(
                    1, 2
                )  ## Transpose to (bsz, seq_len, channel)

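                # Validation uses the plain flow-matching loss (sft=True) on
                # ground-truth audio rather than the DPO preference loss.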
                val_loss, _, _, _ = model(
                    audio_latent, text, duration=duration, sft=True
                )

                total_val_loss += val_loss.detach().float()
                eval_progress_bar.update(1)

        if accelerator.is_main_process:

            result = {}
            result["epoch"] = float(epoch + 1)

            result["epoch/train_loss"] = round(
                total_loss.item() / len(train_dataloader), 4
            )
            result["epoch/val_loss"] = round(
                total_val_loss.item() / len(eval_dataloader), 4
            )

            wandb.log(result, step=completed_steps)

            with open("{}/summary.jsonl".format(output_dir), "a") as f:
                f.write(json.dumps(result) + "\n\n")

            logger.info(result)

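        # NOTE: save_checkpoint is currently always True, so in "best" mode the
        # "best" checkpoint is rewritten after every epoch.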
        save_checkpoint = True
        accelerator.wait_for_everyone()
        if accelerator.is_main_process and args.checkpointing_steps == "best":
            if save_checkpoint:
                accelerator.save_state("{}/{}".format(output_dir, "best"))

            if (epoch + 1) % args.save_every == 0:
                accelerator.save_state(
                    "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
                )

        if accelerator.is_main_process and args.checkpointing_steps == "epoch":
            accelerator.save_state(
                "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
            )


if __name__ == "__main__":
    main()