Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
c2dc89be
Unverified
Commit
c2dc89be
authored
Mar 16, 2022
by
Patrick von Platen
Committed by
GitHub
Mar 16, 2022
Browse files
[Xtreme-S] fix some namings (#16183)
parent
99fd3eb4
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
81 additions
and
29 deletions
+81
-29
examples/research_projects/xtreme-s/README.md
examples/research_projects/xtreme-s/README.md
+4
-7
examples/research_projects/xtreme-s/requirements.txt
examples/research_projects/xtreme-s/requirements.txt
+0
-0
examples/research_projects/xtreme-s/run_xtreme_s.py
examples/research_projects/xtreme-s/run_xtreme_s.py
+77
-22
No files found.
examples/research_projects/xreme-s/README.md
→
examples/research_projects/x
t
reme-s/README.md
View file @
c2dc89be
...
@@ -81,9 +81,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
...
@@ -81,9 +81,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
python
-m
torch.distributed.launch
\
python
-m
torch.distributed.launch
\
--nproc_per_node
=
8
\
--nproc_per_node
=
8
\
run_xtreme_s.py
\
run_xtreme_s.py
\
--task
=
"mls"
\
--language
=
"all"
\
--model_name_or_path
=
"facebook/wav2vec2-xls-r-300m"
\
--model_name_or_path
=
"facebook/wav2vec2-xls-r-300m"
\
--dataset_name
=
"google/xtreme_s"
\
--dataset_config_name
=
"mls.all"
\
--eval_split_name
=
"test"
\
--eval_split_name
=
"test"
\
--output_dir
=
"xtreme_s_xlsr_300m_mls"
\
--output_dir
=
"xtreme_s_xlsr_300m_mls"
\
--overwrite_output_dir
\
--overwrite_output_dir
\
...
@@ -94,7 +94,6 @@ python -m torch.distributed.launch \
...
@@ -94,7 +94,6 @@ python -m torch.distributed.launch \
--learning_rate
=
"3e-4"
\
--learning_rate
=
"3e-4"
\
--warmup_steps
=
3000
\
--warmup_steps
=
3000
\
--evaluation_strategy
=
"steps"
\
--evaluation_strategy
=
"steps"
\
--target_column_name
=
"transcription"
\
--max_duration_in_seconds
=
20
\
--max_duration_in_seconds
=
20
\
--save_steps
=
500
\
--save_steps
=
500
\
--eval_steps
=
500
\
--eval_steps
=
500
\
...
@@ -126,10 +125,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
...
@@ -126,10 +125,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
python
-m
torch.distributed.launch
\
python
-m
torch.distributed.launch
\
--nproc_per_node
=
2
\
--nproc_per_node
=
2
\
run_xtreme_s.py
\
run_xtreme_s.py
\
--task
=
"minds14"
\
--language
=
"all"
\
--model_name_or_path
=
"facebook/wav2vec2-xls-r-300m"
\
--model_name_or_path
=
"facebook/wav2vec2-xls-r-300m"
\
--dataset_name
=
"google/xtreme_s"
\
--dataset_config_name
=
"minds14.all"
\
--eval_split_name
=
"test"
\
--output_dir
=
"xtreme_s_xlsr_300m_minds14"
\
--output_dir
=
"xtreme_s_xlsr_300m_minds14"
\
--overwrite_output_dir
\
--overwrite_output_dir
\
--num_train_epochs
=
50
\
--num_train_epochs
=
50
\
...
@@ -139,7 +137,6 @@ python -m torch.distributed.launch \
...
@@ -139,7 +137,6 @@ python -m torch.distributed.launch \
--learning_rate
=
"3e-4"
\
--learning_rate
=
"3e-4"
\
--warmup_steps
=
1500
\
--warmup_steps
=
1500
\
--evaluation_strategy
=
"steps"
\
--evaluation_strategy
=
"steps"
\
--target_column_name
=
"intent_class"
\
--max_duration_in_seconds
=
30
\
--max_duration_in_seconds
=
30
\
--save_steps
=
200
\
--save_steps
=
200
\
--eval_steps
=
200
\
--eval_steps
=
200
\
...
...
examples/research_projects/xreme-s/requirements.txt
→
examples/research_projects/x
t
reme-s/requirements.txt
View file @
c2dc89be
File moved
examples/research_projects/xreme-s/run_xtreme_s.py
→
examples/research_projects/x
t
reme-s/run_xtreme_s.py
View file @
c2dc89be
...
@@ -62,6 +62,17 @@ def list_field(default=None, metadata=None):
...
@@ -62,6 +62,17 @@ def list_field(default=None, metadata=None):
return
field
(
default_factory
=
lambda
:
default
,
metadata
=
metadata
)
return
field
(
default_factory
=
lambda
:
default
,
metadata
=
metadata
)
TASK_TO_TARGET_COLUMN_NAME
=
{
"fleurs-asr"
:
"transcription"
,
"fleurs-lang_id"
:
"lang_id"
,
"mls"
:
"transcription"
,
"voxpopuli"
:
"transcription"
,
"covost2"
:
"translation"
,
"minds14"
:
"intent_class"
,
"babel"
:
"transcription"
,
}
@
dataclass
@
dataclass
class
ModelArguments
:
class
ModelArguments
:
"""
"""
...
@@ -144,8 +155,16 @@ class DataTrainingArguments:
...
@@ -144,8 +155,16 @@ class DataTrainingArguments:
default
=
"xtreme_s"
,
default
=
"xtreme_s"
,
metadata
=
{
"help"
:
"The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"
},
metadata
=
{
"help"
:
"The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"
},
)
)
dataset_config_name
:
str
=
field
(
task
:
str
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"The configuration name of the dataset to use (via the datasets library)."
}
default
=
None
,
metadata
=
{
"help"
:
"The task name of the benchmark to use (via the datasets library). Should be on of: "
"'fleurs-asr', 'mls', 'voxpopuli', 'covost2', 'minds14', 'fleurs-lang_id', 'babel'."
},
)
language
:
str
=
field
(
default
=
"all"
,
metadata
=
{
"help"
:
"The language id as defined in the datasets config name or `all` for all languages."
},
)
)
train_split_name
:
str
=
field
(
train_split_name
:
str
=
field
(
default
=
"train"
,
default
=
"train"
,
...
@@ -160,6 +179,13 @@ class DataTrainingArguments:
...
@@ -160,6 +179,13 @@ class DataTrainingArguments:
"Defaults to 'validation'"
"Defaults to 'validation'"
},
},
)
)
predict_split_name
:
str
=
field
(
default
=
"test"
,
metadata
=
{
"help"
:
"The name of the prediction data set split to use (via the datasets library). "
"Defaults to 'test'"
},
)
audio_column_name
:
str
=
field
(
audio_column_name
:
str
=
field
(
default
=
"audio"
,
default
=
"audio"
,
metadata
=
{
"help"
:
"The name of the dataset column containing the audio data. Defaults to 'audio'"
},
metadata
=
{
"help"
:
"The name of the dataset column containing the audio data. Defaults to 'audio'"
},
...
@@ -192,6 +218,13 @@ class DataTrainingArguments:
...
@@ -192,6 +218,13 @@ class DataTrainingArguments:
"value if set."
"value if set."
},
},
)
)
max_predict_samples
:
Optional
[
int
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"For debugging purposes or quicker training, truncate the number of prediction examples to this "
"value if set."
},
)
chars_to_ignore
:
Optional
[
List
[
str
]]
=
list_field
(
chars_to_ignore
:
Optional
[
List
[
str
]]
=
list_field
(
default
=
', ? . ! - ; : " “ % ‘ ” �'
.
split
(
" "
),
default
=
', ? . ! - ; : " “ % ‘ ” �'
.
split
(
" "
),
metadata
=
{
"help"
:
"A list of characters to remove from the transcripts."
},
metadata
=
{
"help"
:
"A list of characters to remove from the transcripts."
},
...
@@ -387,22 +420,31 @@ def main():
...
@@ -387,22 +420,31 @@ def main():
# 1. First, let's load the dataset
# 1. First, let's load the dataset
raw_datasets
=
DatasetDict
()
raw_datasets
=
DatasetDict
()
if
data_args
.
dataset_config_name
is
None
:
task_name
=
data_args
.
task
lang_id
=
data_args
.
language
if
task_name
is
None
:
raise
ValueError
(
"Set --task should be set to '<xtreme_s_task>' "
"(e.g. 'fleurs-asr', 'mls', 'covost2', 'minds14') "
)
if
lang_id
is
None
:
raise
ValueError
(
raise
ValueError
(
"Set --
dataset_config_nam
e should be set to
'<xtreme_s_subset>.<language(s)>'
"
"Set --
languag
e should be set to
the language id of the sub dataset
"
"(e.g. '
mls.
pl', '
covost2.
en.tr', '
minds14.
fr-FR') "
"
config to be used
(e.g. 'pl', 'en.tr', 'fr-FR')
or 'all'
"
"
or '<xtreme_s_subset>.all'
for multi-lingual fine-tuning."
" for multi-lingual fine-tuning."
)
)
ta
sk_name
=
data_args
.
dataset_config_name
.
split
(
"."
)[
0
]
ta
rget_column_name
=
TASK_TO_TARGET_COLUMN_NAME
[
task_name
]
target_column_name
=
data_args
.
target_column_name
# here we differentiate between tasks with text as the target and classification tasks
# here we differentiate between tasks with text as the target and classification tasks
is_text_target
=
target_column_name
in
(
"transcription"
,
"translation"
)
is_text_target
=
target_column_name
in
(
"transcription"
,
"translation"
)
config_name
=
"."
.
join
([
task_name
.
split
(
"-"
)[
0
],
lang_id
])
if
training_args
.
do_train
:
if
training_args
.
do_train
:
raw_datasets
[
"train"
]
=
load_dataset
(
raw_datasets
[
"train"
]
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_name
,
data_args
.
dataset_
config_name
,
config_name
,
split
=
data_args
.
train_split_name
,
split
=
data_args
.
train_split_name
,
use_auth_token
=
data_args
.
use_auth_token
,
use_auth_token
=
data_args
.
use_auth_token
,
cache_dir
=
model_args
.
cache_dir
,
cache_dir
=
model_args
.
cache_dir
,
...
@@ -432,7 +474,7 @@ def main():
...
@@ -432,7 +474,7 @@ def main():
if
training_args
.
do_eval
:
if
training_args
.
do_eval
:
raw_datasets
[
"eval"
]
=
load_dataset
(
raw_datasets
[
"eval"
]
=
load_dataset
(
data_args
.
dataset_name
,
data_args
.
dataset_name
,
data_args
.
dataset_
config_name
,
config_name
,
split
=
data_args
.
eval_split_name
,
split
=
data_args
.
eval_split_name
,
use_auth_token
=
data_args
.
use_auth_token
,
use_auth_token
=
data_args
.
use_auth_token
,
cache_dir
=
model_args
.
cache_dir
,
cache_dir
=
model_args
.
cache_dir
,
...
@@ -441,6 +483,18 @@ def main():
...
@@ -441,6 +483,18 @@ def main():
if
data_args
.
max_eval_samples
is
not
None
:
if
data_args
.
max_eval_samples
is
not
None
:
raw_datasets
[
"eval"
]
=
raw_datasets
[
"eval"
].
select
(
range
(
data_args
.
max_eval_samples
))
raw_datasets
[
"eval"
]
=
raw_datasets
[
"eval"
].
select
(
range
(
data_args
.
max_eval_samples
))
if
training_args
.
do_predict
:
raw_datasets
[
"predict"
]
=
load_dataset
(
data_args
.
dataset_name
,
config_name
,
split
=
data_args
.
predict_split_name
,
use_auth_token
=
data_args
.
use_auth_token
,
cache_dir
=
model_args
.
cache_dir
,
)
if
data_args
.
max_predict_samples
is
not
None
:
raw_datasets
[
"predict"
]
=
raw_datasets
[
"predict"
].
select
(
range
(
data_args
.
max_predict_samples
))
# 2. We remove some special characters from the datasets
# 2. We remove some special characters from the datasets
# that make training complicated and do not help in transcribing the speech
# that make training complicated and do not help in transcribing the speech
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
# E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
...
@@ -757,24 +811,25 @@ def main():
...
@@ -757,24 +811,25 @@ def main():
# Evaluation
# Evaluation
results
=
{}
results
=
{}
if
training_args
.
do_eval
:
if
training_args
.
do_predict
:
logger
.
info
(
"*** Evaluate ***"
)
logger
.
info
(
"*** Predicte ***"
)
metrics
=
trainer
.
evaluate
()
metrics
=
trainer
.
evaluate
(
vectorized_datasets
[
"predict"
])
max_eval_samples
=
(
max_predict_samples
=
(
data_args
.
max_eval_samples
if
data_args
.
max_eval_samples
is
not
None
else
len
(
vectorized_datasets
[
"eval"
])
data_args
.
max_predict_samples
if
data_args
.
max_predict_samples
is
not
None
else
len
(
vectorized_datasets
[
"predict"
])
)
)
metrics
[
"
eval
_samples"
]
=
min
(
max_
eval
_samples
,
len
(
vectorized_datasets
[
"
eval
"
]))
metrics
[
"
predict
_samples"
]
=
min
(
max_
predict
_samples
,
len
(
vectorized_datasets
[
"
predict
"
]))
trainer
.
log_metrics
(
"
eval
"
,
metrics
)
trainer
.
log_metrics
(
"
predict
"
,
metrics
)
trainer
.
save_metrics
(
"
eval
"
,
metrics
)
trainer
.
save_metrics
(
"
predict
"
,
metrics
)
# Write model card and (optionally) push to hub
# Write model card and (optionally) push to hub
config_name
=
data_args
.
dataset_config_name
if
data_args
.
dataset_config_name
is
not
None
else
"na"
kwargs
=
{
kwargs
=
{
"finetuned_from"
:
model_args
.
model_name_or_path
,
"finetuned_from"
:
model_args
.
model_name_or_path
,
"tasks"
:
"speech-recognition"
,
"tasks"
:
task_name
,
"tags"
:
[
"automatic-speech-recognition"
,
data_args
.
dataset_name
],
"tags"
:
[
task_name
,
data_args
.
dataset_name
],
"dataset_args"
:
f
"Config:
{
config_name
}
, Training split:
{
data_args
.
train_split_name
}
, Eval split:
{
data_args
.
eval_split_name
}
"
,
"dataset_args"
:
f
"Config:
{
config_name
}
, Training split:
{
data_args
.
train_split_name
}
, Eval split:
{
data_args
.
eval_split_name
}
, Predict split:
{
data_args
.
predict_split_name
}
"
,
"dataset"
:
f
"
{
data_args
.
dataset_name
.
upper
()
}
-
{
config_name
.
upper
()
}
"
,
"dataset"
:
f
"
{
data_args
.
dataset_name
.
upper
()
}
-
{
config_name
.
upper
()
}
"
,
}
}
if
"common_voice"
in
data_args
.
dataset_name
:
if
"common_voice"
in
data_args
.
dataset_name
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment