"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "403d530eec105c0e229fc2b754afdf77a4439def"
Unverified Commit c2dc89be authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

[Xtreme-S] fix some namings (#16183)

parent 99fd3eb4
...@@ -81,9 +81,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/ ...@@ -81,9 +81,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
python -m torch.distributed.launch \ python -m torch.distributed.launch \
--nproc_per_node=8 \ --nproc_per_node=8 \
run_xtreme_s.py \ run_xtreme_s.py \
--task="mls" \
--language="all" \
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \ --model_name_or_path="facebook/wav2vec2-xls-r-300m" \
--dataset_name="google/xtreme_s" \
--dataset_config_name="mls.all" \
--eval_split_name="test" \ --eval_split_name="test" \
--output_dir="xtreme_s_xlsr_300m_mls" \ --output_dir="xtreme_s_xlsr_300m_mls" \
--overwrite_output_dir \ --overwrite_output_dir \
...@@ -94,7 +94,6 @@ python -m torch.distributed.launch \ ...@@ -94,7 +94,6 @@ python -m torch.distributed.launch \
--learning_rate="3e-4" \ --learning_rate="3e-4" \
--warmup_steps=3000 \ --warmup_steps=3000 \
--evaluation_strategy="steps" \ --evaluation_strategy="steps" \
--target_column_name="transcription" \
--max_duration_in_seconds=20 \ --max_duration_in_seconds=20 \
--save_steps=500 \ --save_steps=500 \
--eval_steps=500 \ --eval_steps=500 \
...@@ -126,10 +125,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/ ...@@ -126,10 +125,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
python -m torch.distributed.launch \ python -m torch.distributed.launch \
--nproc_per_node=2 \ --nproc_per_node=2 \
run_xtreme_s.py \ run_xtreme_s.py \
--task="minds14" \
--language="all" \
--model_name_or_path="facebook/wav2vec2-xls-r-300m" \ --model_name_or_path="facebook/wav2vec2-xls-r-300m" \
--dataset_name="google/xtreme_s" \
--dataset_config_name="minds14.all" \
--eval_split_name="test" \
--output_dir="xtreme_s_xlsr_300m_minds14" \ --output_dir="xtreme_s_xlsr_300m_minds14" \
--overwrite_output_dir \ --overwrite_output_dir \
--num_train_epochs=50 \ --num_train_epochs=50 \
...@@ -139,7 +137,6 @@ python -m torch.distributed.launch \ ...@@ -139,7 +137,6 @@ python -m torch.distributed.launch \
--learning_rate="3e-4" \ --learning_rate="3e-4" \
--warmup_steps=1500 \ --warmup_steps=1500 \
--evaluation_strategy="steps" \ --evaluation_strategy="steps" \
--target_column_name="intent_class" \
--max_duration_in_seconds=30 \ --max_duration_in_seconds=30 \
--save_steps=200 \ --save_steps=200 \
--eval_steps=200 \ --eval_steps=200 \
......
...@@ -62,6 +62,17 @@ def list_field(default=None, metadata=None): ...@@ -62,6 +62,17 @@ def list_field(default=None, metadata=None):
return field(default_factory=lambda: default, metadata=metadata) return field(default_factory=lambda: default, metadata=metadata)
TASK_TO_TARGET_COLUMN_NAME = {
"fleurs-asr": "transcription",
"fleurs-lang_id": "lang_id",
"mls": "transcription",
"voxpopuli": "transcription",
"covost2": "translation",
"minds14": "intent_class",
"babel": "transcription",
}
@dataclass @dataclass
class ModelArguments: class ModelArguments:
""" """
...@@ -144,8 +155,16 @@ class DataTrainingArguments: ...@@ -144,8 +155,16 @@ class DataTrainingArguments:
default="xtreme_s", default="xtreme_s",
metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"}, metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"},
) )
dataset_config_name: str = field( task: str = field(
default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."} default=None,
metadata={
"help": "The task name of the benchmark to use (via the datasets library). Should be on of: "
"'fleurs-asr', 'mls', 'voxpopuli', 'covost2', 'minds14', 'fleurs-lang_id', 'babel'."
},
)
language: str = field(
default="all",
metadata={"help": "The language id as defined in the datasets config name or `all` for all languages."},
) )
train_split_name: str = field( train_split_name: str = field(
default="train", default="train",
...@@ -160,6 +179,13 @@ class DataTrainingArguments: ...@@ -160,6 +179,13 @@ class DataTrainingArguments:
"Defaults to 'validation'" "Defaults to 'validation'"
}, },
) )
predict_split_name: str = field(
default="test",
metadata={
"help": "The name of the prediction data set split to use (via the datasets library). "
"Defaults to 'test'"
},
)
audio_column_name: str = field( audio_column_name: str = field(
default="audio", default="audio",
metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"}, metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
...@@ -192,6 +218,13 @@ class DataTrainingArguments: ...@@ -192,6 +218,13 @@ class DataTrainingArguments:
"value if set." "value if set."
}, },
) )
max_predict_samples: Optional[int] = field(
default=None,
metadata={
"help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
chars_to_ignore: Optional[List[str]] = list_field( chars_to_ignore: Optional[List[str]] = list_field(
default=', ? . ! - ; : " “ % ‘ ” �'.split(" "), default=', ? . ! - ; : " “ % ‘ ” �'.split(" "),
metadata={"help": "A list of characters to remove from the transcripts."}, metadata={"help": "A list of characters to remove from the transcripts."},
...@@ -387,22 +420,31 @@ def main(): ...@@ -387,22 +420,31 @@ def main():
# 1. First, let's load the dataset # 1. First, let's load the dataset
raw_datasets = DatasetDict() raw_datasets = DatasetDict()
if data_args.dataset_config_name is None: task_name = data_args.task
lang_id = data_args.language
if task_name is None:
raise ValueError(
"Set --task should be set to '<xtreme_s_task>' " "(e.g. 'fleurs-asr', 'mls', 'covost2', 'minds14') "
)
if lang_id is None:
raise ValueError( raise ValueError(
"Set --dataset_config_name should be set to '<xtreme_s_subset>.<language(s)>' " "Set --language should be set to the language id of the sub dataset "
"(e.g. 'mls.pl', 'covost2.en.tr', 'minds14.fr-FR') " "config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'"
"or '<xtreme_s_subset>.all' for multi-lingual fine-tuning." " for multi-lingual fine-tuning."
) )
task_name = data_args.dataset_config_name.split(".")[0] target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
target_column_name = data_args.target_column_name
# here we differentiate between tasks with text as the target and classification tasks # here we differentiate between tasks with text as the target and classification tasks
is_text_target = target_column_name in ("transcription", "translation") is_text_target = target_column_name in ("transcription", "translation")
config_name = ".".join([task_name.split("-")[0], lang_id])
if training_args.do_train: if training_args.do_train:
raw_datasets["train"] = load_dataset( raw_datasets["train"] = load_dataset(
data_args.dataset_name, data_args.dataset_name,
data_args.dataset_config_name, config_name,
split=data_args.train_split_name, split=data_args.train_split_name,
use_auth_token=data_args.use_auth_token, use_auth_token=data_args.use_auth_token,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
...@@ -432,7 +474,7 @@ def main(): ...@@ -432,7 +474,7 @@ def main():
if training_args.do_eval: if training_args.do_eval:
raw_datasets["eval"] = load_dataset( raw_datasets["eval"] = load_dataset(
data_args.dataset_name, data_args.dataset_name,
data_args.dataset_config_name, config_name,
split=data_args.eval_split_name, split=data_args.eval_split_name,
use_auth_token=data_args.use_auth_token, use_auth_token=data_args.use_auth_token,
cache_dir=model_args.cache_dir, cache_dir=model_args.cache_dir,
...@@ -441,6 +483,18 @@ def main(): ...@@ -441,6 +483,18 @@ def main():
if data_args.max_eval_samples is not None: if data_args.max_eval_samples is not None:
raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples)) raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
if training_args.do_predict:
raw_datasets["predict"] = load_dataset(
data_args.dataset_name,
config_name,
split=data_args.predict_split_name,
use_auth_token=data_args.use_auth_token,
cache_dir=model_args.cache_dir,
)
if data_args.max_predict_samples is not None:
raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples))
# 2. We remove some special characters from the datasets # 2. We remove some special characters from the datasets
# that make training complicated and do not help in transcribing the speech # that make training complicated and do not help in transcribing the speech
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
...@@ -757,24 +811,25 @@ def main(): ...@@ -757,24 +811,25 @@ def main():
# Evaluation # Evaluation
results = {} results = {}
if training_args.do_eval: if training_args.do_predict:
logger.info("*** Evaluate ***") logger.info("*** Predicte ***")
metrics = trainer.evaluate() metrics = trainer.evaluate(vectorized_datasets["predict"])
max_eval_samples = ( max_predict_samples = (
data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"]) data_args.max_predict_samples
if data_args.max_predict_samples is not None
else len(vectorized_datasets["predict"])
) )
metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"])) metrics["predict_samples"] = min(max_predict_samples, len(vectorized_datasets["predict"]))
trainer.log_metrics("eval", metrics) trainer.log_metrics("predict", metrics)
trainer.save_metrics("eval", metrics) trainer.save_metrics("predict", metrics)
# Write model card and (optionally) push to hub # Write model card and (optionally) push to hub
config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
kwargs = { kwargs = {
"finetuned_from": model_args.model_name_or_path, "finetuned_from": model_args.model_name_or_path,
"tasks": "speech-recognition", "tasks": task_name,
"tags": ["automatic-speech-recognition", data_args.dataset_name], "tags": [task_name, data_args.dataset_name],
"dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}", "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}",
"dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}", "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
} }
if "common_voice" in data_args.dataset_name: if "common_voice" in data_args.dataset_name:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment