Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
d718c0c3
Unverified
Commit
d718c0c3
authored
Feb 02, 2022
by
Patrick von Platen
Committed by
GitHub
Feb 02, 2022
Browse files
[Wav2Vec2ProcessorWithLM] add alpha & beta to batch decode & decode (#15465)
parent
1d94d575
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
88 additions
and
3 deletions
+88
-3
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
...rs/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
+34
-0
tests/test_processor_wav2vec2_with_lm.py
tests/test_processor_wav2vec2_with_lm.py
+54
-3
No files found.
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
View file @
d718c0c3
...
@@ -253,6 +253,10 @@ class Wav2Vec2ProcessorWithLM:
...
@@ -253,6 +253,10 @@ class Wav2Vec2ProcessorWithLM:
token_min_logp
:
Optional
[
float
]
=
None
,
token_min_logp
:
Optional
[
float
]
=
None
,
hotwords
:
Optional
[
Iterable
[
str
]]
=
None
,
hotwords
:
Optional
[
Iterable
[
str
]]
=
None
,
hotword_weight
:
Optional
[
float
]
=
None
,
hotword_weight
:
Optional
[
float
]
=
None
,
alpha
:
Optional
[
float
]
=
None
,
beta
:
Optional
[
float
]
=
None
,
unk_score_offset
:
Optional
[
float
]
=
None
,
lm_score_boundary
:
Optional
[
bool
]
=
None
,
):
):
"""
"""
Batch decode output logits to audio transcription with language model support.
Batch decode output logits to audio transcription with language model support.
...
@@ -280,6 +284,14 @@ class Wav2Vec2ProcessorWithLM:
...
@@ -280,6 +284,14 @@ class Wav2Vec2ProcessorWithLM:
List of words with extra importance, can be OOV for LM
List of words with extra importance, can be OOV for LM
hotword_weight (`float`, *optional*):
hotword_weight (`float`, *optional*):
Weight factor for hotword importance. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
Weight factor for hotword importance. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
alpha (`float`, *optional*):
Weight for language model during shallow fusion
beta (`float`, *optional*):
Weight for length score adjustment during scoring
unk_score_offset (`float`, *optional*):
Amount of log score offset for unknown tokens
lm_score_boundary (`bool`, *optional*):
Whether to have kenlm respect boundaries when scoring
Returns:
Returns:
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
...
@@ -298,6 +310,11 @@ class Wav2Vec2ProcessorWithLM:
...
@@ -298,6 +310,11 @@ class Wav2Vec2ProcessorWithLM:
token_min_logp
=
token_min_logp
if
token_min_logp
is
not
None
else
DEFAULT_MIN_TOKEN_LOGP
token_min_logp
=
token_min_logp
if
token_min_logp
is
not
None
else
DEFAULT_MIN_TOKEN_LOGP
hotword_weight
=
hotword_weight
if
hotword_weight
is
not
None
else
DEFAULT_HOTWORD_WEIGHT
hotword_weight
=
hotword_weight
if
hotword_weight
is
not
None
else
DEFAULT_HOTWORD_WEIGHT
# reset params at every forward call. It's just a `set` method in pyctcdecode
self
.
decoder
.
reset_params
(
alpha
=
alpha
,
beta
=
beta
,
unk_score_offset
=
unk_score_offset
,
lm_score_boundary
=
lm_score_boundary
)
# create multiprocessing pool and list numpy arrays
# create multiprocessing pool and list numpy arrays
logits_list
=
[
array
for
array
in
logits
]
logits_list
=
[
array
for
array
in
logits
]
pool
=
get_context
(
"fork"
).
Pool
(
num_processes
)
pool
=
get_context
(
"fork"
).
Pool
(
num_processes
)
...
@@ -330,6 +347,10 @@ class Wav2Vec2ProcessorWithLM:
...
@@ -330,6 +347,10 @@ class Wav2Vec2ProcessorWithLM:
token_min_logp
:
Optional
[
float
]
=
None
,
token_min_logp
:
Optional
[
float
]
=
None
,
hotwords
:
Optional
[
Iterable
[
str
]]
=
None
,
hotwords
:
Optional
[
Iterable
[
str
]]
=
None
,
hotword_weight
:
Optional
[
float
]
=
None
,
hotword_weight
:
Optional
[
float
]
=
None
,
alpha
:
Optional
[
float
]
=
None
,
beta
:
Optional
[
float
]
=
None
,
unk_score_offset
:
Optional
[
float
]
=
None
,
lm_score_boundary
:
Optional
[
bool
]
=
None
,
):
):
"""
"""
Decode output logits to audio transcription with language model support.
Decode output logits to audio transcription with language model support.
...
@@ -349,6 +370,14 @@ class Wav2Vec2ProcessorWithLM:
...
@@ -349,6 +370,14 @@ class Wav2Vec2ProcessorWithLM:
List of words with extra importance which can be missing from the LM's vocabulary, e.g. ["huggingface"]
List of words with extra importance which can be missing from the LM's vocabulary, e.g. ["huggingface"]
hotword_weight (`float`, *optional*):
hotword_weight (`float`, *optional*):
Weight multiplier that boosts hotword scores. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
Weight multiplier that boosts hotword scores. Defaults to pyctcdecode's DEFAULT_HOTWORD_WEIGHT.
alpha (`float`, *optional*):
Weight for language model during shallow fusion
beta (`float`, *optional*):
Weight for length score adjustment during scoring
unk_score_offset (`float`, *optional*):
Amount of log score offset for unknown tokens
lm_score_boundary (`bool`, *optional*):
Whether to have kenlm respect boundaries when scoring
Returns:
Returns:
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
[`~models.wav2vec2.Wav2Vec2DecoderWithLMOutput`] or `tuple`.
...
@@ -367,6 +396,11 @@ class Wav2Vec2ProcessorWithLM:
...
@@ -367,6 +396,11 @@ class Wav2Vec2ProcessorWithLM:
token_min_logp
=
token_min_logp
if
token_min_logp
is
not
None
else
DEFAULT_MIN_TOKEN_LOGP
token_min_logp
=
token_min_logp
if
token_min_logp
is
not
None
else
DEFAULT_MIN_TOKEN_LOGP
hotword_weight
=
hotword_weight
if
hotword_weight
is
not
None
else
DEFAULT_HOTWORD_WEIGHT
hotword_weight
=
hotword_weight
if
hotword_weight
is
not
None
else
DEFAULT_HOTWORD_WEIGHT
# reset params at every forward call. It's just a `set` method in pyctcdecode
self
.
decoder
.
reset_params
(
alpha
=
alpha
,
beta
=
beta
,
unk_score_offset
=
unk_score_offset
,
lm_score_boundary
=
lm_score_boundary
)
# pyctcdecode
# pyctcdecode
decoded_beams
=
self
.
decoder
.
decode_beams
(
decoded_beams
=
self
.
decoder
.
decode_beams
(
logits
,
logits
,
...
...
tests/test_processor_wav2vec2_with_lm.py
View file @
d718c0c3
...
@@ -17,7 +17,7 @@ import os
...
@@ -17,7 +17,7 @@ import os
import
shutil
import
shutil
import
tempfile
import
tempfile
import
unittest
import
unittest
from
multiprocessing
import
Pool
from
multiprocessing
import
get_context
from
pathlib
import
Path
from
pathlib
import
Path
import
numpy
as
np
import
numpy
as
np
...
@@ -196,7 +196,9 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
...
@@ -196,7 +196,9 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_processor
=
processor
.
batch_decode
(
logits
).
text
decoded_processor
=
processor
.
batch_decode
(
logits
).
text
logits_list
=
[
array
for
array
in
logits
]
logits_list
=
[
array
for
array
in
logits
]
decoded_decoder
=
[
d
[
0
][
0
]
for
d
in
decoder
.
decode_beams_batch
(
Pool
(),
logits_list
)]
pool
=
get_context
(
"fork"
).
Pool
()
decoded_decoder
=
[
d
[
0
][
0
]
for
d
in
decoder
.
decode_beams_batch
(
pool
,
logits_list
)]
pool
.
close
()
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
([
"<s> <s> </s>"
,
"<s> <s> <s>"
],
decoded_processor
)
self
.
assertListEqual
([
"<s> <s> </s>"
,
"<s> <s> <s>"
],
decoded_processor
)
...
@@ -223,19 +225,68 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
...
@@ -223,19 +225,68 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
decoded_processor
=
decoded_processor_out
.
text
decoded_processor
=
decoded_processor_out
.
text
logits_list
=
[
array
for
array
in
logits
]
logits_list
=
[
array
for
array
in
logits
]
pool
=
get_context
(
"fork"
).
Pool
()
decoded_decoder_out
=
decoder
.
decode_beams_batch
(
decoded_decoder_out
=
decoder
.
decode_beams_batch
(
Pool
()
,
pool
,
logits_list
,
logits_list
,
beam_width
=
beam_width
,
beam_width
=
beam_width
,
beam_prune_logp
=
beam_prune_logp
,
beam_prune_logp
=
beam_prune_logp
,
token_min_logp
=
token_min_logp
,
token_min_logp
=
token_min_logp
,
)
)
pool
.
close
()
decoded_decoder
=
[
d
[
0
][
0
]
for
d
in
decoded_decoder_out
]
decoded_decoder
=
[
d
[
0
][
0
]
for
d
in
decoded_decoder_out
]
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
([
"<s> </s> </s>"
,
"<s> <s> </s>"
],
decoded_processor
)
self
.
assertListEqual
([
"<s> </s> </s>"
,
"<s> <s> </s>"
],
decoded_processor
)
def test_decoder_with_params_of_lm(self):
    """Check that `batch_decode` forwards the language-model parameters.

    The processor decodes dummy logits with explicit `alpha`, `beta`,
    `unk_score_offset` and `lm_score_boundary` values. The test verifies that
    the language model held by the processor's decoder carries those values,
    and that the transcription equals what the decoder produces when it is
    configured with the same values and called directly.
    """
    feature_extractor = self.get_feature_extractor()
    tokenizer = self.get_tokenizer()
    decoder = self.get_decoder()

    processor = Wav2Vec2ProcessorWithLM(tokenizer=tokenizer, feature_extractor=feature_extractor, decoder=decoder)

    logits = self._get_dummy_logits()

    alpha = 2.0
    beta = 5.0
    unk_score_offset = -20.0
    lm_score_boundary = True

    decoded_processor_out = processor.batch_decode(
        logits,
        alpha=alpha,
        beta=beta,
        unk_score_offset=unk_score_offset,
        lm_score_boundary=lm_score_boundary,
    )
    decoded_processor = decoded_processor_out.text

    # Inspect the language model *before* the decoder is reconfigured by hand
    # below; checked afterwards, these assertions would pass even if
    # `batch_decode` had ignored the parameters.
    lm_model = processor.decoder.model_container[processor.decoder._model_key]
    self.assertEqual(lm_model.alpha, alpha)
    self.assertEqual(lm_model.beta, beta)
    self.assertEqual(lm_model.unk_score_offset, unk_score_offset)
    self.assertEqual(lm_model.score_boundary, lm_score_boundary)

    logits_list = list(logits)

    # Configure the reference decoder explicitly with the same parameters.
    decoder.reset_params(
        alpha=alpha,
        beta=beta,
        unk_score_offset=unk_score_offset,
        lm_score_boundary=lm_score_boundary,
    )

    # The context manager releases the worker processes even when decoding
    # raises, instead of relying on a trailing `pool.close()`.
    with get_context("fork").Pool() as pool:
        decoded_decoder_out = decoder.decode_beams_batch(
            pool,
            logits_list,
        )

    # Keep only the text of the best beam for every sample.
    decoded_decoder = [d[0][0] for d in decoded_decoder_out]

    self.assertListEqual(decoded_decoder, decoded_processor)
    self.assertListEqual(["<s> </s> <s> </s> </s>", "</s> </s> <s> </s> </s>"], decoded_processor)
def
test_decoder_download_ignores_files
(
self
):
def
test_decoder_download_ignores_files
(
self
):
processor
=
Wav2Vec2ProcessorWithLM
.
from_pretrained
(
"hf-internal-testing/processor_with_lm"
)
processor
=
Wav2Vec2ProcessorWithLM
.
from_pretrained
(
"hf-internal-testing/processor_with_lm"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment