Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
efb35a41
Unverified
Commit
efb35a41
authored
Jan 11, 2022
by
Patrick von Platen
Committed by
GitHub
Jan 11, 2022
Browse files
[Wav2Vec2ProcessorWithLM] improve decoder download (#15040)
parent
6ea62666
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
1 deletion
+22
-1
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
...rs/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
+8
-1
tests/test_processor_wav2vec2_with_lm.py
tests/test_processor_wav2vec2_with_lm.py
+14
-0
No files found.
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
View file @
efb35a41
...
@@ -166,7 +166,14 @@ class Wav2Vec2ProcessorWithLM:
...
@@ -166,7 +166,14 @@ class Wav2Vec2ProcessorWithLM:
# BeamSearchDecoderCTC has no auto class
# BeamSearchDecoderCTC has no auto class
kwargs
.
pop
(
"_from_auto"
,
None
)
kwargs
.
pop
(
"_from_auto"
,
None
)
decoder
=
BeamSearchDecoderCTC
.
load_from_hf_hub
(
pretrained_model_name_or_path
,
**
kwargs
)
# make sure that only relevant filenames are downloaded
language_model_filenames
=
os
.
path
.
join
(
BeamSearchDecoderCTC
.
_LANGUAGE_MODEL_SERIALIZED_DIRECTORY
,
"*"
)
alphabet_filename
=
BeamSearchDecoderCTC
.
_ALPHABET_SERIALIZED_FILENAME
allow_regex
=
[
language_model_filenames
,
alphabet_filename
]
decoder
=
BeamSearchDecoderCTC
.
load_from_hf_hub
(
pretrained_model_name_or_path
,
allow_regex
=
allow_regex
,
**
kwargs
)
# set language model attributes
# set language model attributes
for
attribute
in
[
"alpha"
,
"beta"
,
"unk_score_offset"
,
"score_boundary"
]:
for
attribute
in
[
"alpha"
,
"beta"
,
"unk_score_offset"
,
"score_boundary"
]:
...
...
tests/test_processor_wav2vec2_with_lm.py
View file @
efb35a41
...
@@ -18,6 +18,7 @@ import shutil
...
@@ -18,6 +18,7 @@ import shutil
import
tempfile
import
tempfile
import
unittest
import
unittest
from
multiprocessing
import
Pool
from
multiprocessing
import
Pool
from
pathlib
import
Path
import
numpy
as
np
import
numpy
as
np
...
@@ -234,3 +235,16 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
...
@@ -234,3 +235,16 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
([
"<s> </s> </s>"
,
"<s> <s> </s>"
],
decoded_processor
)
self
.
assertListEqual
([
"<s> </s> </s>"
,
"<s> <s> </s>"
],
decoded_processor
)
def test_decoder_download_ignores_files(self):
    """Check that loading the processor leaves only decoder files in the listed directory.

    Loads the processor from ``hf-internal-testing/processor_with_lm``, walks
    two directory levels up from the kenlm model path, and asserts that the
    directory found there contains exactly ``alphabet.json`` and
    ``language_model``.
    """
    processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")

    language_model = processor.decoder.model_container[processor.decoder._model_key]
    # the kenlm model path is a bytes object, hence the decode; the directory
    # inspected below sits two levels above that file
    path_to_cached_dir = Path(language_model._kenlm_model.path.decode("utf-8")).parent.parent.absolute()

    # os.listdir returns entries in arbitrary order, so sort before comparing
    # against the (already sorted) expected list to keep the test deterministic
    downloaded_decoder_files = sorted(os.listdir(path_to_cached_dir))

    # test that only decoder relevant files from
    # https://huggingface.co/hf-internal-testing/processor_with_lm/tree/main
    # are downloaded and none of the rest (e.g. README.md, ...)
    self.assertListEqual(downloaded_decoder_files, ["alphabet.json", "language_model"])
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment