Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
932eab94
Unverified
Commit
932eab94
authored
Mar 04, 2020
by
Patrick von Platen
Committed by
GitHub
Mar 04, 2020
Browse files
include tf gpt2 tests for attn mask and past variable (#3122)
parent
256cbbc4
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
76 additions
and
0 deletions
+76
-0
tests/test_modeling_tf_gpt2.py
tests/test_modeling_tf_gpt2.py
+76
-0
No files found.
tests/test_modeling_tf_gpt2.py
View file @
932eab94
...
...
@@ -30,6 +30,7 @@ if is_tf_available():
TFGPT2LMHeadModel
,
TFGPT2DoubleHeadsModel
,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
,
shape_list
,
)
...
...
@@ -167,6 +168,73 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
list
(
result
[
"sequence_output"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
],
)
def create_and_check_gpt2_model_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
    """Check that decoding one new token with cached ``past`` states matches a
    full forward pass over the extended sequence.
    """
    model = TFGPT2Model(config=config)

    # Run the original sequence once; the second return value is the cache.
    sequence_output, past = model(input_ids, token_type_ids=token_type_ids)

    # Fabricate one additional token (and matching token type) per batch row.
    new_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)
    new_token_types = ids_tensor([self.batch_size, 1], self.type_vocab_size)

    # Extended inputs: original sequence with the new token appended.
    extended_input_ids = tf.concat([input_ids, new_tokens], axis=-1)
    extended_token_type_ids = tf.concat([token_type_ids, new_token_types], axis=-1)

    # Path 1: forward the whole extended sequence with no cache.
    full_output, _ = model(extended_input_ids, token_type_ids=extended_token_type_ids)
    # Path 2: forward only the new token, reusing the cached states.
    cached_output, _ = model(new_tokens, token_type_ids=new_token_types, past=past)

    # Pick one random hidden-dimension index and compare the new token's
    # hidden state from both paths (last position vs. only position).
    slice_idx = int(ids_tensor((1,), shape_list(cached_output)[-1]))
    full_slice = full_output[:, -1, slice_idx]
    cached_slice = cached_output[:, 0, slice_idx]

    # Both decoding paths must agree numerically.
    tf.debugging.assert_near(cached_slice, full_slice, rtol=1e-12)
def create_and_check_gpt2_model_attention_mask_past(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
    """Verify that cached-``past`` decoding agrees with a full forward pass
    when an attention mask is in effect.

    Between the two passes, one position inside the masked-out region of
    ``input_ids`` is overwritten with random tokens; because that position is
    masked, both passes must still produce the same hidden states.
    """
    model = TFGPT2Model(config=config)

    # create attention mask: first half of each row attends (ones),
    # second half is masked out (zeros)
    half_seq_length = self.seq_length // 2
    attn_mask_begin = tf.ones((self.batch_size, half_seq_length), dtype=tf.int32)
    attn_mask_end = tf.zeros((self.batch_size, self.seq_length - half_seq_length), dtype=tf.int32)
    attn_mask = tf.concat([attn_mask_begin, attn_mask_end], axis=1)

    # first forward pass — collects the cached key/value states in `past`
    output, past = model(input_ids, attention_mask=attn_mask)

    # create hypothetical next token and extend to next_input_ids
    next_tokens = ids_tensor((self.batch_size, 1), config.vocab_size)

    # change a random masked slice from input_ids: the chosen index
    # seq_length - random_seq_idx_to_change always falls in the zeroed
    # (masked) second half of the mask, so the change must be invisible
    random_seq_idx_to_change = ids_tensor((1,), half_seq_length).numpy() + 1
    random_other_next_tokens = ids_tensor((self.batch_size, self.seq_length), config.vocab_size)
    vector_condition = tf.range(self.seq_length) == (self.seq_length - random_seq_idx_to_change)
    # broadcast the per-position condition to (batch_size, seq_length)
    condition = tf.transpose(tf.broadcast_to(tf.expand_dims(vector_condition, -1), (self.seq_length, self.batch_size)))
    input_ids = tf.where(condition, random_other_next_tokens, input_ids)

    # append to next input_ids and attn_mask (the new token may attend: ones)
    next_input_ids = tf.concat([input_ids, next_tokens], axis=-1)
    attn_mask = tf.concat([attn_mask, tf.ones((shape_list(attn_mask)[0], 1), dtype=tf.int32)], axis=1)

    # get two different outputs: full re-run vs. single-token run with cache
    output_from_no_past, _ = model(next_input_ids, attention_mask=attn_mask)
    output_from_past, _ = model(next_tokens, past=past, attention_mask=attn_mask)

    # select random slice of the hidden dimension; compare the appended
    # token's hidden state (last position vs. only position)
    random_slice_idx = int(ids_tensor((1,), shape_list(output_from_past)[-1]))
    output_from_no_past_slice = output_from_no_past[:, -1, random_slice_idx]
    output_from_past_slice = output_from_past[:, 0, random_slice_idx]

    # test that outputs are equal for slice
    # NOTE(review): rtol=1e-12 is far below float32 precision, i.e. this is
    # effectively an exact-equality check — confirm it is not flaky
    tf.debugging.assert_near(output_from_past_slice, output_from_no_past_slice, rtol=1e-12)
def
create_and_check_gpt2_lm_head
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
model
=
TFGPT2LMHeadModel
(
config
=
config
)
inputs
=
{
...
...
@@ -237,6 +305,14 @@ class TFGPT2ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_gpt2_model
(
*
config_and_inputs
)
def test_gpt2_model_past(self):
    """Exercise the cached-``past`` consistency check with fresh inputs."""
    self.model_tester.create_and_check_gpt2_model_past(*self.model_tester.prepare_config_and_inputs())
def test_gpt2_model_att_mask_past(self):
    """Exercise the attention-mask + cached-``past`` consistency check."""
    self.model_tester.create_and_check_gpt2_model_attention_mask_past(*self.model_tester.prepare_config_and_inputs())
def test_gpt2_lm_head(self):
    """Exercise the LM-head model check with fresh inputs."""
    self.model_tester.create_and_check_gpt2_lm_head(*self.model_tester.prepare_config_and_inputs())
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment