Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e11d923b
Unverified
Commit
e11d923b
authored
Aug 25, 2020
by
Sam Shleifer
Committed by
GitHub
Aug 25, 2020
Browse files
Fix pegasus-xsum integration test (#6726)
parent
7e6397a7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
tests/test_modeling_pegasus.py
tests/test_modeling_pegasus.py
+3
-3
No files found.
tests/test_modeling_pegasus.py
View file @
e11d923b
...
@@ -20,8 +20,8 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
...
@@ -20,8 +20,8 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
checkpoint_name
=
"google/pegasus-xsum"
checkpoint_name
=
"google/pegasus-xsum"
src_text
=
[
PGE_ARTICLE
,
XSUM_ENTRY_LONGER
]
src_text
=
[
PGE_ARTICLE
,
XSUM_ENTRY_LONGER
]
tgt_text
=
[
tgt_text
=
[
"California's largest electricity provider has turned off power to
ten
s of thousands of customers."
,
"California's largest electricity provider has turned off power to
hundred
s of thousands of customers."
,
"N-Dubz have
reveale
d they were
n't expecting
to get four nominations
at
this year's Mobo Awards."
,
"N-Dubz have
sai
d they were
surprised
to get four nominations
for
this year's Mobo Awards."
,
]
]
@
cached_property
@
cached_property
...
@@ -37,7 +37,7 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
...
@@ -37,7 +37,7 @@ class PegasusXSUMIntegrationTest(AbstractSeq2SeqIntegrationTest):
assert
inputs
.
input_ids
.
shape
==
(
2
,
421
)
assert
inputs
.
input_ids
.
shape
==
(
2
,
421
)
translated_tokens
=
self
.
model
.
generate
(
**
inputs
)
translated_tokens
=
self
.
model
.
generate
(
**
inputs
)
decoded
=
self
.
tokenizer
.
batch_decode
(
translated_tokens
,
skip_special_tokens
=
True
)
decoded
=
self
.
tokenizer
.
batch_decode
(
translated_tokens
,
skip_special_tokens
=
True
)
self
.
assert
Equal
(
self
.
tgt_text
,
decoded
)
assert
self
.
tgt_text
==
decoded
if
"cuda"
not
in
torch_device
:
if
"cuda"
not
in
torch_device
:
return
return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment