chenpangpang / transformers / Commits / 60a5babd

Commit 60a5babd authored Nov 05, 2019 by thomwolf

    adding files

parent e99071f1

Showing 5 changed files with 1278 additions and 0 deletions (+1278 / -0)
transformers/configuration_t5.py                               +130  -0
transformers/convert_t5_original_tf_checkpoint_to_pytorch.py   +65   -0
transformers/modeling_t5.py                                    +373  -0
transformers/modeling_tf_t5.py                                 +496  -0
transformers/tokenization_t5.py                                +214  -0
transformers/configuration_t5.py  0 → 100644
# coding=utf-8
# Copyright 2010, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" T5 model configuration """
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import sys

import six
from io import open

from .configuration_utils import PretrainedConfig

logger = logging.getLogger(__name__)

T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
    't5-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-uncased-config.json",
    't5-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-uncased-config.json",
}
class T5Config(PretrainedConfig):
    r"""
        :class:`~transformers.T5Config` is the configuration class to store the configuration of a
        `T5Model`.

        Arguments:
            vocab_size_or_config_json_file: Vocabulary size of `inputs_ids` in `T5Model`.
            hidden_size: Size of the encoder layers and the pooler layer.
            num_hidden_layers: Number of hidden layers in the Transformer encoder.
            num_attention_heads: Number of attention heads for each attention layer in
                the Transformer encoder.
            intermediate_size: The size of the "intermediate" (i.e., feed-forward)
                layer in the Transformer encoder.
            hidden_act: The non-linear activation function (function or string) in the
                encoder and pooler. If string, "gelu", "relu", "swish" and "gelu_new" are supported.
            hidden_dropout_prob: The dropout probability for all fully connected
                layers in the embeddings, encoder, and pooler.
            attention_probs_dropout_prob: The dropout ratio for the attention
                probabilities.
            max_position_embeddings: The maximum sequence length that this model might
                ever be used with. Typically set this to something large just in case
                (e.g., 512 or 1024 or 2048).
            type_vocab_size: The vocabulary size of the `token_type_ids` passed into
                `T5Model`.
            initializer_range: The stdev of the truncated_normal_initializer for
                initializing all weight matrices.
            layer_norm_eps: The epsilon used by LayerNorm.
    """
    pretrained_config_archive_map = T5_PRETRAINED_CONFIG_ARCHIVE_MAP

    def __init__(self,
                 vocab_size_or_config_json_file=50257,
                 n_positions=1024,
                 n_ctx=1024,
                 n_embd=768,
                 n_layer=12,
                 n_head=12,
                 resid_pdrop=0.1,
                 embd_pdrop=0.1,
                 attn_pdrop=0.1,
                 layer_norm_epsilon=1e-5,
                 initializer_range=0.02,
                 num_labels=1,
                 summary_type='cls_index',
                 summary_use_proj=True,
                 summary_activation=None,
                 summary_proj_to_labels=True,
                 summary_first_dropout=0.1,
                 **kwargs):
        super(T5Config, self).__init__(**kwargs)

        self.vocab_size = vocab_size_or_config_json_file if isinstance(vocab_size_or_config_json_file, six.string_types) else -1
        self.n_ctx = n_ctx
        self.n_positions = n_positions
        self.n_embd = n_embd
        self.n_layer = n_layer
        self.n_head = n_head
        self.resid_pdrop = resid_pdrop
        self.embd_pdrop = embd_pdrop
        self.attn_pdrop = attn_pdrop
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.num_labels = num_labels
        self.summary_type = summary_type
        self.summary_use_proj = summary_use_proj
        self.summary_activation = summary_activation
        self.summary_first_dropout = summary_first_dropout
        self.summary_proj_to_labels = summary_proj_to_labels

        if isinstance(vocab_size_or_config_json_file, six.string_types):
            with open(vocab_size_or_config_json_file, "r", encoding="utf-8") as reader:
                json_config = json.loads(reader.read())
            for key, value in json_config.items():
                self.__dict__[key] = value
        elif not isinstance(vocab_size_or_config_json_file, int):
            raise ValueError("First argument must be either a vocabulary size (int) "
                             "or the path to a pretrained model config file (str)")
    @property
    def max_position_embeddings(self):
        return self.n_positions

    @property
    def hidden_size(self):
        return self.n_embd

    @property
    def num_attention_heads(self):
        return self.n_head

    @property
    def num_hidden_layers(self):
        return self.n_layer
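
A minimal usage sketch for the configuration above (an illustration only; it assumes the package's top-level __init__ exports T5Config, as the conversion script below expects, and the hyper-parameter values are arbitrary):

from transformers import T5Config

config = T5Config(n_embd=512, n_layer=6, n_head=8)     # GPT-2-style field names, as defined above
assert config.hidden_size == config.n_embd             # the read-only properties expose BERT-style aliases
assert config.num_hidden_layers == config.n_layer
print(config.to_json_string())                          # JSON (de)serialization comes from PretrainedConfig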
transformers/convert_t5_original_tf_checkpoint_to_pytorch.py  0 → 100755
# coding=utf-8
# Copyright 2018 The T5 authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Convert T5 checkpoint."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import torch

from transformers import T5Config, T5ForPreTraining, load_tf_weights_in_t5

import logging
logging.basicConfig(level=logging.INFO)
def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, t5_config_file, pytorch_dump_path):
    # Initialise PyTorch model
    config = T5Config.from_json_file(t5_config_file)
    print("Building PyTorch model from configuration: {}".format(str(config)))
    model = T5ForPreTraining(config)

    # Load weights from tf checkpoint
    load_tf_weights_in_t5(model, config, tf_checkpoint_path)

    # Save pytorch-model
    print("Save PyTorch model to {}".format(pytorch_dump_path))
    torch.save(model.state_dict(), pytorch_dump_path)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--tf_checkpoint_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the TensorFlow checkpoint path.")
    parser.add_argument("--t5_config_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The config json file corresponding to the pre-trained T5 model. \n"
                             "This specifies the model architecture.")
    parser.add_argument("--pytorch_dump_path",
                        default=None,
                        type=str,
                        required=True,
                        help="Path to the output PyTorch model.")
    args = parser.parse_args()
    convert_tf_checkpoint_to_pytorch(args.tf_checkpoint_path,
                                     args.t5_config_file,
                                     args.pytorch_dump_path)
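
A sketch of driving the converter from Python rather than the command line (the paths are placeholders; this requires TensorFlow to be installed and the T5ForPreTraining class imported above to actually be available in the package):

from transformers.convert_t5_original_tf_checkpoint_to_pytorch import convert_tf_checkpoint_to_pytorch

convert_tf_checkpoint_to_pytorch(
    tf_checkpoint_path="/path/to/tf_checkpoint/model.ckpt",   # placeholder path
    t5_config_file="/path/to/t5_config.json",                 # placeholder path
    pytorch_dump_path="/path/to/output/pytorch_model.bin",    # placeholder path
)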
transformers/modeling_t5.py  0 → 100644
# coding=utf-8
# Copyright 2018 T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch T5 model. """
from __future__ import absolute_import, division, print_function, unicode_literals

import json
import logging
import math
import os
import sys
from io import open

import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss

from .modeling_utils import PreTrainedModel, prune_linear_layer
from .configuration_t5 import T5Config
from .file_utils import add_start_docstrings

logger = logging.getLogger(__name__)
####################################################
# This dict contains shortcut names and associated url
# for the pretrained weights provided with the models
####################################################
T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
    't5-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-uncased-pytorch_model.bin",
    't5-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-uncased-pytorch_model.bin",
}
####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
    """ Load tf checkpoints in a pytorch model.
    """
    try:
        import re
        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error("Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
                     "https://www.tensorflow.org/install/ for installation instructions.")
        raise
    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
    # Load weights from TF model
    init_vars = tf.train.list_variables(tf_path)
    names = []
    arrays = []
    for name, shape in init_vars:
        logger.info("Loading TF weight {} with shape {}".format(name, shape))
        array = tf.train.load_variable(tf_path, name)
        names.append(name)
        arrays.append(array)

    for name, array in zip(names, arrays):
        name = name.split('/')
        # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculate m and v
        # which are not required for using pretrained model
        if any(n in ["adam_v", "adam_m", "global_step"] for n in name):
            logger.info("Skipping {}".format("/".join(name)))
            continue
        pointer = model
        for m_name in name:
            if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
                l = re.split(r'_(\d+)', m_name)
            else:
                l = [m_name]
            if l[0] == 'kernel' or l[0] == 'gamma':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'output_bias' or l[0] == 'beta':
                pointer = getattr(pointer, 'bias')
            elif l[0] == 'output_weights':
                pointer = getattr(pointer, 'weight')
            elif l[0] == 'squad':
                pointer = getattr(pointer, 'classifier')
            else:
                try:
                    pointer = getattr(pointer, l[0])
                except AttributeError:
                    logger.info("Skipping {}".format("/".join(name)))
                    continue
            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
        if m_name[-11:] == '_embeddings':
            pointer = getattr(pointer, 'weight')
        elif m_name == 'kernel':
            array = np.transpose(array)
        try:
            assert pointer.shape == array.shape
        except AssertionError as e:
            e.args += (pointer.shape, array.shape)
            raise
        logger.info("Initialize PyTorch weight {}".format(name))
        pointer.data = torch.from_numpy(array)
    return model
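
# A worked illustration of the name-to-attribute mapping above, using a
# hypothetical TF variable name (not taken from a real checkpoint):
#   "encoder/layer_0/attention/kernel" is split on '/' and walked piece by piece:
#     - re.split(r'_(\d+)', "layer_0") gives ['layer', '0', ''], so the loop does
#       getattr(pointer, 'layer') followed by pointer[0]
#     - "attention" is resolved with a plain getattr(pointer, 'attention')
#     - "kernel" maps to the module's 'weight' parameter, and the numpy array is
#       transposed because TF stores dense kernels as (in_features, out_features)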
####################################################
# PyTorch Models are constructed by sub-classing
# - torch.nn.Module for the layers and
# - PreTrainedModel for the models (itself a sub-class of torch.nn.Module)
####################################################
class T5Layer(nn.Module):
    def __init__(self, config):
        super(T5Layer, self).__init__()
        self.attention = T5Attention(config)
        self.intermediate = T5Intermediate(config)
        self.output = T5Output(config)

    def forward(self, hidden_states, attention_mask=None, head_mask=None):
        attention_outputs = self.attention(hidden_states, attention_mask, head_mask)
        attention_output = attention_outputs[0]
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
        return outputs
class T5PreTrainedModel(PreTrainedModel):
    """ An abstract class to handle weights initialization and
        a simple interface for downloading and loading pretrained models.
    """
    config_class = T5Config
    pretrained_model_archive_map = T5_PRETRAINED_MODEL_ARCHIVE_MAP
    load_tf_weights = load_tf_weights_in_t5
    base_model_prefix = "transformer"

    def _init_weights(self, module):
        """ Initialize the weights """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Slightly different from the TF version which uses truncated_normal for initialization
            # cf https://github.com/pytorch/pytorch/pull/5617
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()
T5_START_DOCSTRING = r"""    The T5 model was proposed in
    `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`_
    by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
    It's an encoder-decoder pre-trained transformer.

    This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
    refer to the PyTorch documentation for all matter related to general usage and behavior.

    .. _`Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`:
        https://arxiv.org/abs/1910.10683

    .. _`torch.nn.Module`:
        https://pytorch.org/docs/stable/nn.html#module

    Parameters:
        config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the configuration.
            Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
"""

T5_INPUTS_DOCSTRING = r"""
    Inputs:
        **input_ids**: ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of input sequence tokens in the vocabulary.
            To match pre-training, T5 input sequence should be formatted with [CLS] and [SEP] tokens as follows:

            (a) For sequence pairs:
                ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``

            (b) For single sequences:
                ``tokens: [CLS] the dog is hairy . [SEP]``

            T5 is a model with relative position embeddings so you should be able to pad the inputs on
            the right or the left.
            Indices can be obtained using :class:`transformers.T5Tokenizer`.
            See :func:`transformers.PreTrainedTokenizer.encode` and
            :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
        **attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
            Mask to avoid performing attention on padding token indices.
            Mask values selected in ``[0, 1]``:
            ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Indices of positions of each input sequence tokens in the position embeddings.
            Selected in the range ``[0, config.max_position_embeddings - 1]``.
        **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
            Mask to nullify selected heads of the self-attention modules.
            Mask values selected in ``[0, 1]``:
            ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
"""
@add_start_docstrings("The bare single stack (encoder or decoder) of a T5 Model transformer outputting raw hidden-states "
                      "without any specific head on top.",
                      T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
class T5Model(T5PreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the model.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = T5Tokenizer.from_pretrained('t5-base-uncased')
        model = T5Model.from_pretrained('t5-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids)
        last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
    """
    def __init__(self, config):
        super(T5Model, self).__init__(config)

        self.embeddings = T5Embeddings(config)
        self.encoder = T5Encoder(config)
        self.pooler = T5Pooler(config)

        self.init_weights()
    def get_input_embeddings(self):
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, new_embeddings):
        self.embeddings.word_embeddings = new_embeddings
    def _prune_heads(self, heads_to_prune):
        """ Prunes heads of the model.
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
            See base class PreTrainedModel
        """
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
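        # For instance (illustrative values): a single-sequence mask [1, 1, 0]
        # becomes [0.0, 0.0, -10000.0] here, so the padded position receives a
        # large negative bias before the attention softmax.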
        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        if head_mask is not None:
            if head_mask.dim() == 1:
                head_mask = head_mask.unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
                head_mask = head_mask.expand(self.config.num_hidden_layers, -1, -1, -1, -1)
            elif head_mask.dim() == 2:
                head_mask = head_mask.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)  # We can specify head_mask for each layer
            head_mask = head_mask.to(dtype=next(self.parameters()).dtype)  # switch to float if need + fp16 compatibility
        else:
            head_mask = [None] * self.config.num_hidden_layers

        ##################################
        # Replace this with your model code
        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder(embedding_output,
                                       extended_attention_mask,
                                       head_mask=head_mask)
        sequence_output = encoder_outputs[0]
        outputs = (sequence_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here
        return outputs  # sequence_output, (hidden_states), (attentions)
@add_start_docstrings("""T5 Model with a `language modeling` head on top. """,
                      T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
class T5WithLMHead(T5PreTrainedModel):
    r"""
        **lm_labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
            Labels for computing the masked language modeling loss.
            Indices should be in ``[-1, 0, ..., config.vocab_size]`` (see ``input_ids`` docstring)
            Tokens with indices set to ``-1`` are ignored (masked), the loss is only computed for the tokens with labels
            in ``[0, ..., config.vocab_size]``

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **loss**: (`optional`, returned when ``lm_labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
            Masked language modeling loss.
        **prediction_scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        tokenizer = T5Tokenizer.from_pretrained('t5-base-uncased')
        model = T5WithLMHead.from_pretrained('t5-base-uncased')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, lm_labels=input_ids)
        loss, prediction_scores = outputs[:2]
    """
    def __init__(self, config):
        super(T5WithLMHead, self).__init__(config)

        self.transformer = T5Model(config)
        self.lm_head = nn.Linear(config.n_embd, config.vocab_size)

        self.init_weights()
    def get_output_embeddings(self):
        return self.lm_head
    def forward(self, input_ids, attention_mask=None, token_type_ids=None, position_ids=None,
                head_mask=None, lm_labels=None):
        outputs = self.transformer(input_ids,
                                   attention_mask=attention_mask,
                                   token_type_ids=token_type_ids,
                                   position_ids=position_ids,
                                   head_mask=head_mask)

        sequence_output = outputs[0]
        lm_logits = self.lm_head(sequence_output)

        outputs = (lm_logits,) + outputs[2:]  # Add hidden states and attention if they are here
        if lm_labels is not None:
            shift_logits = lm_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_fct = CrossEntropyLoss(ignore_index=-1)
            loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
            outputs = (loss,) + outputs

        return outputs  # (lm_loss), lm_logits, (hidden_states), (attentions)
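
A small illustration of the label shifting in the loss above (a sketch with made-up token ids): position i of shift_logits is scored against token i + 1 of lm_labels, so the head is trained to predict the next token.

import torch

lm_labels = torch.tensor([[10, 11, 12, 13]])   # hypothetical ids, batch size 1
shift_labels = lm_labels[..., 1:]              # tensor([[11, 12, 13]])
# shift_logits = lm_logits[..., :-1, :] keeps the predictions made at positions
# 0, 1 and 2, which the CrossEntropyLoss compares against tokens 11, 12 and 13.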
transformers/modeling_tf_t5.py  0 → 100644
This diff is collapsed (the 496 added lines are not shown here).
transformers/tokenization_t5.py  0 → 100644
# coding=utf-8
# Copyright 2018 T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tokenization class for model T5."""
from __future__ import absolute_import, division, print_function, unicode_literals

import collections
import logging
import os
import unicodedata
from io import open

from .tokenization_utils import PreTrainedTokenizer

logger = logging.getLogger(__name__)
####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to file names for serializing Tokenizer instances
####################################################
VOCAB_FILES_NAMES = {'vocab_file': 'vocab.txt'}

####################################################
# Mapping from the keyword arguments names of Tokenizer `__init__`
# to pretrained vocabulary URL for all the model shortcut names.
####################################################
PRETRAINED_VOCAB_FILES_MAP = {
    'vocab_file':
    {
        't5-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-uncased-vocab.txt",
        't5-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-uncased-vocab.txt",
    }
}

####################################################
# Mapping from model shortcut names to max length of inputs
####################################################
PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
    't5-base-uncased': 512,
    't5-large-uncased': 512,
}

####################################################
# Mapping from model shortcut names to a dictionary of additional
# keyword arguments for Tokenizer `__init__`.
# To be used for checkpoint specific configurations.
####################################################
PRETRAINED_INIT_CONFIGURATION = {
    't5-base-uncased': {'do_lower_case': True},
    't5-large-uncased': {'do_lower_case': True},
}
def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        tokens = reader.readlines()
    for index, token in enumerate(tokens):
        token = token.rstrip('\n')
        vocab[token] = index
    return vocab
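
# Example (hypothetical vocab.txt contents, one token per line):
#
#     [PAD]
#     [UNK]
#     the
#     dog
#
# load_vocab() would return
# OrderedDict([('[PAD]', 0), ('[UNK]', 1), ('the', 2), ('dog', 3)]).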
class T5Tokenizer(PreTrainedTokenizer):
    r"""
    Constructs a T5Tokenizer.
    :class:`~transformers.T5Tokenizer` runs end-to-end tokenization: punctuation splitting + wordpiece

    Args:
        vocab_file: Path to a one-wordpiece-per-line vocabulary file
        do_lower_case: Whether to lower case the input. Only has an effect when do_wordpiece_only=False
    """

    vocab_files_names = VOCAB_FILES_NAMES
    pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
    pretrained_init_configuration = PRETRAINED_INIT_CONFIGURATION
    max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
    def __init__(self, vocab_file, do_lower_case=True,
                 unk_token="[UNK]", sep_token="[SEP]", pad_token="[PAD]", cls_token="[CLS]",
                 mask_token="[MASK]", **kwargs):
        """Constructs a T5Tokenizer.

        Args:
            **vocab_file**: Path to a one-wordpiece-per-line vocabulary file
            **do_lower_case**: (`optional`) boolean (default True)
                Whether to lower case the input
                Only has an effect when do_basic_tokenize=True
        """
        super(T5Tokenizer, self).__init__(unk_token=unk_token, sep_token=sep_token,
                                          pad_token=pad_token, cls_token=cls_token,
                                          mask_token=mask_token, **kwargs)
        self.max_len_single_sentence = self.max_len - 2  # take into account special tokens
        self.max_len_sentences_pair = self.max_len - 3  # take into account special tokens

        if not os.path.isfile(vocab_file):
            raise ValueError(
                "Can't find a vocabulary file at path '{}'. To load the vocabulary from a Google pretrained "
                "model use `tokenizer = T5Tokenizer.from_pretrained(PRETRAINED_MODEL_NAME)`".format(vocab_file))
        self.vocab = load_vocab(vocab_file)
    @property
    def vocab_size(self):
        return len(self.vocab)
    def _tokenize(self, text):
        """ Take as input a string and return a list of strings (tokens) for words/sub-words
        """
        split_tokens = []
        if self.do_basic_tokenize:
            for token in self.basic_tokenizer.tokenize(text, never_split=self.all_special_tokens):
                for sub_token in self.wordpiece_tokenizer.tokenize(token):
                    split_tokens.append(sub_token)
        else:
            split_tokens = self.wordpiece_tokenizer.tokenize(text)
        return split_tokens
    def _convert_token_to_id(self, token):
        """ Converts a token (str/unicode) in an id using the vocab. """
        return self.vocab.get(token, self.vocab.get(self.unk_token))
    def _convert_id_to_token(self, index):
        """Converts an index (integer) in a token (string/unicode) using the vocab."""
        return self.ids_to_tokens.get(index, self.unk_token)
    def convert_tokens_to_string(self, tokens):
        """ Converts a sequence of tokens (string) in a single string. """
        out_string = ' '.join(tokens).replace(' ##', '').strip()
        return out_string
    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
        """
        Build model inputs from a sequence or a pair of sequence for sequence classification tasks
        by concatenating and adding special tokens.
        A BERT sequence has the following format:
            single sequence: [CLS] X [SEP]
            pair of sequences: [CLS] A [SEP] B [SEP]
        """
        if token_ids_1 is None:
            return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
        cls = [self.cls_token_id]
        sep = [self.sep_token_id]
        return cls + token_ids_0 + sep + token_ids_1 + sep
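
    # For example (hypothetical ids): with cls_token_id=101 and sep_token_id=102,
    # build_inputs_with_special_tokens([7, 8]) returns [101, 7, 8, 102] and
    # build_inputs_with_special_tokens([7, 8], [9]) returns [101, 7, 8, 102, 9, 102].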
    def get_special_tokens_mask(self, token_ids_0, token_ids_1=None, already_has_special_tokens=False):
        """
        Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
        special tokens using the tokenizer ``prepare_for_model`` or ``encode_plus`` methods.

        Args:
            token_ids_0: list of ids (must not contain special tokens)
            token_ids_1: Optional list of ids (must not contain special tokens), necessary when fetching sequence ids
                for sequence pairs
            already_has_special_tokens: (default False) Set to True if the token list is already formatted with
                special tokens for the model

        Returns:
            A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
        """
        if already_has_special_tokens:
            if token_ids_1 is not None:
                raise ValueError("You should not supply a second sequence if the provided sequence of "
                                 "ids is already formatted with special tokens for the model.")
            return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))

        if token_ids_1 is not None:
            return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
        return [1] + ([0] * len(token_ids_0)) + [1]
    def create_token_type_ids_from_sequences(self, token_ids_0, token_ids_1=None):
        """
        Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
        A BERT sequence pair mask has the following format:
            0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
            | first sequence    | second sequence
        if token_ids_1 is None, only returns the first portion of the mask (0's).
        """
        sep = [self.sep_token_id]
        cls = [self.cls_token_id]
        if token_ids_1 is None:
            return len(cls + token_ids_0 + sep) * [0]
        return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
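
    # For example: for token_ids_0 of length 2 and token_ids_1 of length 1, the
    # method returns [0, 0, 0, 0, 1, 1]; the four zeros cover [CLS], the two
    # first-sequence tokens and the first [SEP].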
    def save_vocabulary(self, vocab_path):
        """Save the tokenizer vocabulary to a directory or file."""
        index = 0
        if os.path.isdir(vocab_path):
            vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES['vocab_file'])
        else:
            vocab_file = vocab_path
        with open(vocab_file, "w", encoding="utf-8") as writer:
            for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
                if index != token_index:
                    logger.warning("Saving vocabulary to {}: vocabulary indices are not consecutive."
                                   " Please check that the vocabulary is not corrupted!".format(vocab_file))
                    index = token_index
                writer.write(token + u'\n')
                index += 1
        return (vocab_file,)
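
A short end-to-end sketch of the intended tokenizer usage (an illustration only: it assumes the package exports T5Tokenizer, that the 't5-base-uncased' vocabulary is downloadable, and that the basic/wordpiece tokenizers referenced in _tokenize() are wired up):

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained('t5-base-uncased')
ids = tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)
# encode() comes from the PreTrainedTokenizer base class: it lower-cases and
# wordpiece-tokenizes the text, looks the tokens up in the vocab, and wraps the
# ids as [CLS] ... [SEP] via build_inputs_with_special_tokens() above.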