Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
3a4376d0
Unverified
Commit
3a4376d0
authored
Feb 16, 2022
by
Patrick von Platen
Committed by
GitHub
Feb 16, 2022
Browse files
[Wav2Vec2ProcessorWithLM] Fix auto processor with lm (#15683)
parent
cdc51ffd
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
0 deletions
+22
-0
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
...rs/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
+2
-0
tests/test_processor_wav2vec2_with_lm.py
tests/test_processor_wav2vec2_with_lm.py
+20
-0
No files found.
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
View file @
3a4376d0
...
@@ -138,6 +138,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
...
@@ -138,6 +138,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
else
:
else
:
# BeamSearchDecoderCTC has no auto class
# BeamSearchDecoderCTC has no auto class
kwargs
.
pop
(
"_from_auto"
,
None
)
kwargs
.
pop
(
"_from_auto"
,
None
)
# snapshot_download has no `trust_remote_code` flag
kwargs
.
pop
(
"trust_remote_code"
,
None
)
# make sure that only relevant filenames are downloaded
# make sure that only relevant filenames are downloaded
language_model_filenames
=
os
.
path
.
join
(
BeamSearchDecoderCTC
.
_LANGUAGE_MODEL_SERIALIZED_DIRECTORY
,
"*"
)
language_model_filenames
=
os
.
path
.
join
(
BeamSearchDecoderCTC
.
_LANGUAGE_MODEL_SERIALIZED_DIRECTORY
,
"*"
)
...
...
tests/test_processor_wav2vec2_with_lm.py
View file @
3a4376d0
...
@@ -22,6 +22,7 @@ from pathlib import Path
...
@@ -22,6 +22,7 @@ from pathlib import Path
import
numpy
as
np
import
numpy
as
np
from
transformers
import
AutoProcessor
from
transformers.file_utils
import
FEATURE_EXTRACTOR_NAME
,
is_pyctcdecode_available
from
transformers.file_utils
import
FEATURE_EXTRACTOR_NAME
,
is_pyctcdecode_available
from
transformers.models.wav2vec2
import
Wav2Vec2CTCTokenizer
,
Wav2Vec2FeatureExtractor
from
transformers.models.wav2vec2
import
Wav2Vec2CTCTokenizer
,
Wav2Vec2FeatureExtractor
from
transformers.models.wav2vec2.tokenization_wav2vec2
import
VOCAB_FILES_NAMES
from
transformers.models.wav2vec2.tokenization_wav2vec2
import
VOCAB_FILES_NAMES
...
@@ -330,3 +331,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
...
@@ -330,3 +331,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
# test that both decoder form hub and local files in cache are the same
# test that both decoder form hub and local files in cache are the same
self
.
assertListEqual
(
local_decoder_files
,
expected_decoder_files
)
self
.
assertListEqual
(
local_decoder_files
,
expected_decoder_files
)
def
test_processor_from_auto_processor
(
self
):
processor_wav2vec2
=
Wav2Vec2ProcessorWithLM
.
from_pretrained
(
"hf-internal-testing/processor_with_lm"
)
processor_auto
=
AutoProcessor
.
from_pretrained
(
"hf-internal-testing/processor_with_lm"
)
raw_speech
=
floats_list
((
3
,
1000
))
input_wav2vec2
=
processor_wav2vec2
(
raw_speech
,
return_tensors
=
"np"
)
input_auto
=
processor_auto
(
raw_speech
,
return_tensors
=
"np"
)
for
key
in
input_wav2vec2
.
keys
():
self
.
assertAlmostEqual
(
input_wav2vec2
[
key
].
sum
(),
input_auto
[
key
].
sum
(),
delta
=
1e-2
)
logits
=
self
.
_get_dummy_logits
()
decoded_wav2vec2
=
processor_wav2vec2
.
batch_decode
(
logits
)
decoded_auto
=
processor_auto
.
batch_decode
(
logits
)
self
.
assertListEqual
(
decoded_wav2vec2
.
text
,
decoded_auto
.
text
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment