Unverified Commit 772307be authored by Yoach Lacombe, committed by GitHub

Making CTC training example more general (#28582)



* add w2v2bert compatibility

* Update examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

---------
Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
parent 186aa6be
examples/pytorch/speech-recognition/run_speech_recognition_ctc.py

@@ -132,10 +132,17 @@ class ModelArguments:
     ctc_loss_reduction: Optional[str] = field(
         default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
     )
+    ctc_zero_infinity: Optional[bool] = field(
+        default=False,
+        metadata={
+            "help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`. Infinite losses mainly"
+            " occur when the inputs are too short to be aligned to the targets."
+        },
+    )
     add_adapter: Optional[bool] = field(
         default=False,
         metadata={
-            "help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2BERT Encoder. Can be very"
+            "help": "Whether a convolutional attention network should be stacked on top of the Wav2Vec2Bert Encoder. Can be very"
             "useful to downsample the output length."
         },
     )
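Context on the new flag: `torch.nn.CTCLoss` returns an infinite loss whenever an utterance has fewer input frames than are needed to emit its target sequence, and the resulting inf/NaN gradients can derail training. `zero_infinity` zeroes those losses and their gradients instead. A minimal standalone sketch (toy shapes, not part of the script):

```python
import torch

# 5 input frames cannot align to a 20-token target, so the CTC loss is infinite.
log_probs = torch.randn(5, 1, 32).log_softmax(-1)  # (time, batch, vocab)
targets = torch.randint(1, 32, (1, 20))            # labels avoid the blank id (0)
input_lengths = torch.tensor([5])
target_lengths = torch.tensor([20])

loss = torch.nn.CTCLoss(zero_infinity=False)(log_probs, targets, input_lengths, target_lengths)
print(loss)  # tensor(inf)

loss = torch.nn.CTCLoss(zero_infinity=True)(log_probs, targets, input_lengths, target_lengths)
print(loss)  # tensor(0.) -- the infinite loss (and its gradient) is zeroed
```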
@@ -316,11 +323,14 @@ class DataCollatorCTCWithPadding:
     padding: Union[bool, str] = "longest"
     pad_to_multiple_of: Optional[int] = None
     pad_to_multiple_of_labels: Optional[int] = None
+    feature_extractor_input_name: Optional[str] = "input_values"

     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
         # different padding methods
-        input_features = [{"input_values": feature["input_values"]} for feature in features]
+        input_features = [
+            {self.feature_extractor_input_name: feature[self.feature_extractor_input_name]} for feature in features
+        ]
         label_features = [{"input_ids": feature["labels"]} for feature in features]

         batch = self.processor.pad(
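The collator change above parameterizes the dict key it pads under, so the same class serves models whose processors expect `input_features` (e.g. Wav2Vec2-BERT) as well as the classic `input_values` (Wav2Vec2). A standalone sketch of the keyed-collation idea, with a hypothetical `KeyedCollator` and toy zero-padding standing in for `processor.pad`:

```python
from dataclasses import dataclass
from typing import Dict, List

import torch


@dataclass
class KeyedCollator:
    # "input_values" for raw-waveform models (Wav2Vec2); "input_features" for
    # feature-based models (e.g. Wav2Vec2-BERT). The key is configuration, not code.
    input_name: str = "input_values"

    def __call__(self, features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        arrays = [f[self.input_name] for f in features]
        # toy right-padding with zeros along the last axis
        max_len = max(a.shape[-1] for a in arrays)
        padded = torch.stack(
            [torch.nn.functional.pad(a, (0, max_len - a.shape[-1])) for a in arrays]
        )
        return {self.input_name: padded}


batch = KeyedCollator()([{"input_values": torch.randn(8)}, {"input_values": torch.randn(5)}])
print(batch["input_values"].shape)  # torch.Size([2, 8])
```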
@@ -606,6 +616,7 @@ def main():
         "gradient_checkpointing": training_args.gradient_checkpointing,
         "layerdrop": model_args.layerdrop,
         "ctc_loss_reduction": model_args.ctc_loss_reduction,
+        "ctc_zero_infinity": model_args.ctc_zero_infinity,
         "pad_token_id": tokenizer.pad_token_id,
         "vocab_size": len(tokenizer),
         "activation_dropout": model_args.activation_dropout,
@@ -643,6 +654,7 @@ def main():
     min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
     audio_column_name = data_args.audio_column_name
     num_workers = data_args.preprocessing_num_workers
+    feature_extractor_input_name = feature_extractor.model_input_names[0]

     # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
     phoneme_language = data_args.phoneme_language
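`feature_extractor.model_input_names[0]` is how the script discovers the right input key at runtime instead of hard-coding it. An illustration with two real transformers classes (`model_input_names` is a class attribute on feature extractors; the exact contents shown are from the library at the time of writing and may differ across versions — Wav2Vec2-BERT checkpoints ship a `SeamlessM4TFeatureExtractor`):

```python
from transformers import SeamlessM4TFeatureExtractor, Wav2Vec2FeatureExtractor

# the first entry names the tensor the model consumes
print(Wav2Vec2FeatureExtractor.model_input_names[0])     # input_values  (raw waveform)
print(SeamlessM4TFeatureExtractor.model_input_names[0])  # input_features (log-mel banks)
```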
@@ -654,8 +666,9 @@ def main():
         sample = batch[audio_column_name]

         inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
-        batch["input_values"] = inputs.input_values[0]
-        batch["input_length"] = len(batch["input_values"])
+        batch[feature_extractor_input_name] = getattr(inputs, feature_extractor_input_name)[0]
+        # take length of raw audio waveform
+        batch["input_length"] = len(sample["array"].squeeze())

         # encode targets
         additional_kwargs = {}
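Measuring `input_length` on the raw waveform rather than the extractor output matters here: feature-based extractors return log-mel frames, not samples, while the script's duration filter is expressed in samples (`min_duration_in_seconds * feature_extractor.sampling_rate`). A sketch of that filter arithmetic (example durations are illustrative, not the script's defaults):

```python
sampling_rate = 16_000
min_duration_in_seconds, max_duration_in_seconds = 1.0, 20.0

min_input_length = min_duration_in_seconds * sampling_rate  # 16_000 samples
max_input_length = max_duration_in_seconds * sampling_rate  # 320_000 samples


def is_audio_in_length_range(input_length: int) -> bool:
    # `input_length` is now always a raw-waveform sample count, regardless of
    # whether the model consumes input_values or input_features
    return min_input_length < input_length < max_input_length


print(is_audio_in_length_range(8_000))   # False: 0.5 s, too short
print(is_audio_in_length_range(48_000))  # True: 3 s
```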
@@ -736,7 +749,9 @@ def main():
         processor = Wav2Vec2Processor.from_pretrained(training_args.output_dir)

     # Instantiate custom data collator
-    data_collator = DataCollatorCTCWithPadding(processor=processor)
+    data_collator = DataCollatorCTCWithPadding(
+        processor=processor, feature_extractor_input_name=feature_extractor_input_name
+    )

     # Initialize Trainer
     trainer = Trainer(