Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
67047b86
Unverified
Commit
67047b86
authored
Feb 15, 2022
by
arampacha
Committed by
GitHub
Feb 15, 2022
Browse files
add scores to Wav2Vec2WithLMOutput (#15413)
* add scores to Wav2Vec2WithLMOutput * style fixup
parent
45f56580
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
35 additions
and
14 deletions
+35
-14
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
...rs/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
+18
-6
tests/test_processor_wav2vec2_with_lm.py
tests/test_processor_wav2vec2_with_lm.py
+17
-8
No files found.
src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
View file @
67047b86
...
...
@@ -42,9 +42,15 @@ class Wav2Vec2DecoderWithLMOutput(ModelOutput):
Args:
text (list of `str`):
Decoded logits in text form. Usually the speech transcription.
logit_score (list of `float`):
Total logit score of the beam associated with produced text.
lm_score (list of `float`):
Fused lm_score of the beam associated with produced text.
"""
text
:
Union
[
List
[
str
],
str
]
logit_score
:
Union
[
List
[
float
],
float
]
=
None
lm_score
:
Union
[
List
[
float
],
float
]
=
None
class
Wav2Vec2ProcessorWithLM
(
ProcessorMixin
):
...
...
@@ -283,7 +289,8 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
)
# create multiprocessing pool and list numpy arrays
logits_list
=
[
array
for
array
in
logits
]
# filter out logits padding
logits_list
=
[
array
[(
array
!=
-
100.0
).
all
(
axis
=-
1
)]
for
array
in
logits
]
pool
=
get_context
(
"fork"
).
Pool
(
num_processes
)
# pyctcdecode
...
...
@@ -300,11 +307,14 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
# close multi-processing pool
pool
.
close
()
# extract text
batch_texts
=
[
d
[
0
][
0
]
for
d
in
decoded_beams
]
# extract text and scores
batch_texts
,
logit_scores
,
lm_scores
=
[],
[],
[]
for
d
in
decoded_beams
:
batch_texts
.
append
(
d
[
0
][
0
])
logit_scores
.
append
(
d
[
0
][
-
2
])
lm_scores
.
append
(
d
[
0
][
-
1
])
# more output features will be added in the future
return
Wav2Vec2DecoderWithLMOutput
(
text
=
batch_texts
)
return
Wav2Vec2DecoderWithLMOutput
(
text
=
batch_texts
,
logit_score
=
logit_scores
,
lm_score
=
lm_scores
)
def
decode
(
self
,
...
...
@@ -379,7 +389,9 @@ class Wav2Vec2ProcessorWithLM(ProcessorMixin):
)
# more output features will be added in the future
return
Wav2Vec2DecoderWithLMOutput
(
text
=
decoded_beams
[
0
][
0
])
return
Wav2Vec2DecoderWithLMOutput
(
text
=
decoded_beams
[
0
][
0
],
logit_score
=
decoded_beams
[
0
][
-
2
],
lm_score
=
decoded_beams
[
0
][
-
1
]
)
@
contextmanager
def
as_target_processor
(
self
):
...
...
tests/test_processor_wav2vec2_with_lm.py
View file @
67047b86
...
...
@@ -178,12 +178,14 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits
=
self
.
_get_dummy_logits
(
shape
=
(
10
,
16
),
seed
=
13
)
decoded_processor
=
processor
.
decode
(
logits
)
.
text
decoded_processor
=
processor
.
decode
(
logits
)
decoded_decoder
=
decoder
.
decode_beams
(
logits
)[
0
]
[
0
]
decoded_decoder
=
decoder
.
decode_beams
(
logits
)[
0
]
self
.
assertEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertEqual
(
"</s> <s> </s>"
,
decoded_processor
)
self
.
assertEqual
(
decoded_decoder
[
0
],
decoded_processor
.
text
)
self
.
assertEqual
(
"</s> <s> </s>"
,
decoded_processor
.
text
)
self
.
assertEqual
(
decoded_decoder
[
-
2
],
decoded_processor
.
logit_score
)
self
.
assertEqual
(
decoded_decoder
[
-
1
],
decoded_processor
.
lm_score
)
def
test_decoder_batch
(
self
):
feature_extractor
=
self
.
get_feature_extractor
()
...
...
@@ -194,15 +196,22 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits
=
self
.
_get_dummy_logits
()
decoded_processor
=
processor
.
batch_decode
(
logits
)
.
text
decoded_processor
=
processor
.
batch_decode
(
logits
)
logits_list
=
[
array
for
array
in
logits
]
pool
=
get_context
(
"fork"
).
Pool
()
decoded_decoder
=
[
d
[
0
][
0
]
for
d
in
decoder
.
decode_beams_batch
(
pool
,
logits_list
)]
decoded_beams
=
decoder
.
decode_beams_batch
(
pool
,
logits_list
)
texts_decoder
,
logit_scores_decoder
,
lm_scores_decoder
=
[],
[],
[]
for
beams
in
decoded_beams
:
texts_decoder
.
append
(
beams
[
0
][
0
])
logit_scores_decoder
.
append
(
beams
[
0
][
-
2
])
lm_scores_decoder
.
append
(
beams
[
0
][
-
1
])
pool
.
close
()
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
([
"<s> <s> </s>"
,
"<s> <s> <s>"
],
decoded_processor
)
self
.
assertListEqual
(
texts_decoder
,
decoded_processor
.
text
)
self
.
assertListEqual
([
"<s> <s> </s>"
,
"<s> <s> <s>"
],
decoded_processor
.
text
)
self
.
assertListEqual
(
logit_scores_decoder
,
decoded_processor
.
logit_score
)
self
.
assertListEqual
(
lm_scores_decoder
,
decoded_processor
.
lm_score
)
def
test_decoder_with_params
(
self
):
feature_extractor
=
self
.
get_feature_extractor
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment