run_parler_tts_training.py 46 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#!/usr/bin/env python
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Yoach Lacombe's avatar
Yoach Lacombe committed
17
""" Train Parler-TTS using 🤗 Accelerate"""
18
19
20
21

import logging
import os
import re
22
import sys
Yoach Lacombe's avatar
Yoach Lacombe committed
23
import time
24
from multiprocess import set_start_method
25
from datetime import timedelta
26
27

from tqdm import tqdm
Yoach Lacombe's avatar
Yoach Lacombe committed
28
from pathlib import Path
29
30

import torch
31
32
33
34
35
from torch.utils.data import DataLoader

import datasets
from datasets import DatasetDict, Dataset, IterableDataset, concatenate_datasets

36
from huggingface_hub import HfApi
37
38

import transformers
Yoach Lacombe's avatar
Yoach Lacombe committed
39
from transformers import AutoFeatureExtractor, AutoTokenizer, HfArgumentParser
40
from transformers.trainer_pt_utils import LengthGroupedSampler
Yoach Lacombe's avatar
Yoach Lacombe committed
41
from transformers.optimization import get_scheduler
Yoach Lacombe's avatar
Yoach Lacombe committed
42
from transformers.utils import send_example_telemetry
43

44

45
46
47
from accelerate import Accelerator
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs, TorchDynamoPlugin
from accelerate.utils.memory import release_memory
48

Yoach Lacombe's avatar
Yoach Lacombe committed
49
50
from parler_tts import (
    ParlerTTSConfig,
51
    ParlerTTSForConditionalGeneration,
Yoach Lacombe's avatar
Yoach Lacombe committed
52
53
    build_delay_pattern_mask,
)
54

Dan Lyth's avatar
Dan Lyth committed
55
56
57
58
from training.utils import get_last_checkpoint, rotate_checkpoints, log_pred, log_metric
from training.arguments import ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments
from training.data import load_multiple_datasets, DataCollatorParlerTTSWithPadding, DataCollatorEncodecWithPadding
from training.eval import clap_similarity, wer
59
60


61
# Module-level logger for this script; main() sets its level to INFO on the main process and WARN on the others.
logger = logging.getLogger(__name__)
62

Yoach Lacombe's avatar
Yoach Lacombe committed
63

64
65
66
67
68
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

Yoach Lacombe's avatar
Yoach Lacombe committed
69
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, ParlerTTSTrainingArguments))
70
71
72
73
74
75
76
77
78
    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
    else:
        model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Sending telemetry. Tracking the example usage helps us better allocate resources to maintain them. The
    # information sent is the one passed as arguments along with your Python/PyTorch versions.
Yoach Lacombe's avatar
Yoach Lacombe committed
79
    send_example_telemetry("run_parler_tts", model_args, data_args)
Yoach Lacombe's avatar
Yoach Lacombe committed
80

Yoach Lacombe's avatar
Yoach Lacombe committed
81
82
83
84
85
86
    if training_args.dtype == "float16":
        mixed_precision = "fp16"
    elif training_args.dtype == "bfloat16":
        mixed_precision = "bf16"
    else:
        mixed_precision = "no"
Yoach Lacombe's avatar
Yoach Lacombe committed
87
88
89
90
91
92
93
94
95

    if data_args.pad_to_max_length and (
        data_args.max_duration_in_seconds is None
        or data_args.max_prompt_token_length is None
        or data_args.max_description_token_length is None
    ):
        raise ValueError(
            "`pad_to_max_length` is `True` but one of the following parameters has not been set: `max_duration_in_seconds`, `max_prompt_token_length`, `max_description_token_length`"
        )
96
97

    padding = "max_length" if data_args.pad_to_max_length else "longest"
98

99
    ####### A. Preparation
100
101
102
    kwargs_handlers = [InitProcessGroupKwargs(timeout=timedelta(minutes=60))]
    if training_args.torch_compile:
        # TODO(YL): add more compile modes?
Yoach Lacombe's avatar
Yoach Lacombe committed
103
104
        kwargs_handlers.append(TorchDynamoPlugin(backend="inductor", mode="default"))  # reduce-overhead

Yoach Lacombe's avatar
Yoach Lacombe committed
105
106
107
108
109
    accelerator = Accelerator(
        gradient_accumulation_steps=training_args.gradient_accumulation_steps,
        mixed_precision=mixed_precision,
        log_with=training_args.report_to,
        project_dir=training_args.output_dir,
110
        kwargs_handlers=kwargs_handlers,
Yoach Lacombe's avatar
Yoach Lacombe committed
111
    )
Yoach Lacombe's avatar
Yoach Lacombe committed
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133

    accelerator.init_trackers(
        project_name=data_args.wandb_project,
        config={
            "learning_rate": training_args.learning_rate,
            "model_name_or_path": model_args.model_name_or_path,
            "num_train_epochs": training_args.num_train_epochs,
            "gradient_accumulation_steps": training_args.gradient_accumulation_steps,
            "per_device_train_batch_size": training_args.per_device_train_batch_size,
            "global_batch_size": training_args.per_device_train_batch_size * accelerator.num_processes,
            "mixed_precision": mixed_precision,
            "lr_scheduler_type": training_args.lr_scheduler_type,
            "warmup_steps": training_args.warmup_steps,
            "freeze_text_encoder": model_args.freeze_text_encoder,
            "max_duration_in_seconds": data_args.max_duration_in_seconds,
            "weight_decay": training_args.weight_decay,
            "adam_beta1": training_args.adam_beta1,
            "adam_beta2": training_args.adam_beta2,
            "temperature": model_args.temperature,
        },
    )

Yoach Lacombe's avatar
Yoach Lacombe committed
134
    # Detecting last checkpoint and eventually continue from last checkpoint
135
136
137
138
139
140
141
142
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir) and training_args.do_train and not training_args.overwrite_output_dir:
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
            raise ValueError(
                f"Output directory ({training_args.output_dir}) already exists and is not empty. "
                "Use --overwrite_output_dir to overcome."
            )
Yoach Lacombe's avatar
Yoach Lacombe committed
143
        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
144
145
146
147
148
149
150
151
152
153
154
            logger.info(
                f"Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change "
                "the `--output_dir` or add `--overwrite_output_dir` to train from scratch."
            )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
155
    logger.setLevel(logging.INFO if accelerator.is_main_process else logging.WARN)
156

Yoach Lacombe's avatar
Yoach Lacombe committed
157
    # Log a small summary on each process
158
159
160
161
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        f"distributed training: {training_args.parallel_mode.value == 'distributed'}, 16-bits training: {training_args.fp16}"
    )
Yoach Lacombe's avatar
Yoach Lacombe committed
162
163
164
165

    # Set the verbosity to info of the Transformers logger (on main process only)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
166
        transformers.utils.logging.set_verbosity_info()
Yoach Lacombe's avatar
Yoach Lacombe committed
167
168
169
170
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()

171
172
173
174
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)
175
    num_workers = data_args.preprocessing_num_workers
Yoach Lacombe's avatar
Yoach Lacombe committed
176

177
178
179
    # 1. First, let's instantiate the feature extractor, tokenizers and model
    # Note for distributed training, the .from_pretrained methods guarantee that only
    # one local process can concurrently download model & vocab.
Yoach Lacombe's avatar
Yoach Lacombe committed
180

181
182
183
184
185
186
187
188
    # load feature extractor
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.feature_extractor_name or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        token=data_args.token,
        trust_remote_code=data_args.trust_remote_code,
    )
    sampling_rate = feature_extractor.sampling_rate
Yoach Lacombe's avatar
Yoach Lacombe committed
189

190
191
192
193
194
195
196
    # load prompt tokenizer
    prompt_tokenizer = AutoTokenizer.from_pretrained(
        model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        token=data_args.token,
        trust_remote_code=data_args.trust_remote_code,
        use_fast=model_args.use_fast_tokenizer,
Yoach Lacombe's avatar
Yoach Lacombe committed
197
        padding_side="left",  # prompt has to be padded on the left bc it's preprend to codebooks hidden states
198
    )
Yoach Lacombe's avatar
Yoach Lacombe committed
199

200
201
202
203
204
205
206
207
    # load description tokenizer
    description_tokenizer = AutoTokenizer.from_pretrained(
        model_args.description_tokenizer_name or model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        token=data_args.token,
        trust_remote_code=data_args.trust_remote_code,
        use_fast=model_args.use_fast_tokenizer,
    )
Yoach Lacombe's avatar
Yoach Lacombe committed
208

209
    if model_args.use_fast_tokenizer:
Yoach Lacombe's avatar
Yoach Lacombe committed
210
211
212
        logger.warning(
            "Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235"
        )
213
214
        prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
        description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
215

216
    # 2. Now, let's load the dataset
Yoach Lacombe's avatar
Yoach Lacombe committed
217

218
219
    if data_args.save_to_disk is not None:
        os.makedirs(data_args.save_to_disk, exist_ok=True)
Yoach Lacombe's avatar
Yoach Lacombe committed
220

221
222
223
224
    # assume that the dataset has been saved to `save_to_disk` if the latter is not empty
    dataset_was_precomputed = len(os.listdir(data_args.save_to_disk)) > 0
    if dataset_was_precomputed:
        vectorized_datasets = datasets.load_from_disk(data_args.save_to_disk)
Yoach Lacombe's avatar
Yoach Lacombe committed
225
    else:
226
227
228
229
        raw_datasets = DatasetDict()

        columns_to_keep = {
            "target_audio_column_name": data_args.target_audio_column_name,
Yoach Lacombe's avatar
Yoach Lacombe committed
230
            "prompt_column_name": data_args.prompt_column_name,
231
232
        }
        if data_args.description_column_name is not None:
233
            columns_to_keep["description_column_name"] = data_args.description_column_name
Yoach Lacombe's avatar
Yoach Lacombe committed
234

235
236
237
238
239
240
241
242
243
244
245
246
247
        if training_args.do_train:
            raw_datasets["train"] = load_multiple_datasets(
                accelerator,
                data_args.train_dataset_name,
                data_args.train_dataset_config_name,
                metadata_dataset_names=data_args.train_metadata_dataset_name,
                splits=data_args.train_split_name,
                dataset_samples=data_args.train_dataset_samples,
                seed=training_args.seed,
                cache_dir=model_args.cache_dir,
                num_proc=data_args.preprocessing_num_workers,
                id_column_name=data_args.id_column_name,
                columns_to_keep=columns_to_keep.values(),
248
                prompt_column_name=data_args.prompt_column_name,
249
250
                audio_column_name=data_args.target_audio_column_name,
                sampling_rate=sampling_rate,
251
                logger=logger,
252
253
                # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
            )
Yoach Lacombe's avatar
Yoach Lacombe committed
254

255
256
257
258
259
260
            for key in columns_to_keep:
                if columns_to_keep[key] not in raw_datasets["train"].column_names:
                    raise ValueError(
                        f"--{key} '{columns_to_keep[key]}' not found in dataset '{data_args.train_dataset_name}'."
                        f" Make sure to set `--{key}` to the correct audio column - one of"
                        f" {', '.join(raw_datasets['train'].column_names)}."
Yoach Lacombe's avatar
Yoach Lacombe committed
261
                    )
262
263
264
265
266
267
268
269

            if data_args.max_train_samples is not None:
                raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))

        if training_args.do_eval:
            raw_datasets["eval"] = load_multiple_datasets(
                accelerator,
                data_args.eval_dataset_name if data_args.eval_dataset_name else data_args.train_dataset_name,
Yoach Lacombe's avatar
Yoach Lacombe committed
270
271
272
                data_args.eval_dataset_config_name
                if data_args.eval_dataset_config_name
                else data_args.train_dataset_config_name,
273
274
275
276
277
278
                metadata_dataset_names=data_args.eval_metadata_dataset_name,
                splits=data_args.eval_split_name,
                cache_dir=model_args.cache_dir,
                num_proc=data_args.preprocessing_num_workers,
                id_column_name=data_args.id_column_name,
                columns_to_keep=columns_to_keep.values(),
279
280
281
                prompt_column_name=data_args.prompt_column_name,
                audio_column_name=data_args.target_audio_column_name,
                sampling_rate=sampling_rate,
282
                logger=logger,
283
284
                # streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
            )
285

286
            if data_args.max_eval_samples is not None:
Yoach Lacombe's avatar
Yoach Lacombe committed
287
288
289
                raw_datasets["eval"] = (
                    raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
                )
290

291
    # 3. Next, let's load the config.
Yoach Lacombe's avatar
Yoach Lacombe committed
292
    config = ParlerTTSConfig.from_pretrained(
293
294
295
296
297
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        token=data_args.token,
        trust_remote_code=data_args.trust_remote_code,
    )
Yoach Lacombe's avatar
Yoach Lacombe committed
298

299
    # update pad token id and decoder_start_token_id
Yoach Lacombe's avatar
Yoach Lacombe committed
300
301
    config.update(
        {
Yoach Lacombe's avatar
Yoach Lacombe committed
302
            "pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else config.pad_token_id,
Yoach Lacombe's avatar
Yoach Lacombe committed
303
304
            "decoder_start_token_id": model_args.decoder_start_token_id
            if model_args.decoder_start_token_id is not None
305
            else config.decoder_start_token_id,
Yoach Lacombe's avatar
Yoach Lacombe committed
306
307
308
        }
    )

Yoach Lacombe's avatar
Yoach Lacombe committed
309
    # create model
Yoach Lacombe's avatar
Yoach Lacombe committed
310
    model = ParlerTTSForConditionalGeneration.from_pretrained(
311
312
313
314
315
316
        model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
        config=config,
        token=data_args.token,
        trust_remote_code=data_args.trust_remote_code,
    )
Yoach Lacombe's avatar
Yoach Lacombe committed
317

318
319
320
    # enable gradient checkpointing if necessary
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
Yoach Lacombe's avatar
Yoach Lacombe committed
321

322
    # 4. Now we preprocess the datasets including loading the audio, resampling and normalization
323
324
325
    # Thankfully, `datasets` takes care of automatically loading and resampling the audio,
    # so that we just need to set the correct target sampling rate and normalize the input
    # via the `feature_extractor`
Yoach Lacombe's avatar
Yoach Lacombe committed
326

327
    # derive max & min input length for sample rate & max duration
328
329
330
    sampling_rate = feature_extractor.sampling_rate
    max_target_length = data_args.max_duration_in_seconds * sampling_rate
    min_target_length = data_args.min_duration_in_seconds * sampling_rate
331
332
333
334
    target_audio_column_name = data_args.target_audio_column_name
    description_column_name = data_args.description_column_name
    prompt_column_name = data_args.prompt_column_name
    feature_extractor_input_name = feature_extractor.model_input_names[0]
Yoach Lacombe's avatar
Yoach Lacombe committed
335
336
    audio_encoder_pad_token_id = config.decoder.pad_token_id
    audio_encoder_eos_token_id = config.decoder.eos_token_id
Yoach Lacombe's avatar
Yoach Lacombe committed
337
338
339
    audio_encoder_bos_token_id = model.generation_config.decoder_start_token_id
    max_length = model.generation_config.max_length
    num_codebooks = model.decoder.config.num_codebooks
Yoach Lacombe's avatar
Yoach Lacombe committed
340
    bandwidth = model_args.bandwidth
Yoach Lacombe's avatar
Yoach Lacombe committed
341

342
343
    # Freeze Encoders
    model.freeze_encoders(model_args.freeze_text_encoder)
Yoach Lacombe's avatar
Yoach Lacombe committed
344

345
346
347
348
349
    # Test all-gather: used as a warm-up and to avoid timeouts
    test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
    gathered_tensor = accelerator.gather(test_tensor)
    print("gathered_tensor", gathered_tensor)
    accelerator.wait_for_everyone()
Yoach Lacombe's avatar
Yoach Lacombe committed
350
351

    if not dataset_was_precomputed:
352
        # Filter on text length
353
        if description_column_name is not None and data_args.max_text_length is not None:
354
355
356
357
358
359
360
            with accelerator.main_process_first():
                # keep only descriptions shorter than max_text_length characters
                raw_datasets = raw_datasets.filter(
                    lambda x: len(x) < data_args.max_text_length,
                    num_proc=num_workers,
                    input_columns=[description_column_name],
                )
361

362
363
364
365
        # Preprocessing the dataset.
        # We need to tokenize the texts.
        def pass_through_processors(description, prompt):
            """Tokenize one example's description and prompt.

            Used with ``datasets.map(input_columns=[description_column_name, prompt_column_name])``.
            Both strings are stripped of surrounding whitespace before tokenization.

            Returns:
                dict with ``input_ids`` (description tokens from ``description_tokenizer``) and
                ``prompt_input_ids`` (prompt tokens from ``prompt_tokenizer``).
            """
            description_ids = description_tokenizer(description.strip())["input_ids"]
            prompt_ids = prompt_tokenizer(prompt.strip())["input_ids"]
            return {"input_ids": description_ids, "prompt_input_ids": prompt_ids}
Yoach Lacombe's avatar
Yoach Lacombe committed
371

372
        with accelerator.main_process_first():
373
            # trick: drop all raw columns (including audio) in this map so the entire audio column is not rewritten, which takes ages
374
            vectorized_datasets = raw_datasets.map(
375
376
                pass_through_processors,
                remove_columns=next(iter(raw_datasets.values())).column_names,
377
                input_columns=[description_column_name, prompt_column_name],
378
379
380
                num_proc=num_workers,
                desc="preprocess datasets",
            )
381

382
        # We use Accelerate to perform distributed inference
383
        # T5 doesn't support fp16
Yoach Lacombe's avatar
Yoach Lacombe committed
384
        autocast_kwargs = AutocastKwargs(enabled=(mixed_precision != "fp16"))
385
386

        # Now we encode the audio labels with encodec.
387
        ####### B. Encode audio
388

389
        logger.info("*** Encode target audio with encodec ***")
Yoach Lacombe's avatar
Yoach Lacombe committed
390

391
392
        # no need to prepare audio_decoder because used for inference without mixed precision
        # see: https://huggingface.co/docs/accelerate/main/en/package_reference/accelerator#accelerate.Accelerator.prepare
393
394
395
396
        if training_args.torch_compile:
            audio_decoder = accelerator.prepare_model(model.audio_encoder, evaluation_mode=True)
        else:
            audio_decoder = model.audio_encoder
397

Yoach Lacombe's avatar
Yoach Lacombe committed
398
399
400
401
402
403
404
        encoder_data_collator = DataCollatorEncodecWithPadding(
            feature_extractor,
            audio_column_name=target_audio_column_name,
            feature_extractor_input_name=feature_extractor_input_name,
            max_length=max_target_length,
            padding=padding,
        )
405
406
407
408
409
410
411
412
413

        def apply_audio_decoder(batch):
            """Encode one collated audio batch into discrete codes.

            ``len_audio`` is popped from ``batch`` (the dict is mutated). The remaining entries are
            passed as keyword arguments, together with ``bandwidth``, to ``audio_decoder.encode``
            under ``torch.no_grad()``.

            Returns:
                dict with
                  - ``len_audio``: the popped lengths, unchanged;
                  - ``labels``: ``audio_codes`` with dim 0 squeezed and dims 1/2 swapped;
                  - ``ratio``: number of code frames (last dim of ``audio_codes``) divided by the
                    longest ``len_audio`` in the batch, broadcast to ``len_audio``'s shape.
            """
            lengths = batch.pop("len_audio")
            audio_decoder.to(batch["input_values"].device).eval()
            with torch.no_grad():
                codes = audio_decoder.encode(**batch, bandwidth=bandwidth)["audio_codes"]
            return {
                "len_audio": lengths,
                # (1, bsz, codebooks, seq_len) -> (bsz, seq_len, codebooks)
                "labels": codes.squeeze(0).transpose(1, 2),
                "ratio": torch.ones_like(lengths) * codes.shape[-1] / lengths.max(),
            }
417

418
419
        for split in vectorized_datasets:
            data_loader = DataLoader(
420
                raw_datasets[split],
Yoach Lacombe's avatar
Yoach Lacombe committed
421
                batch_size=training_args.audio_encoder_per_device_batch_size,
422
423
424
                collate_fn=encoder_data_collator,
                num_workers=training_args.dataloader_num_workers,
                pin_memory=True,
425
            )
Yoach Lacombe's avatar
Yoach Lacombe committed
426
427
            data_loader = accelerator.prepare(data_loader)

428
429
430
431
432
433
            all_generated_labels = []
            all_lens = []
            for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
                generate_labels = apply_audio_decoder(batch)
                generate_labels = accelerator.pad_across_processes(generate_labels, dim=1, pad_index=0)
                generate_labels = accelerator.gather_for_metrics(generate_labels)
Yoach Lacombe's avatar
Yoach Lacombe committed
434

435
                if accelerator.is_main_process:
Yoach Lacombe's avatar
Yoach Lacombe committed
436
                    lab = generate_labels["labels"].cpu().transpose(1, 2).to(torch.int16)
437
438
                    rat = generate_labels["ratio"].cpu().squeeze()
                    lens = generate_labels["len_audio"].cpu().squeeze()
Yoach Lacombe's avatar
Yoach Lacombe committed
439
440
                    lab = [l[:, : int(ratio * length)] for (l, ratio, length) in zip(lab, rat, lens)]

441
442
                    all_generated_labels.extend(lab)
                    all_lens.extend(lens)
Yoach Lacombe's avatar
Yoach Lacombe committed
443

444
445
            # (1, codebooks, seq_len) where seq_len=1
            bos_labels = torch.ones((1, num_codebooks, 1)) * audio_encoder_bos_token_id
Yoach Lacombe's avatar
Yoach Lacombe committed
446

447
            if accelerator.is_main_process:
448
                tmp_labels = Dataset.from_dict({"labels": all_generated_labels, "target_length": all_lens})
Yoach Lacombe's avatar
Yoach Lacombe committed
449
450
451
452
                tmp_labels.save_to_disk(
                    os.path.join(data_args.temporary_save_to_disk, split),
                    num_proc=1 if split == "eval" else data_args.preprocessing_num_workers,
                )
453
454
            accelerator.wait_for_everyone()
            del all_generated_labels
Yoach Lacombe's avatar
Yoach Lacombe committed
455

456
            tmp_labels = datasets.load_from_disk(os.path.join(data_args.temporary_save_to_disk, split))
457
458
            with accelerator.main_process_first():
                vectorized_datasets[split] = concatenate_datasets([vectorized_datasets[split], tmp_labels], axis=1)
Yoach Lacombe's avatar
Yoach Lacombe committed
459

460
            def postprocess_dataset(labels):
                """Convert one sample's audio codes into delay-pattern training labels.

                Steps:
                  1. Wrap the codes in a tensor and add a leading dim (treated as (1, codebooks, seq_len)).
                  2. Prepend the ``bos_labels`` column.
                  3. Run ``build_delay_pattern_mask``.
                  4. Replace every ``-1`` in the returned mask with the EOS id.
                  5. Drop index 0 along dim 1.

                Returns ``{"labels": tensor}`` for ``datasets.map``.
                """
                codes = torch.tensor(labels).unsqueeze(0)
                # NOTE(review): bos_labels is built with torch.ones (float32 by default), so this cat
                # likely promotes the codes to float — confirm downstream consumers expect that.
                codes_with_bos = torch.cat([bos_labels, codes], dim=-1)

                # Only the mask is used; the first return value is discarded.
                _, delay_pattern_mask = build_delay_pattern_mask(
                    codes_with_bos,
                    bos_token_id=audio_encoder_bos_token_id,
                    pad_token_id=audio_encoder_eos_token_id,
                    max_length=codes_with_bos.shape[-1] + num_codebooks,
                    num_codebooks=num_codebooks,
                )

                # The leading entries of the mask are the labels themselves; its -1 slots are where EOS
                # goes. Intended layout (B=BOS, E=EOS):
                #  - [B, a, b, E, E, E, E]
                #  - [B, B, c, d, E, E, E]
                #  - [B, B, B, e, f, E, E]
                #  - [B, B, B, B, g, h, E]
                delayed = torch.where(delay_pattern_mask == -1, audio_encoder_eos_token_id, delay_pattern_mask)

                # Drop the first step (the all-BOS column in the layout above); nothing is trimmed from the end.
                return {"labels": delayed[:, 1:]}

            with accelerator.main_process_first():
                vectorized_datasets[split] = vectorized_datasets[split].map(
                    postprocess_dataset,
Yoach Lacombe's avatar
Yoach Lacombe committed
491
                    num_proc=data_args.preprocessing_num_workers,  # this one is resource consuming if many processor.
492
                    input_columns=["labels"],
493
494
495
496
                    desc="Postprocessing labeling",
                )

        accelerator.free_memory()
497
        del generate_labels, all_lens
498

499
        with accelerator.main_process_first():
500
            # NOTE: filtering is done at the end because in the `datasets` library, caching audio files is done after most operations
Yoach Lacombe's avatar
Yoach Lacombe committed
501
            # caching audio files is time and disk-space consuming, so we want to avoid it at all costs, especially for large (>1Kh) audio datasets.
502
503
            # That's also why we avoid concatenating the processed datasets (vectorized_datasets) with the audio column present in raw_datasets.

504
505
506
507
508
509
510
511
512
            def is_audio_in_length_range(length):
                """Return True if ``length`` lies strictly between ``min_target_length`` and ``max_target_length``.

                Both bounds are computed in ``main`` as duration limits (seconds) times the sampling rate.
                """
                return min_target_length < length < max_target_length

            # keep only samples whose target_length is strictly between min_target_length and max_target_length
            vectorized_datasets = vectorized_datasets.filter(
                is_audio_in_length_range,
                num_proc=num_workers,
                input_columns=["target_length"],
            )
Yoach Lacombe's avatar
Yoach Lacombe committed
513

514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
            if description_column_name is not None and data_args.max_description_token_length is not None:
                with accelerator.main_process_first():
                    # keep only descriptions whose tokenized length is shorter than max_description_token_length
                    vectorized_datasets = vectorized_datasets.filter(
                        lambda x: len(x) < data_args.max_description_token_length,
                        num_proc=num_workers,
                        input_columns=["input_ids"],
                    )

            if data_args.max_prompt_token_length is not None:
                with accelerator.main_process_first():
                    # keep only prompts whose tokenized length is shorter than max_prompt_token_length
                    vectorized_datasets = vectorized_datasets.filter(
                        lambda x: len(x) < data_args.max_prompt_token_length,
                        num_proc=num_workers,
                        input_columns=["prompt_input_ids"],
                    )
Yoach Lacombe's avatar
Yoach Lacombe committed
531

532
    if data_args.save_to_disk is not None and not dataset_was_precomputed:
533
        if accelerator.is_main_process:
Yoach Lacombe's avatar
Yoach Lacombe committed
534
535
536
537
            vectorized_datasets.save_to_disk(
                data_args.save_to_disk,
                num_proc=min(data_args.preprocessing_num_workers, len(vectorized_datasets["eval"]) - 1),
            )
538
        logger.info(f"Dataset saved at {data_args.save_to_disk}")
Yoach Lacombe's avatar
Yoach Lacombe committed
539

540
541
542
    audio_max_length = None
    if training_args.torch_compile:
        audio_max_length = max(vectorized_datasets["train"]["target_length"])
Yoach Lacombe's avatar
Yoach Lacombe committed
543
        with accelerator.main_process_first():
544
            max_sample = vectorized_datasets["train"].filter(
Yoach Lacombe's avatar
Yoach Lacombe committed
545
546
547
548
                lambda x: x == audio_max_length,
                num_proc=num_workers,
                input_columns=["target_length"],
            )
549
        audio_max_length = torch.tensor(max_sample[0]["labels"]).shape[1]
550
551
552
553
554
555

    # for large datasets it is advised to run the preprocessing on a
    # single machine first with ``args.preprocessing_only`` since there will most likely
    # be a timeout when running the script in distributed mode.
    # In a second step ``args.preprocessing_only`` can then be set to `False` to load the
    # cached dataset
556
    if data_args.preprocessing_only and data_args.save_to_disk is None:
Yoach Lacombe's avatar
Yoach Lacombe committed
557
558
559
        raise ValueError(
            "`preprocessing_only=True` but `save_to_disk` is not set. The latter should indicates where to save the dataset locally."
        )
560
561
    elif data_args.preprocessing_only:
        logger.info(f"Data preprocessing finished. Files save at {data_args.save_to_disk}")
562
        return
Yoach Lacombe's avatar
Yoach Lacombe committed
563

564
    # 6. Next, we can prepare the training.
Yoach Lacombe's avatar
Yoach Lacombe committed
565

Yoach Lacombe's avatar
Yoach Lacombe committed
566
567
    # Evaluation metrics: CLAP text/audio similarity and ASR word error rate.
    def compute_metrics(audios, descriptions, prompts, device="cpu"):
        """Score generated audio against the text it was conditioned on.

        Args:
            audios: generated audio tensors; each is moved to CPU and converted to numpy.
            descriptions: description token ids, decoded with ``description_tokenizer``.
            prompts: prompt token ids, decoded with ``prompt_tokenizer``.
            device: device forwarded to ``clap_similarity`` and ``wer``.

        Returns:
            ``(results, texts, prompts, audios, transcriptions)``: ``results`` holds the
            ``"clap"`` and ``"wer"`` scores; ``texts`` / ``prompts`` are the decoded strings,
            ``audios`` the numpy arrays, and ``transcriptions`` the second value returned by ``wer``.
        """
        decoded_descriptions = description_tokenizer.batch_decode(descriptions, skip_special_tokens=True)
        decoded_prompts = prompt_tokenizer.batch_decode(prompts, skip_special_tokens=True)
        waveforms = [audio.cpu().numpy() for audio in audios]

        results = {
            "clap": clap_similarity(model_args.clap_model_name_or_path, decoded_descriptions, waveforms, device)
        }

        word_error, transcriptions = wer(
            model_args.asr_model_name_or_path,
            decoded_prompts,
            waveforms,
            device,
            training_args.per_device_eval_batch_size,
            sampling_rate,
        )
        results["wer"] = word_error

        return results, decoded_descriptions, decoded_prompts, waveforms, transcriptions
Yoach Lacombe's avatar
Yoach Lacombe committed
588

Yoach Lacombe's avatar
Yoach Lacombe committed
589
590
591
592
593
594
    # Define Training Schedule
    # Store some constants
    per_device_train_batch_size = int(training_args.per_device_train_batch_size)
    train_batch_size = per_device_train_batch_size * accelerator.num_processes
    gradient_accumulation_steps = int(training_args.gradient_accumulation_steps)
    per_device_eval_batch_size = int(training_args.per_device_eval_batch_size)

    if training_args.max_steps > 0:
        logger.info("max_steps is given, it will override any value given in num_train_epochs")
        total_train_steps = int(training_args.max_steps)
        # Setting a very large number of epochs so we go as many times as necessary over the iterator.
        num_epochs = sys.maxsize
        steps_per_epoch = total_train_steps
    else:
        # Epoch-based schedule. As in the HF Trainer, max_steps <= 0 means "not set";
        # previously max_steps == 0 matched neither branch and left these names unbound.
        num_epochs = int(training_args.num_train_epochs)
        steps_per_epoch = len(vectorized_datasets["train"]) // (train_batch_size * gradient_accumulation_steps)
        total_train_steps = steps_per_epoch * num_epochs

    if training_args.eval_steps is None:
        logger.info("eval_steps is not set, evaluating at the end of each epoch")
        eval_steps = steps_per_epoch
    else:
        eval_steps = training_args.eval_steps
Yoach Lacombe's avatar
Yoach Lacombe committed
612

613
    # T5 doesn't support fp16: under fp16 mixed precision the autocast handler is built
    # disabled; it is enabled for every other precision setting.
    autocast_kwargs = AutocastKwargs(enabled=mixed_precision != "fp16")

    # Define optimizer, LR scheduler, collator
    adam_betas = (training_args.adam_beta1, training_args.adam_beta2)
    optimizer = torch.optim.AdamW(
        params=model.parameters(),
        lr=training_args.learning_rate,
        betas=adam_betas,
        eps=training_args.adam_epsilon,
        weight_decay=training_args.weight_decay,
    )

    # The LR scheduler gets stepped `num_processes` times per optimizer step, so warmup and
    # total step counts are scaled by the same factor to compensate.
    n_processes = accelerator.num_processes
    lr_scheduler = get_scheduler(
        name=training_args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=training_args.get_warmup_steps(total_train_steps) * n_processes,
        num_training_steps=total_train_steps * n_processes,
    )

    # Custom collator that pads prompts, descriptions and audio labels per batch.
    data_collator = DataCollatorParlerTTSWithPadding(
        prompt_tokenizer=prompt_tokenizer,
        description_tokenizer=description_tokenizer,
        pad_to_multiple_of=data_args.pad_to_multiple_of,
        padding=padding,
        prompt_max_length=data_args.max_prompt_token_length,
        description_max_length=data_args.max_description_token_length,
        audio_max_length=audio_max_length,
    )

    # Wrap model, optimizer and scheduler for the current distributed / precision setup.
    model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
Yoach Lacombe's avatar
Yoach Lacombe committed
646

Yoach Lacombe's avatar
Yoach Lacombe committed
647
648
    # Summarise the run configuration.
    effective_batch_size = train_batch_size * gradient_accumulation_steps
    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {total_train_steps * effective_batch_size}")
    logger.info(f"  Instantaneous batch size per device = {per_device_train_batch_size}")
    logger.info(f"  Gradient accumulation steps = {gradient_accumulation_steps}")
    logger.info(f"  Total train batch size (w. parallel & distributed) = {effective_batch_size}")
    logger.info(f"  Total optimization steps = {total_train_steps}")

    # ======================== Training ================================
    train_time = 0
    train_start = time.time()
    steps_trained_progress_bar = tqdm(
        range(total_train_steps),
        desc="Train steps ... ",
        position=0,
        disable=not accelerator.is_local_main_process,
    )
    continue_training = True
    epochs_trained = 0
    cur_step = 0

    # An explicit `resume_from_checkpoint` takes precedence; otherwise fall back to
    # `last_checkpoint` (which may itself be None).
    checkpoint = (
        training_args.resume_from_checkpoint
        if training_args.resume_from_checkpoint is not None
        else last_checkpoint
    )
Yoach Lacombe's avatar
Yoach Lacombe committed
671

Yoach Lacombe's avatar
Yoach Lacombe committed
672
673
    # Main process only: create the Hub repo (when pushing) and make sure output_dir exists.
    # `api` and `repo_id` are later used by the checkpoint upload during training.
    if accelerator.is_main_process:
        if training_args.push_to_hub:
            api = HfApi(token=training_args.hub_token)

            # Create repo (repo_name from args or inferred)
            repo_name = training_args.hub_model_id
            if repo_name is None:
                repo_name = Path(training_args.output_dir).absolute().name
            repo_id = api.create_repo(repo_name, exist_ok=True).repo_id

            # HfApi does not clone the repo locally, so output_dir may not exist yet.
            os.makedirs(training_args.output_dir, exist_ok=True)

            # Add "wandb" to .gitignore unless it is already listed. The previous "w+" mode
            # truncated the file first, so the check never saw existing entries and any
            # prior .gitignore content was wiped.
            gitignore_path = os.path.join(training_args.output_dir, ".gitignore")
            with open(gitignore_path, "a+") as gitignore:
                gitignore.seek(0)
                existing = gitignore.read()
                if "wandb" not in existing.splitlines():
                    # In "a+" mode writes always go to the end of the file.
                    if existing and not existing.endswith("\n"):
                        gitignore.write("\n")
                    gitignore.write("wandb\n")
        elif training_args.output_dir is not None:
            os.makedirs(training_args.output_dir, exist_ok=True)
    accelerator.wait_for_everyone()
Yoach Lacombe's avatar
Yoach Lacombe committed
688

Yoach Lacombe's avatar
Yoach Lacombe committed
689
690
691
692
693
694
    # Save the tokenizer, feature extractor and config to output_dir so a single processor
    # can be created later. Inside `main_process_first`, the other processes wait until
    # the main process (the only writer) is done.
    with accelerator.main_process_first():
        if accelerator.is_main_process:
            prompt_name = model_args.prompt_tokenizer_name
            description_name = model_args.description_tokenizer_name
            tokenizers_match = (prompt_name is None and bool(description_name)) or prompt_name == description_name
            if not tokenizers_match:
                logger.warning(
                    f"Prompt tokenizer ('{model_args.prompt_tokenizer_name}') and description tokenizer ('{model_args.description_tokenizer_name}') are not the same. Saving only the prompt tokenizer."
                )
            # Only the prompt tokenizer is saved either way; the mismatch case just warns.
            prompt_tokenizer.save_pretrained(training_args.output_dir)
            feature_extractor.save_pretrained(training_args.output_dir)
            config.save_pretrained(training_args.output_dir)
Yoach Lacombe's avatar
Yoach Lacombe committed
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725

    # Restore accelerator state and recover (step, epoch) from a checkpoint directory named
    # "checkpoint-<step>-epoch-<epoch>", the naming used when checkpoints are saved in the
    # training loop. Sets `resume_step`: the number of batches to skip in the first resumed
    # epoch, or None.
    if checkpoint is not None:
        # Validate the name before the (expensive) state load; a non-matching path used to
        # surface as an opaque AttributeError on `None.group`.
        pattern = r"checkpoint-(\d+)-epoch-(\d+)"
        match = re.search(pattern, checkpoint)
        if match is None:
            raise ValueError(
                f"Cannot infer the training step and epoch from checkpoint path '{checkpoint}': "
                "expected a directory named like 'checkpoint-<step>-epoch-<epoch>'."
            )
        accelerator.load_state(checkpoint)
        cur_step = int(match.group(1))
        epochs_trained = int(match.group(2))

        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info(f"  Continuing training from epoch {epochs_trained}")
        logger.info(f"  Continuing training from global step {cur_step}")

        steps_trained_progress_bar.update(cur_step)

        # Replay the per-epoch shuffles so the data order matches the interrupted run.
        for _ in range(epochs_trained):
            vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)

        if training_args.max_steps <= 0:
            # Epoch-based schedule: we know exactly the number of steps per epoch, so we can
            # skip through the required number of batches.
            resume_step = (cur_step - epochs_trained * steps_per_epoch) * gradient_accumulation_steps
        else:
            # Currently we don't know how many steps we've taken in the current epoch
            # So we just shuffle the dataset one extra time and start from a fresh epoch
            # This is "good enough" for our purposes but not fully correct
            resume_step = None
            vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
    else:
        resume_step = None
Yoach Lacombe's avatar
Yoach Lacombe committed
738

Yoach Lacombe's avatar
Yoach Lacombe committed
739
740
    # Generation settings passed to `generate` in `generate_step`.
    gen_kwargs = dict(
        do_sample=model_args.do_sample,
        temperature=model_args.temperature,
        max_length=model_args.max_length,
    )
Yoach Lacombe's avatar
Yoach Lacombe committed
744

Yoach Lacombe's avatar
Yoach Lacombe committed
745
746
747
    # Define gradient update step fn
    def train_step(
        batch,
        accelerator,
        autocast_kwargs,
    ):
        """Run the training forward pass for one batch.

        Puts ``model`` in train mode. Under fp16 mixed precision, the text encoder is run
        first inside ``accelerator.autocast(autocast_handler=autocast_kwargs)``. Its output is
        stored in ``batch["encoder_outputs"]``, which mutates ``batch``.

        Returns:
            ``(ce_loss, {"loss": ce_loss})`` with ``ce_loss = outputs.loss``.
        """
        model.train()

        if mixed_precision == "fp16":
            # fp16 doesn't work with T5-like models. In distributed mode the prepared model
            # is wrapped, so its text encoder is reached through `.module`.
            is_distributed = training_args.parallel_mode.value == "distributed"
            text_encoder = model.module.text_encoder if is_distributed else model.text_encoder
            with accelerator.autocast(autocast_handler=autocast_kwargs):
                batch["encoder_outputs"] = text_encoder(
                    input_ids=batch.get("input_ids"),
                    attention_mask=batch.get("attention_mask", None),
                )

        # CE (data) loss
        ce_loss = model(**batch).loss
        return ce_loss, {"loss": ce_loss}
Yoach Lacombe's avatar
Yoach Lacombe committed
772

Yoach Lacombe's avatar
Yoach Lacombe committed
773
    # Define eval fn
    def eval_step(
        batch,
        accelerator,
        autocast_kwargs,
    ):
        """Compute the validation loss for one batch without tracking gradients.

        With torch.compile the un-compiled module (``model._orig_mod``) is used. Under fp16
        the text encoder outputs are precomputed into ``batch["encoder_outputs"]``, which
        mutates ``batch``.

        Returns:
            ``{"loss": outputs.loss}``.
        """
        eval_model = model._orig_mod if training_args.torch_compile else model
        eval_model.eval()

        if mixed_precision == "fp16":
            # fp16 doesn't work with T5-like models. The wrapped `.module` path is only used
            # in distributed mode without torch.compile.
            use_direct_encoder = (
                training_args.parallel_mode.value != "distributed" or training_args.torch_compile
            )
            with accelerator.autocast(autocast_handler=autocast_kwargs):
                with torch.no_grad():
                    text_encoder = (
                        eval_model.text_encoder if use_direct_encoder else eval_model.module.text_encoder
                    )
                    encoder_outputs = text_encoder(
                        input_ids=batch.get("input_ids"),
                        attention_mask=batch.get("attention_mask", None),
                    )
                batch["encoder_outputs"] = encoder_outputs

        with torch.no_grad():
            outputs = eval_model(**batch)

        # CE (data) loss
        return {"loss": outputs.loss}

    def generate_step(batch):
        """Generate audio for one eval batch.

        Drops ``decoder_attention_mask`` from ``batch`` (in place) before calling
        ``generate`` with ``gen_kwargs``. The result is padded along dim 1 with 0 so that
        it can be gathered across processes.
        """
        batch.pop("decoder_attention_mask", None)
        generation_model = accelerator.unwrap_model(model, keep_fp32_wrapper=mixed_precision != "fp16").eval()
        if training_args.torch_compile:
            # With torch.compile, generate from the un-compiled module instead.
            generation_model = model._orig_mod

        generated = generation_model.generate(**batch, **gen_kwargs)
        return accelerator.pad_across_processes(generated, dim=1, pad_index=0)

    def _make_eval_dataloader():
        """Build an accelerator-prepared DataLoader over the eval split (no dropped batches)."""
        eval_dataloader = DataLoader(
            vectorized_datasets["eval"],
            collate_fn=data_collator,
            batch_size=per_device_eval_batch_size,
            drop_last=False,
            # Previously `training_args.dataloader_pin_memory` (a bool) was passed here by mistake.
            num_workers=training_args.dataloader_num_workers,
            pin_memory=training_args.dataloader_pin_memory,
        )
        return accelerator.prepare(eval_dataloader)

    # Main loop: one reshuffled pass over the train split per epoch. Logging, checkpointing
    # and evaluation are triggered on optimizer steps (when `accelerator.sync_gradients`).
    for epoch in range(epochs_trained, num_epochs):
        vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
        sampler = None
        if training_args.group_by_length:
            sampler = LengthGroupedSampler(train_batch_size, lengths=vectorized_datasets["train"]["target_length"])
        train_dataloader = DataLoader(
            vectorized_datasets["train"],
            collate_fn=data_collator,
            batch_size=per_device_train_batch_size,
            sampler=sampler,
            num_workers=training_args.dataloader_num_workers,
            pin_memory=training_args.dataloader_pin_memory,
        )
        train_dataloader = accelerator.prepare(train_dataloader)
        if hasattr(train_dataloader, "dataset") and isinstance(train_dataloader.dataset, IterableDataset):
            train_dataloader.dataset.set_epoch(epoch)

        if resume_step is not None:
            # Skip the first N batches in the dataloader when resuming from a checkpoint
            train_dataloader = accelerator.skip_first_batches(train_dataloader, resume_step)
            resume_step = None

        for batch in train_dataloader:
            with accelerator.accumulate(model):
                loss, train_metric = train_step(batch, accelerator, autocast_kwargs)
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(model.parameters(), training_args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Check if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                steps_trained_progress_bar.update(1)
                cur_step += 1

                if cur_step % training_args.logging_steps == 0:
                    steps_trained_progress_bar.write(
                        f"Step... ({cur_step} / {total_train_steps} | Loss:"
                        f" {train_metric['loss']}, Learning Rate:"
                        f" {lr_scheduler.get_last_lr()[0]})"
                    )
                    log_metric(
                        accelerator,
                        metrics=train_metric,
                        learning_rate=lr_scheduler.get_last_lr()[0],
                        train_time=train_time + time.time() - train_start,
                        step=cur_step,
                        epoch=epoch,
                        prefix="train",
                    )

                # save checkpoint and weights after each save_steps and at the end of training
                if (cur_step % training_args.save_steps == 0) or cur_step == total_train_steps:
                    intermediate_dir = os.path.join(training_args.output_dir, f"checkpoint-{cur_step}-epoch-{epoch}")
                    # safe_serialization=False to avoid shared tensors saving issue (TODO(YL): it's a temporary fix)
                    # https://github.com/huggingface/transformers/issues/27293#issuecomment-1872560074
                    accelerator.save_state(output_dir=intermediate_dir, safe_serialization=False)
                    accelerator.wait_for_everyone()
                    if accelerator.is_main_process:
                        rotate_checkpoints(
                            training_args.save_total_limit, output_dir=training_args.output_dir, logger=logger
                        )

                        if cur_step == total_train_steps:
                            # un-wrap student model for save
                            unwrapped_model = accelerator.unwrap_model(model)
                            unwrapped_model.save_pretrained(training_args.output_dir)

                        if training_args.push_to_hub:
                            api.upload_folder(
                                repo_id=repo_id,
                                folder_path=training_args.output_dir,
                                commit_message=f"Saving train state of step {cur_step}",
                                run_as_future=True,
                            )

                if training_args.do_eval and (cur_step % eval_steps == 0 or cur_step == total_train_steps):
                    train_time += time.time() - train_start
                    # ======================== Evaluating ==============================
                    eval_metrics = []
                    eval_preds = []
                    eval_descriptions = []
                    eval_prompts = []
                    eval_start = time.time()

                    # release training input batch
                    batch = release_memory(batch)

                    # Loss pass over the eval split.
                    validation_dataloader = _make_eval_dataloader()
                    for batch in tqdm(
                        validation_dataloader,
                        desc="Evaluating - Inference ...",
                        position=2,
                        disable=not accelerator.is_local_main_process,
                    ):
                        # Model forward
                        eval_metric = eval_step(batch, accelerator, autocast_kwargs)
                        eval_metric = accelerator.gather_for_metrics(eval_metric)
                        eval_metrics.append(eval_metric)

                    if training_args.predict_with_generate:
                        # Generation pass over a fresh eval dataloader.
                        validation_dataloader = _make_eval_dataloader()
                        for batch in tqdm(
                            validation_dataloader,
                            desc="Evaluating - Generation ...",
                            position=2,
                            disable=not accelerator.is_local_main_process,
                        ):
                            generated_audios = generate_step(batch)
                            # Gather all predictions and targets
                            generated_audios, input_ids, prompts = accelerator.pad_across_processes(
                                (generated_audios, batch["input_ids"], batch["prompt_input_ids"]), dim=1, pad_index=0
                            )
                            generated_audios, input_ids, prompts = accelerator.gather_for_metrics(
                                (generated_audios, input_ids, prompts)
                            )
                            eval_preds.extend(generated_audios.to("cpu"))
                            eval_descriptions.extend(input_ids.to("cpu"))
                            eval_prompts.extend(prompts.to("cpu"))

                    eval_time = time.time() - eval_start
                    # normalize eval metrics
                    eval_metrics = {
                        key: torch.mean(torch.cat([d[key].unsqueeze(0) for d in eval_metrics]))
                        for key in eval_metrics[0]
                    }

                    # compute metrics
                    metrics_desc = ""
                    if training_args.predict_with_generate:
                        metric_values, pred_descriptions, pred_prompts, audios, transcriptions = compute_metrics(
                            eval_preds, eval_descriptions, eval_prompts, accelerator.device
                        )
                        eval_metrics.update(metric_values)
                        metrics_desc = " ".join([f"Eval {key}: {value} |" for key, value in metric_values.items()])
                        if "wandb" in training_args.report_to:
                            log_pred(
                                accelerator,
                                pred_descriptions,
                                pred_prompts,
                                transcriptions,
                                audios,
                                sampling_rate=sampling_rate,
                                step=cur_step,
                                prefix="eval",
                            )

                    # Print metrics and update progress bar
                    steps_trained_progress_bar.write(
                        f"Eval results for step ({cur_step} / {total_train_steps} | Eval Loss: {eval_metrics['loss']} |"
                        f" {metrics_desc})"
                    )

                    log_metric(
                        accelerator,
                        metrics=eval_metrics,
                        train_time=eval_time,
                        step=cur_step,
                        epoch=epoch,
                        prefix="eval",
                    )

                    # release eval batch and relax metrics
                    eval_metrics = []
                    eval_preds = []
                    eval_descriptions = []
                    eval_prompts = []
                    batch = release_memory(batch)

                    # flush the train metrics
                    train_start = time.time()

                # break condition
                if cur_step == total_train_steps:
                    continue_training = False
                    break

        if not continue_training:
            break

    accelerator.end_training()
1014
1015
1016


if __name__ == "__main__":
    # Select the "spawn" start method for `multiprocess` before running the script.
    set_start_method("spawn")
    main()