Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
54192058
"examples/legacy/seq2seq/run_eval.py" did not exist on "b76cb1c3dfc64d1dcaddc3d6d9313dddeb626d05"
Unverified
Commit
54192058
authored
May 19, 2022
by
Patrick von Platen
Committed by
GitHub
May 19, 2022
Browse files
[Test OPT] Add batch generation test opt (#17359)
* up * up
parent
48c22691
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
41 additions
and
0 deletions
+41
-0
tests/models/opt/test_modeling_opt.py
tests/models/opt/test_modeling_opt.py
+41
-0
No files found.
tests/models/opt/test_modeling_opt.py
View file @
54192058
...
@@ -366,6 +366,47 @@ class OPTGenerationTest(unittest.TestCase):
...
@@ -366,6 +366,47 @@ class OPTGenerationTest(unittest.TestCase):
self
.
assertListEqual
(
predicted_outputs
,
EXPECTED_OUTPUTS
)
self
.
assertListEqual
(
predicted_outputs
,
EXPECTED_OUTPUTS
)
def test_batch_generation(self):
    """Batched, left-padded generation must reproduce what the model generates
    for each sentence on its own (i.e. padding must not change the outputs)."""
    checkpoint = "facebook/opt-350m"

    tok = GPT2Tokenizer.from_pretrained(checkpoint)
    lm = OPTForCausalLM.from_pretrained(checkpoint)
    lm.to(torch_device)

    # decoder-only models require left padding for batched generation
    tok.padding_side = "left"

    # use different length sentences to test batching
    sentences = [
        "Hello, my dog is a little",
        "Today, I",
    ]

    batch = tok(sentences, return_tensors="pt", padding=True)
    batch_ids = batch["input_ids"].to(torch_device)
    batch_outputs = lm.generate(
        input_ids=batch_ids,
        attention_mask=batch["attention_mask"].to(torch_device),
    )

    # generate for the first (longest) sentence alone — no padding involved
    ids_first = tok(sentences[0], return_tensors="pt").input_ids.to(torch_device)
    out_first = lm.generate(input_ids=ids_first)

    # the second sentence was padded up to the first's length inside the batch;
    # shrink its solo generation budget by that padding amount so both runs
    # produce the same number of new tokens
    pad_count = ids_first.shape[-1] - batch["attention_mask"][-1].long().sum().cpu().item()
    ids_second = tok(sentences[1], return_tensors="pt").input_ids.to(torch_device)
    out_second = lm.generate(input_ids=ids_second, max_length=lm.config.max_length - pad_count)

    decoded_batch = tok.batch_decode(batch_outputs, skip_special_tokens=True)
    decoded_first = tok.decode(out_first[0], skip_special_tokens=True)
    decoded_second = tok.decode(out_second[0], skip_special_tokens=True)

    expected_output_sentence = [
        "Hello, my dog is a little bit of a dork.\nI'm a little bit",
        "Today, I was in the middle of a conversation with a friend about the",
    ]
    self.assertListEqual(expected_output_sentence, decoded_batch)
    self.assertListEqual(decoded_batch, [decoded_first, decoded_second])
def
test_generation_post_attn_layer_norm
(
self
):
def
test_generation_post_attn_layer_norm
(
self
):
model_id
=
"facebook/opt-350m"
model_id
=
"facebook/opt-350m"
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment