"tools/testModelZooRequirements.txt" did not exist on "553ddbc073416a839d6a3c42bb91280856a7aa1c"
train_dpo.py 22.9 KB
Newer Older
hungchiayu1's avatar
updates  
hungchiayu1 committed
1
2
3
4
5
6
7
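"""DPO fine-tuning script for TangoFlux text-to-audio generation.

Trains the rectified-flow transformer with Direct Preference Optimization
(DPO) on (chosen, rejected) audio pairs, keeping a frozen copy of the
transformer as the reference model for the preference loss.
"""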
import time
import argparse
import json
import logging
import math
import os
import yaml

# from tqdm import tqdm
import copy
from pathlib import Path
import diffusers
import datasets
import numpy as np
import pandas as pd
import wandb
import transformers
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from datasets import load_dataset, Audio
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
from transformers import SchedulerType, get_scheduler
from tangoflux.model import TangoFlux
from tangoflux.utils import Text2AudioDataset, read_wav_file, DPOText2AudioDataset

from diffusers import AutoencoderOobleck
import torchaudio

logger = get_logger(__name__)


def parse_args():
    parser = argparse.ArgumentParser(
        description="Rectified flow for text to audio generation task."
    )

    parser.add_argument(
        "--num_examples",
        type=int,
        default=-1,
        help="How many examples to use for training and validation.",
    )

    parser.add_argument(
        "--text_column",
        type=str,
        default="captions",
        help="The name of the column in the datasets containing the input texts.",
    )
    parser.add_argument(
        "--audio_column",
        type=str,
        default="location",
        help="The name of the column in the datasets containing the audio paths.",
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.95,
        help="The beta2 parameter for the Adam optimizer.",
    )
    parser.add_argument(
        "--config",
        type=str,
        default="tangoflux_config.yaml",
        help="Config file defining the model size.",
    )

    parser.add_argument(
        "--weight_decay", type=float, default=1e-8, help="Weight decay to use."
    )

    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )

    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="linear",
        help="The scheduler type to use.",
        choices=[
            "linear",
            "cosine",
            "cosine_with_restarts",
            "polynomial",
            "constant",
            "constant_with_warmup",
        ],
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Epsilon value for the Adam optimizer",
    )
    parser.add_argument(
        "--seed", type=int, default=None, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=str,
        default="best",
        help="Whether the various states should be saved at the end of every 'epoch' or 'best' whenever validation loss decreases.",
    )
    parser.add_argument(
        "--save_every",
        type=int,
        default=5,
        help="Save model after every how many epochs when checkpointing_steps is set to best.",
    )

    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help="If the training should continue from a local checkpoint folder.",
    )

    parser.add_argument(
        "--report_to",
        type=str,
        default="all",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`,'
            ' `"wandb"`, `"comet_ml"` and `"clearml"`. Use `"all"` (default) to report to all integrations.'
            " Only applicable when `--with_tracking` is passed."
        ),
    )

    parser.add_argument(
        "--load_from_checkpoint",
        type=str,
        default=None,
        help="Whether to continue training from a model weight",
    )
    parser.add_argument(
        "--audio_length",
        type=float,
        default=30,
        help="Audio duration",
    )

    args = parser.parse_args()

    # Sanity checks
    # if args.train_file is None and args.validation_file is None:
    #     raise ValueError("Need a training/validation file.")
    # else:
    #     if args.train_file is not None:
    #         extension = args.train_file.split(".")[-1]
    #         assert extension in ["csv", "json"], "`train_file` should be a csv or a json file."
    #     if args.validation_file is not None:
    #         extension = args.validation_file.split(".")[-1]
    #         assert extension in ["csv", "json"], "`validation_file` should be a csv or a json file."

    return args


def main():
    args = parse_args()
    accelerator_log_kwargs = {}

    def load_config(config_path):
        with open(config_path, "r") as file:
            return yaml.safe_load(file)

    config = load_config(args.config)

    learning_rate = float(config["training"]["learning_rate"])
    num_train_epochs = int(config["training"]["num_train_epochs"])
    num_warmup_steps = int(config["training"]["num_warmup_steps"])
    per_device_batch_size = int(config["training"]["per_device_batch_size"])
    gradient_accumulation_steps = int(config["training"]["gradient_accumulation_steps"])

    output_dir = config["paths"]["output_dir"]

    accelerator = Accelerator(
        gradient_accumulation_steps=gradient_accumulation_steps,
        **accelerator_log_kwargs,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    datasets.utils.logging.set_verbosity_error()
    diffusers.utils.logging.set_verbosity_error()
    transformers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Handle output directory creation and wandb tracking
    if accelerator.is_main_process:
        if output_dir is None or output_dir == "":
            output_dir = "saved/" + str(int(time.time()))

            if not os.path.exists("saved"):
                os.makedirs("saved")

            os.makedirs(output_dir, exist_ok=True)

        elif output_dir is not None:
            os.makedirs(output_dir, exist_ok=True)

        os.makedirs("{}/{}".format(output_dir, "outputs"), exist_ok=True)
        with open("{}/summary.jsonl".format(output_dir), "a") as f:
            f.write(json.dumps(dict(vars(args))) + "\n\n")

        accelerator.project_configuration.automatic_checkpoint_naming = False

        wandb.init(
            project="Text to Audio Flow matching DPO",
            settings=wandb.Settings(_disable_stats=True),
        )

    accelerator.wait_for_everyone()

    # Get the datasets
    data_files = {}
    # if args.train_file is not None:
    if config["paths"]["train_file"] != "":
        data_files["train"] = config["paths"]["train_file"]
    # if args.validation_file is not None:
    if config["paths"]["val_file"] != "":
        data_files["validation"] = config["paths"]["val_file"]
    if config["paths"]["test_file"] != "":
        data_files["test"] = config["paths"]["test_file"]
    else:
        data_files["test"] = config["paths"]["val_file"]

    extension = "json"
    train_dataset = load_dataset(extension, data_files=data_files["train"])
    data_files.pop("train")
    raw_datasets = load_dataset(extension, data_files=data_files)
    text_column, audio_column = args.text_column, args.audio_column

    model = TangoFlux(config=config["model"], initialize_reference_model=True)
    vae = AutoencoderOobleck.from_pretrained(
        "stabilityai/stable-audio-open-1.0", subfolder="vae"
    )

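    # The VAE is only used to encode waveforms into latents for flow matching,
    # so it stays frozen in eval mode throughout training.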
    ## Freeze vae
    for param in vae.parameters():
        param.requires_grad = False
    vae.eval()

    ## Freeze text encoder params
    for param in model.text_encoder.parameters():
        param.requires_grad = False
    model.text_encoder.eval()

    prefix = ""

    with accelerator.main_process_first():
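        # Each train row supplies a caption plus "chosen" and "reject" audio
        # paths and a "duration" field; DPOText2AudioDataset pairs them for
        # preference training.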
        train_dataset = DPOText2AudioDataset(
            train_dataset["train"],
            prefix,
            text_column,
            "chosen",
            "reject",
            "duration",
            args.num_examples,
        )
        eval_dataset = Text2AudioDataset(
            raw_datasets["validation"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )
        test_dataset = Text2AudioDataset(
            raw_datasets["test"],
            prefix,
            text_column,
            audio_column,
            "duration",
            args.num_examples,
        )

        accelerator.print(
            "Num instances in train: {}, validation: {}, test: {}".format(
                train_dataset.get_num_instances(),
                eval_dataset.get_num_instances(),
                test_dataset.get_num_instances(),
            )
        )

    train_dataloader = DataLoader(
        train_dataset,
        shuffle=True,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=train_dataset.collate_fn,
    )
    eval_dataloader = DataLoader(
        eval_dataset,
        shuffle=True,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=eval_dataset.collate_fn,
    )
    test_dataloader = DataLoader(
        test_dataset,
        shuffle=False,
        batch_size=config["training"]["per_device_batch_size"],
        collate_fn=test_dataset.collate_fn,
    )

    # Optimizer

    optimizer_parameters = list(model.transformer.parameters()) + list(
        model.fc.parameters()
    )
    num_trainable_parameters = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    accelerator.print("Num trainable parameters: {}".format(num_trainable_parameters))

    if args.load_from_checkpoint:
        from safetensors.torch import load_file

        w1 = load_file(args.load_from_checkpoint)
        model.load_state_dict(w1, strict=False)
        logger.info("Weights loaded from{}".format(args.load_from_checkpoint))

    import copy

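    # DPO keeps a frozen copy of the policy transformer as the reference
    # model; only model.transformer (and model.fc) receive gradient updates.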
    model.ref_transformer = copy.deepcopy(model.transformer)
    model.ref_transformer.requires_grad_(False)
    model.ref_transformer.eval()
    for param in model.ref_transformer.parameters():
        param.requires_grad = False

    @torch.no_grad()
    def initialize_or_update_ref_transformer(
        model, accelerator: Accelerator, alpha=0.5
    ):
        """
        Initializes or updates ref_transformer as alpha * ref + 1-alpha * transformer.

        Args:
            model (torch.nn.Module): The main model containing the 'transformer' attribute.
            accelerator (Accelerator): The Accelerator instance used to unwrap the model.
mrfakename's avatar
mrfakename committed
375
            initial_ref_model (torch.nn.Module, optional): An optional initial reference model.
hungchiayu1's avatar
updates  
hungchiayu1 committed
376
                If not provided, ref_transformer is initialized as a copy of transformer.
mrfakename's avatar
mrfakename committed
377

hungchiayu1's avatar
updates  
hungchiayu1 committed
378
379
380
381
382
        Returns:
            torch.nn.Module: The model with the updated ref_transformer.
        """
        # Unwrap the model to access the original underlying model
        unwrapped_model = accelerator.unwrap_model(model)

        with torch.no_grad():
            for ref_param, model_param in zip(
                unwrapped_model.ref_transformer.parameters(),
                unwrapped_model.transformer.parameters(),
            ):
                average_param = alpha * ref_param.data + (1 - alpha) * model_param.data

                ref_param.data.copy_(average_param)

        unwrapped_model.ref_transformer.eval()
        unwrapped_model.ref_transformer.requires_grad_(False)
        for param in unwrapped_model.ref_transformer.parameters():
            param.requires_grad = False

        return model

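    # Note: initialize_or_update_ref_transformer is available for EMA-style
    # refreshes of the reference model (e.g. once per epoch); as written, the
    # training loop below leaves the reference at its initial weights.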

    optimizer = torch.optim.AdamW(
        optimizer_parameters,
        lr=learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if args.max_train_steps is None:
        args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

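    # The scheduler is stepped every micro-batch rather than once per
    # optimizer update, so the warmup/training step counts passed below are
    # scaled by the accumulation steps (and, for warmup, the process count).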
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=num_warmup_steps
        * gradient_accumulation_steps
        * accelerator.num_processes,
        num_training_steps=args.max_train_steps * gradient_accumulation_steps,
    )

    # Prepare everything with our `accelerator`.
    vae, model, optimizer, lr_scheduler = accelerator.prepare(
        vae, model, optimizer, lr_scheduler
    )

    train_dataloader, eval_dataloader, test_dataloader = accelerator.prepare(
        train_dataloader, eval_dataloader, test_dataloader
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / gradient_accumulation_steps
    )
    if overrode_max_train_steps:
        args.max_train_steps = num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # Figure out how many steps we should save the Accelerator states
    checkpointing_steps = args.checkpointing_steps
    if checkpointing_steps is not None and checkpointing_steps.isdigit():
        checkpointing_steps = int(checkpointing_steps)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.

    # Train!
    total_batch_size = (
        per_device_batch_size * accelerator.num_processes * gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num Epochs = {num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {per_device_batch_size}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(
        range(args.max_train_steps), disable=not accelerator.is_local_main_process
    )

    completed_steps = 0
    starting_epoch = 0
    # Potentially load in the weights and states from a previous save
    resume_from_checkpoint = config["paths"]["resume_from_checkpoint"]
    if resume_from_checkpoint != "":
        accelerator.load_state(resume_from_checkpoint)
        accelerator.print(f"Resumed from local checkpoint: {resume_from_checkpoint}")

    # Duration of the audio clips in seconds
    best_loss = np.inf
    length = config["training"]["max_audio_duration"]

    for epoch in range(starting_epoch, num_train_epochs):
        model.train()
        total_loss, total_val_loss = 0, 0

        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                device = accelerator.device
                text, audio_w, audio_l, duration, _ = batch

                with torch.no_grad():
                    audio_list_w = []
                    audio_list_l = []
                    for audio_path in audio_w:

                        wav = read_wav_file(
                            audio_path, length
                        )  ## Read at most `length` seconds of audio
                        if wav.shape[0] == 1:
                            ## If the audio is mono, duplicate the channel to make it "fake stereo"
                            wav = wav.repeat(2, 1)
                        audio_list_w.append(wav)

                    for audio_path in audio_l:
                        wav = read_wav_file(
                            audio_path, length
                        )  ## Read at most `length` seconds of audio
                        if wav.shape[0] == 1:
                            ## If the audio is mono, duplicate the channel to make it "fake stereo"
                            wav = wav.repeat(2, 1)
                        audio_list_l.append(wav)

                    audio_input_w = torch.stack(audio_list_w, dim=0).to(device)
                    audio_input_l = torch.stack(audio_list_l, dim=0).to(device)
                    # audio_input_ = audio_input.to(device)
                    unwrapped_vae = accelerator.unwrap_model(vae)

                    duration = torch.tensor(duration, device=device)
                    duration = torch.clamp(
                        duration, max=length
                    )  ## Clamp durations to the configured max_audio_duration

                    audio_latent_w = unwrapped_vae.encode(
                        audio_input_w
                    ).latent_dist.sample()
                    audio_latent_l = unwrapped_vae.encode(
                        audio_input_l
                    ).latent_dist.sample()
                    audio_latent = torch.cat((audio_latent_w, audio_latent_l), dim=0)
                    audio_latent = audio_latent.transpose(
                        1, 2
                    )  ## Transpose to (bsz, seq_len, channel)

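                # The first half of audio_latent holds the preferred ("chosen")
                # latents and the second half the rejected ones. With sft=False
                # the model returns the DPO preference loss along with the raw
                # flow-matching losses of the policy and the frozen reference,
                # plus the implicit reward accuracy.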
                loss, raw_model_loss, raw_ref_loss, implicit_acc = model(
                    audio_latent, text, duration=duration, sft=False
                )

                total_loss += loss.detach().float()
                accelerator.backward(loss)
                optimizer.step()
                lr_scheduler.step()
                # if accelerator.sync_gradients:
                if accelerator.sync_gradients:
                    # accelerator.clip_grad_value_(model.parameters(),1.0)
                    progress_bar.update(1)
                    completed_steps += 1

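            # Every 10 completed steps, log the global L2 gradient norm and
            # the DPO diagnostics from the most recent batch to wandb.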
            if completed_steps % 10 == 0 and accelerator.is_main_process:

                total_norm = 0.0
                for p in model.parameters():
                    if p.grad is not None:
                        param_norm = p.grad.data.norm(2)
                        total_norm += param_norm.item() ** 2

                total_norm = total_norm**0.5
                logger.info(
                    f"Step {completed_steps}, Loss: {loss.item()}, Grad Norm: {total_norm}"
                )

                lr = lr_scheduler.get_last_lr()[0]

                result = {
                    "train_loss": loss.item(),
                    "grad_norm": total_norm,
                    "learning_rate": lr,
                    "raw_model_loss": raw_model_loss,
                    "raw_ref_loss": raw_ref_loss,
                    "implicit_acc": implicit_acc,
                }

                # result["val_loss"] = round(total_val_loss.item()/len(eval_dataloader), 4)
                wandb.log(result, step=completed_steps)

            # Checks if the accelerator has performed an optimization step behind the scenes

            if isinstance(checkpointing_steps, int):
                if completed_steps % checkpointing_steps == 0:
                    save_dir = os.path.join(output_dir, f"step_{completed_steps}")
                    accelerator.save_state(save_dir)

            if completed_steps >= args.max_train_steps:
                break

        model.eval()
        eval_progress_bar = tqdm(
            range(len(eval_dataloader)), disable=not accelerator.is_local_main_process
        )
        for step, batch in enumerate(eval_dataloader):
            with accelerator.accumulate(model), torch.no_grad():
                device = model.device
                text, audios, duration, _ = batch

                audio_list = []
                for audio_path in audios:

                    wav = read_wav_file(
                        audio_path, length
                    )  ## Read at most `length` seconds of audio
                    if wav.shape[0] == 1:
                        ## If the audio is mono, duplicate the channel to make it "fake stereo"
                        wav = wav.repeat(2, 1)
                    audio_list.append(wav)

                audio_input = torch.stack(audio_list, dim=0)
                audio_input = audio_input.to(device)
                duration = torch.tensor(duration, device=device)
                unwrapped_vae = accelerator.unwrap_model(vae)
                audio_latent = unwrapped_vae.encode(audio_input).latent_dist.sample()
                audio_latent = audio_latent.transpose(
                    1, 2
                )  ## Transpose to (bsz, seq_len, channel)

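                # Validation scores reference audio with the plain
                # flow-matching (SFT) objective, so sft=True and only the loss
                # term is kept.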
                val_loss, _, _, _ = model(
                    audio_latent, text, duration=duration, sft=True
                )

                total_val_loss += val_loss.detach().float()
                eval_progress_bar.update(1)

        if accelerator.is_main_process:

            result = {}
            result["epoch"] = float(epoch + 1)

            result["epoch/train_loss"] = round(
                total_loss.item() / len(train_dataloader), 4
            )
            result["epoch/val_loss"] = round(
                total_val_loss.item() / len(eval_dataloader), 4
            )

            wandb.log(result, step=completed_steps)

            with open("{}/summary.jsonl".format(output_dir), "a") as f:
                f.write(json.dumps(result) + "\n\n")

            logger.info(result)

        save_checkpoint = True
        accelerator.wait_for_everyone()
        if accelerator.is_main_process and args.checkpointing_steps == "best":
            if save_checkpoint:
                accelerator.save_state("{}/{}".format(output_dir, "best"))

            if (epoch + 1) % args.save_every == 0:
                accelerator.save_state(
                    "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
                )

        if accelerator.is_main_process and args.checkpointing_steps == "epoch":
            accelerator.save_state(
                "{}/{}".format(output_dir, "epoch_" + str(epoch + 1))
            )


if __name__ == "__main__":
    main()
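

# Example launch (illustrative; assumes an `accelerate` config has been set up
# and a YAML config matching the structure read in main()):
#
#   accelerate launch train_dpo.py --config tangoflux_config.yaml \
#       --load_from_checkpoint /path/to/tangoflux.safetensors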