Commit 55f49c5f (unverified)
Authored Nov 12, 2021 by Patrick von Platen, committed by GitHub on Nov 12, 2021

[Wav2Vec2 Example] Improve fine-tuning script (#14373)

* improve some stuff
* finish
* correct last

Parent: 21546e59
Showing 1 changed file with 27 additions and 5 deletions

examples/pytorch/speech-recognition/run_speech_recognition_ctc.py  (+27 -5)
@@ -99,9 +99,24 @@ class ModelArguments:
         metadata={
             "help": "Probability of each feature vector along the time axis to be chosen as the start of the vector"
             "span to be masked. Approximately ``mask_time_prob * sequence_length // mask_time_length`` feature"
-            "vectors will be masked along the time axis. This is only relevant if ``apply_spec_augment is True``."
+            "vectors will be masked along the time axis."
         },
     )
+    mask_time_length: Optional[int] = field(
+        default=10,
+        metadata={"help": "Length of vector span to mask along the time axis."},
+    )
+    mask_feature_prob: Optional[float] = field(
+        default=0.0,
+        metadata={
+            "help": "Probability of each feature vector along the feature axis to be chosen as the start of the vector"
+            "span to be masked. Approximately ``mask_feature_prob * sequence_length // mask_feature_length`` feature bins will be masked along the time axis."
+        },
+    )
+    mask_feature_length: Optional[int] = field(
+        default=10,
+        metadata={"help": "Length of vector span to mask along the feature axis."},
+    )
     layerdrop: Optional[float] = field(default=0.0, metadata={"help": "The LayerDrop probability."})
     ctc_loss_reduction: Optional[str] = field(
         default="mean", metadata={"help": "The way the ctc loss should be reduced. Should be one of 'mean' or 'sum'."}
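Note on the new masking arguments: they correspond to the SpecAugment-style attributes of the Wav2Vec2 configuration. A minimal sketch of how such values can be applied to a config, assuming the usual config.update(...) pattern from the examples; the numeric values below are illustrative, not taken from this commit:

    from transformers import Wav2Vec2Config

    # Illustrative values only; the script fills these from ModelArguments at runtime.
    config = Wav2Vec2Config()
    config.update(
        {
            "mask_time_prob": 0.05,      # chance that a given frame starts a masked span (time axis)
            "mask_time_length": 10,      # length of each masked span along the time axis
            "mask_feature_prob": 0.0,    # chance that a given bin starts a masked span (feature axis)
            "mask_feature_length": 10,   # length of each masked span along the feature axis
        }
    )
    print(config.mask_time_prob, config.mask_time_length)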
@@ -169,6 +184,10 @@ class DataTrainingArguments:
         default=None,
         metadata={"help": "A list of characters to remove from the transcripts."},
     )
+    eval_metrics: Optional[List[str]] = list_field(
+        default=["wer"],
+        metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
+    )
     max_duration_in_seconds: Optional[float] = field(
         default=20.0,
         metadata={
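The new eval_metrics argument is a list, so more than one metric can be requested in a single run. A minimal sketch of how HfArgumentParser exposes such a list-typed field on the command line; the MetricsArguments dataclass and the parsed values are hypothetical stand-ins for the script's DataTrainingArguments:

    from dataclasses import dataclass, field
    from typing import List

    from transformers import HfArgumentParser


    @dataclass
    class MetricsArguments:
        # Hypothetical stand-in for the eval_metrics field added to DataTrainingArguments above.
        eval_metrics: List[str] = field(
            default_factory=lambda: ["wer"],
            metadata={"help": "A list of metrics the model should be evaluated on. E.g. `'wer cer'`"},
        )


    parser = HfArgumentParser(MetricsArguments)
    # List-typed fields become nargs="+" flags, so several metrics can be passed in one go.
    (metrics_args,) = parser.parse_args_into_dataclasses(args=["--eval_metrics", "wer", "cer"])
    print(metrics_args.eval_metrics)  # ['wer', 'cer']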
@@ -446,6 +465,9 @@ def main():
             "hidden_dropout": model_args.hidden_dropout,
             "final_dropout": model_args.final_dropout,
             "mask_time_prob": model_args.mask_time_prob,
+            "mask_time_length": model_args.mask_time_length,
+            "mask_feature_prob": model_args.mask_feature_prob,
+            "mask_feature_length": model_args.mask_feature_length,
             "gradient_checkpointing": training_args.gradient_checkpointing,
             "layerdrop": model_args.layerdrop,
             "ctc_loss_reduction": model_args.ctc_loss_reduction,
@@ -519,8 +541,8 @@ def main():
     # Let's use word error rate (WER) as our evaluation metric,
     # instantiate a data collator and the trainer

-    # Define Metric during training
-    wer_metric = load_metric("wer")
+    # Define evaluation metrics during training, *i.e.* word error rate, character error rate
+    eval_metrics = {metric: load_metric(metric) for metric in data_args.eval_metrics}

     # for large datasets it is advised to run the preprocessing on a
     # single machine first with ``args.preprocessing_only`` since there will mostly likely
@@ -541,9 +563,9 @@ def main():
         # we do not want to group tokens when computing the metrics
         label_str = processor.batch_decode(pred.label_ids, group_tokens=False)

-        wer = wer_metric.compute(predictions=pred_str, references=label_str)
+        metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}

-        return {"wer": wer}
+        return metrics

     # Instantiate custom data collator
     data_collator = DataCollatorCTCWithPadding(processor=processor)
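With eval_metrics holding one loaded metric per name, compute_metrics now evaluates each of them and returns the combined dictionary. A rough, self-contained sketch of that pattern on made-up strings, assuming datasets.load_metric (the metrics API the script used at the time) and its jiwer backend for wer/cer are available:

    from datasets import load_metric

    # Made-up predictions and references, purely for illustration.
    pred_str = ["hello world", "good morning"]
    label_str = ["hello world", "good mornings"]

    # One loaded metric per requested name, mirroring the dict built in the script.
    eval_metrics = {metric: load_metric(metric) for metric in ["wer", "cer"]}

    # wer/cer compute() each return a single float.
    metrics = {k: v.compute(predictions=pred_str, references=label_str) for k, v in eval_metrics.items()}
    print(metrics)  # e.g. {'wer': 0.25, 'cer': ...}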