Unverified Commit 6192549c authored by bofeng huang, committed by GitHub
Browse files

[examples/speech-recognition] Add SpecAugment to run_speech_recognition_seq2seq.py (#21942)



* Add specaugment to run_speech_recognition_seq2seq.py

* Remove useless argument: text_column

* Fix quality

* Update return_attention_mask condition

* Update specaugment arguments only for whisper models

* Remove SpecAugment arguments from ModelArguments, only leave default values for simplicity

* Apply suggestions from code review
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Update apply_spec_augment only for whisper models

* Apply suggestions from code review
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>

* Rename return_attention_mask to forward_attention_mask to avoid confusion with wav2vec2 models

---------
Co-authored-by: Sanchit Gandhi <93869735+sanchit-gandhi@users.noreply.github.com>
parent b427b263
@@ -113,6 +113,12 @@ class ModelArguments:
    suppress_tokens: List[int] = field(
        default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
    )
    apply_spec_augment: bool = field(
        default=False,
        metadata={
            "help": "Whether to apply *SpecAugment* data augmentation to the input features. This is currently only relevant for Wav2Vec2, HuBERT, WavLM and Whisper models."
        },
    )
@dataclass
@@ -127,10 +133,6 @@ class DataTrainingArguments:
    dataset_config_name: Optional[str] = field(
        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
    )
    text_column: Optional[str] = field(
        default=None,
        metadata={"help": "The name of the column in the datasets containing the full texts (for summarization)."},
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
    )
@@ -227,10 +229,13 @@ class DataCollatorSpeechSeq2SeqWithPadding:
            The processor used for processing the data.
        decoder_start_token_id (`int`)
            The begin-of-sentence of the decoder.
        forward_attention_mask (`bool`)
            Whether to return attention_mask.
    """

    processor: Any
    decoder_start_token_id: int
    forward_attention_mask: bool

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        # split inputs and labels since they have to be of different lengths and need
@@ -241,6 +246,9 @@ class DataCollatorSpeechSeq2SeqWithPadding:
        batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")

        if self.forward_attention_mask:
            batch["attention_mask"] = torch.LongTensor([feature["attention_mask"] for feature in features])

        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors="pt")

        # replace padding with -100 to ignore loss correctly
@@ -367,6 +375,10 @@ def main():
    config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})

    # SpecAugment for whisper models
    if getattr(config, "model_type", None) == "whisper":
        config.update({"apply_spec_augment": model_args.apply_spec_augment})

    feature_extractor = AutoFeatureExtractor.from_pretrained(
        model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
@@ -418,6 +430,12 @@ def main():
    text_column_name = data_args.text_column_name
    model_input_name = feature_extractor.model_input_names[0]
    do_lower_case = data_args.do_lower_case

    # if SpecAugment is used for whisper models, return attention_mask to guide the mask along time axis
    forward_attention_mask = (
        getattr(config, "model_type", None) == "whisper"
        and getattr(config, "apply_spec_augment", False)
        and getattr(config, "mask_time_prob", 0) > 0
    )

    if data_args.max_train_samples is not None:
        raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
@@ -428,10 +446,14 @@ def main():
    def prepare_dataset(batch):
        # process audio
        sample = batch[audio_column_name]
        inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
        inputs = feature_extractor(
            sample["array"], sampling_rate=sample["sampling_rate"], return_attention_mask=forward_attention_mask
        )
        # process audio length
        batch[model_input_name] = inputs.get(model_input_name)[0]
        batch["input_length"] = len(sample["array"])
        if forward_attention_mask:
            batch["attention_mask"] = inputs.get("attention_mask")[0]

        # process targets
        input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
@@ -496,6 +518,7 @@ def main():
    data_collator = DataCollatorSpeechSeq2SeqWithPadding(
        processor=processor,
        decoder_start_token_id=model.config.decoder_start_token_id,
        forward_attention_mask=forward_attention_mask,
    )

    # 11. Initialize Trainer
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment