"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "77966a43a405667962f22d919dc97183a0ef644b"
Unverified Commit 55f49c5f authored by Patrick von Platen, committed by GitHub

[Wav2Vec2 Example] Improve fine-tuning script (#14373)

* improve some stuff

* finish

* correct last
parent 21546e59
@@ -99,9 +99,24 @@ class ModelArguments:
         metadata={
             "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
             "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
-            "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
+            "vectors will be masked along the time axis."
         },
     )
+    mask_time_length: Optional[int] = field(
+        default=10,
+        metadata={"help": "Length of vector span to mask along the time axis."},
+    )
+    mask_feature_prob: Optional[float] = field(
+        default=0.0,
+        metadata={
+            "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
+            "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the feature axis."
+        },
+    )
+    mask_feature_length: Optional[int] = field(
+        default=10,
+        metadata={"help": "Length of vector span to mask along the feature axis."},
+    )
     layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})
     ctc_loss_reduction: Optional[str] = field(
         default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
@@ -169,6 +184,10 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "A list of characters to remove from the transcripts."},
     )
+    eval_metrics: Optional[List[str]] = list_field(
+        default=["wer"],
+        metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
+    )
     max_duration_in_seconds: Optional[float] = field(
         default=20.0,
         metadata={
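
The new field uses the `list_field` helper defined at the top of the script. A minimal sketch of that pattern, assuming it matches the helper used elsewhere in the example scripts (dataclasses require a factory for mutable defaults like `["wer"]`):

```python
from dataclasses import field

def list_field(default=None, metadata=None):
    # A list default must come from default_factory, not a shared literal.
    return field(default_factory=lambda: default, metadata=metadata)
```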
@@ -446,6 +465,9 @@ def main():
         "hidden_dropout": model_args.hidden_dropout,
         "final_dropout": model_args.final_dropout,
         "mask_time_prob": model_args.mask_time_prob,
+        "mask_time_length": model_args.mask_time_length,
+        "mask_feature_prob": model_args.mask_feature_prob,
+        "mask_feature_length": model_args.mask_feature_length,
         "gradient_checkpointing": training_args.gradient_checkpointing,
         "layerdrop": model_args.layerdrop,
         "ctc_loss_reduction": model_args.ctc_loss_reduction,
@@ -519,8 +541,8 @@
     # Let's use word error rate (WER) as our evaluation metric,
     # instantiate a data collator and the trainer

-    # Define Metric during training
-    wer_metric = load_metric("wer")
+    # Define evaluation metrics during training, *i.e.* word error rate, character error rate
+    eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}

     # for large datasets it is advised to run the preprocessing on a
     # single machine first with ``args.preprocessing_only`` since there will most likely
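
As a quick illustration of the two metrics the default setup supports ("wer", plus "cer", which additionally requires the `jiwer` package), with made-up strings:

```python
from datasets import load_metric

preds, refs = ["hello ther"], ["hello there"]
wer = load_metric("wer").compute(predictions=preds, references=refs)  # word-level error rate
cer = load_metric("cer").compute(predictions=preds, references=refs)  # character-level error rate
```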
@@ -541,9 +563,9 @@
         # we do not want to group tokens when computing the metrics
         label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

-        wer = wer_metric.compute(predictions=pred_str, references=label_str)
+        metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}

-        return {"wer": wer}
+        return metrics

     # Instantiate custom data collator
     data_collator = DataCollatorCTCWithPadding(processor=processor)
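
Put together, the changed lines sit inside the script's `compute_metrics`. A sketch of the whole function, assuming the usual CTC decoding pattern from this example script (argmax over logits, `-100` label positions restored to the pad token before decoding):

```python
import numpy as np

def compute_metrics(pred):
    # Greedy CTC decoding: pick the most likely token at each time step.
    pred_ids = np.argmax(pred.predictions, axis=-1)

    # -100 marks ignored label positions; restore the pad token before decoding.
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id

    pred_str = processor.batch_decode(pred_ids)
    # we do not want to group tokens when computing the metrics
    label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

    return {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
```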