Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d718c0c3
Unverified
Commit
d718c0c3
authored
Feb 02, 2022
by
Patrick von Platen
Committed by
GitHub
Feb 02, 2022
Browse files
[Wav2Vec2ProcessorWithLM] add alpha & beta to batch decode & decode (#15465)
parent
1d94d575
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
88 additions
and
3 deletions
+88
-3
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
...rs/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
+34
-0
tests/test_processor_wav2vec2_with_lm.py
tests/test_processor_wav2vec2_with_lm.py
+54
-3
No files found.
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
View file @
d718c0c3
...
...
@@ -253,6 +253,10 @@ class Wav2Vec2ProcessorWithLM:
token_min_logp
:
Optional
[
float
]
=
None
,
hotwords
:
Optional
[
Iterable
[
str
]]
=
None
,
hotword_weight
:
Optional
[
float
]
=
None
,
alpha
:
Optional
[
float
]
=
None
,
beta
:
Optional
[
float
]
=
None
,
unk_score_offset
:
Optional
[
float
]
=
None
,
lm_score_boundary
:
Optional
[
bool
]
=
None
,
):
"""
Batch decode output logits to audio transcription with language model support.
...
...
@@ -280,6 +284,14 @@ class Wav2Vec2ProcessorWithLM:
List of words with extra importance, can be OOV for LM
hotword_weight (`int`, *optional*):
Weight factor for hotword importance Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
alpha (`float`, *optional*):
Weight for language model during shallow fusion
beta (`float`, *optional*):
Weight for length score adjustment of during scoring
unk_score_offset (`float`, *optional*):
Amount of log score offset for unknown tokens
lm_score_boundary (`bool`, *optional*):
Whether to have kenlm respect boundaries when scoring
Returns:
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
...
...
@@ -298,6 +310,11 @@ class Wav2Vec2ProcessorWithLM:
token_min_logp
=
token_min_logp
if
token_min_logp
is
not
None
else
DEFAULT_MIN_TOKEN_LOGP
hotword_weight
=
hotword_weight
if
hotword_weight
is
not
None
else
DEFAULT_HOTWORD_WEIGHT
# reset params at every forward call. It's just a `set` method in pyctcdecode
self
.
decoder
.
reset_params
(
alpha
=
alpha
,
beta
=
beta
,
unk_score_offset
=
unk_score_offset
,
lm_score_boundary
=
lm_score_boundary
)
# create multiprocessing pool and list numpy arrays
logits_list
=
[
array
for
array
in
logits
]
pool
=
get_context
(
"fork"
).
Pool
(
num_processes
)
...
...
@@ -330,6 +347,10 @@ class Wav2Vec2ProcessorWithLM:
token_min_logp
:
Optional
[
float
]
=
None
,
hotwords
:
Optional
[
Iterable
[
str
]]
=
None
,
hotword_weight
:
Optional
[
float
]
=
None
,
alpha
:
Optional
[
float
]
=
None
,
beta
:
Optional
[
float
]
=
None
,
unk_score_offset
:
Optional
[
float
]
=
None
,
lm_score_boundary
:
Optional
[
bool
]
=
None
,
):
"""
Decode output logits to audio transcription with language model support.
...
...
@@ -349,6 +370,14 @@ class Wav2Vec2ProcessorWithLM:
List of words with extra importance which can be missing from the LM's vocabulary, e.g. ["huggingface"]
hotword_weight (`int`, *optional*):
Weight multiplier that boosts hotword scores. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
alpha (`float`, *optional*):
Weight for language model during shallow fusion
beta (`float`, *optional*):
Weight for length score adjustment of during scoring
unk_score_offset (`float`, *optional*):
Amount of log score offset for unknown tokens
lm_score_boundary (`bool`, *optional*):
Whether to have kenlm respect boundaries when scoring
Returns:
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
...
...
@@ -367,6 +396,11 @@ class Wav2Vec2ProcessorWithLM:
token_min_logp
=
token_min_logp
if
token_min_logp
is
not
None
else
DEFAULT_MIN_TOKEN_LOGP
hotword_weight
=
hotword_weight
if
hotword_weight
is
not
None
else
DEFAULT_HOTWORD_WEIGHT
# reset params at every forward call. It's just a `set` method in pyctcdecode
self
.
decoder
.
reset_params
(
alpha
=
alpha
,
beta
=
beta
,
unk_score_offset
=
unk_score_offset
,
lm_score_boundary
=
lm_score_boundary
)
# pyctcdecode
decoded_beams
=
self
.
decoder
.
decode_beams
(
logits
,
...
...
tests/test_processor_wav2vec2_with_lm.py
View file @
d718c0c3
...
...
@@ -17,7 +17,7 @@ import os
import
shutil
import
tempfile
import
unittest
from
multiprocessing
import
Pool
from
multiprocessing
import
get_context
from
pathlib
import
Path
import
numpy
as
np
...
...
@@ -196,7 +196,9 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_processor
=
processor
.
batch_decode
(
logits
).
text
logits_list
=
[
array
for
array
in
logits
]
decoded_decoder
=
[
d
[
0
][
0
]
for
d
in
decoder
.
decode_beams_batch
(
Pool
(),
logits_list
)]
pool
=
get_context
(
"fork"
).
Pool
()
decoded_decoder
=
[
d
[
0
][
0
]
for
d
in
decoder
.
decode_beams_batch
(
pool
,
logits_list
)]
pool
.
close
()
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
([
"<s> <s> </s>"
,
"<s> <s> <s>"
],
decoded_processor
)
...
...
@@ -223,19 +225,68 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_processor
=
decoded_processor_out
.
text
logits_list
=
[
array
for
array
in
logits
]
pool
=
get_context
(
"fork"
).
Pool
()
decoded_decoder_out
=
decoder
.
decode_beams_batch
(
P
ool
()
,
p
ool
,
logits_list
,
beam_width
=
beam_width
,
beam_prune_logp
=
beam_prune_logp
,
token_min_logp
=
token_min_logp
,
)
pool
.
close
()
decoded_decoder
=
[
d
[
0
][
0
]
for
d
in
decoded_decoder_out
]
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
([
"<s> </s> </s>"
,
"<s> <s> </s>"
],
decoded_processor
)
def
test_decoder_with_params_of_lm
(
self
):
feature_extractor
=
self
.
get_feature_extractor
()
tokenizer
=
self
.
get_tokenizer
()
decoder
=
self
.
get_decoder
()
processor
=
Wav2Vec2ProcessorWithLM
(
tokenizer
=
tokenizer
,
feature_extractor
=
feature_extractor
,
decoder
=
decoder
)
logits
=
self
.
_get_dummy_logits
()
alpha
=
2.0
beta
=
5.0
unk_score_offset
=
-
20.0
lm_score_boundary
=
True
decoded_processor_out
=
processor
.
batch_decode
(
logits
,
alpha
=
alpha
,
beta
=
beta
,
unk_score_offset
=
unk_score_offset
,
lm_score_boundary
=
lm_score_boundary
,
)
decoded_processor
=
decoded_processor_out
.
text
logits_list
=
[
array
for
array
in
logits
]
decoder
.
reset_params
(
alpha
=
alpha
,
beta
=
beta
,
unk_score_offset
=
unk_score_offset
,
lm_score_boundary
=
lm_score_boundary
,
)
pool
=
get_context
(
"fork"
).
Pool
()
decoded_decoder_out
=
decoder
.
decode_beams_batch
(
pool
,
logits_list
,
)
pool
.
close
()
decoded_decoder
=
[
d
[
0
][
0
]
for
d
in
decoded_decoder_out
]
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
([
"<s> </s> <s> </s> </s>"
,
"</s> </s> <s> </s> </s>"
],
decoded_processor
)
lm_model
=
processor
.
decoder
.
model_container
[
processor
.
decoder
.
_model_key
]
self
.
assertEqual
(
lm_model
.
alpha
,
2.0
)
self
.
assertEqual
(
lm_model
.
beta
,
5.0
)
self
.
assertEqual
(
lm_model
.
unk_score_offset
,
-
20.0
)
self
.
assertEqual
(
lm_model
.
score_boundary
,
True
)
def
test_decoder_download_ignores_files
(
self
):
processor
=
Wav2Vec2ProcessorWithLM
.
from_pretrained
(
"hf-internal-testing/processor_with_lm"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment