Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
72863735
Commit
72863735
authored
Sep 09, 2019
by
thomwolf
Browse files
WIP GPT2
parent
34f28b2a
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
507 additions
and
426 deletions
+507
-426
pytorch_transformers/modeling_tf_bert.py
pytorch_transformers/modeling_tf_bert.py
+12
-107
pytorch_transformers/modeling_tf_gpt2.py
pytorch_transformers/modeling_tf_gpt2.py
+162
-310
pytorch_transformers/modeling_tf_utils.py
pytorch_transformers/modeling_tf_utils.py
+106
-4
pytorch_transformers/tests/modeling_gpt2_test.py
pytorch_transformers/tests/modeling_gpt2_test.py
+11
-5
pytorch_transformers/tests/modeling_tf_gpt2_test.py
pytorch_transformers/tests/modeling_tf_gpt2_test.py
+216
-0
No files found.
pytorch_transformers/modeling_tf_bert.py
View file @
72863735
...
...
@@ -704,20 +704,7 @@ class TFBertModel(TFBertPreTrainedModel):
BERT_START_DOCSTRING
,
BERT_INPUTS_DOCSTRING
)
class
TFBertForPreTraining
(
TFBertPreTrainedModel
):
r
"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
**next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
Indices should be in ``[0, 1]``.
``0`` indicates sequence B is a continuation of sequence A,
``1`` indicates sequence B is a random sequence.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when both ``masked_lm_labels`` and ``next_sentence_label`` are provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total loss as the sum of the masked language modeling loss and the next sequence prediction (classification) loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**seq_relationship_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, 2)``
...
...
@@ -762,15 +749,7 @@ class TFBertForPreTraining(TFBertPreTrainedModel):
BERT_START_DOCSTRING
,
BERT_INPUTS_DOCSTRING
)
class
TFBertForMaskedLM
(
TFBertPreTrainedModel
):
r
"""
**masked_lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the masked language modeling loss.
Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``masked_lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Masked language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
...
...
@@ -786,8 +765,8 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertForMaskedLM.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids
, masked_lm_labels=input_ids
)
loss,
prediction_scores = outputs[:2]
outputs = model(input_ids)
prediction_scores = outputs[:2]
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
...
...
@@ -811,12 +790,6 @@ class TFBertForMaskedLM(TFBertPreTrainedModel):
BERT_START_DOCSTRING
,
BERT_INPUTS_DOCSTRING
)
class
TFBertForNextSentencePrediction
(
TFBertPreTrainedModel
):
r
"""
**next_sentence_label**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair (see ``input_ids`` docstring)
Indices should be in ``[0, 1]``.
``0`` indicates sequence B is a continuation of sequence A,
``1`` indicates sequence B is a random sequence.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``next_sentence_label`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Next sequence prediction (classification) loss.
...
...
@@ -862,15 +835,7 @@ class TFBertForNextSentencePrediction(TFBertPreTrainedModel):
BERT_START_DOCSTRING
,
BERT_INPUTS_DOCSTRING
)
class
TFBertForSequenceClassification
(
TFBertPreTrainedModel
):
r
"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the sequence classification/regression loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
If ``config.num_labels == 1`` a regression loss is computed (Mean-Square loss),
If ``config.num_labels > 1`` a classification loss is computed (Cross-Entropy).
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification (or regression if config.num_labels==1) loss.
**logits**: ``torch.FloatTensor`` of shape ``(batch_size, config.num_labels)``
Classification (or regression if config.num_labels==1) scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
...
...
@@ -886,8 +851,7 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1]).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
outputs = model(input_ids)
loss, logits = outputs[:2]
"""
...
...
@@ -905,7 +869,8 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
pooled_output
=
outputs
[
1
]
pooled_output
=
self
.
dropout
(
pooled_output
)
if
training
:
pooled_output
=
self
.
dropout
(
pooled_output
)
logits
=
self
.
classifier
(
pooled_output
)
outputs
=
(
logits
,)
+
outputs
[
2
:]
# add hidden states and attention if they are here
...
...
@@ -915,53 +880,10 @@ class TFBertForSequenceClassification(TFBertPreTrainedModel):
@
add_start_docstrings
(
"""Bert Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RocStories/SWAG tasks. """
,
BERT_START_DOCSTRING
)
BERT_START_DOCSTRING
,
BERT_INPUTS_DOCSTRING
)
class
TFBertForMultipleChoice
(
TFBertPreTrainedModel
):
r
"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
To match pre-training, BERT input sequence should be formatted with [CLS] and [SEP] tokens as follows:
(a) For sequence pairs:
``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``
``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``
(b) For single sequences:
``tokens: [CLS] the dog is hairy . [SEP]``
``token_type_ids: 0 0 0 0 0 0 0``
Indices can be obtained using :class:`pytorch_transformers.BertTokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Segment token indices to indicate first and second portions of the inputs.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
corresponds to a `sentence B` token
(see `BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Mask to avoid performing attention on padding token indices.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**classification_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above).
Classification scores (before SoftMax).
...
...
@@ -979,8 +901,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
model = BertForMultipleChoice.from_pretrained('bert-base-uncased')
choices = ["Hello, my dog is cute", "Hello, my cat is amazing"]
input_ids = torch.tensor([tokenizer.encode(s) for s in choices]).unsqueeze(0) # Batch size 1, 2 choices
labels = torch.tensor(1).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
outputs = model(input_ids)
loss, classification_scores = outputs[:2]
"""
...
...
@@ -1025,7 +946,8 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
pooled_output
=
outputs
[
1
]
pooled_output
=
self
.
dropout
(
pooled_output
)
if
training
:
pooled_output
=
self
.
dropout
(
pooled_output
)
logits
=
self
.
classifier
(
pooled_output
)
reshaped_logits
=
tf
.
reshape
(
logits
,
(
-
1
,
num_choices
))
...
...
@@ -1039,13 +961,7 @@ class TFBertForMultipleChoice(TFBertPreTrainedModel):
BERT_START_DOCSTRING
,
BERT_INPUTS_DOCSTRING
)
class
TFBertForTokenClassification
(
TFBertPreTrainedModel
):
r
"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for computing the token classification loss.
Indices should be in ``[0, ..., config.num_labels - 1]``.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
Classification scores (before SoftMax).
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
...
...
@@ -1061,8 +977,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForTokenClassification.from_pretrained('bert-base-uncased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
outputs = model(input_ids)
loss, scores = outputs[:2]
"""
...
...
@@ -1080,7 +995,8 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
sequence_output
=
outputs
[
0
]
sequence_output
=
self
.
dropout
(
sequence_output
)
if
training
:
sequence_output
=
self
.
dropout
(
sequence_output
)
logits
=
self
.
classifier
(
sequence_output
)
outputs
=
(
logits
,)
+
outputs
[
2
:]
# add hidden states and attention if they are here
...
...
@@ -1093,18 +1009,7 @@ class TFBertForTokenClassification(TFBertPreTrainedModel):
BERT_START_DOCSTRING
,
BERT_INPUTS_DOCSTRING
)
class
TFBertForQuestionAnswering
(
TFBertPreTrainedModel
):
r
"""
**start_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the start of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
**end_positions**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for position (index) of the end of the labelled span for computing the token classification loss.
Positions are clamped to the length of the sequence (`sequence_length`).
Position outside of the sequence are not taken into account for computing the loss.
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
**start_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
Span-start scores (before SoftMax).
**end_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length,)``
...
...
pytorch_transformers/modeling_tf_gpt2.py
View file @
72863735
This diff is collapsed.
Click to expand it.
pytorch_transformers/modeling_tf_utils.py
View file @
72863735
...
...
@@ -273,15 +273,117 @@ class TFConv1D(tf.keras.layers.Layer):
mean
=
0.
,
stddev
=
0.02
))
self
.
bias
=
self
.
add_weight
(
"bias"
,
shape
=
[
self
.
nx
,
self
.
nf
],
shape
=
[
1
,
self
.
nf
],
initializer
=
tf
.
zeros_initializer
())
@
tf
.
function
def
call
(
self
,
x
):
size_out
=
tf
.
shape
(
x
)[:
-
1
]
+
(
self
.
nf
,)
bz
,
sl
=
shape_list
(
x
)[:
2
]
x
=
tf
.
reshape
(
x
,
[
-
1
,
tf
.
shape
(
x
)[
-
1
]
])
x
=
tf
.
reshape
(
x
,
[
-
1
,
self
.
nx
])
x
=
tf
.
matmul
(
x
,
self
.
weight
)
+
self
.
bias
x
=
tf
.
reshape
(
x
,
size_out
)
x
=
tf
.
reshape
(
x
,
[
bz
,
sl
,
self
.
nf
])
return
x
class
TFSequenceSummary
(
tf
.
keras
.
layers
.
Layer
):
r
""" Compute a single vector summary of a sequence hidden states according to various possibilities:
Args of the config class:
summary_type:
- 'last' => [default] take the last token hidden state (like XLNet)
- 'first' => take the first token hidden state (like Bert)
- 'mean' => take the mean of all tokens hidden states
- 'cls_index' => supply a Tensor of classification token position (GPT/GPT-2)
- 'attn' => Not implemented now, use multi-head attention
summary_use_proj: Add a projection after the vector extraction
summary_proj_to_labels: If True, the projection outputs to config.num_labels classes (otherwise to hidden_size). Default: False.
summary_activation: 'tanh' => add a tanh activation to the output, Other => no activation. Default
summary_first_dropout: Add a dropout before the projection and activation
summary_last_dropout: Add a dropout after the projection and activation
"""
def
__init__
(
self
,
config
,
**
kwargs
):
super
(
TFSequenceSummary
,
self
).
__init__
(
**
kwargs
)
self
.
summary_type
=
config
.
summary_type
if
hasattr
(
config
,
'summary_use_proj'
)
else
'last'
if
self
.
summary_type
==
'attn'
:
# We should use a standard multi-head attention module with absolute positional embedding for that.
# Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
# We can probably just use the multi-head attention module of PyTorch >=1.1.0
raise
NotImplementedError
self
.
summary
=
tf
.
keras
.
layers
.
Identity
(
name
=
'summary'
)
if
hasattr
(
config
,
'summary_use_proj'
)
and
config
.
summary_use_proj
:
if
hasattr
(
config
,
'summary_proj_to_labels'
)
and
config
.
summary_proj_to_labels
and
config
.
num_labels
>
0
:
num_classes
=
config
.
num_labels
else
:
num_classes
=
config
.
hidden_size
self
.
summary
=
tf
.
keras
.
layers
.
Dense
(
num_classes
,
name
=
'summary'
)
self
.
activation
=
None
if
hasattr
(
config
,
'summary_activation'
)
and
config
.
summary_activation
==
'tanh'
:
self
.
activation
=
tf
.
keras
.
layers
.
Tanh
()
self
.
first_dropout
=
None
if
hasattr
(
config
,
'summary_first_dropout'
)
and
config
.
summary_first_dropout
>
0
:
self
.
first_dropout
=
tf
.
keras
.
layers
.
Dropout
(
config
.
summary_first_dropout
)
self
.
last_dropout
=
None
if
hasattr
(
config
,
'summary_last_dropout'
)
and
config
.
summary_last_dropout
>
0
:
self
.
last_dropout
=
tf
.
keras
.
layers
.
Dropout
(
config
.
summary_last_dropout
)
@
tf
.
function
def
call
(
self
,
inputs
,
training
=
False
):
""" hidden_states: float Tensor in shape [bsz, seq_len, hidden_size], the hidden-states of the last layer.
cls_index: [optional] position of the classification token if summary_type == 'cls_index',
shape (bsz,) or more generally (bsz, ...) where ... are optional leading dimensions of hidden_states.
if summary_type == 'cls_index' and cls_index is None:
we take the last token of the sequence as classification token
"""
if
not
isinstance
(
inputs
,
(
dict
,
tuple
,
list
)):
hidden_states
=
inputs
cls_index
=
None
elif
isinstance
(
inputs
,
(
tuple
,
list
)):
hidden_states
=
inputs
[
0
]
cls_index
=
inputs
[
1
]
if
len
(
inputs
)
>
1
else
None
assert
len
(
inputs
)
<=
2
,
"Too many inputs."
else
:
input_ids
=
inputs
.
get
(
'input_ids'
)
cls_index
=
inputs
.
get
(
'cls_index'
,
None
)
if
self
.
summary_type
==
'last'
:
output
=
hidden_states
[:,
-
1
]
elif
self
.
summary_type
==
'first'
:
output
=
hidden_states
[:,
0
]
elif
self
.
summary_type
==
'mean'
:
output
=
tf
.
mean
(
hidden_states
,
axis
=
1
)
elif
self
.
summary_type
==
'cls_index'
:
if
cls_index
is
None
:
cls_index
=
tf
.
fill
(
tf
.
shape
(
hidden_states
[...,
:
1
,
:]),
hidden_states
.
shape
[
-
2
]
-
1
,
dtype
=
tf
.
int32
)
else
:
cls_index
=
cls_index
[...,
tf
.
newaxis
,
tf
.
newaxis
]
cls_index
=
cls_index
.
expand
((
-
1
,)
*
(
cls_index
.
dim
()
-
1
)
+
(
hidden_states
.
size
(
-
1
),))
# shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
output
=
hidden_states
.
gather
(
-
2
,
cls_index
).
squeeze
(
-
2
)
# shape (bsz, XX, hidden_size)
elif
self
.
summary_type
==
'attn'
:
raise
NotImplementedError
if
training
and
self
.
first_dropout
is
not
None
:
output
=
self
.
first_dropout
(
output
)
output
=
self
.
summary
(
output
)
if
self
.
activation
is
not
None
:
output
=
self
.
activation
(
output
)
if
training
and
self
.
last_dropout
is
not
None
:
output
=
self
.
last_dropout
(
output
)
return
output
def
shape_list
(
x
):
"""Deal with dynamic shape in tensorflow cleanly."""
static
=
x
.
shape
.
as_list
()
dynamic
=
tf
.
shape
(
x
)
return
[
dynamic
[
i
]
if
s
is
None
else
s
for
i
,
s
in
enumerate
(
static
)]
pytorch_transformers/tests/modeling_gpt2_test.py
View file @
72863735
...
...
@@ -44,6 +44,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
seq_length
=
7
,
is_training
=
True
,
use_token_type_ids
=
True
,
use_input_mask
=
True
,
use_labels
=
True
,
vocab_size
=
99
,
hidden_size
=
32
,
...
...
@@ -66,6 +67,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
self
.
seq_length
=
seq_length
self
.
is_training
=
is_training
self
.
use_token_type_ids
=
use_token_type_ids
self
.
use_input_mask
=
use_input_mask
self
.
use_labels
=
use_labels
self
.
vocab_size
=
vocab_size
self
.
hidden_size
=
hidden_size
...
...
@@ -86,6 +88,10 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
vocab_size
=
2
)
token_type_ids
=
None
if
self
.
use_token_type_ids
:
token_type_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
...
...
@@ -115,14 +121,14 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
head_mask
=
ids_tensor
([
self
.
num_hidden_layers
,
self
.
num_attention_heads
],
2
)
return
config
,
input_ids
,
head_mask
,
token_type_ids
,
sequence_labels
,
token_labels
,
choice_labels
return
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
sequence_labels
,
token_labels
,
choice_labels
def
check_loss_output
(
self
,
result
):
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
def
create_and_check_gpt2_model
(
self
,
config
,
input_ids
,
head_mask
,
token_type_ids
,
*
args
):
def
create_and_check_gpt2_model
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
model
=
GPT2Model
(
config
=
config
)
model
.
eval
()
...
...
@@ -139,7 +145,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
self
.
parent
.
assertEqual
(
len
(
result
[
"presents"
]),
config
.
n_layer
)
def
create_and_check_lm_head_model
(
self
,
config
,
input_ids
,
head_mask
,
token_type_ids
,
*
args
):
def
create_and_check_lm_head_model
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
model
=
GPT2LMHeadModel
(
config
)
model
.
eval
()
...
...
@@ -157,7 +163,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
list
(
result
[
"lm_logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
def
create_and_check_double_lm_head_model
(
self
,
config
,
input_ids
,
head_mask
,
token_type_ids
,
*
args
):
def
create_and_check_double_lm_head_model
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
model
=
GPT2DoubleHeadsModel
(
config
)
model
.
eval
()
...
...
@@ -177,7 +183,7 @@ class GPT2ModelTest(CommonTestCases.CommonModelTester):
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids
,
head_mask
,
token_type_ids
,
sequence_labels
,
token_labels
,
choice_labels
)
=
config_and_inputs
(
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
sequence_labels
,
token_labels
,
choice_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids
,
'token_type_ids'
:
token_type_ids
,
...
...
pytorch_transformers/tests/modeling_tf_gpt2_test.py
0 → 100644
View file @
72863735
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
unittest
import
shutil
import
pytest
import
sys
from
.modeling_tf_common_test
import
(
TFCommonTestCases
,
ids_tensor
)
from
.configuration_common_test
import
ConfigTester
from
pytorch_transformers
import
GPT2Config
,
is_tf_available
try
:
import
tensorflow
as
tf
from
pytorch_transformers.modeling_tf_gpt2
import
(
TFGPT2Model
,
TFGPT2LMHeadModel
,
TFGPT2DoubleHeadsModel
,
TF_GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
)
except
ImportError
:
pytestmark
=
pytest
.
mark
.
skip
(
"Require TensorFlow"
)
class
TFGPT2ModelTest
(
TFCommonTestCases
.
TFCommonModelTester
):
all_model_classes
=
(
TFGPT2Model
,
TFGPT2LMHeadModel
,
TFGPT2DoubleHeadsModel
)
if
is_tf_available
()
else
()
class
TFGPT2ModelTester
(
object
):
def
__init__
(
self
,
parent
,
batch_size
=
13
,
seq_length
=
7
,
is_training
=
True
,
use_token_type_ids
=
True
,
use_input_mask
=
True
,
use_labels
=
True
,
vocab_size
=
99
,
hidden_size
=
32
,
num_hidden_layers
=
5
,
num_attention_heads
=
4
,
intermediate_size
=
37
,
hidden_act
=
"gelu"
,
hidden_dropout_prob
=
0.1
,
attention_probs_dropout_prob
=
0.1
,
max_position_embeddings
=
512
,
type_vocab_size
=
16
,
type_sequence_label_size
=
2
,
initializer_range
=
0.02
,
num_labels
=
3
,
num_choices
=
4
,
scope
=
None
,
):
self
.
parent
=
parent
self
.
batch_size
=
batch_size
self
.
seq_length
=
seq_length
self
.
is_training
=
is_training
self
.
use_token_type_ids
=
use_token_type_ids
self
.
use_input_mask
=
use_input_mask
self
.
use_labels
=
use_labels
self
.
vocab_size
=
vocab_size
self
.
hidden_size
=
hidden_size
self
.
num_hidden_layers
=
num_hidden_layers
self
.
num_attention_heads
=
num_attention_heads
self
.
intermediate_size
=
intermediate_size
self
.
hidden_act
=
hidden_act
self
.
hidden_dropout_prob
=
hidden_dropout_prob
self
.
attention_probs_dropout_prob
=
attention_probs_dropout_prob
self
.
max_position_embeddings
=
max_position_embeddings
self
.
type_vocab_size
=
type_vocab_size
self
.
type_sequence_label_size
=
type_sequence_label_size
self
.
initializer_range
=
initializer_range
self
.
num_labels
=
num_labels
self
.
num_choices
=
num_choices
self
.
scope
=
scope
def
prepare_config_and_inputs
(
self
):
input_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
input_mask
=
None
if
self
.
use_input_mask
:
input_mask
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
vocab_size
=
2
)
token_type_ids
=
None
if
self
.
use_token_type_ids
:
token_type_ids
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
sequence_labels
=
None
token_labels
=
None
choice_labels
=
None
if
self
.
use_labels
:
sequence_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
num_labels
)
choice_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
num_choices
)
config
=
GPT2Config
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
n_embd
=
self
.
hidden_size
,
n_layer
=
self
.
num_hidden_layers
,
n_head
=
self
.
num_attention_heads
,
# intermediate_size=self.intermediate_size,
# hidden_act=self.hidden_act,
# hidden_dropout_prob=self.hidden_dropout_prob,
# attention_probs_dropout_prob=self.attention_probs_dropout_prob,
n_positions
=
self
.
max_position_embeddings
,
n_ctx
=
self
.
max_position_embeddings
# type_vocab_size=self.type_vocab_size,
# initializer_range=self.initializer_range
)
head_mask
=
ids_tensor
([
self
.
num_hidden_layers
,
self
.
num_attention_heads
],
2
)
return
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
sequence_labels
,
token_labels
,
choice_labels
def
create_and_check_gpt2_model
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
model
=
TFGPT2Model
(
config
=
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'attention_mask'
:
input_mask
,
'token_type_ids'
:
token_type_ids
}
sequence_output
=
model
(
inputs
)[
0
]
inputs
=
[
input_ids
,
None
,
input_mask
]
# None is the input for 'past'
sequence_output
=
model
(
inputs
)[
0
]
sequence_output
=
model
(
input_ids
)[
0
]
result
=
{
"sequence_output"
:
sequence_output
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"sequence_output"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
hidden_size
])
def
create_and_check_gpt2_lm_head
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
model
=
TFGPT2LMHeadModel
(
config
=
config
)
inputs
=
{
'input_ids'
:
input_ids
,
'attention_mask'
:
input_mask
,
'token_type_ids'
:
token_type_ids
}
prediction_scores
=
model
(
inputs
)[
0
]
result
=
{
"prediction_scores"
:
prediction_scores
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"prediction_scores"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
vocab_size
])
def
create_and_check_gpt2_double_head
(
self
,
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
*
args
):
pass
# model = TFGPT2DoubleHeadsModel(config=config)
# inputs = {'input_ids': input_ids,
# 'attention_mask': input_mask,
# 'token_type_ids': token_type_ids}
# seq_relationship_score, = model(inputs)[0]
# result = {
# "seq_relationship_score": seq_relationship_score.numpy(),
# }
# self.parent.assertListEqual(
# list(result["seq_relationship_score"].shape),
# [self.batch_size, 2])
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids
,
input_mask
,
head_mask
,
token_type_ids
,
sequence_labels
,
token_labels
,
choice_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids
,
'token_type_ids'
:
token_type_ids
,
'attention_mask'
:
input_mask
}
return
config
,
inputs_dict
def
setUp
(
self
):
self
.
model_tester
=
TFGPT2ModelTest
.
TFGPT2ModelTester
(
self
)
self
.
config_tester
=
ConfigTester
(
self
,
config_class
=
GPT2Config
,
hidden_size
=
37
)
def
test_config
(
self
):
self
.
config_tester
.
run_common_tests
()
def
test_gpt2_model
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_gpt2_model
(
*
config_and_inputs
)
def
test_gpt2_lm_head
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_gpt2_lm_head
(
*
config_and_inputs
)
def
test_gpt2_double_head
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_gpt2_double_head
(
*
config_and_inputs
)
@
pytest
.
mark
.
slow
def
test_model_from_pretrained
(
self
):
cache_dir
=
"/tmp/pytorch_transformers_test/"
for
model_name
in
list
(
TF_gpt2_PRETRAINED_MODEL_ARCHIVE_MAP
.
keys
())[:
1
]:
model
=
TFGPT2Model
.
from_pretrained
(
model_name
,
cache_dir
=
cache_dir
)
shutil
.
rmtree
(
cache_dir
)
self
.
assertIsNotNone
(
model
)
if
__name__
==
"__main__"
:
unittest
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment