Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
efb35a41
"docs/source/en/model_doc/dialogpt.md" did not exist on "d22894dfd40d5c858e8398e2783545103d191b47"
Unverified
Commit
efb35a41
authored
Jan 11, 2022
by
Patrick von Platen
Committed by
GitHub
Jan 11, 2022
Browse files
[Wav2Vec2ProcessorWithLM] improve decoder download (#15040)
parent
6ea62666
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
1 deletion
+22
-1
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
...rs/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
+8
-1
tests/test_processor_wav2vec2_with_lm.py
tests/test_processor_wav2vec2_with_lm.py
+14
-0
No files found.
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
View file @
efb35a41
...
...
@@ -166,7 +166,14 @@ class Wav2Vec2ProcessorWithLM:
# BeamSearchDecoderCTC has no auto class
kwargs
.
pop
(
"_from_auto"
,
None
)
decoder
=
BeamSearchDecoderCTC
.
load_from_hf_hub
(
pretrained_model_name_or_path
,
**
kwargs
)
# make sure that only relevant filenames are downloaded
language_model_filenames
=
os
.
path
.
join
(
BeamSearchDecoderCTC
.
_LANGUAGE_MODEL_SERIALIZED_DIRECTORY
,
"*"
)
alphabet_filename
=
BeamSearchDecoderCTC
.
_ALPHABET_SERIALIZED_FILENAME
allow_regex
=
[
language_model_filenames
,
alphabet_filename
]
decoder
=
BeamSearchDecoderCTC
.
load_from_hf_hub
(
pretrained_model_name_or_path
,
allow_regex
=
allow_regex
,
**
kwargs
)
# set language model attributes
for
attribute
in
[
"alpha"
,
"beta"
,
"unk_score_offset"
,
"score_boundary"
]:
...
...
tests/test_processor_wav2vec2_with_lm.py
View file @
efb35a41
...
...
@@ -18,6 +18,7 @@ import shutil
import
tempfile
import
unittest
from
multiprocessing
import
Pool
from
pathlib
import
Path
import
numpy
as
np
...
...
@@ -234,3 +235,16 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
([
"<s> </s> </s>"
,
"<s> <s> </s>"
],
decoded_processor
)
def test_decoder_download_ignores_files(self):
    """Check that loading the processor fetches only the decoder's own files.

    Loads the processor from the ``hf-internal-testing/processor_with_lm``
    repository, locates the directory two levels above the KenLM model file
    reported by the loaded decoder, and asserts that this directory contains
    exactly ``alphabet.json`` and ``language_model``.
    """
    processor = Wav2Vec2ProcessorWithLM.from_pretrained("hf-internal-testing/processor_with_lm")

    language_model = processor.decoder.model_container[processor.decoder._model_key]
    # The KenLM model reports its file path as bytes; the grandparent of that
    # file is used as the directory holding the downloaded decoder files.
    path_to_cached_dir = Path(language_model._kenlm_model.path.decode("utf-8")).parent.parent.absolute()

    # os.listdir() returns entries in arbitrary order, so sort both sides to
    # keep the comparison independent of the filesystem's listing order.
    downloaded_decoder_files = sorted(os.listdir(path_to_cached_dir))
    expected_decoder_files = sorted(["alphabet.json", "language_model"])

    # test that only decoder relevant files from
    # https://huggingface.co/hf-internal-testing/processor_with_lm/tree/main
    # are downloaded and none of the rest (e.g. README.md, ...)
    self.assertListEqual(downloaded_decoder_files, expected_decoder_files)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment