Unverified commit 772307be authored by Yoach Lacombe, committed by GitHub

Making CTC training example more general (#28582)



* add w2v2bert compatibility

* Update examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 186aa6be
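
The thrust of the change: instead of hardcoding the Wav2Vec2 input key "input_values", the script now asks the feature extractor for its model-specific input name, so Wav2Vec2-BERT checkpoints (whose feature extractor emits "input_features") run through the same code path. A quick probe of the two families; the checkpoint ids below are illustrative assumptions, not taken from this diff:

    from transformers import AutoFeatureExtractor

    # Wav2Vec2 consumes raw waveforms; Wav2Vec2-BERT consumes log-mel features.
    # model_input_names[0] is the batch key the rest of the script should use.
    for ckpt in ("facebook/wav2vec2-base", "facebook/w2v-bert-2.0"):
        fe = AutoFeatureExtractor.from_pretrained(ckpt)
        print(ckpt, "->", fe.model_input_names[0])
    # expected: "input_values" for wav2vec2, "input_features" for w2v-bert
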
@@ -132,10 +132,17 @@ class ModelArguments:
     ctc_loss_reduction: Optional[str] = field(
         default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
     )
+    ctc_zero_infinity: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly"
+            " occur when the inputs are too short to be aligned to the targets."
+        },
+    )
     add_adapter: Optional[bool] = field(
         default=False,
         metadata={
-            "help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2BERT Encoder. Can be very"
+            "help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very"
             "useful to downsample the output length."
         },
     )
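
The new ctc_zero_infinity flag maps directly onto the zero_infinity argument of PyTorch's CTC loss. When a transcript has more labels than the encoder produced frames, no valid alignment exists and the loss is infinite; zeroing it (and its gradient) skips such utterances instead of blowing up training. A minimal sketch with toy shapes:

    import torch

    # 2 time steps of log-probs over a 5-symbol vocab (blank = 0), batch size 1
    log_probs = torch.randn(2, 1, 5).log_softmax(-1)
    targets = torch.tensor([[1, 2, 3]])  # 3 labels but only 2 frames: unalignable
    input_lengths, target_lengths = torch.tensor([2]), torch.tensor([3])

    for zero_inf in (False, True):
        loss = torch.nn.CTCLoss(zero_infinity=zero_inf)(log_probs, targets, input_lengths, target_lengths)
        print(zero_inf, loss.item())  # inf without the flag, 0.0 with it
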
@@ -316,11 +323,14 @@ class DataCollatorCTCWithPadding:
     padding: Union[bool, str] = "longest"
     pad_to_multiple_of: Optional[int] = None
     pad_to_multiple_of_labels: Optional[int] = None
+    feature_extractor_input_name: Optional[str] = "input_values"

     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
         # different padding methods
-        input_features = [{"input_values": feature["input_values"]} for feature in features]
+        input_features = [
+            {self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features
+        ]
         label_features = [{"input_ids": feature["labels"]} for feature in features]

         batch = self.processor.pad(
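
With the key made a collator field, the padding call itself is unchanged; the feature extractor's pad method pads its own model_input_names[0] key, whatever that is. A self-contained sketch of the padding step, using a default-constructed Wav2Vec2FeatureExtractor so nothing is downloaded:

    from transformers import Wav2Vec2FeatureExtractor

    fe = Wav2Vec2FeatureExtractor()  # defaults: raw 16 kHz waveform features
    key = fe.model_input_names[0]    # "input_values" for this extractor

    # two "utterances" of different lengths, keyed by the extractor's input name
    rows = [{key: [0.1, 0.2, 0.3]}, {key: [0.4]}]
    batch = fe.pad(rows, padding="longest", return_tensors="pt")
    print(batch[key].shape)  # torch.Size([2, 3]), shorter row zero-padded
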
@@ -606,6 +616,7 @@ def main():
         "gradient_checkpointing": training_args.gradient_checkpointing,
         "layerdrop": model_args.layerdrop,
         "ctc_loss_reduction": model_args.ctc_loss_reduction,
+        "ctc_zero_infinity": model_args.ctc_zero_infinity,
         "pad_token_id": tokenizer.pad_token_id,
         "vocab_size": len(tokenizer),
         "activation_dropout": model_args.activation_dropout,
@@ -643,6 +654,7 @@ def main():
     min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
     audio_column_name = data_args.audio_column_name
     num_workers = data_args.preprocessing_num_workers
+    feature_extractor_input_name = feature_extractor.model_input_names[0]

     # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
     phoneme_language = data_args.phoneme_language
@@ -654,8 +666,9 @@ def main():
         sample = batch[audio_column_name]

         inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
-        batch["input_values"] = inputs.input_values[0]
-        batch["input_length"] = len(batch["input_values"])
+        batch[feature_extractor_input_name] = getattr(inputs, feature_extractor_input_name)[0]
+        # take length of raw audio waveform
+        batch["input_length"] = len(sample["array"].squeeze())

         # encode targets
         additional_kwargs = {}
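
Taking input_length from the raw waveform rather than the extracted features matters here: a Wav2Vec2 extractor returns the (normalized) waveform itself, so the two lengths coincide, but a mel-based extractor like Wav2Vec2-BERT's downsamples to far fewer frames, which would break the min/max-duration filtering that compares input_length against seconds * sampling_rate. A sketch of the raw-waveform case:

    import numpy as np
    from transformers import Wav2Vec2FeatureExtractor

    audio = np.zeros(16000, dtype=np.float32)  # one second at 16 kHz
    fe = Wav2Vec2FeatureExtractor()
    feats = getattr(fe(audio, sampling_rate=16000), fe.model_input_names[0])[0]
    # raw-waveform extractor: feature length equals waveform length
    print(len(audio), len(feats))  # 16000 16000
    # a mel extractor would yield on the order of 50 frames per second here,
    # so len(feats) could not be compared against seconds * sampling_rate
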
@@ -736,7 +749,9 @@ def main():
     processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)

     # Instantiate custom data collator
-    data_collator = DataCollatorCTCWithPadding(processor=processor)
+    data_collator = DataCollatorCTCWithPadding(
+        processor=processor, feature_extractor_input_name=feature_extractor_input_name
+    )

     # Initialize Trainer
     trainer = Trainer(
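
After the change, the same entry point fine-tunes either family; an invocation along these lines should work (checkpoint, dataset, and hyperparameters here are illustrative assumptions, not part of the diff):

    python run_speech_recognition_ctc.py \
        --model_name_or_path facebook/w2v-bert-2.0 \
        --dataset_name mozilla-foundation/common_voice_11_0 \
        --dataset_config_name tr \
        --output_dir ./w2v-bert-ctc \
        --ctc_zero_infinity True \
        --add_adapter True \
        --do_train --do_eval
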