Commit 6b3438df, authored Sep 09, 2019 by thomwolf
fixing GPT2 double head model and updating the torch version tests
Parent: e3600372
Showing 6 changed files with 98 additions and 48 deletions (+98, -48).
pytorch_transformers/modeling_gpt2.py (+8, -5)
pytorch_transformers/modeling_tf_gpt2.py (+16, -13)
pytorch_transformers/modeling_tf_utils.py (+10, -5)
pytorch_transformers/modeling_utils.py (+1, -1)
pytorch_transformers/tests/modeling_gpt2_test.py (+30, -6)
pytorch_transformers/tests/modeling_tf_gpt2_test.py (+33, -18)
pytorch_transformers/modeling_gpt2.py
@@ -367,6 +367,13 @@ class GPT2Model(GPT2PreTrainedModel):
             self.h[layer].attn.prune_heads(heads)
 
     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1])
+
         if past is None:
             past_length = 0
             past = [None] * len(self.h)
@@ -378,6 +385,7 @@ class GPT2Model(GPT2PreTrainedModel):
         # Attention mask.
         if attention_mask is not None:
+            attention_mask = attention_mask.view(-1, input_shape[-1])
             # We create a 3D attention mask from a 2D tensor mask.
             # Sizes are [batch_size, 1, 1, to_seq_length]
             # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
@@ -407,14 +415,9 @@ class GPT2Model(GPT2PreTrainedModel):
         else:
             head_mask = [None] * self.config.n_layer
 
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_ids.size(-1))
-        position_ids = position_ids.view(-1, position_ids.size(-1))
-
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
             token_type_embeds = self.wte(token_type_ids)
         else:
             token_type_embeds = 0
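Note: the lines added at the top of forward collapse any extra leading dimensions of input_ids (and of the optional masks/ids) into the batch dimension before embedding. A standalone sketch of that reshaping, not part of this commit, with made-up sizes:

    import torch

    input_ids = torch.randint(0, 99, (2, 3, 5))            # (batch, num_choices, seq_len)
    input_shape = input_ids.size()
    flat_ids = input_ids.view(-1, input_shape[-1])         # (batch * num_choices, seq_len) == (6, 5)

    attention_mask = torch.ones(2, 3, 5)
    flat_mask = attention_mask.view(-1, input_shape[-1])   # the mask is flattened the same way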
pytorch_transformers/modeling_tf_gpt2.py
@@ -314,17 +314,16 @@ class TFGPT2Embeddings(tf.keras.layers.Layer):
     def _linear(self, inputs):
         """Computes logits by running inputs through a linear layer.
             Args:
-                inputs: A float32 tensor with shape [batch_size, length, hidden_size]
+                inputs: A float32 tensor with shape [..., hidden_size]
             Returns:
-                float32 tensor with shape [batch_size, length, vocab_size].
+                float32 tensor with shape [..., vocab_size].
         """
-        batch_size = shape_list(inputs)[0]
-        length = shape_list(inputs)[1]
+        first_dims = shape_list(inputs)[:-1]
 
         x = tf.reshape(inputs, [-1, self.hidden_size])
         logits = tf.matmul(x, self.weight, transpose_b=True)
 
-        return tf.reshape(logits, [batch_size, length, self.vocab_size])
+        return tf.reshape(logits, first_dims + [self.vocab_size])
 
 
 class TFGPT2MainLayer(tf.keras.layers.Layer):
     def __init__(self, config, *inputs, **kwargs):
@@ -679,10 +678,11 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
     @tf.function
     def call(self, inputs, training=False):
         if not isinstance(inputs, (dict, tuple, list)):
-            raise ValueError("Inputs should be a list or a dict with at least two elements: 'inputs_ids' and 'mc_token_ids'")
+            input_ids = inputs
+            mc_token_ids, past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None, None
         elif isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            mc_token_ids = inputs[1]
+            mc_token_ids = inputs[1] if len(inputs) > 1 else None
             past = inputs[2] if len(inputs) > 2 else None
             attention_mask = inputs[3] if len(inputs) > 3 else None
             token_type_ids = inputs[4] if len(inputs) > 4 else None
@@ -691,7 +691,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
             assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs.get('input_ids')
-            mc_token_ids = inputs.get('mc_token_ids')
+            mc_token_ids = inputs.get('mc_token_ids', None)
             past = inputs.get('past', None)
             attention_mask = inputs.get('attention_mask', None)
             token_type_ids = inputs.get('token_type_ids', None)
@@ -699,9 +699,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
             head_mask = inputs.get('head_mask', None)
             assert len(inputs) <= 5, "Too many inputs."
 
-        assert len(shape_list(input_ids)) == 3, "Inputs should have 3 dimensions: batch, choices, sequence length"
-        num_choices = shape_list(input_ids)[1]
-        seq_length = shape_list(input_ids)[2]
+        input_shapes = shape_list(input_ids)
+        seq_length = input_shapes[-1]
 
         flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
@@ -710,13 +710,16 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
 
-        outputs = self.transformer(flat_inputs, training=training)
+        transformer_outputs = self.transformer(flat_inputs, training=training)
         hidden_states = transformer_outputs[0]
 
+        hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
+
         lm_logits = self.transformer.wte(hidden_states, mode="linear")
         mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)
+        mc_logits = tf.squeeze(mc_logits, axis=-1)
 
         outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
 
         return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
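Note: the _linear change generalizes the output projection from a fixed [batch_size, length, hidden_size] input to any number of leading dimensions, which is what lets the double-heads model push [batch, num_choices, seq_len, hidden] states through it. A minimal sketch of the same reshape trick, not the library code itself (plain TF, with shape_list replaced by shape.as_list() and invented sizes):

    import tensorflow as tf

    hidden_size, vocab_size = 8, 11
    weight = tf.random.normal((vocab_size, hidden_size))
    inputs = tf.random.normal((2, 3, 5, hidden_size))        # arbitrary leading dims

    first_dims = inputs.shape.as_list()[:-1]                  # [2, 3, 5]
    x = tf.reshape(inputs, [-1, hidden_size])                 # (2*3*5, hidden_size)
    logits = tf.matmul(x, weight, transpose_b=True)
    logits = tf.reshape(logits, first_dims + [vocab_size])    # (2, 3, 5, vocab_size)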
pytorch_transformers/modeling_tf_utils.py
@@ -359,13 +359,18 @@ class TFSequenceSummary(tf.keras.layers.Layer):
         elif self.summary_type == 'mean':
             output = tf.mean(hidden_states, axis=1)
         elif self.summary_type == 'cls_index':
+            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
             if cls_index is None:
-                cls_index = tf.fill(tf.shape(hidden_states[..., :1, :]), hidden_states.shape[-2]-1, dtype=tf.int32)
-            else:
-                cls_index = cls_index[..., tf.newaxis, tf.newaxis]
-                cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
+                cls_index = tf.fill(hidden_shape[:-2], hidden_shape[-2] - 1)  # A tensor full of shape [batch] or [batch, num choices] full of sequence length
+            cls_shape = shape_list(cls_index)
+            if len(cls_shape) <= len(hidden_shape) - 2:
+                cls_index = cls_index[..., tf.newaxis]
+            # else:
+            #     cls_index = cls_index[..., tf.newaxis]
+            #     cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
             # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
-            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
+            output = tf.squeeze(output, axis=len(hidden_shape) - 2)  # shape of output: (batch, num choices, hidden_size)
         elif self.summary_type == 'attn':
             raise NotImplementedError
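Note: the new cls_index branch replaces PyTorch-style gather/expand calls, which do not exist on TF tensors, with tf.gather using batch_dims. An illustrative sketch of that gather, not the library code itself, with invented sizes:

    import tensorflow as tf

    batch, num_choices, seq_len, hidden = 2, 3, 5, 4
    hidden_states = tf.random.normal((batch, num_choices, seq_len, hidden))
    cls_index = tf.constant([[4, 2, 1], [3, 0, 4]])             # (batch, num_choices), one position per choice

    cls_index = cls_index[..., tf.newaxis]                      # (batch, num_choices, 1)
    output = tf.gather(hidden_states, cls_index, batch_dims=2)  # (batch, num_choices, 1, hidden)
    output = tf.squeeze(output, axis=2)                         # (batch, num_choices, hidden)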
pytorch_transformers/modeling_utils.py
@@ -679,7 +679,7 @@ class SequenceSummary(nn.Module):
         self.last_dropout = nn.Dropout(config.summary_last_dropout)
 
     def forward(self, hidden_states, cls_index=None):
-        """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
+        """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
            cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                if summary_type == 'cls_index' and cls_index is None:
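Note: only the docstring changes here; the PyTorch SequenceSummary already handles extra leading dimensions such as a num_choices axis. For reference, a small standalone sketch (not taken from this file) of how a 'cls_index'-style summary behaves with a [bsz, num_choices, seq_len, hidden_size] input:

    import torch

    hidden_states = torch.randn(2, 3, 5, 8)                   # (bsz, num_choices, seq_len, hidden)
    cls_index = torch.tensor([[4, 2, 1], [3, 0, 4]])          # (bsz, num_choices)

    idx = cls_index[..., None, None].expand(-1, -1, -1, hidden_states.size(-1))  # (bsz, num_choices, 1, hidden)
    summary = hidden_states.gather(-2, idx).squeeze(-2)       # (bsz, num_choices, hidden)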
pytorch_transformers/tests/modeling_gpt2_test.py
@@ -46,6 +46,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
                      use_token_type_ids=True,
                      use_input_mask=True,
                      use_labels=True,
+                     use_mc_token_ids=True,
                      vocab_size=99,
                      hidden_size=32,
                      num_hidden_layers=5,
@@ -69,6 +70,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
             self.use_token_type_ids = use_token_type_ids
             self.use_input_mask = use_input_mask
             self.use_labels = use_labels
+            self.use_mc_token_ids = use_mc_token_ids
             self.vocab_size = vocab_size
             self.hidden_size = hidden_size
             self.num_hidden_layers = num_hidden_layers
@@ -96,6 +98,10 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
             if self.use_token_type_ids:
                 token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
 
+            mc_token_ids = None
+            if self.use_mc_token_ids:
+                mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
+
             sequence_labels = None
             token_labels = None
             choice_labels = None
@@ -121,7 +127,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
             head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
 
-            return config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels
+            return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels
 
         def check_loss_output(self, result):
             self.parent.assertListEqual(
@@ -163,15 +169,27 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
                 list(result["lm_logits"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
 
-        def create_and_check_double_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+        def create_and_check_double_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args):
             model = GPT2DoubleHeadsModel(config)
             model.eval()
 
-            loss, lm_logits, mc_logits, _ = model(input_ids, token_type_ids=token_type_ids, lm_labels=input_ids)
+            multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+
+            inputs = {'input_ids': multiple_choice_inputs_ids,
+                      'mc_token_ids': mc_token_ids,
+                      'attention_mask': multiple_choice_input_mask,
+                      'token_type_ids': multiple_choice_token_type_ids,
+                      'lm_labels': multiple_choice_inputs_ids}
+
+            loss, lm_logits, mc_logits, _ = model(**inputs)
+
             result = {
                 "loss": loss,
-                "lm_logits": lm_logits
+                "lm_logits": lm_logits,
+                "mc_logits": mc_logits
             }
 
             self.parent.assertListEqual(
@@ -179,11 +197,17 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
                 [])
             self.parent.assertListEqual(
                 list(result["lm_logits"].size()),
-                [self.batch_size, self.seq_length, self.vocab_size])
+                [self.batch_size, self.num_choices, self.seq_length, self.vocab_size])
+            self.parent.assertListEqual(
+                list(result["mc_logits"].size()),
+                [self.batch_size, self.num_choices])
 
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
-            (config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
+            (config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
 
             inputs_dict = {
                 'input_ids': input_ids,
                 'token_type_ids': token_type_ids,
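Note: the updated test drives GPT2DoubleHeadsModel with (batch, num_choices, seq_len) inputs plus mc_token_ids. A hedged usage sketch along the same lines, not part of the commit; the config keyword names below assume the pytorch_transformers 1.x API (vocab_size_or_config_json_file, n_embd, n_layer, n_head), and all sizes are made up:

    import torch
    from pytorch_transformers import GPT2Config, GPT2DoubleHeadsModel

    batch_size, num_choices, seq_length, vocab_size = 2, 3, 7, 99
    config = GPT2Config(vocab_size_or_config_json_file=vocab_size,
                        n_positions=64, n_ctx=64, n_embd=32, n_layer=2, n_head=4)
    model = GPT2DoubleHeadsModel(config)
    model.eval()

    input_ids = torch.randint(0, vocab_size, (batch_size, seq_length))
    multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, num_choices, -1).contiguous()
    mc_token_ids = torch.randint(0, seq_length, (batch_size, num_choices))

    loss, lm_logits, mc_logits, presents = model(input_ids=multiple_choice_inputs_ids,
                                                  mc_token_ids=mc_token_ids,
                                                  lm_labels=multiple_choice_inputs_ids)
    # lm_logits: (batch_size, num_choices, seq_length, vocab_size)
    # mc_logits: (batch_size, num_choices)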
pytorch_transformers/tests/modeling_tf_gpt2_test.py
@@ -37,9 +37,9 @@ else:
 class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
 
-    # all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel,
-    #                      TFGPT2DoubleHeadsModel) if is_tf_available() else ()
-    all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
+    all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel,
+                         TFGPT2DoubleHeadsModel) if is_tf_available() else ()
+    # all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
 
     class TFGPT2ModelTester(object):
@@ -51,6 +51,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
                      use_token_type_ids=True,
                      use_input_mask=True,
                      use_labels=True,
+                     use_mc_token_ids=True,
                      vocab_size=99,
                      hidden_size=32,
                      num_hidden_layers=5,
@@ -74,6 +75,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
             self.use_token_type_ids = use_token_type_ids
             self.use_input_mask = use_input_mask
             self.use_labels = use_labels
+            self.use_mc_token_ids = use_mc_token_ids
             self.vocab_size = vocab_size
             self.hidden_size = hidden_size
             self.num_hidden_layers = num_hidden_layers
@@ -101,6 +103,10 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
             if self.use_token_type_ids:
                 token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
 
+            mc_token_ids = None
+            if self.use_mc_token_ids:
+                mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
+
             sequence_labels = None
             token_labels = None
             choice_labels = None
@@ -126,7 +132,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
             head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
 
-            return config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels
+            return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels
 
         def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
             model = TFGPT2Model(config=config)
@@ -162,25 +168,34 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
                 [self.batch_size, self.seq_length, self.vocab_size])
 
-        def create_and_check_gpt2_double_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
-            pass
-            # model = TFGPT2DoubleHeadsModel(config=config)
-            # inputs = {'input_ids': input_ids,
-            #           'attention_mask': input_mask,
-            #           'token_type_ids': token_type_ids}
-            # seq_relationship_score, = model(inputs)[0]
-            # result = {
-            #     "seq_relationship_score": seq_relationship_score.numpy(),
-            # }
-            # self.parent.assertListEqual(
-            #     list(result["seq_relationship_score"].shape),
-            #     [self.batch_size, 2])
+        def create_and_check_gpt2_double_head(self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args):
+            model = TFGPT2DoubleHeadsModel(config=config)
+
+            multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
+            multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
+            multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
+
+            inputs = {'input_ids': multiple_choice_inputs_ids,
+                      'mc_token_ids': mc_token_ids,
+                      'attention_mask': multiple_choice_input_mask,
+                      'token_type_ids': multiple_choice_token_type_ids}
+
+            lm_logits, mc_logits = model(inputs)[:2]
+
+            result = {"lm_logits": lm_logits.numpy(),
+                      "mc_logits": mc_logits.numpy()}
+
+            self.parent.assertListEqual(
+                list(result["lm_logits"].shape),
+                [self.batch_size, self.num_choices, self.seq_length, self.vocab_size])
+            self.parent.assertListEqual(
+                list(result["mc_logits"].shape),
+                [self.batch_size, self.num_choices])
 
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
             (config, input_ids, input_mask, head_mask, token_type_ids,
-             sequence_labels, token_labels, choice_labels) = config_and_inputs
+             mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
 
             inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
 
             return config, inputs_dict
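Note: the TF test builds its multiple-choice inputs by tiling the 2D tensors along a new choices axis. A quick standalone sketch of that transformation, not part of the commit, with invented sizes:

    import tensorflow as tf

    batch_size, num_choices, seq_length = 2, 3, 7
    input_ids = tf.random.uniform((batch_size, seq_length), maxval=99, dtype=tf.int32)

    # (batch, seq_len) -> (batch, 1, seq_len) -> (batch, num_choices, seq_len)
    multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, num_choices, 1))
    assert multiple_choice_inputs_ids.shape == (batch_size, num_choices, seq_length)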