Unverified Commit af1a7c8c authored by Sanchit Gandhi and committed by GitHub

[Examples] Generalise Seq2Seq ASR to handle Whisper (#19519)

* merge conflicts

* bos and eos in datacollator

* (temp) hardcode removal of attention mask

* freeze encoder

* actually freeze encoder

* set max length / num beams according to gen kwargs

* (temp) fix tests

* don't pop attn mask

* override return attention mask config from Hub

* Hub configs updated 🤗

* final fixes

* update type annotations

* backward comp
parent 7ecb0391
@@ -97,6 +97,22 @@ class ModelArguments:
     freeze_feature_encoder: bool = field(
         default=True, metadata={"help": "Whether to freeze the feature encoder layers of the model."}
     )
+    freeze_encoder: bool = field(
+        default=False, metadata={"help": "Whether to freeze the entire encoder of the seq2seq model."}
+    )
+    forced_decoder_ids: List[List[int]] = field(
+        default=None,
+        metadata={
+            "help": (
+                "A list of pairs of integers which indicates a mapping from generation indices to token indices "
+                "that will be forced before sampling. For example, [[0, 123]] means the first generated token "
+                "will always be a token of index 123."
+            )
+        },
+    )
+    suppress_tokens: List[int] = field(
+        default=None, metadata={"help": "A list of tokens that will be suppressed at generation."}
+    )
 
 
 @dataclass
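For context, `forced_decoder_ids` and `suppress_tokens` feed directly into generation. A minimal sketch of their effect, assuming an illustrative Whisper checkpoint (the token ids below are placeholders, not real vocabulary entries):

```python
from transformers import WhisperForConditionalGeneration

# Illustrative checkpoint; any Whisper checkpoint exposes the same config fields.
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")

# [[0, 123]] forces token id 123 at generation position 0, i.e. the first generated token.
model.config.forced_decoder_ids = [[0, 123]]

# Tokens listed here have their logits masked out, so they can never be sampled.
model.config.suppress_tokens = [456, 789]
```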
@@ -187,6 +203,19 @@ class DataTrainingArguments:
         default=True,
         metadata={"help": "Whether the target text should be lower cased."},
     )
+    language: str = field(
+        default=None,
+        metadata={
+            "help": (
+                "Language for multilingual fine-tuning. This argument should be set for multilingual fine-tuning "
+                "only. For English speech recognition, it should be set to `None`."
+            )
+        },
+    )
+    task: str = field(
+        default="transcribe",
+        metadata={"help": "Task, either `transcribe` for speech recognition or `translate` for speech translation."},
+    )
 
 
 @dataclass
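The `language`/`task` pair is what the script later passes to the tokenizer to build Whisper's decoder prompt. A short sketch of the mapping, assuming a multilingual checkpoint (name illustrative):

```python
from transformers import WhisperTokenizer

tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-small")  # illustrative multilingual checkpoint

# Multilingual fine-tuning: prime the decoder with language and task tokens,
# roughly <|startoftranscript|><|hi|><|transcribe|> for Hindi transcription.
tokenizer.set_prefix_tokens(language="hindi", task="transcribe")

# For English-only fine-tuning, `language` stays unset and the default prefix is kept.
```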
@@ -194,7 +223,7 @@ class DataCollatorSpeechSeq2SeqWithPadding:
     """
     Data collator that will dynamically pad the inputs received.
     Args:
-        processor ([`Wav2Vec2Processor`])
+        processor ([`WhisperProcessor`])
            The processor used for processing the data.
        decoder_start_token_id (`int`)
            The begin-of-sentence of the decoder.
@@ -206,7 +235,8 @@ class DataCollatorSpeechSeq2SeqWithPadding:
     def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
         # split inputs and labels since they have to be of different lengths and need
         # different padding methods
-        input_features = [{"input_values": feature["input_values"]} for feature in features]
+        model_input_name = self.processor.model_input_names[0]
+        input_features = [{model_input_name: feature[model_input_name]} for feature in features]
         label_features = [{"input_ids": feature["labels"]} for feature in features]
 
         batch = self.processor.feature_extractor.pad(input_features, return_tensors="pt")
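The key change above is reading the input name from the processor instead of hard-coding `"input_values"`. A sketch of why that generalises the collator, with illustrative checkpoints:

```python
from transformers import AutoProcessor

# Illustrative checkpoints: one CTC-style model (raw waveform input) and one
# Whisper model (log-Mel spectrogram input).
wav2vec2_processor = AutoProcessor.from_pretrained("facebook/wav2vec2-base-960h")
whisper_processor = AutoProcessor.from_pretrained("openai/whisper-tiny")

print(wav2vec2_processor.model_input_names[0])  # expected: "input_values"
print(whisper_processor.model_input_names[0])   # expected: "input_features"
```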
@@ -333,6 +363,8 @@ def main():
         use_auth_token=True if model_args.use_auth_token else None,
     )
 
+    config.update({"forced_decoder_ids": model_args.forced_decoder_ids, "suppress_tokens": model_args.suppress_tokens})
+
     feature_extractor = AutoFeatureExtractor.from_pretrained(
         model_args.feature_extractor_name if model_args.feature_extractor_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
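`config.update` is a standard `PretrainedConfig` helper; here it lets the command-line values (which may be `None`) override the generation defaults a Whisper checkpoint stores on the Hub. A sketch, again with an illustrative checkpoint:

```python
from transformers import AutoConfig, AutoModelForSpeechSeq2Seq

config = AutoConfig.from_pretrained("openai/whisper-tiny")  # illustrative checkpoint

# Command-line values win over the Hub config; None simply clears forcing/suppression.
config.update({"forced_decoder_ids": None, "suppress_tokens": None})

model = AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-tiny", config=config)
```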
@@ -360,6 +392,14 @@ def main():
     if model_args.freeze_feature_encoder:
         model.freeze_feature_encoder()
 
+    if model_args.freeze_encoder:
+        model.freeze_encoder()
+        model.model.encoder.gradient_checkpointing = False
+
+    if data_args.language is not None:
+        # We only need to set the task id when the language is specified (i.e. in a multilingual setting)
+        tokenizer.set_prefix_tokens(language=data_args.language, task=data_args.task)
+
     # 6. Resample speech dataset if necessary
     dataset_sampling_rate = next(iter(raw_datasets.values())).features[data_args.audio_column_name].sampling_rate
     if dataset_sampling_rate != feature_extractor.sampling_rate:
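`model.freeze_encoder()` is the Whisper-side counterpart of `freeze_feature_encoder()`: it stops gradient updates for every encoder parameter, which is also why gradient checkpointing is switched off for the now-frozen encoder. A rough sketch of the assumed effect, not the library implementation:

```python
from transformers import WhisperForConditionalGeneration

model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")  # illustrative

model.freeze_encoder()
model.model.encoder.gradient_checkpointing = False  # checkpointing is pointless on a frozen module

# Assumed effect: only decoder (and shared) parameters remain trainable.
assert not any(p.requires_grad for p in model.model.encoder.parameters())
```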
@@ -388,8 +428,8 @@ def main():
         sample = batch[audio_column_name]
         inputs = feature_extractor(sample["array"], sampling_rate=sample["sampling_rate"])
         # process audio length
-        batch[model_input_name] = inputs.input_values[0]
-        batch["input_length"] = len(batch["input_values"])
+        batch[model_input_name] = inputs.get(model_input_name)[0]
+        batch["input_length"] = len(sample["array"])
 
         # process targets
         input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
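The two replaced lines make `prepare_dataset` feature-extractor agnostic: the model input is fetched by name, and the audio length is measured on the raw array, since Whisper's log-Mel features are padded to a fixed 30-second window and no longer reflect the true duration. A sketch with an illustrative checkpoint and synthetic audio:

```python
import numpy as np
from transformers import WhisperFeatureExtractor

feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")  # illustrative

# One second of silence at 16 kHz stands in for a decoded dataset sample.
array = np.zeros(16000, dtype=np.float32)
inputs = feature_extractor(array, sampling_rate=16000)

model_input_name = feature_extractor.model_input_names[0]  # "input_features" for Whisper
features = inputs.get(model_input_name)[0]

print(np.asarray(features).shape)  # expected: (80, 3000) log-Mel bins x frames
print(len(array))                  # 16000 -> what the script stores as "input_length"
```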
@@ -452,7 +492,8 @@ def main():
     # 10. Define data collator
     data_collator = DataCollatorSpeechSeq2SeqWithPadding(
-        processor=processor, decoder_start_token_id=model.config.decoder_start_token_id
+        processor=processor,
+        decoder_start_token_id=model.config.decoder_start_token_id,
     )
 
     # 11. Initialize Trainer
@@ -492,7 +533,9 @@ def main():
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
         metrics = trainer.evaluate(
-            metric_key_prefix="eval", max_length=model.config.max_length, num_beams=model.config.num_beams
+            metric_key_prefix="eval",
+            max_length=training_args.generation_max_length,
+            num_beams=training_args.generation_num_beams,
         )
         max_eval_samples = (
             data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
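Evaluation now takes its generation settings from `Seq2SeqTrainingArguments` instead of the model config, so they can be controlled from the command line. A sketch of the relevant arguments (values and output path illustrative):

```python
from transformers import Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./whisper-seq2seq-asr",  # illustrative path
    predict_with_generate=True,
    generation_max_length=225,           # forwarded as max_length to trainer.evaluate()
    generation_num_beams=1,              # forwarded as num_beams to trainer.evaluate()
)
```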