Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e53331c9
Unverified
Commit
e53331c9
authored
Nov 22, 2022
by
Joao Gante
Committed by
GitHub
Nov 22, 2022
Browse files
Generate: fix plbart generation tests (#20391)
parent
2e17db8a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
4 additions
and
3 deletions
+4
-3
tests/models/plbart/test_modeling_plbart.py
tests/models/plbart/test_modeling_plbart.py
+4
-3
No files found.
tests/models/plbart/test_modeling_plbart.py
View file @
e53331c9
...
@@ -409,12 +409,12 @@ class PLBartBaseIntegrationTest(AbstractSeq2SeqIntegrationTest):
...
@@ -409,12 +409,12 @@ class PLBartBaseIntegrationTest(AbstractSeq2SeqIntegrationTest):
src_text
=
[
"Is 0 the first Fibonacci number ?"
,
"Find the sum of all prime numbers ."
]
src_text
=
[
"Is 0 the first Fibonacci number ?"
,
"Find the sum of all prime numbers ."
]
tgt_text
=
[
"0 the first Fibonacci number?"
,
"the sum of all prime numbers.......... the the"
]
tgt_text
=
[
"0 the first Fibonacci number?"
,
"the sum of all prime numbers.......... the the"
]
@
unittest
.
skip
(
"This test is broken, fix me gante"
)
def
test_base_generate
(
self
):
def
test_base_generate
(
self
):
inputs
=
self
.
tokenizer
([
self
.
src_text
[
0
]],
return_tensors
=
"pt"
).
to
(
torch_device
)
inputs
=
self
.
tokenizer
([
self
.
src_text
[
0
]],
return_tensors
=
"pt"
).
to
(
torch_device
)
src_lan
=
self
.
tokenizer
.
_convert_lang_code_special_format
(
"en_XX"
)
translated_tokens
=
self
.
model
.
generate
(
translated_tokens
=
self
.
model
.
generate
(
input_ids
=
inputs
[
"input_ids"
].
to
(
torch_device
),
input_ids
=
inputs
[
"input_ids"
].
to
(
torch_device
),
decoder_start_token_id
=
self
.
tokenizer
.
lang_code_to_id
[
"en_XX"
],
decoder_start_token_id
=
self
.
tokenizer
.
lang_code_to_id
[
src_lan
],
)
)
decoded
=
self
.
tokenizer
.
batch_decode
(
translated_tokens
,
skip_special_tokens
=
True
)
decoded
=
self
.
tokenizer
.
batch_decode
(
translated_tokens
,
skip_special_tokens
=
True
)
self
.
assertEqual
(
self
.
tgt_text
[
0
],
decoded
[
0
])
self
.
assertEqual
(
self
.
tgt_text
[
0
],
decoded
[
0
])
...
@@ -422,8 +422,9 @@ class PLBartBaseIntegrationTest(AbstractSeq2SeqIntegrationTest):
...
@@ -422,8 +422,9 @@ class PLBartBaseIntegrationTest(AbstractSeq2SeqIntegrationTest):
@
slow
@
slow
def
test_fill_mask
(
self
):
def
test_fill_mask
(
self
):
inputs
=
self
.
tokenizer
([
"Is 0 the <mask> Fibonacci <mask> ?"
],
return_tensors
=
"pt"
).
to
(
torch_device
)
inputs
=
self
.
tokenizer
([
"Is 0 the <mask> Fibonacci <mask> ?"
],
return_tensors
=
"pt"
).
to
(
torch_device
)
src_lan
=
self
.
tokenizer
.
_convert_lang_code_special_format
(
"en_XX"
)
outputs
=
self
.
model
.
generate
(
outputs
=
self
.
model
.
generate
(
inputs
[
"input_ids"
],
decoder_start_token_id
=
self
.
tokenizer
.
lang_code_to_id
[
"en_XX"
],
num_beams
=
1
inputs
[
"input_ids"
],
decoder_start_token_id
=
self
.
tokenizer
.
lang_code_to_id
[
src_lan
],
num_beams
=
1
)
)
prediction
:
str
=
self
.
tokenizer
.
batch_decode
(
prediction
:
str
=
self
.
tokenizer
.
batch_decode
(
outputs
,
clean_up_tokenization_spaces
=
True
,
skip_special_tokens
=
True
outputs
,
clean_up_tokenization_spaces
=
True
,
skip_special_tokens
=
True
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment