chenpangpang / transformers · Commits

Commit 6b3438df
authored Sep 09, 2019 by thomwolf
parent e3600372

fixing GPT2 double head model and updating the torch version tests
Showing 6 changed files with 98 additions and 48 deletions (+98, -48)
pytorch_transformers/modeling_gpt2.py                  +8  -5
pytorch_transformers/modeling_tf_gpt2.py               +16 -13
pytorch_transformers/modeling_tf_utils.py              +10 -5
pytorch_transformers/modeling_utils.py                 +1  -1
pytorch_transformers/tests/modeling_gpt2_test.py       +30 -6
pytorch_transformers/tests/modeling_tf_gpt2_test.py    +33 -18
pytorch_transformers/modeling_gpt2.py

@@ -367,6 +367,13 @@ class GPT2Model(GPT2PreTrainedModel):
             self.h[layer].attn.prune_heads(heads)
 
     def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        if token_type_ids is not None:
+            token_type_ids = token_type_ids.view(-1, input_shape[-1])
+        if position_ids is not None:
+            position_ids = position_ids.view(-1, input_shape[-1])
+
         if past is None:
             past_length = 0
             past = [None] * len(self.h)
@@ -378,6 +385,7 @@ class GPT2Model(GPT2PreTrainedModel):
 
         # Attention mask.
         if attention_mask is not None:
+            attention_mask = attention_mask.view(-1, input_shape[-1])
             # We create a 3D attention mask from a 2D tensor mask.
             # Sizes are [batch_size, 1, 1, to_seq_length]
             # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
@@ -407,14 +415,9 @@ class GPT2Model(GPT2PreTrainedModel):
         else:
             head_mask = [None] * self.config.n_layer
 
-        input_shape = input_ids.size()
-        input_ids = input_ids.view(-1, input_ids.size(-1))
-        position_ids = position_ids.view(-1, position_ids.size(-1))
-
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
         if token_type_ids is not None:
-            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
             token_type_embeds = self.wte(token_type_ids)
         else:
             token_type_embeds = 0
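The lines added at the top of GPT2Model.forward fold any extra leading dimensions of input_ids into the batch dimension, which is what lets GPT2DoubleHeadsModel push [batch, num_choices, seq_len] tensors through the base model unchanged. A minimal standalone sketch of that reshaping (illustrative shapes and names, not code from this commit):

    import torch

    # Illustrative shapes only: 2 examples, 3 answer choices, 5 tokens each.
    input_ids = torch.randint(0, 100, (2, 3, 5))

    input_shape = input_ids.size()                         # torch.Size([2, 3, 5])
    flat_input_ids = input_ids.view(-1, input_shape[-1])   # (6, 5): choices folded into the batch

    # ... the transformer runs on the flat (6, 5) batch ...
    hidden_size = 8
    flat_hidden = torch.zeros(flat_input_ids.size(0), flat_input_ids.size(1), hidden_size)

    # Outputs can be folded back to the original leading dimensions afterwards.
    hidden = flat_hidden.view(*input_shape, hidden_size)   # (2, 3, 5, 8)
    print(flat_input_ids.shape, hidden.shape)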
pytorch_transformers/modeling_tf_gpt2.py

@@ -314,17 +314,16 @@ class TFGPT2Embeddings(tf.keras.layers.Layer):
     def _linear(self, inputs):
         """Computes logits by running inputs through a linear layer.
             Args:
-                inputs: A float32 tensor with shape [batch_size, length, hidden_size]
+                inputs: A float32 tensor with shape [..., hidden_size]
             Returns:
-                float32 tensor with shape [batch_size, length, vocab_size].
+                float32 tensor with shape [..., vocab_size].
         """
-        batch_size = shape_list(inputs)[0]
-        length = shape_list(inputs)[1]
+        first_dims = shape_list(inputs)[:-1]
 
         x = tf.reshape(inputs, [-1, self.hidden_size])
         logits = tf.matmul(x, self.weight, transpose_b=True)
 
-        return tf.reshape(logits, [batch_size, length, self.vocab_size])
+        return tf.reshape(logits, first_dims + [self.vocab_size])
 
 
 class TFGPT2MainLayer(tf.keras.layers.Layer):
     def __init__(self, config, *inputs, **kwargs):
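The `_linear` change applies the same idea to the output projection: keep every leading dimension, flatten for the matmul, then restore `first_dims` on the way out. A small standalone sketch, with a local `shape_list` standing in for the library helper of the same name (assumes TF 2.x eager mode; sizes are illustrative):

    import tensorflow as tf

    def shape_list(x):
        # Stand-in for the library helper: static dims where known, dynamic otherwise.
        static = x.shape.as_list()
        dynamic = tf.shape(x)
        return [dynamic[i] if s is None else s for i, s in enumerate(static)]

    hidden_size, vocab_size = 8, 11
    weight = tf.random.normal((vocab_size, hidden_size))

    inputs = tf.random.normal((2, 3, 5, hidden_size))        # arbitrary leading dims [..., hidden_size]
    first_dims = shape_list(inputs)[:-1]                     # [2, 3, 5]

    x = tf.reshape(inputs, [-1, hidden_size])                # (30, hidden_size)
    logits = tf.matmul(x, weight, transpose_b=True)          # (30, vocab_size)
    logits = tf.reshape(logits, first_dims + [vocab_size])   # back to (2, 3, 5, vocab_size)
    print(logits.shape)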
@@ -679,10 +678,11 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
     @tf.function
     def call(self, inputs, training=False):
         if not isinstance(inputs, (dict, tuple, list)):
-            raise ValueError("Inputs should be a list or a dict with at least two elements: 'inputs_ids' and 'mc_token_ids'")
+            input_ids = inputs
+            mc_token_ids, past, attention_mask, token_type_ids, position_ids, head_mask = None, None, None, None, None, None
         elif isinstance(inputs, (tuple, list)):
             input_ids = inputs[0]
-            mc_token_ids = inputs[1]
+            mc_token_ids = inputs[1] if len(inputs) > 1 else None
             past = inputs[2] if len(inputs) > 2 else None
             attention_mask = inputs[3] if len(inputs) > 3 else None
             token_type_ids = inputs[4] if len(inputs) > 4 else None
@@ -691,7 +691,7 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
             assert len(inputs) <= 7, "Too many inputs."
         else:
             input_ids = inputs.get('input_ids')
-            mc_token_ids = inputs.get('mc_token_ids')
+            mc_token_ids = inputs.get('mc_token_ids', None)
             past = inputs.get('past', None)
             attention_mask = inputs.get('attention_mask', None)
             token_type_ids = inputs.get('token_type_ids', None)
@@ -699,9 +699,9 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
             head_mask = inputs.get('head_mask', None)
             assert len(inputs) <= 5, "Too many inputs."
 
-        assert len(shape_list(input_ids)) == 3, "Inputs should have 3 dimensions: batch, choices, sequence length"
-        num_choices = shape_list(input_ids)[1]
-        seq_length = shape_list(input_ids)[2]
+        input_shapes = shape_list(input_ids)
+
+        seq_length = input_shapes[-1]
 
         flat_input_ids = tf.reshape(input_ids, (-1, seq_length))
         flat_attention_mask = tf.reshape(attention_mask, (-1, seq_length)) if attention_mask is not None else None
@@ -710,13 +710,16 @@ class TFGPT2DoubleHeadsModel(TFGPT2PreTrainedModel):
         flat_inputs = [flat_input_ids, past, flat_attention_mask, flat_token_type_ids, flat_position_ids, head_mask]
 
-        outputs = self.transformer(flat_inputs, training=training)
+        transformer_outputs = self.transformer(flat_inputs, training=training)
         hidden_states = transformer_outputs[0]
 
+        hidden_states = tf.reshape(hidden_states, input_shapes + shape_list(hidden_states)[-1:])
+
         lm_logits = self.transformer.wte(hidden_states, mode="linear")
         mc_logits = self.multiple_choice_head([hidden_states, mc_token_ids], training=training)
+        mc_logits = tf.squeeze(mc_logits, axis=-1)
 
         outputs = (lm_logits, mc_logits) + transformer_outputs[1:]
 
         return outputs  # (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
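The fixed TFGPT2DoubleHeadsModel.call flattens the [batch, num_choices, seq_len] ids into one large batch, runs the transformer, and reshapes the hidden states back before the two heads; the added tf.squeeze then leaves mc_logits with shape [batch, num_choices]. A standalone sketch of that flatten/un-flatten pattern, with a toy embedding and Dense layer standing in for the real transformer (all names and sizes here are illustrative assumptions, not the library's code):

    import tensorflow as tf

    batch_size, num_choices, seq_length, hidden_size = 2, 3, 5, 8

    input_ids = tf.random.uniform((batch_size, num_choices, seq_length), maxval=100, dtype=tf.int32)
    input_shapes = input_ids.shape.as_list()                  # [2, 3, 5]
    seq_len = input_shapes[-1]

    flat_input_ids = tf.reshape(input_ids, (-1, seq_len))     # (6, 5): choices folded into the batch

    # Stand-in for self.transformer(...): embed then a Dense layer.
    embed = tf.keras.layers.Embedding(100, hidden_size)
    flat_hidden = tf.keras.layers.Dense(hidden_size)(embed(flat_input_ids))   # (6, 5, 8)

    # Restore the [batch, num_choices, seq_len, hidden] layout before the heads.
    hidden_states = tf.reshape(flat_hidden, input_shapes + flat_hidden.shape.as_list()[-1:])
    print(hidden_states.shape)   # (2, 3, 5, 8)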
pytorch_transformers/modeling_tf_utils.py

@@ -359,13 +359,18 @@ class TFSequenceSummary(tf.keras.layers.Layer):
         elif self.summary_type == 'mean':
             output = tf.mean(hidden_states, axis=1)
         elif self.summary_type == 'cls_index':
+            hidden_shape = shape_list(hidden_states)  # e.g. [batch, num choices, seq length, hidden dims]
             if cls_index is None:
-                cls_index = tf.fill(tf.shape(hidden_states[..., :1, :]), hidden_states.shape[-2]-1, dtype=tf.int32)
-            else:
-                cls_index = cls_index[..., tf.newaxis, tf.newaxis]
-                cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
+                cls_index = tf.fill(hidden_shape[:-2], hidden_shape[-2]-1)  # A tensor full of shape [batch] or [batch, num choices] full of sequence length
+            cls_shape = shape_list(cls_index)
+            if len(cls_shape) <= len(hidden_shape) - 2:
+                cls_index = cls_index[..., tf.newaxis]
+            # else:
+                # cls_index = cls_index[..., tf.newaxis]
+                # cls_index = cls_index.expand((-1,) * (cls_index.dim()-1) + (hidden_states.size(-1),))
             # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
-            output = hidden_states.gather(-2, cls_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
+            output = tf.gather(hidden_states, cls_index, batch_dims=len(hidden_shape) - 2)
+            output = tf.squeeze(output, axis=len(hidden_shape) - 2)  # shape of output: (batch, num choices, hidden_size)
         elif self.summary_type == 'attn':
             raise NotImplementedError
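The rewritten cls_index branch replaces the copied-over PyTorch gather with tf.gather(..., batch_dims=...), which picks one time step per (batch, choice) pair — the hidden state the multiple-choice head then scores. A tiny illustrative sketch of that call (not library code; shapes are assumed):

    import tensorflow as tf

    batch, choices, seq_len, hidden = 2, 3, 5, 4
    hidden_states = tf.random.normal((batch, choices, seq_len, hidden))

    # One index per (batch, choice) pair, e.g. the position of the classification token.
    cls_index = tf.constant([[4, 2, 0],
                             [1, 3, 4]])                        # shape (2, 3)

    cls_index = cls_index[..., tf.newaxis]                      # (2, 3, 1)
    output = tf.gather(hidden_states, cls_index, batch_dims=2)  # (2, 3, 1, 4)
    output = tf.squeeze(output, axis=2)                         # (2, 3, 4): one vector per choice
    print(output.shape)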
pytorch_transformers/modeling_utils.py

@@ -679,7 +679,7 @@ class SequenceSummary(nn.Module):
         self.last_dropout = nn.Dropout(config.summary_last_dropout)
 
     def forward(self, hidden_states, cls_index=None):
-        """ hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
+        """ hidden_states: float Tensor in shape [bsz, ..., seq_len, hidden_size], the hidden-states of the last layer.
             cls_index: [optional] position of the classification token if summary_type == 'cls_index',
                 shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
                 if summary_type == 'cls_index' and cls_index is None:
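The one-line change above only extends the docstring: the PyTorch SequenceSummary already handles extra leading dimensions. For reference, a sketch of the torch-side mechanics it documents — gather along the sequence axis with the index expanded over the hidden dimension (illustrative shapes, not the library's code):

    import torch

    batch, choices, seq_len, hidden = 2, 3, 5, 4
    hidden_states = torch.randn(batch, choices, seq_len, hidden)

    cls_index = torch.tensor([[4, 2, 0],
                              [1, 3, 4]])                        # (batch, choices)

    # Expand the index so it addresses every hidden dimension at the chosen position.
    idx = cls_index[..., None, None].expand(-1, -1, 1, hidden)   # (2, 3, 1, 4)
    output = hidden_states.gather(-2, idx).squeeze(-2)           # (2, 3, 4)
    print(output.shape)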
pytorch_transformers/tests/modeling_gpt2_test.py

@@ -46,6 +46,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
                      use_token_type_ids=True,
                      use_input_mask=True,
                      use_labels=True,
+                     use_mc_token_ids=True,
                      vocab_size=99,
                      hidden_size=32,
                      num_hidden_layers=5,
@@ -69,6 +70,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
             self.use_token_type_ids = use_token_type_ids
             self.use_input_mask = use_input_mask
             self.use_labels = use_labels
+            self.use_mc_token_ids = use_mc_token_ids
             self.vocab_size = vocab_size
             self.hidden_size = hidden_size
             self.num_hidden_layers = num_hidden_layers
@@ -96,6 +98,10 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
             if self.use_token_type_ids:
                 token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
 
+            mc_token_ids = None
+            if self.use_mc_token_ids:
+                mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
+
             sequence_labels = None
             token_labels = None
             choice_labels = None
@@ -121,7 +127,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
                 head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
 
-            return config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels
+            return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels
 
         def check_loss_output(self, result):
             self.parent.assertListEqual(
@@ -163,15 +169,27 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
                 list(result["lm_logits"].size()),
                 [self.batch_size, self.seq_length, self.vocab_size])
 
-        def create_and_check_double_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
+        def create_and_check_double_lm_head_model(self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args):
             model = GPT2DoubleHeadsModel(config)
             model.eval()
 
-            loss, lm_logits, mc_logits, _ = model(input_ids, token_type_ids=token_type_ids, lm_labels=input_ids)
+            multiple_choice_inputs_ids = input_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_input_mask = input_mask.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+            multiple_choice_token_type_ids = token_type_ids.unsqueeze(1).expand(-1, self.num_choices, -1).contiguous()
+
+            inputs = {'input_ids': multiple_choice_inputs_ids,
+                      'mc_token_ids': mc_token_ids,
+                      'attention_mask': multiple_choice_input_mask,
+                      'token_type_ids': multiple_choice_token_type_ids,
+                      'lm_labels': multiple_choice_inputs_ids}
+
+            loss, lm_logits, mc_logits, _ = model(**inputs)
 
             result = {
                 "loss": loss,
-                "lm_logits": lm_logits
+                "lm_logits": lm_logits,
+                "mc_logits": mc_logits
             }
 
             self.parent.assertListEqual(
@@ -179,11 +197,17 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
                 [])
             self.parent.assertListEqual(
                 list(result["lm_logits"].size()),
-                [self.batch_size, self.seq_length, self.vocab_size])
+                [self.batch_size, self.num_choices, self.seq_length, self.vocab_size])
+            self.parent.assertListEqual(
+                list(result["mc_logits"].size()),
+                [self.batch_size, self.num_choices])
 
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
-            (config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
+            (config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
             inputs_dict = {
                 'input_ids': input_ids,
                 'token_type_ids': token_type_ids,
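Putting the PyTorch side together, a usage sketch in the spirit of the new create_and_check_double_lm_head_model test above. The config values are arbitrary, and the GPT2Config/GPT2DoubleHeadsModel argument spellings are assumptions based on the pytorch_transformers API of this era rather than anything stated in the commit:

    import torch
    from pytorch_transformers import GPT2Config, GPT2DoubleHeadsModel

    batch_size, num_choices, seq_length = 2, 3, 7

    # Small illustrative config; the sizes are assumptions, not values from the commit.
    config = GPT2Config(99, n_positions=32, n_ctx=32, n_embd=32, n_layer=2, n_head=4)
    model = GPT2DoubleHeadsModel(config)
    model.eval()

    input_ids = torch.randint(0, 99, (batch_size, num_choices, seq_length))
    mc_token_ids = torch.full((batch_size, num_choices), seq_length - 1, dtype=torch.long)

    with torch.no_grad():
        loss, lm_logits, mc_logits, _ = model(input_ids,
                                              mc_token_ids=mc_token_ids,
                                              lm_labels=input_ids)

    print(lm_logits.shape)  # (batch_size, num_choices, seq_length, vocab_size)
    print(mc_logits.shape)  # (batch_size, num_choices)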
pytorch_transformers/tests/modeling_tf_gpt2_test.py

@@ -37,9 +37,9 @@ else:
 class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
 
-    # all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel,
-    #                      TFGPT2DoubleHeadsModel) if is_tf_available() else ()
-    all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
+    all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel,
+                         TFGPT2DoubleHeadsModel) if is_tf_available() else ()
+    # all_model_classes = (TFGPT2Model, TFGPT2LMHeadModel) if is_tf_available() else ()
 
     class TFGPT2ModelTester(object):
@@ -51,6 +51,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
                      use_token_type_ids=True,
                      use_input_mask=True,
                      use_labels=True,
+                     use_mc_token_ids=True,
                      vocab_size=99,
                      hidden_size=32,
                      num_hidden_layers=5,
@@ -74,6 +75,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
             self.use_token_type_ids = use_token_type_ids
             self.use_input_mask = use_input_mask
             self.use_labels = use_labels
+            self.use_mc_token_ids = use_mc_token_ids
             self.vocab_size = vocab_size
             self.hidden_size = hidden_size
             self.num_hidden_layers = num_hidden_layers
@@ -101,6 +103,10 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
             if self.use_token_type_ids:
                 token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
 
+            mc_token_ids = None
+            if self.use_mc_token_ids:
+                mc_token_ids = ids_tensor([self.batch_size, self.num_choices], self.seq_length)
+
             sequence_labels = None
             token_labels = None
             choice_labels = None
@@ -126,7 +132,7 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
                 head_mask = ids_tensor([self.num_hidden_layers, self.num_attention_heads], 2)
 
-            return config, input_ids, input_mask, head_mask, token_type_ids, sequence_labels, token_labels, choice_labels
+            return config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, sequence_labels, token_labels, choice_labels
 
         def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
             model = TFGPT2Model(config=config)
@@ -162,25 +168,34 @@ class TFGPT2ModelTest(TFCommonTestCases.TFCommonModelTester):
                 [self.batch_size, self.seq_length, self.vocab_size])
 
-        def create_and_check_gpt2_double_head(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
-            pass
-            # model = TFGPT2DoubleHeadsModel(config=config)
-            # inputs = {'input_ids': input_ids,
-            #           'attention_mask': input_mask,
-            #           'token_type_ids': token_type_ids}
-            # seq_relationship_score, = model(inputs)[0]
-            # result = {
-            #     "seq_relationship_score": seq_relationship_score.numpy(),
-            # }
-            # self.parent.assertListEqual(
-            #     list(result["seq_relationship_score"].shape),
-            #     [self.batch_size, 2])
+        def create_and_check_gpt2_double_head(self, config, input_ids, input_mask, head_mask, token_type_ids, mc_token_ids, *args):
+            model = TFGPT2DoubleHeadsModel(config=config)
+
+            multiple_choice_inputs_ids = tf.tile(tf.expand_dims(input_ids, 1), (1, self.num_choices, 1))
+            multiple_choice_input_mask = tf.tile(tf.expand_dims(input_mask, 1), (1, self.num_choices, 1))
+            multiple_choice_token_type_ids = tf.tile(tf.expand_dims(token_type_ids, 1), (1, self.num_choices, 1))
+
+            inputs = {'input_ids': multiple_choice_inputs_ids,
+                      'mc_token_ids': mc_token_ids,
+                      'attention_mask': multiple_choice_input_mask,
+                      'token_type_ids': multiple_choice_token_type_ids}
+
+            lm_logits, mc_logits = model(inputs)[:2]
+
+            result = {"lm_logits": lm_logits.numpy(),
+                      "mc_logits": mc_logits.numpy()}
+
+            self.parent.assertListEqual(
+                list(result["lm_logits"].shape),
+                [self.batch_size, self.num_choices, self.seq_length, self.vocab_size])
+            self.parent.assertListEqual(
+                list(result["mc_logits"].shape),
+                [self.batch_size, self.num_choices])
 
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
             (config, input_ids, input_mask, head_mask, token_type_ids,
-             sequence_labels, token_labels, choice_labels) = config_and_inputs
+             mc_token_ids, sequence_labels, token_labels, choice_labels) = config_and_inputs
 
             inputs_dict = {'input_ids': input_ids, 'token_type_ids': token_type_ids, 'attention_mask': input_mask}
 
             return config, inputs_dict