Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
55f49c5f
Unverified
Commit
55f49c5f
authored
Nov 12, 2021
by
Patrick von Platen
Committed by
GitHub
Nov 12, 2021
Browse files
[Wav2Vec2 Example] Improve fine-tuning script (#14373)
* improve some stuff * finish * correct last
parent
21546e59
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
5 deletions
+27
-5
examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
.../pytorch/speech-recognition/run_speech_recognition_ctc.py
+27
-5
No files found.
examples/pytorch/speech-recognition/run_speech_recognition_ctc.py
View file @
55f49c5f
...
...
@@ -99,9 +99,24 @@ class ModelArguments:
metadata
=
{
"help"
:
"Probability of each feature vector along the time axis to be chosen as the start of the vector"
"span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
"vectors will be masked along the time axis.
This is only relevant if ``apply_spec_augment is True``.
"
"vectors will be masked along the time axis."
},
)
mask_time_length
:
Optional
[
int
]
=
field
(
default
=
10
,
metadata
=
{
"help"
:
"Length of vector span to mask along the time axis."
},
)
mask_feature_prob
:
Optional
[
float
]
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"Probability of each feature vector along the feature axis to be chosen as the start of the vector"
"span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
},
)
mask_feature_length
:
Optional
[
int
]
=
field
(
default
=
10
,
metadata
=
{
"help"
:
"Length of vector span to mask along the feature axis."
},
)
layerdrop
:
Optional
[
float
]
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"The LayerDrop probability."
})
ctc_loss_reduction
:
Optional
[
str
]
=
field
(
default
=
"mean"
,
metadata
=
{
"help"
:
"The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."
}
...
...
@@ -169,6 +184,10 @@ class DataTrainingArguments:
default
=
None
,
metadata
=
{
"help"
:
"A list of characters to remove from the transcripts."
},
)
eval_metrics
:
Optional
[
List
[
str
]]
=
list_field
(
default
=
[
"wer"
],
metadata
=
{
"help"
:
"A list of metrics the model should be evaluated on. E.g. `'wer cer'`"
},
)
max_duration_in_seconds
:
Optional
[
float
]
=
field
(
default
=
20.0
,
metadata
=
{
...
...
@@ -446,6 +465,9 @@ def main():
"hidden_dropout"
:
model_args
.
hidden_dropout
,
"final_dropout"
:
model_args
.
final_dropout
,
"mask_time_prob"
:
model_args
.
mask_time_prob
,
"mask_time_length"
:
model_args
.
mask_time_length
,
"mask_feature_prob"
:
model_args
.
mask_feature_prob
,
"mask_feature_length"
:
model_args
.
mask_feature_length
,
"gradient_checkpointing"
:
training_args
.
gradient_checkpointing
,
"layerdrop"
:
model_args
.
layerdrop
,
"ctc_loss_reduction"
:
model_args
.
ctc_loss_reduction
,
...
...
@@ -519,8 +541,8 @@ def main():
# Let's use word error rate (WER) as our evaluation metric,
# instantiate a data collator and the trainer
# Define
M
etric during training
wer
_metric
=
load_metric
(
"wer"
)
# Define
evaluation m
etric
s
during training
, *i.e.* word error rate, character error rate
eval
_metric
s
=
{
metric
:
load_metric
(
metric
)
for
metric
in
data_args
.
eval_metrics
}
# for large datasets it is advised to run the preprocessing on a
# single machine first with ``args.preprocessing_only`` since there will most likely
...
...
@@ -541,9 +563,9 @@ def main():
# we do not want to group tokens when computing the metrics
label_str
=
processor
.
batch_decode
(
pred
.
label_ids
,
group_tokens
=
False
)
wer
=
wer_metric
.
compute
(
predictions
=
pred_str
,
references
=
label_str
)
metrics
=
{
k
:
v
.
compute
(
predictions
=
pred_str
,
references
=
label_str
)
for
k
,
v
in
eval_metrics
.
items
()}
return
{
"wer"
:
wer
}
return
metrics
# Instantiate custom data collator
data_collator
=
DataCollatorCTCWithPadding
(
processor
=
processor
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment