Commit 441af9a4 authored by yoach@huggingface.co

update training script with precomputation

parent 5acad845
@@ -27,6 +27,7 @@ import warnings
import math
import time
from multiprocess import set_start_method
from datetime import timedelta
import evaluate
@@ -69,7 +70,7 @@ AutoModel.register(DACConfig, DACModel)
from accelerate import Accelerator
from accelerate.utils import set_seed, AutocastKwargs
from accelerate.utils import set_seed, AutocastKwargs, InitProcessGroupKwargs
from accelerate.utils.memory import release_memory
from stable_speech import StableSpeechForConditionalGeneration, StableSpeechConfig, apply_delay_pattern_mask, build_delay_pattern_mask
@@ -253,6 +254,10 @@ class ModelArguments:
default=6, # TODO
metadata={"help": "Audio encoder bandwidth."},
)
precompute_text_hidden_states: bool = field(
default=False,
metadata={"help": "Whether to precompute text hidden states. Only works when the text encoder is frozen."},
)
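For orientation, here is a minimal sketch of what this new flag amounts to, assuming a frozen T5-style description encoder (the checkpoint and variable names below are illustrative, not the script's own): the encoder is run once over the descriptions and its last hidden state is cached, so training never has to call the text encoder again.

import torch
from transformers import AutoTokenizer, T5EncoderModel

# Illustrative sketch only: cache frozen text-encoder outputs up front so that
# the training loop can reuse them instead of re-running the T5 forward pass.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")
text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-base").eval()

descriptions = ["A calm female voice with slight reverberation."]
cached_hidden_states = []
with torch.no_grad():
    for text in descriptions:
        inputs = tokenizer(text, return_tensors="pt")
        outputs = text_encoder(**inputs)
        # (seq_len, hidden_size) tensor, computed once and reused every epoch
        cached_hidden_states.append(outputs.last_hidden_state.squeeze(0))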
@@ -377,6 +382,9 @@ class DataTrainingArguments:
min_duration_in_seconds: float = field(
default=0.0, metadata={"help": "Filter audio files that are shorter than `min_duration_in_seconds` seconds"}
)
max_text_length: int = field(
default=500, metadata={"help": "Max description length in number of characters."}
)
preprocessing_only: bool = field(
default=False,
metadata={
@@ -671,6 +679,8 @@ def load_multiple_datasets(
seed: Optional[int] = None,
id_column_name: Optional[str] = None,
columns_to_keep: Optional[Set[str]] = None,
sampling_rate: Optional[int] = None,
audio_column_name: Optional[str] = None,
**kwargs,
) -> Union[Dataset, IterableDataset]:
dataset_names_dict = convert_dataset_str_to_list(
@@ -696,6 +706,12 @@
)
dataset_features = dataset.features.keys()
if sampling_rate is not None and audio_column_name is not None:
# resample target audio
dataset = dataset.cast_column(
audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
)
metadata_dataset_name = dataset_dict["metadata_dataset_name"]
if metadata_dataset_name is not None:
metadata_dataset = load_dataset(
@@ -705,26 +721,36 @@
streaming=streaming,
**kwargs,
)
if id_column_name is not None and id_column_name not in dataset.column_names:
raise ValueError(
f"id_column_name={id_column_name} but has not been found in the dataset columns"
f"- one of {', '.join(list(dataset.column_names))}."
)
if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
raise ValueError(
f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
f"- one of {', '.join(list(metadata_dataset.column_names))}."
)
elif id_column_name is not None:
metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
# TODO(YL): I forgot to create unique ids for MLS english.
# To iterate faster, I bypass the original id check and do another one. - Done once
# if dataset_dict["name"] == "stable-speech/mls_eng_10k":
# def concat_ids(book_id, speaker_id, begin_time):
# return {"id": f"{book_id}_{speaker_id}_{str(begin_time).replace('.', '_')}"}
# dataset = dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
# metadata_dataset = metadata_dataset.map(concat_ids, input_columns=["book_id", "speaker_id", "begin_time"], num_proc=24)
# metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
if dataset_dict["name"] != "stable-speech/mls_eng_10k":
if id_column_name is not None and id_column_name not in dataset.column_names:
raise ValueError(
f"id_column_name={id_column_name} but has not been found in the dataset columns"
f"- one of {', '.join(list(dataset.column_names))}."
)
if id_column_name is not None and id_column_name not in metadata_dataset.column_names:
raise ValueError(
f"id_column_name={id_column_name} but has not been found in the metadata dataset columns"
f"- one of {', '.join(list(metadata_dataset.column_names))}."
)
elif id_column_name is not None:
metadata_dataset = metadata_dataset.rename_column(id_column_name, f"metadata_{id_column_name}")
metadata_columns_to_remove = set(metadata_dataset.column_names).intersection(set(dataset.column_names))
metadata_dataset = metadata_dataset.remove_columns(metadata_columns_to_remove)
dataset = concatenate_datasets([dataset, metadata_dataset], axis=1)
if id_column_name is not None:
if id_column_name is not None and dataset_dict["name"] != "stable-speech/mls_eng_10k":
if len(dataset.filter(lambda id1, id2: id1!=id2, input_columns=[id_column_name, f"metadata_{id_column_name}"])) != 0:
raise ValueError(f"Concatenate didn't work. Some ids don't correspond on dataset {dataset_dict['name']}")
@@ -777,12 +803,14 @@ def main():
else:
mixed_precision = "no"
####### A. Preparation
accelerator = Accelerator(
gradient_accumulation_steps=training_args.gradient_accumulation_steps,
mixed_precision=mixed_precision,
log_with=training_args.report_to,
project_dir=training_args.output_dir,
kwargs_handlers=[InitProcessGroupKwargs(timeout=timedelta(minutes=60))],
)
accelerator.init_trackers(project_name=data_args.wandb_project, config={
@@ -848,8 +876,45 @@ def main():
# Set seed before initializing model.
set_seed(training_args.seed)
num_workers = data_args.preprocessing_num_workers
# 1. First, let's instantiate the feature extractor, tokenizers and model
# Note for distributed training, the .from_pretrained methods guarantee that only
# one local process can concurrently download model & vocab.
# load feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
)
sampling_rate = feature_extractor.sampling_rate
# load prompt tokenizer
prompt_tokenizer = AutoTokenizer.from_pretrained(
model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
use_fast=model_args.use_fast_tokenizer,
padding_side="left", # prompt has to be padded on the left because it's prepended to the codebook hidden states
)
# load description tokenizer
description_tokenizer = AutoTokenizer.from_pretrained(
model_args.description_tokenizer_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
use_fast=model_args.use_fast_tokenizer,
)
if model_args.use_fast_tokenizer:
logger.warning("Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235")
prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
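A small aside on the padding_side="left" choice above (the tokenizer checkpoint here is illustrative): with left padding, the last real token of every prompt in a batch lands in the same position, which is what you want when the prompt is prepended to the codebook hidden states.

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/flan-t5-base", padding_side="left")
batch = tok(["short prompt", "a noticeably longer prompt"], padding=True, return_tensors="pt")
# Pad tokens sit on the left, so the final real token of each prompt occupies the
# last column of input_ids and lines up with whatever is appended after it.
print(batch["input_ids"])
print(batch["attention_mask"])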
# 1. First, let's load the dataset
# 2. Now, let's load the dataset
if data_args.save_to_disk is not None:
os.makedirs(data_args.save_to_disk, exist_ok=True)
@@ -866,7 +931,7 @@ def main():
"prompt_column_name": data_args.prompt_column_name
}
if data_args.description_column_name is not None:
columns_to_keep["description_column_nam"] = data_args.description_column_name
columns_to_keep["description_column_name"] = data_args.description_column_name
if training_args.do_train:
raw_datasets["train"] = load_multiple_datasets(
@@ -881,6 +946,8 @@ def main():
num_proc=data_args.preprocessing_num_workers,
id_column_name=data_args.id_column_name,
columns_to_keep=columns_to_keep.values(),
audio_column_name=data_args.target_audio_column_name,
sampling_rate=sampling_rate,
# streaming=data_args.streaming, TODO(SG): optionally enable streaming mode
)
@@ -910,11 +977,11 @@ def main():
)
if data_args.max_eval_samples is not None:
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
raw_datasets["eval"] = raw_datasets["eval"].shuffle(seed=training_args.seed).select(range(data_args.max_eval_samples))
# 2. Next, let's load the config as we might need it to create
# load config TODO: add the option to create the config from scratch
# 3. Next, let's load the config.
# TODO(YL): add the option to create the config from scratch
config = StableSpeechConfig.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
@@ -923,50 +990,13 @@ def main():
)
# update pad token id and decoder_start_token_id
# TODO: verify if this makes sense, maybe should do it for model.decoder
# TODO(YL): verify if this makes sense, maybe should do it for model.decoder
config.update({
"pad_token_id": model_args.pad_token_id if model_args.pad_token_id is not None else model.config.pad_token_id,
"decoder_start_token_id": model_args.decoder_start_token_id if model_args.decoder_start_token_id is not None else model.config.decoder_start_token_id,
})
# 3. Now we can instantiate the feature extractor, tokenizers and model
# Note for distributed training, the .from_pretrained methods guarantee that only
# one local process can concurrently download model & vocab.
# load feature extractor
feature_extractor = AutoFeatureExtractor.from_pretrained(
model_args.feature_extractor_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
)
# load prompt tokenizer
prompt_tokenizer = AutoTokenizer.from_pretrained(
model_args.prompt_tokenizer_name or model_args.description_tokenizer_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
use_fast=model_args.use_fast_tokenizer,
padding_side="left", # prompt has to be padded on the left bc it's preprend to codebooks hidden states
)
# load description tokenizer
description_tokenizer = AutoTokenizer.from_pretrained(
model_args.description_tokenizer_name or model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
token=data_args.token,
trust_remote_code=data_args.trust_remote_code,
use_fast=model_args.use_fast_tokenizer,
)
if model_args.use_fast_tokenizer:
logger.warning("Disabling fast tokenizer warning: https://github.com/huggingface/transformers/blob/main/src/transformers/tokenization_utils_base.py#L3231-L3235")
prompt_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
description_tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
# create model + TODO: not from_pretrained probably
# create model + TODO(YL): not from_pretrained probably
model = StableSpeechForConditionalGeneration.from_pretrained(
model_args.model_name_or_path,
cache_dir=model_args.cache_dir,
@@ -1002,12 +1032,23 @@ def main():
# Freeze Encoders
model.freeze_encoders(model_args.freeze_text_encoder)
if not dataset_was_precomputed:
# resample target audio
raw_datasets = raw_datasets.cast_column(
target_audio_column_name, datasets.features.Audio(sampling_rate=sampling_rate)
)
# TODO: remove
# Test all gather - used for warmup and avoiding timeout
test_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
gathered_tensor = accelerator.gather(test_tensor)
print("gathered_tensor", gathered_tensor)
accelerator.wait_for_everyone()
if not dataset_was_precomputed:
# Filter on text length
if description_column_name is not None:
with accelerator.main_process_first():
# keep only descriptions shorter than max_text_length characters
raw_datasets = raw_datasets.filter(
lambda x: len(x) < data_args.max_text_length,
num_proc=num_workers,
input_columns=[description_column_name],
)
# Preprocessing the datasets.
# We need to read the audio files as arrays and tokenize the texts.
@@ -1023,7 +1064,8 @@ def main():
# load audio
target_sample = batch[target_audio_column_name]
labels = feature_extractor(target_sample["array"], sampling_rate=target_sample["sampling_rate"])
arr = target_sample["array"]
labels = feature_extractor(arr[:min(len(arr), max_target_length+10)], sampling_rate=target_sample["sampling_rate"])
batch["labels"] = labels["input_values"]
# take length of raw audio waveform
@@ -1038,6 +1080,7 @@ def main():
desc="preprocess datasets",
)
with accelerator.main_process_first():
def is_audio_in_length_range(length):
return length > min_target_length and length < max_target_length
@@ -1048,10 +1091,58 @@ def main():
input_columns=["target_length"],
)
# 5. Now we encode the audio labels with encodec.
# We use Accelerate to perform distributed inference
# T5 doesn't support fp16
autocast_kwargs = AutocastKwargs(enabled= (mixed_precision != "fp16"))
####### B. Encode audio
####### B. (Optional) Encode text if the text encoder is frozen
if model_args.freeze_text_encoder and model_args.precompute_text_hidden_states:
text_data_collator = T5TextCollatorStableSpeechWithPadding(description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of)
for split in vectorized_datasets:
data_loader = DataLoader(
vectorized_datasets[split],
batch_size=training_args.text_encode_per_device_eval_batch_size,
collate_fn=text_data_collator,
num_workers=training_args.dataloader_num_workers,
pin_memory=True,
)
data_loader = accelerator.prepare(data_loader)
all_encoder_outputs = []
all_encoder_lengths = []
for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
model.text_encoder.to(batch["input_ids"].device)
with accelerator.autocast(autocast_handler=autocast_kwargs):
encoder_outputs = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
encoder_outputs = accelerator.pad_across_processes(encoder_outputs, dim=1, pad_index=prompt_tokenizer.pad_token_id)
encoder_outputs = accelerator.gather_for_metrics(encoder_outputs)
lengths = accelerator.gather_for_metrics(batch["len_input_ids"])
all_encoder_outputs.extend(encoder_outputs.last_hidden_state.to("cpu"))
all_encoder_lengths.extend(lengths.to("cpu"))
def postprocess_dataset(input_ids, idx):
output = {"encoder_outputs": BaseModelOutput(last_hidden_state=all_encoder_outputs[idx][:all_encoder_lengths[idx]])}
return output
# TODO(YL): done multiple times, how to deal with it.
with accelerator.main_process_first():
vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset,
num_proc=1, # this step is resource-intensive if run with many processes
input_columns=["input_ids"],
desc="Postprocessing labeling",
with_indices=True,
writer_batch_size=100,
)
accelerator.free_memory()
del data_loader, all_encoder_outputs, all_encoder_lengths
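The diff only shows how the hidden states are cached; the following is a hedged sketch of how such a cached encoder_outputs column is typically consumed (an assumption about the collator, not code from this commit): the stored hidden state is wrapped back into a BaseModelOutput and passed as encoder_outputs, so the seq2seq forward skips the frozen text encoder entirely.

import torch
from transformers.modeling_outputs import BaseModelOutput

def build_text_conditioning(example):
    # assumes example["encoder_outputs"] holds the cached last_hidden_state as nested lists
    hidden = torch.tensor(example["encoder_outputs"]["last_hidden_state"])  # (seq_len, hidden)
    return {
        "encoder_outputs": BaseModelOutput(last_hidden_state=hidden.unsqueeze(0)),
        "attention_mask": torch.ones(1, hidden.shape[0], dtype=torch.long),
    }

# model(**build_text_conditioning(vectorized_datasets["train"][0]), labels=...) would then
# bypass the text encoder during training.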
# Now we encode the audio labels with encodec.
####### C. Encode audio
logger.info("*** Encode target audio with encodec ***")
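For context on the audio-encoding step that follows (the checkpoint below is illustrative; this script may use a DAC codec via the DACConfig/DACModel registration above instead): an Encodec-style codec turns the target waveform into discrete codebook indices, and those indices become the decoder labels.

import torch
from transformers import AutoFeatureExtractor, EncodecModel

fe = AutoFeatureExtractor.from_pretrained("facebook/encodec_32khz")
codec = EncodecModel.from_pretrained("facebook/encodec_32khz").eval()

waveform = torch.randn(fe.sampling_rate)  # one second of dummy mono audio
inputs = fe(waveform.numpy(), sampling_rate=fe.sampling_rate, return_tensors="pt")
with torch.no_grad():
    codes = codec.encode(inputs["input_values"], inputs.get("padding_mask")).audio_codes
# codes: (chunks, batch, num_codebooks, frames) of discrete indices used as labels
print(codes.shape)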
@@ -1133,7 +1224,7 @@ def main():
output["prompt_input_ids"] = prompt_input_ids
return output
# TODO: done multiple times, how to deal with it.
# TODO(YL): done multiple times, how to deal with it.
with accelerator.main_process_first():
vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset,
@@ -1146,53 +1237,8 @@ def main():
accelerator.free_memory()
del generate_labels
del generate_labels, all_generated_labels, all_lens, all_ratios
# T5 doesn't support fp16
autocast_kwargs = AutocastKwargs(enabled= (mixed_precision != "fp16"))
####### C. Encode text if text encoder is freezed
if model_args.freeze_text_encoder:
text_data_collator = T5TextCollatorStableSpeechWithPadding(description_tokenizer, pad_to_multiple_of=data_args.pad_to_multiple_of)
for split in vectorized_datasets:
data_loader = DataLoader(
vectorized_datasets[split],
batch_size=training_args.text_encode_per_device_eval_batch_size,
collate_fn=text_data_collator,
num_workers=training_args.dataloader_num_workers,
pin_memory=True,
)
data_loader = accelerator.prepare(data_loader)
all_encoder_outputs = []
all_encoder_lengths = []
for batch in tqdm(data_loader, disable=not accelerator.is_local_main_process):
model.text_encoder.to(batch["input_ids"].device)
with accelerator.autocast(autocast_handler=autocast_kwargs):
encoder_outputs = model.text_encoder(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
encoder_outputs = accelerator.pad_across_processes(encoder_outputs, dim=1, pad_index=prompt_tokenizer.pad_token_id)
encoder_outputs = accelerator.gather_for_metrics(encoder_outputs)
lengths = accelerator.gather_for_metrics(batch["len_input_ids"])
# TODO: check it works multi device
all_encoder_outputs.extend(encoder_outputs.last_hidden_state.to("cpu"))
all_encoder_lengths.extend(lengths.to("cpu"))
def postprocess_dataset(input_ids, idx):
output = {"encoder_outputs": BaseModelOutput(last_hidden_state=all_encoder_outputs[idx][:all_encoder_lengths[idx]])}
return output
# TODO: done multiple times, how to deal with it.
with accelerator.main_process_first():
vectorized_datasets[split] = vectorized_datasets[split].map(
postprocess_dataset,
num_proc=1, # this one is resource consuming if many processor.
input_columns=["input_ids"],
desc="Postprocessing labeling",
with_indices=True,
writer_batch_size=100,
)
if data_args.save_to_disk is not None and not dataset_was_precomputed:
@@ -1418,7 +1464,7 @@ def main():
):
model.train()
if mixed_precision == "fp16" and not model_args.freeze_text_encoder:
if mixed_precision == "fp16" and not (model_args.freeze_text_encoder and model_args.precompute_text_hidden_states):
# fp16 doesn't work with T5-like models
with accelerator.autocast(autocast_handler=autocast_kwargs):
if training_args.parallel_mode.value != "distributed":
@@ -1438,7 +1484,7 @@ def main():
# Define eval fn
def eval_step(batch, accelerator, autocast_kwargs,):
model.eval()
if mixed_precision == "fp16" and not model_args.freeze_text_encoder:
if mixed_precision == "fp16" and not (model_args.freeze_text_encoder and model_args.precompute_text_hidden_states):
# fp16 doesn't work with T5-like models
with accelerator.autocast(autocast_handler=autocast_kwargs):
if training_args.parallel_mode.value != "distributed":
@@ -1462,7 +1508,7 @@ def main():
for epoch in range(epochs_trained, num_epochs):
vectorized_datasets["train"] = vectorized_datasets["train"].shuffle(training_args.seed)
# TODO: add args
# TODO(YL): add args
sampler = LengthGroupedSampler(train_batch_size, lengths = vectorized_datasets["train"]["target_length"])
train_dataloader = DataLoader(
vectorized_datasets["train"],
......