Commit c2dc89be (unverified), authored Mar 16, 2022 by Patrick von Platen; committed via GitHub on Mar 16, 2022. Parent: 99fd3eb4.

[Xtreme-S] fix some namings (#16183)
The commit renames the example folder from `xreme-s` to `xtreme-s` and replaces the `--dataset_config_name` / `--target_column_name` flags with `--task` and `--language`, adding support for a prediction split along the way.

Showing 3 changed files with 81 additions and 29 deletions:

- examples/research_projects/xtreme-s/README.md (+4, -7)
- examples/research_projects/xtreme-s/requirements.txt (+0, -0, moved)
- examples/research_projects/xtreme-s/run_xtreme_s.py (+77, -22)
examples/research_projects/xreme-s/README.md → examples/research_projects/xtreme-s/README.md
@@ -81,9 +81,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
 python -m torch.distributed.launch \
     --nproc_per_node=8 \
     run_xtreme_s.py \
+    --task="mls" \
+    --language="all" \
     --model_name_or_path="facebook/wav2vec2-xls-r-300m" \
-    --dataset_name="google/xtreme_s" \
-    --dataset_config_name="mls.all" \
     --eval_split_name="test" \
     --output_dir="xtreme_s_xlsr_300m_mls" \
     --overwrite_output_dir \
@@ -94,7 +94,6 @@ python -m torch.distributed.launch \
     --learning_rate="3e-4" \
     --warmup_steps=3000 \
     --evaluation_strategy="steps" \
-    --target_column_name="transcription" \
     --max_duration_in_seconds=20 \
     --save_steps=500 \
     --eval_steps=500 \
@@ -126,10 +125,9 @@ The following command shows how to fine-tune the [XLS-R](https://huggingface.co/
 python -m torch.distributed.launch \
     --nproc_per_node=2 \
     run_xtreme_s.py \
+    --task="minds14" \
+    --language="all" \
     --model_name_or_path="facebook/wav2vec2-xls-r-300m" \
-    --dataset_name="google/xtreme_s" \
-    --dataset_config_name="minds14.all" \
-    --eval_split_name="test" \
     --output_dir="xtreme_s_xlsr_300m_minds14" \
     --overwrite_output_dir \
     --num_train_epochs=50 \
@@ -139,7 +137,6 @@ python -m torch.distributed.launch \
     --learning_rate="3e-4" \
     --warmup_steps=1500 \
     --evaluation_strategy="steps" \
-    --target_column_name="intent_class" \
     --max_duration_in_seconds=30 \
     --save_steps=200 \
     --eval_steps=200 \
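Taken together, the README changes drop the explicit dataset/config/target flags: the script now reconstructs the `datasets` config name from `--task` and `--language` itself. A minimal sketch of that derivation (not part of the commit; it mirrors the `config_name` line added in run_xtreme_s.py below):

```python
# Sketch: how --task and --language map back to the old --dataset_config_name value.
task = "mls"        # --task="mls"
language = "all"    # --language="all"

# "fleurs-asr" and "fleurs-lang_id" both load the "fleurs" dataset config,
# hence the split("-")[0].
config_name = ".".join([task.split("-")[0], language])
assert config_name == "mls.all"  # equivalent to the old --dataset_config_name="mls.all"
```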
examples/research_projects/xreme-s/requirements.txt → examples/research_projects/xtreme-s/requirements.txt

File moved without content changes.
examples/research_projects/xreme-s/run_xtreme_s.py → examples/research_projects/xtreme-s/run_xtreme_s.py
@@ -62,6 +62,17 @@ def list_field(default=None, metadata=None):
     return field(default_factory=lambda: default, metadata=metadata)
 
 
+TASK_TO_TARGET_COLUMN_NAME = {
+    "fleurs-asr": "transcription",
+    "fleurs-lang_id": "lang_id",
+    "mls": "transcription",
+    "voxpopuli": "transcription",
+    "covost2": "translation",
+    "minds14": "intent_class",
+    "babel": "transcription",
+}
+
+
 @dataclass
 class ModelArguments:
     """
@@ -144,8 +155,16 @@ class DataTrainingArguments:
         default="xtreme_s",
         metadata={"help": "The name of the dataset to use (via the datasets library). Defaults to 'xtreme_s'"},
     )
-    dataset_config_name: str = field(
-        default=None, metadata={"help": "The configuration name of the dataset to use (via the datasets library)."}
-    )
+    task: str = field(
+        default=None,
+        metadata={
+            "help": "The task name of the benchmark to use (via the datasets library). Should be on of: "
+            "'fleurs-asr', 'mls', 'voxpopuli', 'covost2', 'minds14', 'fleurs-lang_id', 'babel'."
+        },
+    )
+    language: str = field(
+        default="all",
+        metadata={"help": "The language id as defined in the datasets config name or `all` for all languages."},
+    )
     train_split_name: str = field(
         default="train",
@@ -160,6 +179,13 @@ class DataTrainingArguments:
             "Defaults to 'validation'"
         },
     )
+    predict_split_name: str = field(
+        default="test",
+        metadata={
+            "help": "The name of the prediction data set split to use (via the datasets library). "
+            "Defaults to 'test'"
+        },
+    )
     audio_column_name: str = field(
         default="audio",
         metadata={"help": "The name of the dataset column containing the audio data. Defaults to 'audio'"},
@@ -192,6 +218,13 @@ class DataTrainingArguments:
             "value if set."
         },
     )
+    max_predict_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
+            "value if set."
+        },
+    )
     chars_to_ignore: Optional[List[str]] = list_field(
         default=', ? . ! - ; : " “ % ‘ ” �'.split(" "),
         metadata={"help": "A list of characters to remove from the transcripts."},
@@ -387,22 +420,31 @@ def main():
     # 1. First, let's load the dataset
     raw_datasets = DatasetDict()
+    task_name = data_args.task
+    lang_id = data_args.language
 
-    if data_args.dataset_config_name is None:
+    if task_name is None:
+        raise ValueError(
+            "Set --task should be set to '<xtreme_s_task>' "
+            "(e.g. 'fleurs-asr', 'mls', 'covost2', 'minds14') "
+        )
+    if lang_id is None:
         raise ValueError(
-            "Set --dataset_config_name should be set to '<xtreme_s_subset>.<language(s)>' "
-            "(e.g. 'mls.pl', 'covost2.en.tr', 'minds14.fr-FR') "
-            "or '<xtreme_s_subset>.all' for multi-lingual fine-tuning."
+            "Set --language should be set to the language id of the sub dataset "
+            "config to be used (e.g. 'pl', 'en.tr', 'fr-FR') or 'all'"
+            " for multi-lingual fine-tuning."
         )
 
-    task_name = data_args.dataset_config_name.split(".")[0]
-    target_column_name = data_args.target_column_name
+    target_column_name = TASK_TO_TARGET_COLUMN_NAME[task_name]
 
     # here we differentiate between tasks with text as the target and classification tasks
     is_text_target = target_column_name in ("transcription", "translation")
 
+    config_name = ".".join([task_name.split("-")[0], lang_id])
+
     if training_args.do_train:
         raw_datasets["train"] = load_dataset(
             data_args.dataset_name,
-            data_args.dataset_config_name,
+            config_name,
             split=data_args.train_split_name,
             use_auth_token=data_args.use_auth_token,
             cache_dir=model_args.cache_dir,
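With the derived `config_name`, the train, eval, and predict `load_dataset` calls all follow the same pattern. A hedged, self-contained sketch of what the train call resolves to for `--task="minds14" --language="fr-FR"` (dataset and config names as in the diff; running it needs the `datasets` library and network access, and depending on the library version may also need an auth token):

```python
from datasets import load_dataset

# --task="minds14" --language="fr-FR" resolve to the "minds14.fr-FR" config.
config_name = ".".join(["minds14".split("-")[0], "fr-FR"])

# Equivalent of the raw_datasets["train"] assignment in the hunk above.
train_split = load_dataset("google/xtreme_s", config_name, split="train")
```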
@@ -432,7 +474,7 @@ def main():
     if training_args.do_eval:
         raw_datasets["eval"] = load_dataset(
             data_args.dataset_name,
-            data_args.dataset_config_name,
+            config_name,
             split=data_args.eval_split_name,
             use_auth_token=data_args.use_auth_token,
             cache_dir=model_args.cache_dir,
@@ -441,6 +483,18 @@ def main():
         if data_args.max_eval_samples is not None:
             raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
 
+    if training_args.do_predict:
+        raw_datasets["predict"] = load_dataset(
+            data_args.dataset_name,
+            config_name,
+            split=data_args.predict_split_name,
+            use_auth_token=data_args.use_auth_token,
+            cache_dir=model_args.cache_dir,
+        )
+
+        if data_args.max_predict_samples is not None:
+            raw_datasets["predict"] = raw_datasets["predict"].select(range(data_args.max_predict_samples))
+
     # 2. We remove some special characters from the datasets
     # that make training complicated and do not help in transcribing the speech
     # E.g. characters, such as `,` and `.` do not really have an acoustic characteristic
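The new `--do_predict` path reuses the same truncation idiom as the eval split: `Dataset.select(range(n))` keeps the first `n` rows. A small self-contained demonstration with toy data (not from the commit):

```python
from datasets import Dataset

# Toy dataset standing in for raw_datasets["predict"].
ds = Dataset.from_dict({"idx": list(range(100))})

# The --max_predict_samples truncation from the hunk above.
max_predict_samples = 10
ds = ds.select(range(max_predict_samples))
assert len(ds) == 10
```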
@@ -757,24 +811,25 @@ def main():
     # Evaluation
     results = {}
-    if training_args.do_eval:
-        logger.info("*** Evaluate ***")
-        metrics = trainer.evaluate()
-        max_eval_samples = (
-            data_args.max_eval_samples if data_args.max_eval_samples is not None else len(vectorized_datasets["eval"])
+    if training_args.do_predict:
+        logger.info("*** Predicte ***")
+        metrics = trainer.evaluate(vectorized_datasets["predict"])
+        max_predict_samples = (
+            data_args.max_predict_samples if data_args.max_predict_samples is not None else len(vectorized_datasets["predict"])
         )
-        metrics["eval_samples"] = min(max_eval_samples, len(vectorized_datasets["eval"]))
+        metrics["predict_samples"] = min(max_predict_samples, len(vectorized_datasets["predict"]))
 
-        trainer.log_metrics("eval", metrics)
-        trainer.save_metrics("eval", metrics)
+        trainer.log_metrics("predict", metrics)
+        trainer.save_metrics("predict", metrics)
 
     # Write model card and (optionally) push to hub
-    config_name = data_args.dataset_config_name if data_args.dataset_config_name is not None else "na"
     kwargs = {
         "finetuned_from": model_args.model_name_or_path,
-        "tasks": "speech-recognition",
-        "tags": ["automatic-speech-recognition", data_args.dataset_name],
-        "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}",
+        "tasks": task_name,
+        "tags": [task_name, data_args.dataset_name],
+        "dataset_args": f"Config: {config_name}, Training split: {data_args.train_split_name}, Eval split: {data_args.eval_split_name}, Predict split: {data_args.predict_split_name}",
         "dataset": f"{data_args.dataset_name.upper()} - {config_name.upper()}",
     }
     if "common_voice" in data_args.dataset_name:
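The model-card section now records task-specific metadata instead of the hard-coded speech-recognition tags. An illustration of the resulting `kwargs` values (example values chosen here, not from a real run; in the script they come from `data_args` and the derived names):

```python
# Example values standing in for the script's derived names.
task_name = "minds14"
config_name = "minds14.all"
dataset_name = "google/xtreme_s"

kwargs = {
    "tasks": task_name,                 # previously the hard-coded "speech-recognition"
    "tags": [task_name, dataset_name],  # previously ["automatic-speech-recognition", ...]
    "dataset": f"{dataset_name.upper()} - {config_name.upper()}",
}
assert kwargs["dataset"] == "GOOGLE/XTREME_S - MINDS14.ALL"
```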