"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "bc2571e61c985ec82819cf01ad038342771c94d0"
Unverified commit a4a88fa0, authored by Anton Lozhkov and committed by GitHub

[Research] Speed up evaluation for XTREME-S (#16785)

* Avoid repeated per-lang filtering

* Language groups and logits preprocessing

* Style
parent 2d91e3c3
@@ -136,6 +136,10 @@ class ModelArguments:
         metadata={"help": "Length of vector span to mask along the feature axis."},
     )
     layerdrop: float = field(default=0.0, metadata={"help": "The LayerDrop probability."})
+    ctc_zero_infinity: bool = field(
+        default=False,
+        metadata={"help": "Whether to zero infinite losses and the associated gradients of `torch.nn.CTCLoss`."},
+    )
     ctc_loss_reduction: Optional[str] = field(
         default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
     )
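For context, `ctc_zero_infinity` is forwarded to the model config (see the config kwargs hunk further down) and maps onto the `zero_infinity` flag of PyTorch's CTC loss: when a clip is shorter than its transcription, CTC has no valid alignment and the loss is +inf, and zeroing it keeps one bad sample from wiping out the batch gradient. A minimal sketch with made-up shapes, not taken from the script:

import torch

# (T=2, N=1, C=5) log-probabilities; a target of length 4 cannot fit into 2 frames,
# so the CTC loss for this sample is infinite.
log_probs = torch.randn(2, 1, 5).log_softmax(dim=-1)
targets = torch.tensor([[1, 2, 3, 4]])
input_lengths, target_lengths = torch.tensor([2]), torch.tensor([4])

loss_default = torch.nn.CTCLoss()(log_probs, targets, input_lengths, target_lengths)
loss_zeroed = torch.nn.CTCLoss(zero_infinity=True)(log_probs, targets, input_lengths, target_lengths)
print(loss_default, loss_zeroed)  # tensor(inf) vs. tensor(0.)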
@@ -166,6 +170,15 @@ class DataTrainingArguments:
         default="all",
         metadata={"help": "The language id as defined in the datasets config name or `all` for all languages."},
     )
+    language_group: str = field(
+        default=None,
+        metadata={
+            "help": "The language group to select a subset of languages to train on. "
+            "This option is only used the 'fleurs-asr' task. Should be one of: "
+            "'western_european_we', 'eastern_european_ee', 'central_asia_middle_north_african_cmn', "
+            "'sub_saharan_african_ssa', 'south_asian_sa', 'south_east_asian_sea', 'chinese_japanase_korean_cjk'."
+        },
+    )
     train_split_name: str = field(
         default="train",
         metadata={
@@ -441,6 +454,11 @@ def main():
             "config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'"
             " for multi-lingual fine-tuning."
         )
+    if data_args.language_group is not None:
+        if data_args.task != "fleurs-asr":
+            raise ValueError("--language_group should only be used with --task=fleurs-asr")
+        if data_args.language != "all":
+            raise ValueError("--language_group should only be used with --language=all")
     if data_args.target_column_name is None:
         target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
@@ -502,11 +520,23 @@ def main():
         if data_args.max_predict_samples is not None:
             raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples))
+    lang_list = next(iter(raw_datasets.values())).features["lang_id"].names
     if not is_text_target:
         label_list = next(iter(raw_datasets.values())).features[target_column_name].names
-        lang_list = next(iter(raw_datasets.values())).features["lang_id"].names
         num_labels = len(label_list)
+    num_workers = data_args.preprocessing_num_workers
+    lang_group = data_args.language_group
+    if lang_group is not None:
+        with training_args.main_process_first(desc="language group filter"):
+            lang_group_id = next(iter(raw_datasets.values())).features["lang_group_id"].str2int(lang_group)
+            raw_datasets = raw_datasets.filter(
+                lambda lang_group: lang_group == lang_group_id,
+                num_proc=num_workers,
+                input_columns=["lang_group_id"],
+            )
     # 2. We remove some special characters from the datasets
     # that make training complicated and do not help in transcribing the speech
     # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
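The group filter above leans on two 🤗 Datasets features: `input_columns=["lang_group_id"]` hands only that integer column to the predicate (so the audio column never has to be decoded), and `num_proc` spreads the pass over several worker processes. A toy sketch of the same pattern, using a made-up two-column schema rather than the real FLEURS one:

from datasets import ClassLabel, Dataset, Features, Value

features = Features({"text": Value("string"), "lang_group_id": ClassLabel(names=["we", "cjk"])})
ds = Dataset.from_dict({"text": ["a", "b", "c"], "lang_group_id": [0, 1, 0]}, features=features)

# Resolve the human-readable group name to its integer id, then filter on that column only.
group_id = ds.features["lang_group_id"].str2int("we")
subset = ds.filter(lambda gid: gid == group_id, input_columns=["lang_group_id"], num_proc=2)
print(len(subset))  # 2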
@@ -616,6 +646,7 @@ def main():
         "mask_feature_length": model_args.mask_feature_length,
         "gradient_checkpointing": training_args.gradient_checkpointing,
         "layerdrop": model_args.layerdrop,
+        "ctc_zero_infinity": model_args.ctc_zero_infinity,
         "ctc_loss_reduction": model_args.ctc_loss_reduction,
         "activation_dropout": model_args.activation_dropout,
     }
@@ -675,7 +706,6 @@ def main():
     max_input_length = data_args.max_duration_in_seconds * feature_extractor.sampling_rate
     min_input_length = data_args.min_duration_in_seconds * feature_extractor.sampling_rate
     audio_column_name = data_args.audio_column_name
-    num_workers = data_args.preprocessing_num_workers
     # `phoneme_language` is only relevant if the model is fine-tuned on phoneme classification
     phoneme_language = data_args.phoneme_language
@@ -740,13 +770,13 @@ def main():
         logger.info(f"Data preprocessing finished. Files cached at {vectorized_datasets.cache_files}")
         return
-    def compute_asr_metric(pred):
-        pred_logits = pred.predictions
-        pred_ids = np.argmax(pred_logits, axis=-1)
+    def asr_logits_argmax(logits, labels):
+        return logits.argmax(dim=-1)
+    def compute_asr_metric(pred):
         pred.label_ids[pred.label_ids == -100] = tokenizer.pad_token_id
-        pred_str = tokenizer.batch_decode(pred_ids)
+        pred_str = tokenizer.batch_decode(pred.predictions)
         # we do not want to group tokens when computing the metrics
         label_str = tokenizer.batch_decode(pred.label_ids, group_tokens=False)
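This refactor works together with the `preprocess_logits_for_metrics` hook wired into the trainers below: the argmax now runs on-device per eval batch (the hook is called with `(logits, labels)`, hence the unused `labels` argument), so the Trainer accumulates small integer id tensors instead of the full logits of the whole eval set. A rough back-of-the-envelope sketch with made-up shapes:

import torch

# Hypothetical eval batch: 8 utterances, 500 frames, 1000-token vocabulary.
logits = torch.randn(8, 500, 1000)
pred_ids = logits.argmax(dim=-1)  # what asr_logits_argmax keeps for compute_asr_metric

print(logits.nelement() * logits.element_size())      # 16_000_000 bytes of float32 logits
print(pred_ids.nelement() * pred_ids.element_size())  # 32_000 bytes of int64 ids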
@@ -783,6 +813,7 @@ def main():
             model=model,
             data_collator=data_collator,
             args=training_args,
+            preprocess_logits_for_metrics=asr_logits_argmax if training_args.predict_with_generate else None,
             compute_metrics=compute_asr_metric if training_args.predict_with_generate else None,
             train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
             eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
@@ -793,6 +824,7 @@ def main():
             model=model,
             data_collator=data_collator,
             args=training_args,
+            preprocess_logits_for_metrics=asr_logits_argmax if is_text_target else None,
             compute_metrics=compute_asr_metric if is_text_target else compute_classification_metric,
             train_dataset=vectorized_datasets["train"] if training_args.do_train else None,
             eval_dataset=vectorized_datasets["eval"] if training_args.do_eval else None,
@@ -837,11 +869,17 @@ def main():
         average_metrics = defaultdict(list)
         for lang_id in range(len(lang_list)):
             lang_name = lang_list[lang_id]
-            lang_dataset = vectorized_datasets["predict"].filter(lambda example: example["lang"] == lang_id)
+            with training_args.main_process_first(desc="per-language dataset filter"):
+                lang_dataset = vectorized_datasets["predict"].filter(
+                    lambda lang: lang == lang_id,
+                    num_proc=num_workers,
+                    input_columns=["lang"],
+                )
             lang_metrics = trainer.evaluate(lang_dataset)
+            redundant_metrics = ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second", "eval_epoch"]
             for metric_name, value in lang_metrics.items():
                 average_metrics[metric_name].append(value)
-                if metric_name not in ["eval_runtime", "eval_samples_per_second", "eval_steps_per_second"]:
+                if metric_name not in redundant_metrics:
                     metrics[f"{metric_name}_{lang_name}"] = value
         for metric_name, value in average_metrics.items():
             metrics[metric_name] = np.mean(value)
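The per-language prediction loop keeps its structure, but the filter now only looks at the integer `lang` column instead of materializing every column of each example, runs with `num_proc` workers, and sits inside `main_process_first` so that in distributed runs the filter cache is written once and reused by the other ranks. A toy comparison of the old and new predicates, with made-up columns standing in for the vectorized features:

from datasets import Dataset

ds = Dataset.from_dict({"input_values": [[0.1] * 4, [0.2] * 4, [0.3] * 4], "lang": [0, 1, 0]})

old_style = ds.filter(lambda example: example["lang"] == 0)            # loads every column per row
new_style = ds.filter(lambda lang: lang == 0, input_columns=["lang"])  # loads only the "lang" column
assert old_style["lang"] == new_style["lang"] == [0, 0]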