chenpangpang / transformers
Commit 45dc04f3, authored Oct 08, 2019 by thomwolf
tf model [WIP]
Parent: 24831477
Showing 2 changed files with 493 additions and 18 deletions (+493 −18):

    transformers/modeling_ctrl.py       +18  −18
    transformers/modeling_tf_ctrl.py    +475 −0
transformers/modeling_ctrl.py
@@ -111,7 +111,7 @@ class MultiHeadAttention(torch.nn.Module):
        v = self.split_into_heads(v, batch_size)
        if layer_past is not None:
            past_key, past_value = layer_past[0], layer_past[1]
-           k = torch.cat((past_key, k), dim=-1)
+           k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)
        present = torch.stack((k, v))
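The one-line fix in this hunk changes the axis used to append the cached keys: after split_into_heads the tensors have shape (batch, num_heads, seq_len, depth), so the key cache has to grow along the sequence axis (dim=-2), exactly as the value cache already does; with dim=-1 the shapes would not even line up once the cache is longer than the new chunk. A minimal sketch of the intended behaviour (shapes chosen only for illustration):

import torch

# (batch, num_heads, seq_len, depth) -- as produced by split_into_heads
past_key = torch.zeros(1, 2, 3, 4)   # 3 cached positions
new_key = torch.zeros(1, 2, 1, 4)    # 1 new position

k = torch.cat((past_key, new_key), dim=-2)  # grow the sequence axis
print(k.shape)  # torch.Size([1, 2, 4, 4])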
@@ -167,25 +167,25 @@ class EncoderLayer(torch.nn.Module):
class CTRLPreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = CTRLConfig
    pretrained_model_archive_map = CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)


CTRL_START_DOCSTRING = r""" CTRL model was proposed in
transformers/modeling_tf_ctrl.py (new file, 0 → 100644)
# coding=utf-8
# Copyright 2018 Salesforce and HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" TF 2.0 CTRL model."""
from __future__ import absolute_import, division, print_function, unicode_literals

import logging
import os
import sys
from io import open

import numpy as np
import tensorflow as tf

from .configuration_ctrl import CTRLConfig
from .modeling_tf_utils import TFPreTrainedModel, get_initializer
from .file_utils import add_start_docstrings
from .modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

logger = logging.getLogger(__name__)

TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP = {"ctrl": "https://s3.amazonaws.com/models.huggingface.co/bert/ctrl-tf_model.h5"}
def angle_defn(pos, i, d_model_size):
    angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model_size))
    return pos * angle_rates
def positional_encoding(position, d_model_size, dtype):
    # create the sinusoidal pattern for the positional encoding
    angle_rads = angle_defn(np.arange(position)[:, np.newaxis],
                            np.arange(d_model_size)[np.newaxis, :],
                            d_model_size)

    sines = np.sin(angle_rads[:, 0::2])
    cosines = np.cos(angle_rads[:, 1::2])

    pos_encoding = tf.cast(np.concatenate([sines, cosines], axis=-1)[np.newaxis, ...], dtype=tf.float32)
    return pos_encoding
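A quick sanity check for the two helpers above (values are arbitrary; this only illustrates the shapes): the encoding gets a leading broadcast axis, one row per position, and sines/cosines concatenated back to d_model_size columns.

import tensorflow as tf

pe = positional_encoding(10, 16, tf.float32)  # 10 positions, model size 16
print(pe.shape)  # (1, 10, 16)
print(pe.dtype)  # float32 (the helper casts to tf.float32)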
def scaled_dot_product_attention(q, k, v, mask, attention_mask=None, head_mask=None):
    # calculate attention
    matmul_qk = tf.matmul(q, k, transpose_b=True)

    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    if mask is not None:
        scaled_attention_logits += (mask * -1e4)

    if attention_mask is not None:
        # Apply the attention mask
        scaled_attention_logits = scaled_attention_logits + attention_mask

    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)

    # Mask heads if we want to
    if head_mask is not None:
        attention_weights = attention_weights * head_mask

    output = tf.matmul(attention_weights, v)

    return output, attention_weights
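A minimal call of the attention helper above with random tensors (shapes are illustrative only): the output keeps the query shape and the returned weights are (batch, num_heads, q_len, k_len).

import tensorflow as tf

q = tf.random.normal((1, 2, 5, 8))  # (batch, num_heads, q_len, depth)
k = tf.random.normal((1, 2, 5, 8))
v = tf.random.normal((1, 2, 5, 8))

out, weights = scaled_dot_product_attention(q, k, v, mask=None)
print(out.shape)      # (1, 2, 5, 8)
print(weights.shape)  # (1, 2, 5, 5)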
class TFMultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model_size, num_heads, output_attentions=False, **kwargs):
        super(TFMultiHeadAttention, self).__init__(**kwargs)
        self.output_attentions = output_attentions
        self.num_heads = num_heads
        self.d_model_size = d_model_size

        self.depth = int(d_model_size / self.num_heads)

        self.Wq = tf.keras.layers.Dense(d_model_size, name='Wq')
        self.Wk = tf.keras.layers.Dense(d_model_size, name='Wk')
        self.Wv = tf.keras.layers.Dense(d_model_size, name='Wv')

        self.dense = tf.keras.layers.Dense(d_model_size, name='dense')

    def split_into_heads(self, x, batch_size):
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    def call(self, inputs, training=False):
        v, k, q, mask, layer_past, attention_mask, head_mask = inputs
        batch_size = q.shape[0]

        q = self.Wq(q)
        k = self.Wk(k)
        v = self.Wv(v)

        q = self.split_into_heads(q, batch_size)
        k = self.split_into_heads(k, batch_size)
        v = self.split_into_heads(v, batch_size)

        if layer_past is not None:
            past_key, past_value = tf.unstack(layer_past, axis=1)
            # append the cached keys/values along the sequence axis (tf.concat takes `axis`, not `dim`)
            k = tf.concat((past_key, k), axis=-2)
            v = tf.concat((past_value, v), axis=-2)
        present = tf.stack((k, v), axis=1)

        output = scaled_dot_product_attention(q, k, v, mask, attention_mask, head_mask)
        scaled_attention = tf.transpose(output[0], perm=[0, 2, 1, 3])
        attn = output[1]
        original_size_attention = tf.reshape(scaled_attention, (batch_size, -1, self.d_model_size))
        output = self.dense(original_size_attention)

        outputs = (output, present)
        if self.output_attentions:
            outputs = outputs + (attn,)
        return outputs
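A hedged usage sketch for the layer above (eager mode, no cache, no masking; sizes are made up): the inputs are packed into a single list in the (v, k, q, mask, layer_past, attention_mask, head_mask) order that call expects.

import tensorflow as tf

mha = TFMultiHeadAttention(d_model_size=16, num_heads=4, output_attentions=True)
x = tf.random.normal((2, 5, 16))  # (batch, seq_len, d_model_size)

output, present, attn = mha([x, x, x, None, None, None, None])
print(output.shape)   # (2, 5, 16)
print(present.shape)  # (2, 2, 4, 5, 4) -- stacked (k, v) cache
print(attn.shape)     # (2, 4, 5, 5)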
def point_wise_feed_forward_network(d_model_size, dff):
    return tf.keras.Sequential([tf.keras.layers.Dense(dff, activation='relu'),
                                tf.keras.layers.Dense(d_model_size)])
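The feed-forward helper just expands to dff and projects back to d_model_size; a short sketch with arbitrary sizes:

import tensorflow as tf

ffn = point_wise_feed_forward_network(d_model_size=16, dff=64)
y = ffn(tf.zeros((2, 5, 16)))
print(y.shape)  # (2, 5, 16)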
class TFEncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model_size, num_heads, dff, rate=0.1, output_attentions=False, **kwargs):
        super(TFEncoderLayer, self).__init__(**kwargs)

        self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, output_attentions)
        self.ffn = point_wise_feed_forward_network(d_model_size, dff)

        # tf.keras equivalents of the torch.nn.LayerNorm / torch.nn.Dropout used in modeling_ctrl.py
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)

    def call(self, inputs, training=False):
        x, mask, layer_past, attention_mask, head_mask = inputs
        normed = self.layernorm1(x)
        # pack the arguments into a single list, matching TFMultiHeadAttention.call
        attn_outputs = self.multi_head_attention([normed, normed, normed, mask,
                                                  layer_past, attention_mask, head_mask])
        attn_output = attn_outputs[0]
        attn_output = self.dropout1(attn_output, training=training)
        out1 = x + attn_output

        out2 = self.layernorm2(out1)
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout2(ffn_output, training=training)
        out2 = out1 + ffn_output

        outputs = (out2,) + attn_outputs[1:]
        return outputs
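And a sketch of one full encoder block as defined above (pre-norm attention plus a residual, then the feed-forward sub-block), again with illustrative sizes and no cache or masks:

import tensorflow as tf

layer = TFEncoderLayer(d_model_size=16, num_heads=4, dff=64, rate=0.1)
x = tf.random.normal((2, 5, 16))

hidden, present = layer([x, None, None, None, None], training=False)
print(hidden.shape)  # (2, 5, 16)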
class TFCTRLPreTrainedModel(TFPreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = CTRLConfig
    pretrained_model_archive_map = TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
    base_model_prefix = "transformer"
    # assumption: wire up the loader imported above (the WIP referenced an undefined load_bert_pt_weights_in_tf2)
    load_pt_weights = load_pytorch_checkpoint_in_tf2_model

    def _init_weights(self, module):
        """ Initialize the weights.
        """
        if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
CTRL_START_DOCSTRING = r""" CTRL model was proposed in
`CTRL: A Conditional Transformer Language Model for Controllable Generation`_
by Nitish Shirish Keskar*, Bryan McCann*, Lav R. Varshney, Caiming Xiong and Richard Socher.
It's a causal (unidirectional) transformer pre-trained using language modeling on a very large
corpus of ~140 GB of text data with the first token reserved as a control code (such as Links, Books, Wikipedia etc.).
This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
refer to the PyTorch documentation for all matter related to general usage and behavior.
.. _`CTRL: A Conditional Transformer Language Model for Controllable Generation`:
https://www.github.com/salesforce/ctrl
.. _`torch.nn.Module`:
https://pytorch.org/docs/stable/nn.html#module
Parameters:
config (:class:`~transformers.CTRLConfig`): Model configuration class with all the parameters of the model.
Initializing with a config file does not load the weights associated with the model, only the configuration.
Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""
CTRL_INPUTS_DOCSTRING = r""" Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
CTRL is a model with absolute position embeddings so it's usually advised to pad the inputs on
the right rather than the left.
Indices can be obtained using :class:`transformers.CTRLTokenizer`.
See :func:`transformers.PreTrainedTokenizer.encode` and
:func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
**past**:
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
(see `past` output below). Can be used to speed up sequential decoding.
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
A parallel sequence of tokens (can be used to indicate various portions of the inputs).
The embeddings from these tokens will be summed with the respective token embeddings.
Indices are selected in the vocabulary (unlike BERT which has a specific vocabulary for segment indices).
**position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Indices of positions of each input sequence tokens in the position embeddings.
Selected in the range ``[0, config.max_position_embeddings - 1]``.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare CTRL Model transformer outputting raw hidden-states without any specific head on top.",
                      CTRL_START_DOCSTRING, CTRL_INPUTS_DOCSTRING)
class TFCTRLModel(TFCTRLPreTrainedModel):
    r"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
Sequence of hidden-states at the last layer of the model.
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = CTRLTokenizer.from_pretrained('ctrl')
model = CTRLModel.from_pretrained('ctrl')
input_ids = torch.tensor(tokenizer.encode("Links Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids)
last_hidden_states = outputs[0] # The last hidden-state is the first element of the output tuple
"""
    def __init__(self, config, **kwargs):
        super(TFCTRLModel, self).__init__(config, **kwargs)
        self.output_hidden_states = config.output_hidden_states
        self.d_model_size = config.n_embd
        self.num_layers = config.n_layer

        # NOTE: the layers below are still the PyTorch versions carried over from
        # modeling_ctrl.py; torch, nn and EncoderLayer are not imported in this TF module yet.
        self.pos_encoding = positional_encoding(config.n_positions, self.d_model_size, torch.float)

        self.output_attentions = config.output_attentions

        self.w = nn.Embedding(config.vocab_size, config.n_embd)

        self.dropout = nn.Dropout(config.embd_pdrop)
        self.h = nn.ModuleList([EncoderLayer(config.n_embd,
                                             config.n_head,
                                             config.dff,
                                             config.resid_pdrop,
                                             config.output_attentions) for _ in range(config.n_layer)])
        self.layernorm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)

        self.init_weights()
    def _resize_token_embeddings(self, new_num_tokens):
        self.w = self._get_resized_embeddings(self.w, new_num_tokens)
        return self.w

    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
        """
        for layer, heads in heads_to_prune.items():
            self.h[layer].attn.prune_heads(heads)
    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None):
        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, input_shape[-1])
        if position_ids is not None:
            position_ids = position_ids.view(-1, input_shape[-1])

        if past is None:
            past_length = 0
            past = [None] * len(self.h)
        else:
            past_length = past[0][0].size(-2)
        if position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size(-1) + past_length,
                                        dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        # Attention mask.
        if attention_mask is not None:
            attention_mask = attention_mask.view(-1, input_shape[-1])
            # We create a 3D attention mask from a 2D tensor mask.
            # Sizes are [batch_size, 1, 1, to_seq_length]
            # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
            # this attention mask is more simple than the triangular masking of causal attention
            # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

            # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
            # masked positions, this operation will create a tensor which is 0.0 for
            # positions we want to attend and -10000.0 for masked positions.
            # Since we are adding it to the raw scores before the softmax, this is
            # effectively the same as removing these entirely.
            attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
            attention_mask = (1.0 - attention_mask) * -10000.0
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # head_mask has shape n_layer x batch x n_heads x N x N
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.n_layer, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.n_layer
        x = self.w(input_ids)
        # x = embedded.unsqueeze(0) if len(input_ids.shape)<2 else embedded
        seq_len = input_ids.shape[1]
        mask = torch.triu(torch.ones(seq_len, seq_len), 1).to(x.device)

        x *= np.sqrt(self.d_model_size)

        pos_x = self.pos_encoding[position_ids, :].to(x.device)
        x += pos_x

        x = self.dropout(x)

        output_shape = input_shape + (x.size(-1),)
        presents = ()
        all_hidden_states = ()
        all_attentions = []
        for i, (h, layer_past) in enumerate(zip(self.h, past)):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (x.view(*output_shape),)
            outputs = h(x, mask, layer_past=layer_past, attention_mask=attention_mask, head_mask=head_mask[i])
            x, present = outputs[:2]
            presents = presents + (present,)

            if self.output_attentions:
                all_attentions.append(outputs[2])

        x = self.layernorm(x)
        x = x.view(*output_shape)
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (x,)

        outputs = (x, presents)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            # let the number of heads free (-1) so we can extract attention even after head pruning
            attention_output_shape = input_shape[:-1] + (-1,) + all_attentions[0].shape[-2:]
            all_attentions = tuple(t.view(*attention_output_shape) for t in all_attentions)
            outputs = outputs + (all_attentions,)
        return outputs
@add_start_docstrings("""The CTRL Model transformer with a language modeling head on top
(linear layer with weights tied to the input embeddings). """, CTRL_START_DOCSTRING, CTRL_INPUTS_DOCSTRING)
class CTRLLMHeadModel(CTRLPreTrainedModel):
    r"""
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Labels for language modeling.
Note that the labels **are shifted** inside the model, i.e. you can set ``lm_labels = input_ids``
Indices are selected in ``[-1, 0, ..., config.vocab_size]``
All labels set to ``-1`` are ignored (masked), the loss is only
computed for labels in ``[0, ..., config.vocab_size]``
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Language modeling loss.
**prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
**past**:
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
that contains pre-computed hidden-states (key and values in the attention blocks).
Can be used (see `past` input) to speed up sequential decoding.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import torch
from transformers import CTRLTokenizer, CTRLLMHeadModel
tokenizer = CTRLTokenizer.from_pretrained('ctrl')
model = CTRLLMHeadModel.from_pretrained('ctrl')
input_ids = torch.tensor(tokenizer.encode("Links Hello, my dog is cute")).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=input_ids)
loss, logits = outputs[:2]
"""
    def __init__(self, config):
        super(CTRLLMHeadModel, self).__init__(config)
        self.transformer = CTRLModel(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=True)

        self.init_weights()
        self.tie_weights()
    def tie_weights(self):
        """ Make sure we are sharing the input and output embeddings.
            Export to TorchScript can't handle parameter sharing so we are cloning them instead.
        """
        self._tie_or_clone_weights(self.lm_head, self.transformer.w)
    def forward(self, input_ids, past=None, attention_mask=None, token_type_ids=None,
                position_ids=None, head_mask=None, labels=None):
        transformer_outputs = self.transformer(input_ids,
                                               past=past,
                                               attention_mask=attention_mask,
                                               token_type_ids=token_type_ids,
                                               position_ids=position_ids,
                                               head_mask=head_mask)

        hidden_states = transformer_outputs[0]

        lm_logits = self.lm_head(hidden_states)

        outputs = (lm_logits,) + transformer_outputs[1:]

        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (loss), lm_logits, presents, (all hidden_states), (attentions)