Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e3e65173
Unverified
Commit
e3e65173
authored
Oct 07, 2020
by
Sam Shleifer
Committed by
GitHub
Oct 07, 2020
Browse files
Fix 3 failing slow bart/blender tests (#7652)
parent
960faaaf
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
8 additions
and
18 deletions
+8
-18
tests/test_modeling_bart.py
tests/test_modeling_bart.py
+3
-5
tests/test_modeling_blenderbot.py
tests/test_modeling_blenderbot.py
+5
-13
No files found.
tests/test_modeling_bart.py
View file @
e3e65173
...
...
@@ -368,7 +368,7 @@ class BartHeadTests(unittest.TestCase):
torch
.
Tensor
([
0
,
11349
,
495
,
4040
,
571
,
2
]),
]
for
ex
,
desired_result
in
zip
(
examples
,
fairseq_results
):
bart_toks
=
tokenizer
.
encode
(
ex
,
return_tensors
=
"pt"
)
bart_toks
=
tokenizer
.
encode
(
ex
,
return_tensors
=
"pt"
)
.
squeeze
()
assert_tensors_close
(
desired_result
.
long
(),
bart_toks
,
prefix
=
ex
)
def
test_generate_fp16
(
self
):
...
...
@@ -417,11 +417,9 @@ class BartHeadTests(unittest.TestCase):
def
assert_tensors_close
(
a
,
b
,
atol
=
1e-12
,
prefix
=
""
):
"""If tensors not close, or a and b aren't both tensors, raise a nice Assertion error."""
"""If tensors have different shapes, different values or a and b are not both tensors, raise a nice Assertion error."""
if
a
is
None
and
b
is
None
:
return
True
assert
a
.
shape
==
b
.
shape
try
:
if
torch
.
allclose
(
a
,
b
,
atol
=
atol
):
return
True
...
...
@@ -506,7 +504,7 @@ class BartModelIntegrationTests(unittest.TestCase):
inputs_dict
=
prepare_bart_inputs_dict
(
model
.
config
,
input_ids
=
input_ids_no_pad
)
with
torch
.
no_grad
():
logits2
=
model
(
**
inputs_dict
)[
0
]
logits2
=
model
(
**
inputs_dict
)[
0
]
.
squeeze
()
assert_tensors_close
(
batched_logits
[
1
],
logits2
,
atol
=
TOLERANCE
)
assert_tensors_close
(
expected_slice
,
logits_arr
,
atol
=
TOLERANCE
)
...
...
tests/test_modeling_blenderbot.py
View file @
e3e65173
...
...
@@ -134,38 +134,31 @@ class BlenderbotTesterMixin(ModelTesterMixin, unittest.TestCase):
class
Blenderbot3BIntegrationTests
(
unittest
.
TestCase
):
ckpt
=
"facebook/blenderbot-3B"
@
cached_property
def
model
(
self
):
model
=
BlenderbotForConditionalGeneration
.
from_pretrained
(
self
.
ckpt
).
to
(
torch_device
)
if
torch_device
==
"cuda"
:
model
=
model
.
half
()
return
model
@
cached_property
def
tokenizer
(
self
):
return
BlenderbotTokenizer
.
from_pretrained
(
self
.
ckpt
)
@
slow
def
test_generation_from_short_input_same_as_parlai_3B
(
self
):
torch
.
cuda
.
empty_cache
()
model
=
BlenderbotForConditionalGeneration
.
from_pretrained
(
self
.
ckpt
).
half
().
to
(
torch_device
)
src_text
=
[
"Sam"
]
model_inputs
=
self
.
tokenizer
(
src_text
,
return_tensors
=
"pt"
).
to
(
torch_device
)
generated_utterances
=
self
.
model
.
generate
(
**
model_inputs
,
**
FASTER_GEN_KWARGS
)
generated_utterances
=
model
.
generate
(
**
model_inputs
,
**
FASTER_GEN_KWARGS
)
tgt_text
=
'Sam is a great name. It means "sun" in Gaelic.'
generated_txt
=
self
.
tokenizer
.
batch_decode
(
generated_utterances
,
**
TOK_DECODE_KW
)
assert
generated_txt
[
0
].
strip
()
==
tgt_text
@
slow
def
test_generation_from_long_input_same_as_parlai_3B
(
self
):
src_text
=
"Social anxiety
\n
Wow, I am never shy. Do you have anxiety?
\n
Yes. I end up sweating and blushing and feel like i'm going to throw up.
\n
and why is that?"
model_inputs
=
self
.
tokenizer
([
src_text
],
return_tensors
=
"pt"
).
to
(
torch_device
)
generated_ids
=
self
.
model
.
generate
(
**
model_inputs
,
**
FASTER_GEN_KWARGS
)[
0
]
generated_ids
=
model
.
generate
(
**
model_inputs
,
**
FASTER_GEN_KWARGS
)[
0
]
reply
=
self
.
tokenizer
.
decode
(
generated_ids
,
**
TOK_DECODE_KW
)
assert
"I think it's because we are so worried about what people think of us."
==
reply
.
strip
()
del
model
@
require_torch
...
...
@@ -193,7 +186,6 @@ class Blenderbot90MIntegrationTests(unittest.TestCase):
model_inputs
=
self
.
tokenizer
(
src_text
,
return_tensors
=
"pt"
).
to
(
torch_device
)
assert
isinstance
(
self
.
tokenizer
,
BlenderbotSmallTokenizer
)
assert
self
.
model
.
config
.
do
generated_ids
=
self
.
model
.
generate
(
**
model_inputs
)[
0
]
reply
=
self
.
tokenizer
.
decode
(
generated_ids
,
**
TOK_DECODE_KW
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment