Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2c778428
Unverified
Commit
2c778428
authored
Apr 29, 2020
by
Sam Shleifer
Committed by
GitHub
Apr 29, 2020
Browse files
[Fix common tests on GPU] send model, ids to torch_device (#4014)
parent
6faca88e
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
16 deletions
+18
-16
tests/test_modeling_common.py
tests/test_modeling_common.py
+18
-16
No files found.
tests/test_modeling_common.py
View file @
2c778428
...
@@ -19,6 +19,7 @@ import os.path
...
@@ -19,6 +19,7 @@ import os.path
import
random
import
random
import
tempfile
import
tempfile
import
unittest
import
unittest
from
typing
import
List
from
transformers
import
is_torch_available
from
transformers
import
is_torch_available
...
@@ -629,10 +630,10 @@ class ModelTesterMixin:
...
@@ -629,10 +630,10 @@ class ModelTesterMixin:
# iterate over all generative models
# iterate over all generative models
for
model_class
in
self
.
all_generative_model_classes
:
for
model_class
in
self
.
all_generative_model_classes
:
model
=
model_class
(
config
)
model
=
model_class
(
config
)
.
to
(
torch_device
)
if
config
.
bos_token_id
is
None
:
if
config
.
bos_token_id
is
None
:
# if bos token id is not defined mo
b
el needs input_ids
# if bos token id is not defined
,
mo
d
el needs input_ids
with
self
.
assertRaises
(
AssertionError
):
with
self
.
assertRaises
(
AssertionError
):
model
.
generate
(
do_sample
=
True
,
max_length
=
5
)
model
.
generate
(
do_sample
=
True
,
max_length
=
5
)
# num_return_sequences = 1
# num_return_sequences = 1
...
@@ -651,7 +652,10 @@ class ModelTesterMixin:
...
@@ -651,7 +652,10 @@ class ModelTesterMixin:
# check bad words tokens language generation
# check bad words tokens language generation
# create list of 1-seq bad token and list of 2-seq of bad tokens
# create list of 1-seq bad token and list of 2-seq of bad tokens
bad_words_ids
=
[
self
.
_generate_random_bad_tokens
(
1
,
model
),
self
.
_generate_random_bad_tokens
(
2
,
model
)]
bad_words_ids
=
[
self
.
_generate_random_bad_tokens
(
1
,
model
.
config
),
self
.
_generate_random_bad_tokens
(
2
,
model
.
config
),
]
output_tokens
=
model
.
generate
(
output_tokens
=
model
.
generate
(
input_ids
,
do_sample
=
True
,
bad_words_ids
=
bad_words_ids
,
num_return_sequences
=
2
input_ids
,
do_sample
=
True
,
bad_words_ids
=
bad_words_ids
,
num_return_sequences
=
2
)
)
...
@@ -661,10 +665,12 @@ class ModelTesterMixin:
...
@@ -661,10 +665,12 @@ class ModelTesterMixin:
def
test_lm_head_model_random_beam_search_generate
(
self
):
def
test_lm_head_model_random_beam_search_generate
(
self
):
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
config
,
inputs_dict
=
self
.
model_tester
.
prepare_config_and_inputs_for_common
()
input_ids
=
inputs_dict
[
"input_ids"
]
if
"input_ids"
in
inputs_dict
else
inputs_dict
[
"inputs"
]
input_ids
=
(
inputs_dict
[
"input_ids"
]
if
"input_ids"
in
inputs_dict
else
inputs_dict
[
"inputs"
]).
to
(
torch_device
)
for
model_class
in
self
.
all_generative_model_classes
:
for
model_class
in
self
.
all_generative_model_classes
:
model
=
model_class
(
config
)
model
=
model_class
(
config
)
.
to
(
torch_device
)
if
config
.
bos_token_id
is
None
:
if
config
.
bos_token_id
is
None
:
# if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
# if bos token id is not defined mobel needs input_ids, num_return_sequences = 1
...
@@ -684,7 +690,10 @@ class ModelTesterMixin:
...
@@ -684,7 +690,10 @@ class ModelTesterMixin:
# check bad words tokens language generation
# check bad words tokens language generation
# create list of 1-seq bad token and list of 2-seq of bad tokens
# create list of 1-seq bad token and list of 2-seq of bad tokens
bad_words_ids
=
[
self
.
_generate_random_bad_tokens
(
1
,
model
),
self
.
_generate_random_bad_tokens
(
2
,
model
)]
bad_words_ids
=
[
self
.
_generate_random_bad_tokens
(
1
,
model
.
config
),
self
.
_generate_random_bad_tokens
(
2
,
model
.
config
),
]
output_tokens
=
model
.
generate
(
output_tokens
=
model
.
generate
(
input_ids
,
do_sample
=
False
,
bad_words_ids
=
bad_words_ids
,
num_beams
=
2
,
num_return_sequences
=
2
input_ids
,
do_sample
=
False
,
bad_words_ids
=
bad_words_ids
,
num_beams
=
2
,
num_return_sequences
=
2
)
)
...
@@ -692,20 +701,13 @@ class ModelTesterMixin:
...
@@ -692,20 +701,13 @@ class ModelTesterMixin:
generated_ids
=
output_tokens
[:,
input_ids
.
shape
[
-
1
]
:]
generated_ids
=
output_tokens
[:,
input_ids
.
shape
[
-
1
]
:]
self
.
assertFalse
(
self
.
_check_match_tokens
(
generated_ids
.
tolist
(),
bad_words_ids
))
self
.
assertFalse
(
self
.
_check_match_tokens
(
generated_ids
.
tolist
(),
bad_words_ids
))
def
_generate_random_bad_tokens
(
self
,
num_bad_tokens
,
model
)
:
def
_generate_random_bad_tokens
(
self
,
num_bad_tokens
:
int
,
config
)
->
List
[
int
]
:
# special tokens cannot be bad tokens
# special tokens cannot be bad tokens
special_tokens
=
[]
special_tokens
=
[
x
for
x
in
[
config
.
bos_token_id
,
config
.
eos_token_id
,
config
.
pad_token_id
]
if
x
is
not
None
]
if
model
.
config
.
bos_token_id
is
not
None
:
special_tokens
.
append
(
model
.
config
.
bos_token_id
)
if
model
.
config
.
pad_token_id
is
not
None
:
special_tokens
.
append
(
model
.
config
.
pad_token_id
)
if
model
.
config
.
eos_token_id
is
not
None
:
special_tokens
.
append
(
model
.
config
.
eos_token_id
)
# create random bad tokens that are not special tokens
# create random bad tokens that are not special tokens
bad_tokens
=
[]
bad_tokens
=
[]
while
len
(
bad_tokens
)
<
num_bad_tokens
:
while
len
(
bad_tokens
)
<
num_bad_tokens
:
token
=
ids_tensor
((
1
,
1
),
self
.
model_tester
.
vocab_size
).
squeeze
(
0
).
numpy
()[
0
]
token
=
ids_tensor
((
1
,
1
),
self
.
model_tester
.
vocab_size
).
squeeze
(
0
).
cpu
().
numpy
()[
0
]
if
token
not
in
special_tokens
:
if
token
not
in
special_tokens
:
bad_tokens
.
append
(
token
)
bad_tokens
.
append
(
token
)
return
bad_tokens
return
bad_tokens
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment