Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
9a0399e1
"vscode:/vscode.git/clone" did not exist on "8a312956fd49efd69adb98c40996719d4c276a01"
Unverified
Commit
9a0399e1
authored
Feb 08, 2021
by
Patrick von Platen
Committed by
GitHub
Feb 08, 2021
Browse files
fix bart tests (#10060)
parent
b01483fa
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
2 additions
and
7 deletions
+2
-7
tests/test_modeling_bart.py
tests/test_modeling_bart.py
+2
-7
No files found.
tests/test_modeling_bart.py
View file @
9a0399e1
...
...
@@ -42,7 +42,6 @@ if is_torch_available():
BartForSequenceClassification
,
BartModel
,
BartTokenizer
,
BartTokenizerFast
,
pipeline
,
)
from
transformers.models.bart.modeling_bart
import
BartDecoder
,
BartEncoder
,
shift_tokens_right
...
...
@@ -566,10 +565,6 @@ class BartModelIntegrationTests(unittest.TestCase):
def
default_tokenizer
(
self
):
return
BartTokenizer
.
from_pretrained
(
"facebook/bart-large"
)
@
cached_property
def
default_tokenizer_fast
(
self
):
return
BartTokenizerFast
.
from_pretrained
(
"facebook/bart-large"
)
@
slow
def
test_inference_no_head
(
self
):
model
=
BartModel
.
from_pretrained
(
"facebook/bart-large"
).
to
(
torch_device
)
...
...
@@ -589,14 +584,14 @@ class BartModelIntegrationTests(unittest.TestCase):
pbase
=
pipeline
(
task
=
"fill-mask"
,
model
=
"facebook/bart-base"
)
src_text
=
[
" I went to the <mask>."
]
results
=
[
x
[
"token_str"
]
for
x
in
pbase
(
src_text
)]
assert
"
Ġ
bathroom"
in
results
assert
"
bathroom"
in
results
@
slow
def
test_large_mask_filling
(
self
):
plarge
=
pipeline
(
task
=
"fill-mask"
,
model
=
"facebook/bart-large"
)
src_text
=
[
" I went to the <mask>."
]
results
=
[
x
[
"token_str"
]
for
x
in
plarge
(
src_text
)]
expected_results
=
[
"
Ġ
bathroom"
,
"
Ġ
gym"
,
"
Ġ
wrong"
,
"
Ġ
movies"
,
"
Ġ
hospital"
]
expected_results
=
[
"
bathroom"
,
"
gym"
,
"
wrong"
,
"
movies"
,
"
hospital"
]
self
.
assertListEqual
(
results
,
expected_results
)
@
slow
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment