Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
bc084938
Unverified
Commit
bc084938
authored
Jun 29, 2021
by
Will Rice
Committed by
GitHub
Jun 29, 2021
Browse files
Add out of vocabulary error to ASR models (#12288)
* Add OOV error to ASR models * Feedback changes
parent
1fc6817a
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
75 additions
and
0 deletions
+75
-0
src/transformers/models/hubert/modeling_hubert.py
src/transformers/models/hubert/modeling_hubert.py
+3
-0
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
+4
-0
src/transformers/models/wav2vec2/modeling_wav2vec2.py
src/transformers/models/wav2vec2/modeling_wav2vec2.py
+3
-0
tests/test_modeling_hubert.py
tests/test_modeling_hubert.py
+24
-0
tests/test_modeling_tf_wav2vec2.py
tests/test_modeling_tf_wav2vec2.py
+17
-0
tests/test_modeling_wav2vec2.py
tests/test_modeling_wav2vec2.py
+24
-0
No files found.
src/transformers/models/hubert/modeling_hubert.py
View file @
bc084938
...
@@ -1030,6 +1030,9 @@ class HubertForCTC(HubertPreTrainedModel):
...
@@ -1030,6 +1030,9 @@ class HubertForCTC(HubertPreTrainedModel):
loss
=
None
loss
=
None
if
labels
is
not
None
:
if
labels
is
not
None
:
if
labels
.
max
()
>=
self
.
config
.
vocab_size
:
raise
ValueError
(
f
"Label values must be <= vocab_size:
{
self
.
config
.
vocab_size
}
"
)
# retrieve loss input_lengths from attention_mask
# retrieve loss input_lengths from attention_mask
attention_mask
=
(
attention_mask
=
(
attention_mask
if
attention_mask
is
not
None
else
torch
.
ones_like
(
input_values
,
dtype
=
torch
.
long
)
attention_mask
if
attention_mask
is
not
None
else
torch
.
ones_like
(
input_values
,
dtype
=
torch
.
long
)
...
...
src/transformers/models/wav2vec2/modeling_tf_wav2vec2.py
View file @
bc084938
...
@@ -1571,6 +1571,10 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
...
@@ -1571,6 +1571,10 @@ class TFWav2Vec2ForCTC(TFWav2Vec2PreTrainedModel):
logits
=
self
.
lm_head
(
hidden_states
)
logits
=
self
.
lm_head
(
hidden_states
)
if
labels
is
not
None
:
if
labels
is
not
None
:
if
tf
.
reduce_max
(
labels
)
>=
self
.
config
.
vocab_size
:
raise
ValueError
(
f
"Label values must be <= vocab_size:
{
self
.
config
.
vocab_size
}
"
)
attention_mask
=
(
attention_mask
=
(
inputs
[
"attention_mask"
]
inputs
[
"attention_mask"
]
if
inputs
[
"attention_mask"
]
is
not
None
if
inputs
[
"attention_mask"
]
is
not
None
...
...
src/transformers/models/wav2vec2/modeling_wav2vec2.py
View file @
bc084938
...
@@ -1480,6 +1480,9 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
...
@@ -1480,6 +1480,9 @@ class Wav2Vec2ForCTC(Wav2Vec2PreTrainedModel):
loss
=
None
loss
=
None
if
labels
is
not
None
:
if
labels
is
not
None
:
if
labels
.
max
()
>=
self
.
config
.
vocab_size
:
raise
ValueError
(
f
"Label values must be <= vocab_size:
{
self
.
config
.
vocab_size
}
"
)
# retrieve loss input_lengths from attention_mask
# retrieve loss input_lengths from attention_mask
attention_mask
=
(
attention_mask
=
(
attention_mask
if
attention_mask
is
not
None
else
torch
.
ones_like
(
input_values
,
dtype
=
torch
.
long
)
attention_mask
if
attention_mask
is
not
None
else
torch
.
ones_like
(
input_values
,
dtype
=
torch
.
long
)
...
...
tests/test_modeling_hubert.py
View file @
bc084938
...
@@ -18,6 +18,8 @@
...
@@ -18,6 +18,8 @@
import
math
import
math
import
unittest
import
unittest
import
pytest
from
tests.test_modeling_common
import
floats_tensor
,
ids_tensor
,
random_attention_mask
from
tests.test_modeling_common
import
floats_tensor
,
ids_tensor
,
random_attention_mask
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
from
transformers.testing_utils
import
require_datasets
,
require_soundfile
,
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_datasets
,
require_soundfile
,
require_torch
,
slow
,
torch_device
...
@@ -210,6 +212,20 @@ class HubertModelTester:
...
@@ -210,6 +212,20 @@ class HubertModelTester:
loss
.
backward
()
loss
.
backward
()
def
check_labels_out_of_vocab
(
self
,
config
,
input_values
,
*
args
):
model
=
HubertForCTC
(
config
)
model
.
to
(
torch_device
)
model
.
train
()
input_values
=
input_values
[:
3
]
input_lengths
=
[
input_values
.
shape
[
-
1
]
//
i
for
i
in
[
4
,
2
,
1
]]
max_length_labels
=
model
.
_get_feat_extract_output_lengths
(
torch
.
tensor
(
input_lengths
))
labels
=
ids_tensor
((
input_values
.
shape
[
0
],
max
(
max_length_labels
)
-
2
),
model
.
config
.
vocab_size
+
100
)
with
pytest
.
raises
(
ValueError
):
model
(
input_values
,
labels
=
labels
)
def
prepare_config_and_inputs_for_common
(
self
):
def
prepare_config_and_inputs_for_common
(
self
):
config
,
input_values
,
attention_mask
=
self
.
prepare_config_and_inputs
()
config
,
input_values
,
attention_mask
=
self
.
prepare_config_and_inputs
()
inputs_dict
=
{
"input_values"
:
input_values
,
"attention_mask"
:
attention_mask
}
inputs_dict
=
{
"input_values"
:
input_values
,
"attention_mask"
:
attention_mask
}
...
@@ -242,6 +258,10 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -242,6 +258,10 @@ class HubertModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_training
(
*
config_and_inputs
)
self
.
model_tester
.
check_training
(
*
config_and_inputs
)
def
test_labels_out_of_vocab
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_labels_out_of_vocab
(
*
config_and_inputs
)
# Hubert has no inputs_embeds
# Hubert has no inputs_embeds
def
test_inputs_embeds
(
self
):
def
test_inputs_embeds
(
self
):
pass
pass
...
@@ -377,6 +397,10 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -377,6 +397,10 @@ class HubertRobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_training
(
*
config_and_inputs
)
self
.
model_tester
.
check_training
(
*
config_and_inputs
)
def
test_labels_out_of_vocab
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_labels_out_of_vocab
(
*
config_and_inputs
)
# Hubert has no inputs_embeds
# Hubert has no inputs_embeds
def
test_inputs_embeds
(
self
):
def
test_inputs_embeds
(
self
):
pass
pass
...
...
tests/test_modeling_tf_wav2vec2.py
View file @
bc084938
...
@@ -20,6 +20,7 @@ import math
...
@@ -20,6 +20,7 @@ import math
import
unittest
import
unittest
import
numpy
as
np
import
numpy
as
np
import
pytest
from
transformers
import
Wav2Vec2Config
,
is_tf_available
from
transformers
import
Wav2Vec2Config
,
is_tf_available
from
transformers.testing_utils
import
require_datasets
,
require_soundfile
,
require_tf
,
slow
from
transformers.testing_utils
import
require_datasets
,
require_soundfile
,
require_tf
,
slow
...
@@ -202,6 +203,14 @@ class TFWav2Vec2ModelTester:
...
@@ -202,6 +203,14 @@ class TFWav2Vec2ModelTester:
self
.
parent
.
assertFalse
(
tf
.
math
.
is_inf
(
loss
))
self
.
parent
.
assertFalse
(
tf
.
math
.
is_inf
(
loss
))
def
check_labels_out_of_vocab
(
self
,
config
,
input_values
,
*
args
):
model
=
TFWav2Vec2ForCTC
(
config
)
input_lengths
=
tf
.
constant
([
input_values
.
shape
[
-
1
]
//
i
for
i
in
[
4
,
2
,
1
]])
max_length_labels
=
model
.
wav2vec2
.
_get_feat_extract_output_lengths
(
input_lengths
)
labels
=
ids_tensor
((
input_values
.
shape
[
0
],
min
(
max_length_labels
)
-
1
),
model
.
config
.
vocab_size
+
100
)
with
pytest
.
raises
(
ValueError
):
model
(
input_values
,
labels
=
labels
)
def
prepare_config_and_inputs_for_common
(
self
):
def
prepare_config_and_inputs_for_common
(
self
):
config
,
input_values
,
attention_mask
=
self
.
prepare_config_and_inputs
()
config
,
input_values
,
attention_mask
=
self
.
prepare_config_and_inputs
()
inputs_dict
=
{
"input_values"
:
input_values
,
"attention_mask"
:
attention_mask
}
inputs_dict
=
{
"input_values"
:
input_values
,
"attention_mask"
:
attention_mask
}
...
@@ -288,6 +297,10 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -288,6 +297,10 @@ class TFWav2Vec2ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_ctc_loss
(
*
config_and_inputs
)
self
.
model_tester
.
check_ctc_loss
(
*
config_and_inputs
)
def
test_labels_out_of_vocab
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_labels_out_of_vocab
(
*
config_and_inputs
)
def
test_train
(
self
):
def
test_train
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_training
(
*
config_and_inputs
)
self
.
model_tester
.
check_training
(
*
config_and_inputs
)
...
@@ -402,6 +415,10 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
...
@@ -402,6 +415,10 @@ class TFWav2Vec2RobustModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_ctc_loss
(
*
config_and_inputs
)
self
.
model_tester
.
check_ctc_loss
(
*
config_and_inputs
)
def
test_labels_out_of_vocab
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_labels_out_of_vocab
(
*
config_and_inputs
)
def
test_train
(
self
):
def
test_train
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_training
(
*
config_and_inputs
)
self
.
model_tester
.
check_training
(
*
config_and_inputs
)
...
...
tests/test_modeling_wav2vec2.py
View file @
bc084938
...
@@ -18,6 +18,8 @@
...
@@ -18,6 +18,8 @@
import
math
import
math
import
unittest
import
unittest
import
pytest
from
tests.test_modeling_common
import
floats_tensor
,
ids_tensor
,
random_attention_mask
from
tests.test_modeling_common
import
floats_tensor
,
ids_tensor
,
random_attention_mask
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
from
transformers.testing_utils
import
require_datasets
,
require_soundfile
,
require_torch
,
slow
,
torch_device
from
transformers.testing_utils
import
require_datasets
,
require_soundfile
,
require_torch
,
slow
,
torch_device
...
@@ -218,6 +220,20 @@ class Wav2Vec2ModelTester:
...
@@ -218,6 +220,20 @@ class Wav2Vec2ModelTester:
loss
.
backward
()
loss
.
backward
()
def
check_labels_out_of_vocab
(
self
,
config
,
input_values
,
*
args
):
model
=
Wav2Vec2ForCTC
(
config
)
model
.
to
(
torch_device
)
model
.
train
()
input_values
=
input_values
[:
3
]
input_lengths
=
[
input_values
.
shape
[
-
1
]
//
i
for
i
in
[
4
,
2
,
1
]]
max_length_labels
=
model
.
_get_feat_extract_output_lengths
(
torch
.
tensor
(
input_lengths
))
labels
=
ids_tensor
((
input_values
.
shape
[
0
],
max
(
max_length_labels
)
-
2
),
model
.
config
.
vocab_size
+
100
)
with
pytest
.
raises
(
ValueError
):
model
(
input_values
,
labels
=
labels
)
def
prepare_config_and_inputs_for_common
(
self
):
def
prepare_config_and_inputs_for_common
(
self
):
config
,
input_values
,
attention_mask
=
self
.
prepare_config_and_inputs
()
config
,
input_values
,
attention_mask
=
self
.
prepare_config_and_inputs
()
inputs_dict
=
{
"input_values"
:
input_values
,
"attention_mask"
:
attention_mask
}
inputs_dict
=
{
"input_values"
:
input_values
,
"attention_mask"
:
attention_mask
}
...
@@ -252,6 +268,10 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -252,6 +268,10 @@ class Wav2Vec2ModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_training
(
*
config_and_inputs
)
self
.
model_tester
.
check_training
(
*
config_and_inputs
)
def
test_labels_out_of_vocab
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_labels_out_of_vocab
(
*
config_and_inputs
)
# Wav2Vec2 has no inputs_embeds
# Wav2Vec2 has no inputs_embeds
def
test_inputs_embeds
(
self
):
def
test_inputs_embeds
(
self
):
pass
pass
...
@@ -392,6 +412,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
...
@@ -392,6 +412,10 @@ class Wav2Vec2RobustModelTest(ModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_training
(
*
config_and_inputs
)
self
.
model_tester
.
check_training
(
*
config_and_inputs
)
def
test_labels_out_of_vocab
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
check_labels_out_of_vocab
(
*
config_and_inputs
)
# Wav2Vec2 has no inputs_embeds
# Wav2Vec2 has no inputs_embeds
def
test_inputs_embeds
(
self
):
def
test_inputs_embeds
(
self
):
pass
pass
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment