Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
64d83c7a
Commit
64d83c7a
authored
Sep 05, 2019
by
thomwolf
Browse files
WIP
parent
01597e5b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
668 additions
and
0 deletions
+668
-0
pytorch_transformers/modeling_tf_gpt2.py
pytorch_transformers/modeling_tf_gpt2.py
+650
-0
pytorch_transformers/modeling_tf_utils.py
pytorch_transformers/modeling_tf_utils.py
+18
-0
No files found.
pytorch_transformers/modeling_tf_gpt2.py
0 → 100644
View file @
64d83c7a
# coding=utf-8
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 OpenAI GPT-2 model. """
from
__future__
import
absolute_import
,
division
,
print_function
,
unicode_literals
import
collections
import
json
import
logging
import
math
import
os
import
sys
from
io
import
open
import
numpy
as
np
import
tensorflow
as
tf
from
.modeling_tf_utils
import
TFPreTrainedModel
from
.configuration_gpt2
import
GPT2Config
from
.file_utils
import
add_start_docstrings
logger
=
logging
.
getLogger
(
__name__
)
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
=
{
"gpt2"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-tf_model.h5"
,
"gpt2-medium"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-medium-tf_model.h5"
,
"gpt2-large"
:
"https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-large-tf_model.h5"
}
def
load_gpt2_pt_weights_in_tf
(
tf_model
,
config
,
pytorch_checkpoint_path
):
""" Load pytorch checkpoints in a TF 2.0 model and save it using HDF5 format
We use HDF5 to easily do transfer learning
(see https://github.com/tensorflow/tensorflow/blob/ee16fcac960ae660e0e4496658a366e2f745e1f0/tensorflow/python/keras/engine/network.py#L1352-L1357).
"""
try
:
import
re
import
torch
import
numpy
from
tensorflow.python.keras
import
backend
as
K
except
ImportError
:
logger
.
error
(
"Loading a PyTorch model in TensorFlow, requires PyTorch to be installed. Please see "
"https://pytorch.org/ for installation instructions."
)
raise
pt_path
=
os
.
path
.
abspath
(
pytorch_checkpoint_path
)
logger
.
info
(
"Loading PyTorch weights from {}"
.
format
(
pt_path
))
# Load pytorch model
state_dict
=
torch
.
load
(
pt_path
,
map_location
=
'cpu'
)
inputs_list
=
[[
7
,
6
,
0
,
0
,
1
],
[
1
,
2
,
3
,
0
,
0
],
[
0
,
0
,
0
,
4
,
5
]]
tf_inputs
=
tf
.
constant
(
inputs_list
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# build the network
symbolic_weights
=
tf_model
.
trainable_weights
+
tf_model
.
non_trainable_weights
weight_value_tuples
=
[]
for
symbolic_weight
in
symbolic_weights
:
name
=
symbolic_weight
.
name
name
=
name
.
replace
(
'cls_mlm'
,
'cls'
)
# We had to split this layer in two in the TF model to be
name
=
name
.
replace
(
'cls_nsp'
,
'cls'
)
# able to do transfer learning (Keras only allow to remove full layers)
name
=
name
.
replace
(
':0'
,
''
)
name
=
name
.
replace
(
'layer_'
,
'layer/'
)
name
=
name
.
split
(
'/'
)
name
=
name
[
1
:]
transpose
=
bool
(
name
[
-
1
]
==
'kernel'
)
if
name
[
-
1
]
==
'kernel'
or
name
[
-
1
]
==
'embeddings'
:
name
[
-
1
]
=
'weight'
name
=
'.'
.
join
(
name
)
assert
name
in
state_dict
array
=
state_dict
[
name
].
numpy
()
if
transpose
:
array
=
numpy
.
transpose
(
array
)
try
:
assert
list
(
symbolic_weight
.
shape
)
==
list
(
array
.
shape
)
except
AssertionError
as
e
:
e
.
args
+=
(
symbolic_weight
.
shape
,
array
.
shape
)
raise
e
logger
.
info
(
"Initialize TF weight {}"
.
format
(
symbolic_weight
.
name
))
weight_value_tuples
.
append
((
symbolic_weight
,
array
))
K
.
batch_set_value
(
weight_value_tuples
)
tfo
=
tf_model
(
tf_inputs
,
training
=
False
)
# Make sure restore ops are run
return
tf_model
def
gelu
(
x
):
"""Gaussian Error Linear Unit.
This is a smoother version of the RELU.
Original paper: https://arxiv.org/abs/1606.08415
Args:
x: float Tensor to perform activation.
Returns:
`x` with the GELU activation applied.
"""
cdf
=
0.5
*
(
1.0
+
tf
.
tanh
(
(
np
.
sqrt
(
2
/
np
.
pi
)
*
(
x
+
0.044715
*
tf
.
pow
(
x
,
3
)))))
return
x
*
cdf
class
TFAttention
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
nx
,
n_ctx
,
config
,
scale
=
False
):
super
(
Attention
,
self
).
__init__
()
self
.
output_attentions
=
config
.
output_attentions
n_state
=
nx
# in Attention: n_state=768 (nx=n_embd)
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
assert
n_state
%
config
.
n_head
==
0
self
.
register_buffer
(
"bias"
,
torch
.
tril
(
torch
.
ones
(
n_ctx
,
n_ctx
)).
view
(
1
,
1
,
n_ctx
,
n_ctx
))
self
.
n_head
=
config
.
n_head
self
.
split_size
=
n_state
self
.
scale
=
scale
self
.
c_attn
=
Conv1D
(
n_state
*
3
,
nx
)
self
.
c_proj
=
Conv1D
(
n_state
,
nx
)
self
.
attn_dropout
=
nn
.
Dropout
(
config
.
attn_pdrop
)
self
.
resid_dropout
=
nn
.
Dropout
(
config
.
resid_pdrop
)
self
.
pruned_heads
=
set
()
def
prune_heads
(
self
,
heads
):
if
len
(
heads
)
==
0
:
return
mask
=
torch
.
ones
(
self
.
n_head
,
self
.
split_size
//
self
.
n_head
)
heads
=
set
(
heads
)
-
self
.
pruned_heads
# Convert to set and emove already pruned heads
for
head
in
heads
:
# Compute how many pruned heads are before the head and move the index accordingly
head
=
head
-
sum
(
1
if
h
<
head
else
0
for
h
in
self
.
pruned_heads
)
mask
[
head
]
=
0
mask
=
mask
.
view
(
-
1
).
contiguous
().
eq
(
1
)
index
=
torch
.
arange
(
len
(
mask
))[
mask
].
long
()
index_attn
=
torch
.
cat
([
index
,
index
+
self
.
split_size
,
index
+
(
2
*
self
.
split_size
)])
# Prune conv1d layers
self
.
c_attn
=
prune_conv1d_layer
(
self
.
c_attn
,
index_attn
,
dim
=
1
)
self
.
c_proj
=
prune_conv1d_layer
(
self
.
c_proj
,
index
,
dim
=
0
)
# Update hyper params
self
.
split_size
=
(
self
.
split_size
//
self
.
n_head
)
*
(
self
.
n_head
-
len
(
heads
))
self
.
n_head
=
self
.
n_head
-
len
(
heads
)
self
.
pruned_heads
=
self
.
pruned_heads
.
union
(
heads
)
def
_attn
(
self
,
q
,
k
,
v
,
head_mask
=
None
):
w
=
torch
.
matmul
(
q
,
k
)
if
self
.
scale
:
w
=
w
/
math
.
sqrt
(
v
.
size
(
-
1
))
nd
,
ns
=
w
.
size
(
-
2
),
w
.
size
(
-
1
)
b
=
self
.
bias
[:,
:,
ns
-
nd
:
ns
,
:
ns
]
w
=
w
*
b
-
1e4
*
(
1
-
b
)
w
=
nn
.
Softmax
(
dim
=-
1
)(
w
)
w
=
self
.
attn_dropout
(
w
)
# Mask heads if we want to
if
head_mask
is
not
None
:
w
=
w
*
head_mask
outputs
=
[
torch
.
matmul
(
w
,
v
)]
if
self
.
output_attentions
:
outputs
.
append
(
w
)
return
outputs
def
merge_heads
(
self
,
x
):
x
=
x
.
permute
(
0
,
2
,
1
,
3
).
contiguous
()
new_x_shape
=
x
.
size
()[:
-
2
]
+
(
x
.
size
(
-
2
)
*
x
.
size
(
-
1
),)
return
x
.
view
(
*
new_x_shape
)
# in Tensorflow implem: fct merge_states
def
split_heads
(
self
,
x
,
k
=
False
):
new_x_shape
=
x
.
size
()[:
-
1
]
+
(
self
.
n_head
,
x
.
size
(
-
1
)
//
self
.
n_head
)
x
=
x
.
view
(
*
new_x_shape
)
# in Tensorflow implem: fct split_states
if
k
:
return
x
.
permute
(
0
,
2
,
3
,
1
)
# (batch, head, head_features, seq_length)
else
:
return
x
.
permute
(
0
,
2
,
1
,
3
)
# (batch, head, seq_length, head_features)
def
forward
(
self
,
x
,
layer_past
=
None
,
head_mask
=
None
):
x
=
self
.
c_attn
(
x
)
query
,
key
,
value
=
x
.
split
(
self
.
split_size
,
dim
=
2
)
query
=
self
.
split_heads
(
query
)
key
=
self
.
split_heads
(
key
,
k
=
True
)
value
=
self
.
split_heads
(
value
)
if
layer_past
is
not
None
:
past_key
,
past_value
=
layer_past
[
0
].
transpose
(
-
2
,
-
1
),
layer_past
[
1
]
# transpose back cf below
key
=
torch
.
cat
((
past_key
,
key
),
dim
=-
1
)
value
=
torch
.
cat
((
past_value
,
value
),
dim
=-
2
)
present
=
torch
.
stack
((
key
.
transpose
(
-
2
,
-
1
),
value
))
# transpose to have same shapes for stacking
attn_outputs
=
self
.
_attn
(
query
,
key
,
value
,
head_mask
)
a
=
attn_outputs
[
0
]
a
=
self
.
merge_heads
(
a
)
a
=
self
.
c_proj
(
a
)
a
=
self
.
resid_dropout
(
a
)
outputs
=
[
a
,
present
]
+
attn_outputs
[
1
:]
return
outputs
# a, present, (attentions)
class
MLP
(
nn
.
Module
):
def
__init__
(
self
,
n_state
,
config
):
# in MLP: n_state=3072 (4 * n_embd)
super
(
MLP
,
self
).
__init__
()
nx
=
config
.
n_embd
self
.
c_fc
=
Conv1D
(
n_state
,
nx
)
self
.
c_proj
=
Conv1D
(
nx
,
n_state
)
self
.
act
=
gelu
self
.
dropout
=
nn
.
Dropout
(
config
.
resid_pdrop
)
def
forward
(
self
,
x
):
h
=
self
.
act
(
self
.
c_fc
(
x
))
h2
=
self
.
c_proj
(
h
)
return
self
.
dropout
(
h2
)
class
Block
(
nn
.
Module
):
def
__init__
(
self
,
n_ctx
,
config
,
scale
=
False
):
super
(
Block
,
self
).
__init__
()
nx
=
config
.
n_embd
self
.
ln_1
=
nn
.
LayerNorm
(
nx
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
Attention
(
nx
,
n_ctx
,
config
,
scale
)
self
.
ln_2
=
nn
.
LayerNorm
(
nx
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
MLP
(
4
*
nx
,
config
)
def
forward
(
self
,
x
,
layer_past
=
None
,
head_mask
=
None
):
output_attn
=
self
.
attn
(
self
.
ln_1
(
x
),
layer_past
=
layer_past
,
head_mask
=
head_mask
)
a
=
output_attn
[
0
]
# output_attn: a, present, (attentions)
x
=
x
+
a
m
=
self
.
mlp
(
self
.
ln_2
(
x
))
x
=
x
+
m
outputs
=
[
x
]
+
output_attn
[
1
:]
return
outputs
# x, present, (attentions)
class
GPT2PreTrainedModel
(
PreTrainedModel
):
""" An abstract class to handle weights initialization and
a simple interface for dowloading and loading pretrained models.
"""
config_class
=
GPT2Config
pretrained_model_archive_map
=
GPT2_PRETRAINED_MODEL_ARCHIVE_MAP
load_tf_weights
=
load_tf_weights_in_gpt2
base_model_prefix
=
"transformer"
def
__init__
(
self
,
*
inputs
,
**
kwargs
):
super
(
GPT2PreTrainedModel
,
self
).
__init__
(
*
inputs
,
**
kwargs
)
def
_init_weights
(
self
,
module
):
""" Initialize the weights.
"""
if
isinstance
(
module
,
(
nn
.
Linear
,
nn
.
Embedding
,
Conv1D
)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
self
.
config
.
initializer_range
)
if
isinstance
(
module
,
(
nn
.
Linear
,
Conv1D
))
and
module
.
bias
is
not
None
:
module
.
bias
.
data
.
zero_
()
elif
isinstance
(
module
,
nn
.
LayerNorm
):
module
.
bias
.
data
.
zero_
()
module
.
weight
.
data
.
fill_
(
1.0
)
GPT2_START_DOCSTRING
=
r
""" OpenAI GPT-2 model was proposed in
`Language Models are Unsupervised Multitask Learners`_
by Alec Radford*, Jeffrey Wu*, Rewon Child, David Luan, Dario Amodei** and Ilya Sutskever**.
It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
corpus of ~40 GB of text data.
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`Language Models are Unsupervised Multitask Learners`:
https://openai.com/blog/better-language-models/
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~pytorch_transformers.GPT2Config`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~pytorch_transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
GPT2_INPUTS_DOCSTRING
=
r
""" Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
GPT-2 is a model with absolute position embeddings so it's usually advised to pad the inputs on
the right rather than the left.
Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
The embeddings from these tokens will be summed with the respective token embeddings.
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
**past**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@
add_start_docstrings
(
"The bare GPT2 Model transformer outputing raw hidden-states without any specific head on top."
,
GPT2_START_DOCSTRING
,
GPT2_INPUTS_DOCSTRING
)
class
GPT2Model
(
GPT2PreTrainedModel
):
r
"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2Model.from_pretrained('gpt2')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
def
__init__
(
self
,
config
):
super
(
GPT2Model
,
self
).
__init__
(
config
)
self
.
output_hidden_states
=
config
.
output_hidden_states
self
.
output_attentions
=
config
.
output_attentions
self
.
wte
=
nn
.
Embedding
(
config
.
vocab_size
,
config
.
n_embd
)
self
.
wpe
=
nn
.
Embedding
(
config
.
n_positions
,
config
.
n_embd
)
self
.
drop
=
nn
.
Dropout
(
config
.
embd_pdrop
)
self
.
h
=
nn
.
ModuleList
([
Block
(
config
.
n_ctx
,
config
,
scale
=
True
)
for
_
in
range
(
config
.
n_layer
)])
self
.
ln_f
=
nn
.
LayerNorm
(
config
.
n_embd
,
eps
=
config
.
layer_norm_epsilon
)
self
.
init_weights
()
def
_resize_token_embeddings
(
self
,
new_num_tokens
):
self
.
wte
=
self
.
_get_resized_embeddings
(
self
.
wte
,
new_num_tokens
)
return
self
.
wte
def
_prune_heads
(
self
,
heads_to_prune
):
""" Prunes heads of the model.
heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
"""
for
layer
,
heads
in
heads_to_prune
.
items
():
self
.
h
[
layer
].
attn
.
prune_heads
(
heads
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
past
=
None
,
head_mask
=
None
):
if
past
is
None
:
past_length
=
0
past
=
[
None
]
*
len
(
self
.
h
)
else
:
past_length
=
past
[
0
][
0
].
size
(
-
2
)
if
position_ids
is
None
:
position_ids
=
torch
.
arange
(
past_length
,
input_ids
.
size
(
-
1
)
+
past_length
,
dtype
=
torch
.
long
,
device
=
input_ids
.
device
)
position_ids
=
position_ids
.
unsqueeze
(
0
).
expand_as
(
input_ids
)
# Prepare head mask if needed
# 1.0 in head_mask indicate we keep the head
# attention_probs has shape bsz x n_heads x N x N
# head_mask has shape n_layer x batch x n_heads x N x N
if
head_mask
is
not
None
:
if
head_mask
.
dim
()
==
1
:
head_mask
=
head_mask
.
unsqueeze
(
0
).
unsqueeze
(
0
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
head_mask
=
head_mask
.
expand
(
self
.
config
.
n_layer
,
-
1
,
-
1
,
-
1
,
-
1
)
elif
head_mask
.
dim
()
==
2
:
head_mask
=
head_mask
.
unsqueeze
(
1
).
unsqueeze
(
-
1
).
unsqueeze
(
-
1
)
# We can specify head_mask for each layer
head_mask
=
head_mask
.
to
(
dtype
=
next
(
self
.
parameters
()).
dtype
)
# switch to fload if need + fp16 compatibility
else
:
head_mask
=
[
None
]
*
self
.
config
.
n_layer
input_shape
=
input_ids
.
size
()
input_ids
=
input_ids
.
view
(
-
1
,
input_ids
.
size
(
-
1
))
position_ids
=
position_ids
.
view
(
-
1
,
position_ids
.
size
(
-
1
))
inputs_embeds
=
self
.
wte
(
input_ids
)
position_embeds
=
self
.
wpe
(
position_ids
)
if
token_type_ids
is
not
None
:
token_type_ids
=
token_type_ids
.
view
(
-
1
,
token_type_ids
.
size
(
-
1
))
token_type_embeds
=
self
.
wte
(
token_type_ids
)
else
:
token_type_embeds
=
0
hidden_states
=
inputs_embeds
+
position_embeds
+
token_type_embeds
hidden_states
=
self
.
drop
(
hidden_states
)
output_shape
=
input_shape
+
(
hidden_states
.
size
(
-
1
),)
presents
=
()
all_attentions
=
[]
all_hidden_states
=
()
for
i
,
(
block
,
layer_past
)
in
enumerate
(
zip
(
self
.
h
,
past
)):
if
self
.
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
hidden_states
.
view
(
*
output_shape
),)
outputs
=
block
(
hidden_states
,
layer_past
,
head_mask
[
i
])
hidden_states
,
present
=
outputs
[:
2
]
presents
=
presents
+
(
present
,)
if
self
.
output_attentions
:
all_attentions
.
append
(
outputs
[
2
])
hidden_states
=
self
.
ln_f
(
hidden_states
)
hidden_states
=
hidden_states
.
view
(
*
output_shape
)
# Add last hidden state
if
self
.
output_hidden_states
:
all_hidden_states
=
all_hidden_states
+
(
hidden_states
,)
outputs
=
(
hidden_states
,
presents
)
if
self
.
output_hidden_states
:
outputs
=
outputs
+
(
all_hidden_states
,)
if
self
.
output_attentions
:
# let the number of heads free (-1) so we can extract attention even after head pruning
attention_output_shape
=
input_shape
[:
-
1
]
+
(
-
1
,)
+
all_attentions
[
0
].
shape
[
-
2
:]
all_attentions
=
tuple
(
t
.
view
(
*
attention_output_shape
)
for
t
in
all_attentions
)
outputs
=
outputs
+
(
all_attentions
,)
return
outputs
# last hidden state, presents, (all hidden_states), (attentions)
@
add_start_docstrings
(
"""The GPT2 Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """
,
GPT2_START_DOCSTRING
,
GPT2_INPUTS_DOCSTRING
)
class
GPT2LMHeadModel
(
GPT2PreTrainedModel
):
r
"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
All labels set to ``-1`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import torch
from pytorch_transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=input_ids)
loss, logits = outputs[:2]
"""
def
__init__
(
self
,
config
):
super
(
GPT2LMHeadModel
,
self
).
__init__
(
config
)
self
.
transformer
=
GPT2Model
(
config
)
self
.
lm_head
=
nn
.
Linear
(
config
.
n_embd
,
config
.
vocab_size
,
bias
=
False
)
self
.
init_weights
()
self
.
tie_weights
()
def
tie_weights
(
self
):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self
.
_tie_or_clone_weights
(
self
.
lm_head
,
self
.
transformer
.
wte
)
def
forward
(
self
,
input_ids
,
position_ids
=
None
,
token_type_ids
=
None
,
labels
=
None
,
past
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
past
=
past
,
head_mask
=
head_mask
)
hidden_states
=
transformer_outputs
[
0
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
outputs
=
(
lm_logits
,)
+
transformer_outputs
[
1
:]
if
labels
is
not
None
:
# Shift so that tokens < n predict n
shift_logits
=
lm_logits
[...,
:
-
1
,
:].
contiguous
()
shift_labels
=
labels
[...,
1
:].
contiguous
()
# Flatten the tokens
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss
=
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
shift_labels
.
view
(
-
1
))
outputs
=
(
loss
,)
+
outputs
return
outputs
# (loss), lm_logits, presents, (all hidden_states), (attentions)
@
add_start_docstrings
(
"""The GPT2 Model transformer with a language modeling and a multiple-choice classification
head on top e.g. for RocStories/SWAG tasks. The two heads are two linear layers.
The language modeling head has its weights tied to the input embeddings,
the classification head takes as input the input of a specified classification token index in the input sequence).
"""
,
GPT2_START_DOCSTRING
)
class
GPT2DoubleHeadsModel
(
GPT2PreTrainedModel
):
r
""" Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The second dimension of the input (`num_choices`) indicates the number of choices to score.
Indices can be obtained using :class:`pytorch_transformers.BPT2Tokenizer`.
See :func:`pytorch_transformers.PreTrainedTokenizer.encode` and
:func:`pytorch_transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**mc_token_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices)``:
Index of the classification token in each input sequence.
Selected in the range ``[0, input_ids.size(-1) - 1[``.
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
The embeddings from these tokens will be summed with the respective token embeddings.
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
**past**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
All labels set to ``-1`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
**mc_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**lm_loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Language modeling loss.
**mc_loss**: (`optional`, returned when ``multiple_choice_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Multiple choice classification loss.
**lm_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**mc_prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, num_choices)``
Prediction scores of the multiplechoice classification head (scores for each choice before SoftMax).
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import torch
from pytorch_transformers import GPT2Tokenizer, GPT2DoubleHeadsModel
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2DoubleHeadsModel.from_pretrained('gpt2')
# Add a [CLS] to the vocabulary (we should train it also!)
tokenizer.add_special_tokens({'cls_token': '[CLS]'})
model.resize_token_embeddings(len(tokenizer)) # Update the model embeddings with the new vocabulary size
print(tokenizer.cls_token_id, len(tokenizer)) # The newly token the last token of the vocabulary
choices = ["Hello, my dog is cute [CLS]", "Hello, my cat is cute [CLS]"]
encoded_choices = [tokenizer.encode(s) for s in choices]
cls_token_location = [tokens.index(tokenizer.cls_token_id) for tokens in encoded_choices]
input_ids = torch.tensor(encoded_choices).unsqueeze(0) # Batch size: 1, number of choices: 2
mc_token_ids = torch.tensor([cls_token_location]) # Batch size: 1
outputs = model(input_ids, mc_token_ids=mc_token_ids)
lm_prediction_scores, mc_prediction_scores = outputs[:2]
"""
def
__init__
(
self
,
config
):
super
(
GPT2DoubleHeadsModel
,
self
).
__init__
(
config
)
self
.
transformer
=
GPT2Model
(
config
)
self
.
lm_head
=
nn
.
Linear
(
config
.
n_embd
,
config
.
vocab_size
,
bias
=
False
)
self
.
multiple_choice_head
=
SequenceSummary
(
config
)
self
.
init_weights
()
self
.
tie_weights
()
def
tie_weights
(
self
):
""" Make sure we are sharing the input and output embeddings.
Export to TorchScript can't handle parameter sharing so we are cloning them instead.
"""
self
.
_tie_or_clone_weights
(
self
.
lm_head
,
self
.
transformer
.
wte
)
def
forward
(
self
,
input_ids
,
mc_token_ids
=
None
,
lm_labels
=
None
,
mc_labels
=
None
,
token_type_ids
=
None
,
position_ids
=
None
,
past
=
None
,
head_mask
=
None
):
transformer_outputs
=
self
.
transformer
(
input_ids
,
position_ids
=
position_ids
,
token_type_ids
=
token_type_ids
,
past
=
past
,
head_mask
=
head_mask
)
hidden_states
=
transformer_outputs
[
0
]
lm_logits
=
self
.
lm_head
(
hidden_states
)
mc_logits
=
self
.
multiple_choice_head
(
hidden_states
,
mc_token_ids
).
squeeze
(
-
1
)
outputs
=
(
lm_logits
,
mc_logits
)
+
transformer_outputs
[
1
:]
if
mc_labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
loss
=
loss_fct
(
mc_logits
.
view
(
-
1
,
mc_logits
.
size
(
-
1
)),
mc_labels
.
view
(
-
1
))
outputs
=
(
loss
,)
+
outputs
if
lm_labels
is
not
None
:
shift_logits
=
lm_logits
[...,
:
-
1
,
:].
contiguous
()
shift_labels
=
lm_labels
[...,
1
:].
contiguous
()
loss_fct
=
CrossEntropyLoss
(
ignore_index
=-
1
)
loss
=
loss_fct
(
shift_logits
.
view
(
-
1
,
shift_logits
.
size
(
-
1
)),
shift_labels
.
view
(
-
1
))
outputs
=
(
loss
,)
+
outputs
return
outputs
# (lm loss), (mc loss), lm logits, mc logits, presents, (all hidden_states), (attentions)
pytorch_transformers/modeling_tf_utils.py
View file @
64d83c7a
...
...
@@ -255,3 +255,21 @@ class TFPreTrainedModel(tf.keras.Model):
ret
=
model
(
inputs
,
training
=
False
)
# Make sure restore ops are run
return
model
class
TFConv1D
(
tf
.
keras
.
layers
.
Layer
):
def
__init__
(
self
,
nf
,
nx
):
""" TFConv1D layer as defined by Radford et al. for OpenAI GPT (and also used in GPT-2)
Basically works like a Linear layer but the weights are transposed
"""
super
(
TFConv1D
,
self
).
__init__
()
self
.
nf
=
nf
w
=
torch
.
empty
(
nx
,
nf
)
nn
.
init
.
normal_
(
w
,
std
=
0.02
)
self
.
weight
=
nn
.
Parameter
(
w
)
self
.
bias
=
nn
.
Parameter
(
torch
.
zeros
(
nf
))
def
call
(
self
,
x
):
size_out
=
t
.
shape
(
x
)[:
-
1
]
+
(
self
.
nf
,)
x
=
tf
.
addmm
(
self
.
bias
,
x
.
view
(
-
1
,
x
.
size
(
-
1
)),
self
.
weight
)
x
=
x
.
view
(
*
size_out
)
return
x
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment