Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
14d058b9
Unverified
Commit
14d058b9
authored
Jan 24, 2023
by
Sanchit Gandhi
Committed by
GitHub
Jan 24, 2023
Browse files
[W2V2 with LM] Fix decoder test with params (#21277)
parent
94a7edd9
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
3 deletions
+10
-3
tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
...odels/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
+10
-3
No files found.
tests/models/wav2vec2_with_lm/test_processor_wav2vec2_with_lm.py
View file @
14d058b9
...
@@ -230,7 +230,6 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
...
@@ -230,7 +230,6 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
self
.
assertListEqual
(
logit_scores_decoder
,
decoded_processor
.
logit_score
)
self
.
assertListEqual
(
logit_scores_decoder
,
decoded_processor
.
logit_score
)
self
.
assertListEqual
(
lm_scores_decoder
,
decoded_processor
.
lm_score
)
self
.
assertListEqual
(
lm_scores_decoder
,
decoded_processor
.
lm_score
)
@
unittest
.
skip
(
"Fix me Sanchit"
)
def
test_decoder_with_params
(
self
):
def
test_decoder_with_params
(
self
):
feature_extractor
=
self
.
get_feature_extractor
()
feature_extractor
=
self
.
get_feature_extractor
()
tokenizer
=
self
.
get_tokenizer
()
tokenizer
=
self
.
get_tokenizer
()
...
@@ -240,7 +239,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
...
@@ -240,7 +239,7 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
logits
=
self
.
_get_dummy_logits
()
logits
=
self
.
_get_dummy_logits
()
beam_width
=
20
beam_width
=
15
beam_prune_logp
=
-
20.0
beam_prune_logp
=
-
20.0
token_min_logp
=
-
4.0
token_min_logp
=
-
4.0
...
@@ -264,9 +263,17 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
...
@@ -264,9 +263,17 @@ class Wav2Vec2ProcessorWithLMTest(unittest.TestCase):
)
)
decoded_decoder
=
[
d
[
0
][
0
]
for
d
in
decoded_decoder_out
]
decoded_decoder
=
[
d
[
0
][
0
]
for
d
in
decoded_decoder_out
]
logit_scores
=
[
d
[
0
][
2
]
for
d
in
decoded_decoder_out
]
lm_scores
=
[
d
[
0
][
3
]
for
d
in
decoded_decoder_out
]
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
(
decoded_decoder
,
decoded_processor
)
self
.
assertListEqual
([
"<s> </s> </s>"
,
"<s> <s> </s>"
],
decoded_processor
)
self
.
assertListEqual
([
"</s> <s> <s>"
,
"<s> <s> <s>"
],
decoded_processor
)
self
.
assertTrue
(
np
.
array_equal
(
logit_scores
,
decoded_processor_out
.
logit_score
))
self
.
assertTrue
(
np
.
allclose
([
-
20.054
,
-
18.447
],
logit_scores
,
atol
=
1e-3
))
self
.
assertTrue
(
np
.
array_equal
(
lm_scores
,
decoded_processor_out
.
lm_score
))
self
.
assertTrue
(
np
.
allclose
([
-
15.554
,
-
13.9474
],
lm_scores
,
atol
=
1e-3
))
def
test_decoder_with_params_of_lm
(
self
):
def
test_decoder_with_params_of_lm
(
self
):
feature_extractor
=
self
.
get_feature_extractor
()
feature_extractor
=
self
.
get_feature_extractor
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment