chenpangpang / transformers · Commits

Commit 727a79b3, authored Nov 08, 2019 by thomwolf
Parent: 8fda532c

    added TF2 model and tests - updated templates
Showing 11 changed files with 651 additions and 386 deletions (+651 -386)
    templates/adding_a_new_model/modeling_tf_xxx.py    +2    -0
    templates/adding_a_new_model/modeling_xxx.py       +2    -0
    transformers/__init__.py                           +3    -0
    transformers/configuration_auto.py                 +5    -1
    transformers/configuration_t5.py                   +1    -2
    transformers/modeling_t5.py                        +57   -20
    transformers/modeling_tf_pytorch_utils.py          +2    -2
    transformers/modeling_tf_t5.py                     +506  -285
    transformers/modeling_utils.py                     +2    -4
    transformers/tests/modeling_tf_common_test.py      +20   -3
    transformers/tests/modeling_tf_t5_test.py          +51   -69
templates/adding_a_new_model/modeling_tf_xxx.py  (+2 -0)  — view file @ 727a79b3

@@ -26,6 +26,8 @@ import logging
 import math
 import os
 import sys
+import copy
+import itertools
 from io import open

 import numpy as np
templates/adding_a_new_model/modeling_xxx.py  (+2 -0)  — view file @ 727a79b3

@@ -25,6 +25,8 @@ import logging
 import math
 import os
 import sys
+import copy
+import itertools
 from io import open

 import torch
transformers/__init__.py  (+3 -0)  — view file @ 727a79b3

@@ -158,6 +158,9 @@ if is_tf_available():
                                       TFCTRLLMHeadModel,
                                       TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
+    from .modeling_tf_t5 import (TFT5PreTrainedModel, TFT5Model, TFT5WithLMHeadModel,
+                                 TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP)

     # TF 2.0 <=> PyTorch conversion utilities
     from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
                                             load_pytorch_checkpoint_in_tf2_model,
transformers/configuration_auto.py  (+5 -1)  — view file @ 727a79b3

@@ -27,6 +27,7 @@ from .configuration_xlm import XLMConfig
 from .configuration_roberta import RobertaConfig
 from .configuration_distilbert import DistilBertConfig
 from .configuration_ctrl import CTRLConfig
+from .configuration_t5 import T5Config

 logger = logging.getLogger(__name__)

@@ -64,6 +65,7 @@ class AutoConfig(object):
         The configuration class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
+            - contains `t5`: T5Config (T5 model)
             - contains `distilbert`: DistilBertConfig (DistilBERT model)
             - contains `bert`: BertConfig (Bert model)
             - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)

@@ -114,7 +116,9 @@ class AutoConfig(object):
             assert unused_kwargs == {'foo': False}
         """
-        if 'distilbert' in pretrained_model_name_or_path:
+        if 't5' in pretrained_model_name_or_path:
+            return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        elif 'distilbert' in pretrained_model_name_or_path:
             return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif 'roberta' in pretrained_model_name_or_path:
             return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
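Note on the dispatch above: `AutoConfig` matches plain substrings in declaration order, so `'t5'` is tested first, and `'distilbert'` must stay ahead of `'bert'` or it would be shadowed by the broader pattern. A minimal sketch of the resulting behavior (assuming the `'t5-small'` shortcut resolves as set up in this commit):

    from transformers import AutoConfig

    # 't5-small' contains 't5', so the first branch fires and a T5Config comes back.
    config = AutoConfig.from_pretrained('t5-small')
    assert type(config).__name__ == 'T5Config'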
transformers/configuration_t5.py  (+1 -2)  — view file @ 727a79b3

@@ -27,8 +27,7 @@ from .configuration_utils import PretrainedConfig
 logger = logging.getLogger(__name__)

 T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    't5-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-uncased-config.json",
-    't5-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-uncased-config.json",
+    't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
 }
transformers/modeling_t5.py  (+57 -20)  — view file @ 727a79b3

@@ -41,8 +41,7 @@ logger = logging.getLogger(__name__)
 # for the pretrained weights provided with the models
 ####################################################
 T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    't5-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-uncased-pytorch_model.bin",
-    't5-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-uncased-pytorch_model.bin",
+    't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-pytorch_model.bin",
 }
 ####################################################
@@ -442,7 +441,7 @@ class T5PreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(factor * 1.0)
-        elif isinstance(module, T5Model):
+        elif isinstance(module, (T5Model, T5WithLMHeadModel)):
             # Mesh TensorFlow embeddings initialization
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
             module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
@@ -502,11 +501,10 @@ class T5Stack(T5PreTrainedModel):
         # ourselves in which case we just need to make it broadcastable to all heads.
         if attention_mask.dim() == 3:
             extended_attention_mask = attention_mask[:, None, :, :]
-        if attention_mask.dim() == 2:
+        elif attention_mask.dim() == 2:
             # Provided a padding mask of dimensions [batch_size, seq_length]
             # - if the model is a decoder, apply a causal mask in addition to the padding mask
             # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
             if self.config.is_decoder:
                 seq_ids = torch.arange(seq_length, device=hidden_states.device)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
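The decoder branch above builds its causal mask from a broadcast comparison rather than a dedicated triangular-matrix op; a small self-contained sketch of the trick:

    import torch

    # seq_ids[None, None, :] tiled to (batch, seq, seq) is compared against
    # seq_ids[None, :, None]: query position q may attend to key position k iff k <= q.
    batch_size, seq_length = 1, 4
    seq_ids = torch.arange(seq_length)
    causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    # causal_mask[0] is a lower-triangular boolean matrix:
    # [[True, False, False, False],
    #  [True, True,  False, False],
    #  [True, True,  True,  False],
    #  [True, True,  True,  True ]]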
@@ -593,7 +591,7 @@ class T5Stack(T5PreTrainedModel):
 T5_START_DOCSTRING = r"""    The T5 model was proposed in
     `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`_
     by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
-    It's an encoder decoder pre-trained transformer.
+    It's an encoder decoder transformer pre-trained in a text-to-text denoising generative setting.

     This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
     refer to the PyTorch documentation for all matter related to general usage and behavior.
@@ -634,16 +632,13 @@ T5_INPUTS_DOCSTRING = r"""
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
-        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
-            Indices of positions of each input sequence tokens in the position embeddings.
-            Selected in the range ``[0, config.max_position_embeddings - 1]``.
         **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
 """

-@add_start_docstrings("The bare single stack (encoder or decoder) of a T5 Model transformer outputting raw hidden-states"
+@add_start_docstrings("The bare T5 Model transformer outputting raw hidden-states"
                       "without any specific head on top.",
                       T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
 class T5Model(T5PreTrainedModel):
@@ -661,8 +656,8 @@ class T5Model(T5PreTrainedModel):
     Examples::

-        tokenizer = T5Tokenizer.from_pretrained('t5-base-uncased')
-        model = T5Model.from_pretrained('t5-base-uncased')
+        tokenizer = T5Tokenizer.from_pretrained('t5-small')
+        model = T5Model.from_pretrained('t5-small')
         input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
         outputs = model(input_ids)
         last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
@@ -752,8 +747,8 @@ class T5WithLMHeadModel(T5PreTrainedModel):
     Examples::

-        tokenizer = T5Tokenizer.from_pretrained('t5-base-uncased')
-        model = T5WithLMHeadModel.from_pretrained('t5-base-uncased')
+        tokenizer = T5Tokenizer.from_pretrained('t5-small')
+        model = T5WithLMHeadModel.from_pretrained('t5-small')
         input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
         outputs = model(input_ids, lm_labels=input_ids)
         loss, prediction_scores = outputs[:2]
@@ -763,31 +758,73 @@ class T5WithLMHeadModel(T5PreTrainedModel):
     def __init__(self, config):
         super(T5WithLMHeadModel, self).__init__(config)
+        self.model_dim = config.d_model

-        self.transformer = T5Model(config)
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        self.encoder = T5Stack(encoder_config)
+
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        self.decoder = T5Stack(decoder_config)
+
         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

         self.init_weights()

+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+
     def get_output_embeddings(self):
         return self.lm_head

     def forward(self, **kwargs):
         # keyword arguments come in 3 flavors: encoder-specific (prefixed by
         # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
         # that apply to the model as whole.
         # We let the specific kwargs override the common ones in case of conflict.
         lm_labels = kwargs.pop('decoder_lm_labels', None)
-        outputs = self.transformer(**kwargs)
-        sequence_output = outputs[0]
+
+        kwargs_common = dict((k, v) for k, v in kwargs.items()
+                             if not k.startswith("encoder_") and not k.startswith("decoder_"))
+        kwargs_encoder = kwargs_common.copy()
+        kwargs_decoder = kwargs_common.copy()
+        kwargs_encoder.update(dict((k[len("encoder_"):], v) for k, v in kwargs.items() if k.startswith("encoder_")))
+        kwargs_decoder.update(dict((k[len("decoder_"):], v) for k, v in kwargs.items() if k.startswith("decoder_")))
+
+        # Encode if needed (training, first prediction pass)
+        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
+        if encoder_hidden_states is None:
+            encoder_inputs_ids = kwargs_encoder.pop("input_ids")
+            hidden_states = self.shared(encoder_inputs_ids)  # Convert inputs in embeddings
+            encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
+            encoder_hidden_states = encoder_outputs[0]
+        else:
+            encoder_outputs = ()
+
+        # Decode
+        decoder_inputs_ids = kwargs_decoder.pop("input_ids")
+        hidden_states = self.shared(decoder_inputs_ids)  # Convert inputs in embeddings
+        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
+        kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
+        decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)
+
+        sequence_output = decoder_outputs[0]
+        # Rescale output before projecting on vocab
+        # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
+        sequence_output = sequence_output * (self.model_dim ** -0.5)
         lm_logits = self.lm_head(sequence_output)

-        outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
+        decoder_outputs = (lm_logits,) + decoder_outputs[1:]  # Add hidden states and attention if they are here
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = (loss,) + outputs  # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+            decoder_outputs = (loss,) + decoder_outputs  # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

-        return outputs  # (lm_loss), lm_logits, (hidden_states), (attentions)
+        return decoder_outputs + encoder_outputs
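The rewritten forward routes every keyword argument by its prefix, so one call can carry separate encoder and decoder inputs. A usage sketch of that convention ('t5-small' is the shortcut this commit keeps in the archive maps):

    import torch
    from transformers import T5Tokenizer, T5WithLMHeadModel

    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    model = T5WithLMHeadModel.from_pretrained('t5-small')
    input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)

    # 'encoder_'/'decoder_' prefixed kwargs go to the matching stack; un-prefixed
    # kwargs are shared by both. 'decoder_lm_labels' triggers the shifted LM loss.
    outputs = model(encoder_input_ids=input_ids,
                    decoder_input_ids=input_ids,
                    decoder_lm_labels=input_ids)
    loss, lm_logits = outputs[:2]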
transformers/modeling_tf_pytorch_utils.py  (+2 -2)  — view file @ 727a79b3

@@ -156,7 +156,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
                 e.args += (symbolic_weight.shape, array.shape)
                 raise e

-        logger.info("Initialize TF weight {}".format(symbolic_weight.name))
+        # logger.warning("Initialize TF weight {}".format(symbolic_weight.name))

         weight_value_tuples.append((symbolic_weight, array))
         all_pytorch_weights.discard(name)

@@ -269,7 +269,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
                 e.args += (pt_weight.shape, array.shape)
                 raise e

-        logger.info("Initialize PyTorch weight {}".format(pt_weight_name))
+        # logger.warning("Initialize PyTorch weight {}".format(pt_weight_name))

         new_pt_params_dict[pt_weight_name] = torch.from_numpy(array)
         loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array)
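These are the utilities the new TF T5 code relies on to move weights across frameworks. A hedged sketch of the typical round-trip (the checkpoint path here is hypothetical):

    from transformers import T5Config, TFT5Model
    from transformers.modeling_tf_pytorch_utils import load_pytorch_checkpoint_in_tf2_model

    # Build a fresh TF 2.0 T5 and fill it from a PyTorch checkpoint on disk;
    # variable names are matched via convert_tf_weight_name_to_pt_weight_name.
    config = T5Config.from_pretrained('t5-small')
    tf_model = TFT5Model(config)
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, './pytorch_model.bin')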
transformers/modeling_tf_t5.py  (+506 -285)  — view file @ 727a79b3

@@ -22,24 +22,21 @@ import logging
 import math
 import os
-import sys
+import copy
+import itertools
 from io import open

 import numpy as np
 import tensorflow as tf

 from .configuration_t5 import T5Config
-from .modeling_tf_utils import TFPreTrainedModel, get_initializer
+from .modeling_tf_utils import TFPreTrainedModel, TFSharedEmbeddings, shape_list, get_initializer, DUMMY_INPUTS
 from .file_utils import add_start_docstrings

 logger = logging.getLogger(__name__)

 ####################################################
 # This dict contrains shortcut names and associated url
 # for the pretrained weights provided with the models
 ####################################################
 TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    't5-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-uncased-tf_model.h5",
-    't5-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-uncased-tf_model.h5",
+    't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-tf_model.h5",
 }
 ####################################################
@@ -48,33 +45,294 @@ TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
 # - TFPreTrainedModel for the models (it-self a sub-class of tf.keras.Model)
 ####################################################

Removed: the template comment block ("Here is an example of typical layer in a TF 2.0 model of the library. The classes are usually identical to the PyTorch ones and prefixed with 'TF'. Note that class __init__ parameters includes **kwargs (send to 'super'). This let us have a control on class scope and variable names: more precisely, we set the names of the class attributes (lower level layers) to the equivalent attributes names in the PyTorch model so we can have equivalent class and scope structure between PyTorch and TF 2.0 models and easily load one in the other. See the conversion methods in modeling_tf_pytorch_utils.py for more details") together with the placeholder layer it documented:

    class TFT5Layer(tf.keras.layers.Layer):
        def __init__(self, config, **kwargs):
            super(TFT5Layer, self).__init__(**kwargs)
            self.attention = TFT5Attention(config, name='attention')
            self.intermediate = TFT5Intermediate(config, name='intermediate')
            self.transformer_output = TFT5Output(config, name='output')

        def call(self, inputs, training=False):
            hidden_states, attention_mask, head_mask = inputs

            attention_outputs = self.attention([hidden_states, attention_mask, head_mask], training=training)
            attention_output = attention_outputs[0]
            intermediate_output = self.intermediate(attention_output)
            layer_output = self.transformer_output([intermediate_output, attention_output], training=training)
            outputs = (layer_output,) + attention_outputs[1:]  # add attentions if we output them
            return outputs

Added: the real T5 building blocks:

    class TFT5DenseReluDense(tf.keras.layers.Layer):
        def __init__(self, config, **kwargs):
            super(TFT5DenseReluDense, self).__init__(**kwargs)
            self.wi = tf.keras.layers.Dense(config.d_ff, use_bias=False, name='wi')
            self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name='wo')
            self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
            self.act = tf.keras.activations.relu

        def call(self, hidden_states, training=False):
            h = self.wi(hidden_states)
            h = self.act(h)
            h = self.dropout(h, training=training)
            h = self.wo(h)
            return h


    class TFT5LayerFF(tf.keras.layers.Layer):
        def __init__(self, config, **kwargs):
            super(TFT5LayerFF, self).__init__(**kwargs)
            self.DenseReluDense = TFT5DenseReluDense(config, name='DenseReluDense')
            self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='layer_norm')
            self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

        def call(self, hidden_states, training=False):
            norm_x = self.layer_norm(hidden_states)
            y = self.DenseReluDense(norm_x, training=training)
            layer_output = hidden_states + self.dropout(y, training=training)
            return layer_output


    class TFT5Attention(tf.keras.layers.Layer):
        NEW_ID = itertools.count()

        def __init__(self, config, has_relative_attention_bias=False, **kwargs):
            super(TFT5Attention, self).__init__(**kwargs)
            self.layer_id = next(TFT5Attention.NEW_ID)
            self.is_decoder = config.is_decoder
            self.has_relative_attention_bias = has_relative_attention_bias
            self.output_attentions = config.output_attentions
            self.relative_attention_num_buckets = config.relative_attention_num_buckets
            self.dim = config.d_model
            self.d_kv = config.d_kv
            self.n_heads = config.num_heads
            assert self.dim % self.n_heads == 0
            assert self.dim // self.n_heads == self.d_kv

            # Mesh TensorFlow initialization to avoid scaling before softmax
            self.q = tf.keras.layers.Dense(self.dim, use_bias=False, name='q')
            self.k = tf.keras.layers.Dense(self.dim, use_bias=False, name='k')
            self.v = tf.keras.layers.Dense(self.dim, use_bias=False, name='v')
            self.o = tf.keras.layers.Dense(self.dim, use_bias=False, name='o')
            self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

            if self.has_relative_attention_bias:
                self.relative_attention_bias = tf.keras.layers.Embedding(self.relative_attention_num_buckets,
                                                                         self.n_heads,
                                                                         name='relative_attention_bias')
            self.pruned_heads = set()

        def prune_heads(self, heads):
            raise NotImplementedError

        @staticmethod
        def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
            """
            Adapted from Mesh Tensorflow:
            https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

            Translate relative position to a bucket number for relative attention.
            The relative position is defined as memory_position - query_position, i.e.
            the distance in tokens from the attending position to the attended-to
            position. If bidirectional=False, then positive relative positions are
            invalid.
            We use smaller buckets for small absolute relative_position and larger buckets
            for larger absolute relative_positions. All relative positions >=max_distance
            map to the same bucket. All relative positions <=-max_distance map to the
            same bucket. This should allow for more graceful generalization to longer
            sequences than the model has been trained on.
            Args:
                relative_position: an int32 Tensor
                bidirectional: a boolean - whether the attention is bidirectional
                num_buckets: an integer
                max_distance: an integer
            Returns:
                a Tensor with the same shape as relative_position, containing int32
                values in the range [0, num_buckets)
            """
            ret = 0
            n = -relative_position
            if bidirectional:
                num_buckets //= 2
                ret += tf.dtypes.cast(tf.math.less(n, 0), tf.int32) * num_buckets
                n = tf.math.abs(n)
            else:
                n = tf.math.maximum(n, 0)
            # now n is in the range [0, inf)
            max_exact = num_buckets // 2
            is_small = tf.math.less(n, max_exact)
            val_if_large = max_exact + tf.dtypes.cast(
                tf.math.log(tf.dtypes.cast(n, tf.float32) / max_exact)
                / math.log(max_distance / max_exact) * (num_buckets - max_exact), tf.int32)
            val_if_large = tf.math.minimum(val_if_large, num_buckets - 1)
            ret += tf.where(is_small, n, val_if_large)
            return ret

        def compute_bias(self, qlen, klen):
            """ Compute binned relative position bias """
            context_position = tf.range(qlen)[:, None]
            memory_position = tf.range(klen)[None, :]
            relative_position = memory_position - context_position  # shape (qlen, klen)
            rp_bucket = self._relative_position_bucket(relative_position,
                                                       bidirectional=not self.is_decoder,
                                                       num_buckets=self.relative_attention_num_buckets)
            values = self.relative_attention_bias(rp_bucket)  # shape (qlen, klen, num_heads)
            values = tf.expand_dims(tf.transpose(values, [2, 0, 1]), axis=0)  # shape (1, num_heads, qlen, klen)
            return values

        def call(self, input, mask=None, kv=None, position_bias=None, cache=None, head_mask=None, training=False):
            """
            Self-attention (if kv is None) or attention over source sentence (provided by kv).
            """
            # Input is (bs, qlen, dim)
            # Mask is (bs, klen) (non-causal) or (bs, klen, klen)
            bs, qlen, dim = shape_list(input)
            if kv is None:
                klen = qlen if cache is None else cache['slen'] + qlen
            else:
                klen = shape_list(kv)[1]
            # assert dim == self.dim, 'Dimensions do not match: %s input vs %s configured' % (dim, self.dim)
            n_heads = self.n_heads
            dim_per_head = self.dim // n_heads

            def shape(x):
                """  projection """
                return tf.transpose(tf.reshape(x, (bs, -1, self.n_heads, dim_per_head)), perm=(0, 2, 1, 3))

            def unshape(x):
                """  compute context """
                return tf.reshape(tf.transpose(x, perm=(0, 2, 1, 3)), (bs, -1, self.n_heads * dim_per_head))

            q = shape(self.q(input))  # (bs, n_heads, qlen, dim_per_head)
            if kv is None:
                k = shape(self.k(input))  # (bs, n_heads, qlen, dim_per_head)
                v = shape(self.v(input))  # (bs, n_heads, qlen, dim_per_head)
            elif cache is None or self.layer_id not in cache:
                k = v = kv
                k = shape(self.k(k))  # (bs, n_heads, qlen, dim_per_head)
                v = shape(self.v(v))  # (bs, n_heads, qlen, dim_per_head)

            if cache is not None:
                if self.layer_id in cache:
                    if kv is None:
                        k_, v_ = cache[self.layer_id]
                        k = tf.concat([k_, k], axis=2)  # (bs, n_heads, klen, dim_per_head)
                        v = tf.concat([v_, v], axis=2)  # (bs, n_heads, klen, dim_per_head)
                    else:
                        k, v = cache[self.layer_id]
                cache[self.layer_id] = (k, v)

            # q = q / math.sqrt(dim_per_head)  # No scaling in T5
            scores = tf.matmul(q, k, transpose_b=True)  # (bs, n_heads, qlen, klen)

            if position_bias is None:
                if not self.has_relative_attention_bias:
                    raise ValueError("No position_bias provided and no weights to compute position_bias")
                position_bias = self.compute_bias(qlen, klen)
            scores += position_bias
            if mask is not None:
                scores += mask
                # mask = (mask == 0).expand_as(scores)       # (bs, n_heads, qlen, klen)
                # scores.masked_fill_(mask, -float('inf'))   # (bs, n_heads, qlen, klen)

            weights = tf.nn.softmax(scores, axis=-1)  # (bs, n_heads, qlen, klen)
            weights = self.dropout(weights, training=training)  # (bs, n_heads, qlen, klen)

            # Mask heads if we want to
            if head_mask is not None:
                weights = weights * head_mask

            context = tf.matmul(weights, v)  # (bs, n_heads, qlen, dim_per_head)
            context = unshape(context)  # (bs, qlen, dim)
            context = self.o(context)

            outputs = (context,)
            if self.output_attentions:
                outputs = outputs + (weights,)
            if self.has_relative_attention_bias:
                outputs = outputs + (position_bias,)
            return outputs


    class TFT5LayerSelfAttention(tf.keras.layers.Layer):
        def __init__(self, config, has_relative_attention_bias=False, **kwargs):
            super(TFT5LayerSelfAttention, self).__init__(**kwargs)
            self.SelfAttention = TFT5Attention(config,
                                               has_relative_attention_bias=has_relative_attention_bias,
                                               name='SelfAttention')
            self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='layer_norm')
            self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

        def call(self, hidden_states, attention_mask=None, position_bias=None, head_mask=None, training=False):
            norm_x = self.layer_norm(hidden_states)
            attention_output = self.SelfAttention(norm_x,
                                                  mask=attention_mask,
                                                  position_bias=position_bias,
                                                  head_mask=head_mask,
                                                  training=training)
            y = attention_output[0]
            layer_output = hidden_states + self.dropout(y, training=training)
            outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
            return outputs


    class TFT5LayerCrossAttention(tf.keras.layers.Layer):
        def __init__(self, config, has_relative_attention_bias=False, **kwargs):
            super(TFT5LayerCrossAttention, self).__init__(**kwargs)
            self.EncDecAttention = TFT5Attention(config,
                                                 has_relative_attention_bias=has_relative_attention_bias,
                                                 name='EncDecAttention')
            self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon, name='layer_norm')
            self.dropout = tf.keras.layers.Dropout(config.dropout_rate)

        def call(self, hidden_states, kv, attention_mask=None, position_bias=None, head_mask=None, training=False):
            norm_x = self.layer_norm(hidden_states)
            attention_output = self.EncDecAttention(norm_x,
                                                    mask=attention_mask,
                                                    kv=kv,
                                                    position_bias=position_bias,
                                                    head_mask=head_mask,
                                                    training=training)
            y = attention_output[0]
            layer_output = hidden_states + self.dropout(y, training=training)
            outputs = (layer_output,) + attention_output[1:]  # add attentions if we output them
            return outputs


    class TFT5Block(tf.keras.layers.Layer):
        def __init__(self, config, has_relative_attention_bias=False, **kwargs):
            super(TFT5Block, self).__init__(**kwargs)
            self.is_decoder = config.is_decoder
            self.layer = []
            self.layer.append(TFT5LayerSelfAttention(config,
                                                     has_relative_attention_bias=has_relative_attention_bias,
                                                     name='layer_._0'))
            if self.is_decoder:
                self.layer.append(TFT5LayerCrossAttention(config,
                                                          has_relative_attention_bias=has_relative_attention_bias,
                                                          name='layer_._1'))
                self.layer.append(TFT5LayerFF(config, name='layer_._2'))
            else:
                self.layer.append(TFT5LayerFF(config, name='layer_._1'))

        def call(self, hidden_states, attention_mask=None, position_bias=None,
                 encoder_hidden_states=None, encoder_attention_mask=None,
                 encoder_decoder_position_bias=None, head_mask=None, training=False):
            self_attention_outputs = self.layer[0](hidden_states,
                                                   attention_mask=attention_mask,
                                                   position_bias=position_bias,
                                                   head_mask=head_mask,
                                                   training=training)
            hidden_states = self_attention_outputs[0]
            outputs = self_attention_outputs[1:]

            if not self.is_decoder:
                hidden_states = self.layer[1](hidden_states, training=training)
            else:
                cross_attention_outputs = self.layer[1](hidden_states,
                                                        kv=encoder_hidden_states,
                                                        attention_mask=encoder_attention_mask,
                                                        position_bias=encoder_decoder_position_bias,
                                                        head_mask=head_mask,
                                                        training=training)
                hidden_states = cross_attention_outputs[0]
                outputs = cross_attention_outputs[1:] + outputs
                hidden_states = self.layer[2](hidden_states, training=training)

            outputs = (hidden_states,) + outputs  # add attentions if we output them
            return outputs
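The bucketing scheme in the `_relative_position_bucket` docstring is easier to see with numbers. A minimal sketch, assuming TF 2.0 eager mode and the defaults (num_buckets=32, max_distance=128):

    import tensorflow as tf
    from transformers.modeling_tf_t5 import TFT5Attention

    # Bidirectional (encoder) case: half the buckets are reserved for positive offsets.
    rel_pos = tf.constant([[-200, -16, -1, 0, 1, 2, 16, 200]], dtype=tf.int32)
    buckets = TFT5Attention._relative_position_bucket(rel_pos, bidirectional=True)
    # Small |offsets| get their own exact buckets, larger ones share log-spaced
    # buckets, and everything at |offset| >= max_distance collapses into the last one.
    print(buckets.numpy())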
@@ -85,6 +343,19 @@ class TFT5Layer(tf.keras.layers.Layer):
 class TFT5MainLayer(tf.keras.layers.Layer):
     def __init__(self, config, **kwargs):
         super(TFT5MainLayer, self).__init__(**kwargs)
         self.output_attentions = config.output_attentions
         self.output_hidden_states = config.output_hidden_states
+        self.is_decoder = config.is_decoder
+        self.config = config
         self.num_hidden_layers = config.num_layers

+        self.block = [TFT5Block(config,
+                                has_relative_attention_bias=bool(i == 0),
+                                name='block_._{}'.format(i))
+                      for i in range(config.num_layers)]
+        self.final_layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_epsilon,
+                                                                   name='final_layer_norm')
+        self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
+
     def _resize_token_embeddings(self, new_num_tokens):
         raise NotImplementedError  # Not implemented yet in the library fr TF 2.0 models
@@ -92,51 +363,56 @@ class TFT5MainLayer(tf.keras.layers.Layer):
     def _prune_heads(self, heads_to_prune):
         raise NotImplementedError  # Not implemented yet in the library fr TF 2.0 models

Removed: the template call() that parsed list/tuple/dict inputs and built the mask with tf.fill/tf.newaxis:

    def call(self, inputs, attention_mask=None, token_type_ids=None, position_ids=None, head_mask=None, training=False):
        # We allow three types of multi-inputs:
        # - traditional keyword arguments in the call method
        # - all the arguments provided as a dict in the first positional argument of call
        # - all the arguments provided as a list/tuple (ordered) in the first positional argument of call
        # The last two options are useful to use the tf.keras fit() method.
        if isinstance(inputs, (tuple, list)):
            input_ids = inputs[0]
            attention_mask = inputs[1] if len(inputs) > 1 else attention_mask
            token_type_ids = inputs[2] if len(inputs) > 2 else token_type_ids
            position_ids = inputs[3] if len(inputs) > 3 else position_ids
            head_mask = inputs[4] if len(inputs) > 4 else head_mask
            assert len(inputs) <= 5, "Too many inputs."
        elif isinstance(inputs, dict):
            input_ids = inputs.get('input_ids')
            attention_mask = inputs.get('attention_mask', attention_mask)
            token_type_ids = inputs.get('token_type_ids', token_type_ids)
            position_ids = inputs.get('position_ids', position_ids)
            head_mask = inputs.get('head_mask', head_mask)
            assert len(inputs) <= 5, "Too many inputs."
        else:
            input_ids = inputs

        if attention_mask is None:
            attention_mask = tf.fill(tf.shape(input_ids), 1)
        if token_type_ids is None:
            token_type_ids = tf.fill(tf.shape(input_ids), 0)

        # We create a 3D attention mask from a 2D tensor mask.
        # Sizes are [batch_size, 1, 1, to_seq_length]
        # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
        # this attention mask is more simple than the triangular masking of causal attention
        # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
        extended_attention_mask = attention_mask[:, tf.newaxis, tf.newaxis, :]

Added: a call() that takes already-embedded hidden states plus explicit encoder tensors, and builds broadcastable (and, for the decoder, causal) masks:

    def call(self, hidden_states, attention_mask=None, encoder_hidden_states=None,
             encoder_attention_mask=None, head_mask=None, training=False):
        batch_size, seq_length = shape_list(hidden_states)[:2]
        if attention_mask is None:
            attention_mask = tf.fill((batch_size, seq_length), 1)
        if self.is_decoder and encoder_attention_mask is None:
            encoder_seq_length = encoder_hidden_states.shape[1]
            encoder_attention_mask = tf.fill((batch_size, encoder_seq_length), 1)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        attention_mask = tf.cast(attention_mask, dtype=tf.float32)
        num_dims_attention_mask = len(shape_list(attention_mask))
        if num_dims_attention_mask == 3:
            extended_attention_mask = attention_mask[:, None, :, :]
        elif num_dims_attention_mask == 2:
            # Provided a padding mask of dimensions [batch_size, seq_length]
            # - if the model is a decoder, apply a causal mask in addition to the padding mask
            # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
            if self.config.is_decoder:
                seq_ids = tf.range(seq_length)
                causal_mask = tf.less_equal(tf.tile(seq_ids[None, None, :], (batch_size, seq_length, 1)),
                                            seq_ids[None, :, None])
                causal_mask = tf.cast(causal_mask, dtype=tf.float32)
                extended_attention_mask = causal_mask[:, None, :, :] * attention_mask[:, None, None, :]
            else:
                extended_attention_mask = attention_mask[:, None, None, :]

        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
        extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0

        if self.is_decoder:
            # If a 2D ou 3D attention mask is provided for the cross-attention
            # we need to make broadcastabe to [batch_size, num_heads, seq_length, seq_length]
            encoder_attention_mask = tf.cast(encoder_attention_mask, dtype=tf.float32)
            num_dims_encoder_attention_mask = len(shape_list(encoder_attention_mask))
            if num_dims_encoder_attention_mask == 3:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
            if num_dims_encoder_attention_mask == 2:
                encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
            encoder_extended_attention_mask = (1.0 - encoder_extended_attention_mask) * -10000.0
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
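The -10000.0 step above replaces boolean masking with an additive penalty on the pre-softmax scores; a small sketch of the shapes and values involved:

    import tensorflow as tf

    # A [batch, seq] padding mask becomes a [batch, 1, 1, seq] additive mask that
    # broadcasts against attention scores of shape [batch, n_heads, qlen, klen].
    attention_mask = tf.constant([[1., 1., 1., 0., 0.]])
    extended = attention_mask[:, None, None, :]
    extended = (1.0 - extended) * -10000.0
    print(extended.numpy())  # [[[[0. 0. 0. -10000. -10000.]]]]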
@@ -148,14 +424,44 @@ class TFT5MainLayer(tf.keras.layers.Layer):
         head_mask = [None] * self.num_hidden_layers
         # head_mask = tf.constant([0] * self.num_hidden_layers)

Removed: the template body that embedded the inputs and ran a single encoder:

        ##################################
        # Replace this with your model code
        embedding_output = self.embeddings(input_ids, position_ids=position_ids, token_type_ids=token_type_ids)
        encoder_outputs = self.encoder([embedding_output, extended_attention_mask, head_mask], training=training)
        sequence_output = encoder_outputs[0]
        outputs = (sequence_output,) + encoder_outputs[1:]  # add hidden_states and attentions if they are here

        return outputs  # sequence_output, (hidden_states), (attentions)

Added: the real stack loop; the first block computes the relative position biases and every later block reuses them:

        all_hidden_states = ()
        all_attentions = ()
        position_bias = None
        encoder_decoder_position_bias = None
        for i, layer_module in enumerate(self.block):
            if self.output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(hidden_states,
                                         attention_mask=extended_attention_mask,
                                         position_bias=position_bias,
                                         encoder_hidden_states=encoder_hidden_states,
                                         encoder_attention_mask=encoder_extended_attention_mask,
                                         encoder_decoder_position_bias=encoder_decoder_position_bias,
                                         head_mask=head_mask[i],
                                         training=training)
            hidden_states = layer_outputs[0]
            if i == 0:
                position_bias = layer_outputs[2 if self.output_attentions else 1]
                if self.is_decoder:
                    encoder_decoder_position_bias = layer_outputs[4 if self.output_attentions else 2]

            if self.output_attentions:
                all_attentions = all_attentions + (layer_outputs[1],)

        hidden_states = self.final_layer_norm(hidden_states)
        layer_output = self.dropout(hidden_states, training=training)

        # Add last layer
        if self.output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        outputs = (hidden_states,)
        if self.output_hidden_states:
            outputs = outputs + (all_hidden_states,)
        if self.output_attentions:
            outputs = outputs + (all_attentions,)
        return outputs  # last-layer hidden state, (all hidden states), (all attentions)

 ####################################################
@@ -173,18 +479,26 @@ class TFT5PreTrainedModel(TFPreTrainedModel):
     pretrained_model_archive_map = TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP
     base_model_prefix = "transformer"

Added: dummy inputs matching the new two-stack (encoder/decoder) calling convention:

    @property
    def dummy_inputs(self):
        input_ids = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
        input_mask = tf.constant([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
        dummy_inputs = {'decoder_input_ids': input_ids,
                        'encoder_input_ids': input_ids,
                        'decoder_attention_mask': input_mask}
        return dummy_inputs

The module docstring is re-targeted from the BERT-style template to T5:

-T5_START_DOCSTRING = r"""    The XXX model was proposed in
-    `XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`_
-    by Jacob Devlin, Ming-Wei Chang, Kenton Lee and Kristina Toutanova. It's a bidirectional transformer
-    pre-trained using a combination of masked language modeling objective and next sentence prediction
-    on a large corpus comprising the Toronto Book Corpus and Wikipedia.
+T5_START_DOCSTRING = r"""    The T5 model was proposed in
+    `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`_
+    by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
+    It's an encoder decoder transformer pre-trained in a text-to-text denoising generative setting.

     This model is a tf.keras.Model `tf.keras.Model`_ sub-class. Use it as a regular TF 2.0 Keras Model and
     refer to the TF 2.0 documentation for all matter related to general usage and behavior.

-    .. _`XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`:
-        https://arxiv.org/abs/1810.04805
+    .. _`Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`:
+        https://arxiv.org/abs/1910.10683

     .. _`tf.keras.Model`:
         https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/Model
@@ -206,67 +520,50 @@ T5_START_DOCSTRING = r""" The XXX model was proposed in
         `model({'input_ids': input_ids, 'token_type_ids': token_type_ids})`

     Parameters:
-        config (:class:`~transformers.XxxConfig`): Model configuration class with all the parameters of the model.
+        config (:class:`~transformers.T5Config`): Model configuration class with all the parameters of the model.
             Initializing with a config file does not load the weights associated with the model, only the configuration.
             Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model weights.
 """

-XXX_INPUTS_DOCSTRING = r"""
+T5_INPUTS_DOCSTRING = r"""
     Inputs:
         **input_ids**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
             Indices of input sequence tokens in the vocabulary.
-            To match pre-training, XXX input sequence should be formatted with [CLS] and [SEP] tokens as follows:
+            To match pre-training, T5 input sequence should be formatted with [CLS] and [SEP] tokens as follows:

             (a) For sequence pairs:

                 ``tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]``

                 ``token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1``

             (b) For single sequences:

                 ``tokens: [CLS] the dog is hairy . [SEP]``

                 ``token_type_ids: 0 0 0 0 0 0 0``

-            Xxx is a model with absolute position embeddings so it's usually advised to pad the inputs on
-            the right rather than the left.
-            Indices can be obtained using :class:`transformers.XxxTokenizer`.
+            T5 is a model with relative position embeddings so you should be able to pad the inputs on
+            the right or the left.
+            Indices can be obtained using :class:`transformers.T5Tokenizer`.
             See :func:`transformers.PreTrainedTokenizer.encode` and
             :func:`transformers.PreTrainedTokenizer.convert_tokens_to_ids` for details.
         **attention_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
-        **token_type_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
-            Segment token indices to indicate first and second portions of the inputs.
-            Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
-            corresponds to a `sentence B` token
-            (see `XXX: Pre-training of Deep Bidirectional Transformers for Language Understanding`_ for more details).
-        **position_ids**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length)``:
-            Indices of positions of each input sequence tokens in the position embeddings.
-            Selected in the range ``[0, config.max_position_embeddings - 1]``.
         **head_mask**: (`optional`) ``Numpy array`` or ``tf.Tensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
 """

-@add_start_docstrings("The bare Xxx Model transformer outputing raw hidden-states without any specific head on top.",
-                      XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
-class TFXxxModel(TFXxxPreTrainedModel):
+@add_start_docstrings("The bare T5 Model transformer outputting raw hidden-states"
+                      "without any specific head on top.",
+                      T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
+class TFT5Model(TFT5PreTrainedModel):
     r"""
     Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
         **last_hidden_state**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, hidden_size)``
             Sequence of hidden-states at the output of the last layer of the model.
-        **pooler_output**: ``tf.Tensor`` of shape ``(batch_size, hidden_size)``
-            Last layer hidden-state of the first token of the sequence (classification token)
-            further processed by a Linear layer and a Tanh activation function. The Linear
-            layer weights are trained from the next sentence prediction (classification)
-            objective during Xxx pretraining. This output is usually *not* a good summary
-            of the semantic content of the input, you're often better with averaging or pooling
-            the sequence of hidden-states for the whole input sequence.
         **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
             list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
             of shape ``(batch_size, sequence_length, hidden_size)``:
@@ -278,127 +575,72 @@ class TFXxxModel(TFXxxPreTrainedModel):
     Examples::

         import tensorflow as tf
-        from transformers import XxxTokenizer, TFXxxModel
+        from transformers import T5Tokenizer, TFT5Model

-        tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
-        model = TFXxxModel.from_pretrained('xxx-base-uncased')
+        tokenizer = T5Tokenizer.from_pretrained('t5-small')
+        model = TFT5Model.from_pretrained('t5-small')
         input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
         outputs = model(input_ids)
         last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple
     """

Removed: the template TFXxxModel body plus the whole TFXxxForMaskedLM and TFXxxForSequenceClassification classes:

    def __init__(self, config, *inputs, **kwargs):
        super(TFXxxModel, self).__init__(config, *inputs, **kwargs)
        self.transformer = TFXxxMainLayer(config, name='transformer')

    def call(self, inputs, **kwargs):
        outputs = self.transformer(inputs, **kwargs)
        return outputs


    @add_start_docstrings("""Xxx Model with a `language modeling` head on top. """,
                          XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
    class TFXxxForMaskedLM(TFXxxPreTrainedModel):
        r"""
        Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
            **prediction_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
                Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
            **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
                list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
                of shape ``(batch_size, sequence_length, hidden_size)``:
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            **attentions**: (`optional`, returned when ``config.output_attentions=True``)
                list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            import tensorflow as tf
            from transformers import XxxTokenizer, TFXxxForMaskedLM

            tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
            model = TFXxxForMaskedLM.from_pretrained('xxx-base-uncased')
            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
            outputs = model(input_ids)
            prediction_scores = outputs[0]
        """
        def __init__(self, config, *inputs, **kwargs):
            super(TFXxxForMaskedLM, self).__init__(config, *inputs, **kwargs)
            self.transformer = TFXxxMainLayer(config, name='transformer')
            self.mlm = TFXxxMLMHead(config, self.transformer.embeddings, name='mlm')

        def call(self, inputs, **kwargs):
            outputs = self.transformer(inputs, **kwargs)
            sequence_output = outputs[0]
            prediction_scores = self.mlm(sequence_output, training=kwargs.get('training', False))
            outputs = (prediction_scores,) + outputs[2:]  # Add hidden states and attention if they are here
            return outputs  # prediction_scores, (hidden_states), (attentions)


    @add_start_docstrings("""Xxx Model transformer with a sequence classification/regression head on top (a linear layer on top of
        the pooled output) e.g. for GLUE tasks. """,
                          XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
    class TFXxxForSequenceClassification(TFXxxPreTrainedModel):
        r"""
        Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
            **logits**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, config.num_labels)``
                Classification (or regression if config.num_labels==1) scores (before SoftMax).
            **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
                list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
                of shape ``(batch_size, sequence_length, hidden_size)``:
                Hidden-states of the model at the output of each layer plus the initial embedding outputs.
            **attentions**: (`optional`, returned when ``config.output_attentions=True``)
                list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
                Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

        Examples::

            import tensorflow as tf
            from transformers import XxxTokenizer, TFXxxForSequenceClassification

            tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
            model = TFXxxForSequenceClassification.from_pretrained('xxx-base-uncased')
            input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
            outputs = model(input_ids)
            logits = outputs[0]
        """
        def __init__(self, config, *inputs, **kwargs):
            super(TFXxxForSequenceClassification, self).__init__(config, *inputs, **kwargs)
            self.num_labels = config.num_labels
            self.transformer = TFXxxMainLayer(config, name='transformer')
            self.dropout = tf.keras.layers.Dropout(config.hidden_dropout_prob)
            self.classifier = tf.keras.layers.Dense(config.num_labels,
                                                    kernel_initializer=get_initializer(config.initializer_range),
                                                    name='classifier')

        def call(self, inputs, **kwargs):
            outputs = self.transformer(inputs, **kwargs)
            pooled_output = outputs[1]
            pooled_output = self.dropout(pooled_output, training=kwargs.get('training', False))
            logits = self.classifier(pooled_output)
            outputs = (logits,) + outputs[2:]  # add hidden states and attention if they are here
            return outputs  # logits, (hidden_states), (attentions)

Added: the TFT5Model body and the TFT5WithLMHeadModel declaration:

    def __init__(self, config, *inputs, **kwargs):
        super(TFT5Model, self).__init__(config, *inputs, **kwargs)
        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name='shared')

        encoder_config = copy.deepcopy(config)
        self.encoder = TFT5MainLayer(encoder_config, name='encoder')

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        self.decoder = TFT5MainLayer(decoder_config, name='decoder')

    def call(self, decoder_input_ids, **kwargs):
        # We allow two types of multi-inputs:
        # - traditional keyword arguments in the call method
        # - all the arguments provided as a dict in the first positional argument of call
        # The last option is useful to use the tf.keras fit() method.
        if isinstance(decoder_input_ids, dict):
            kwargs.update(decoder_input_ids)
        else:
            kwargs['decoder_input_ids'] = decoder_input_ids

        kwargs_common = dict((k, v) for k, v in kwargs.items()
                             if not k.startswith("encoder_") and not k.startswith("decoder_"))
        kwargs_encoder = kwargs_common.copy()
        kwargs_decoder = kwargs_common.copy()
        kwargs_encoder.update(dict((k[len("encoder_"):], v) for k, v in kwargs.items() if k.startswith("encoder_")))
        kwargs_decoder.update(dict((k[len("decoder_"):], v) for k, v in kwargs.items() if k.startswith("decoder_")))

        # Encode if needed (training, first prediction pass)
        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
        if encoder_hidden_states is None:
            encoder_inputs_ids = kwargs_encoder.pop("input_ids")
            hidden_states = self.shared(encoder_inputs_ids)  # Convert inputs in embeddings
            encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
            encoder_hidden_states = encoder_outputs[0]
        else:
            encoder_outputs = ()

        # Decode
        decoder_inputs_ids = kwargs_decoder.pop("input_ids")
        hidden_states = self.shared(decoder_inputs_ids)  # Convert inputs in embeddings
        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
        kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
        decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)

        return decoder_outputs + encoder_outputs


@add_start_docstrings("""T5 Model with a `language modeling` head on top. """,
                      T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
class TFT5WithLMHeadModel(TFT5PreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
-        **scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
-            Classification scores (before SoftMax).
+        **prediction_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.vocab_size)``
+            Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
             list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
             of shape ``(batch_size, sequence_length, hidden_size)``:
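The call() convention above accepts either a dict as the first positional argument (so tf.keras fit() can feed named inputs) or plain keyword arguments. A usage sketch, assuming eager execution and the 't5-small' shortcut from this commit:

    import tensorflow as tf
    from transformers import T5Tokenizer, TFT5Model

    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    model = TFT5Model.from_pretrained('t5-small')
    input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]

    # Both styles reach the same prefix-routing code path:
    outputs = model({'encoder_input_ids': input_ids, 'decoder_input_ids': input_ids})
    outputs = model(input_ids, encoder_input_ids=input_ids)  # first positional arg = decoder_input_ids
    last_hidden_states = outputs[0]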
@@ -410,87 +652,66 @@ class TFXxxForTokenClassification(TFXxxPreTrainedModel):
Examples::
import tensorflow as tf
from transformers import
Xxx
Tokenizer, TF
XxxForTokenClassification
from transformers import
T5
Tokenizer, TF
T5WithLMHeadModel
tokenizer =
Xxx
Tokenizer.from_pretrained('
xxx-base-uncased
')
model = TF
XxxForTokenClassification
.from_pretrained('
xxx-base-uncased
')
tokenizer =
T5
Tokenizer.from_pretrained('
t5-small
')
model = TF
T5WithLMHeadModel
.from_pretrained('
t5-small
')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
scores = outputs[0]
prediction_
scores = outputs[0]
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
TFXxxForTokenClassification
,
self
).
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
num_labels
=
config
.
num_labels
self
.
transformer
=
TFXxxMainLayer
(
config
,
name
=
'transformer'
)
self
.
dropout
=
tf
.
keras
.
layers
.
Dropout
(
config
.
hidden_dropout_prob
)
self
.
classifier
=
tf
.
keras
.
layers
.
Dense
(
config
.
num_labels
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
'classifier'
)
def
call
(
self
,
inputs
,
**
kwargs
):
outputs
=
self
.
transformer
(
inputs
,
**
kwargs
)
sequence_output
=
outputs
[
0
]
sequence_output
=
self
.
dropout
(
sequence_output
,
training
=
kwargs
.
get
(
'training'
,
False
))
logits
=
self
.
classifier
(
sequence_output
)
outputs
=
(
logits
,)
+
outputs
[
2
:]
# add hidden states and attention if they are here
return
outputs
# scores, (hidden_states), (attentions)
@add_start_docstrings("""Xxx Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layer on top of
    the hidden-states output to compute `span start logits` and `span end logits`). """,
    XXX_START_DOCSTRING, XXX_INPUTS_DOCSTRING)
class TFXxxForQuestionAnswering(TFXxxPreTrainedModel):
    r"""
    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **start_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
            Span-start scores (before SoftMax).
        **end_scores**: ``Numpy array`` or ``tf.Tensor`` of shape ``(batch_size, sequence_length,)``
            Span-end scores (before SoftMax).
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``Numpy array`` or ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``Numpy array`` or ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Examples::

        import tensorflow as tf
        from transformers import XxxTokenizer, TFXxxForQuestionAnswering

        tokenizer = XxxTokenizer.from_pretrained('xxx-base-uncased')
        model = TFXxxForQuestionAnswering.from_pretrained('xxx-base-uncased')
        input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1
        outputs = model(input_ids)
        start_scores, end_scores = outputs[:2]

    """
    def __init__(self, config, *inputs, **kwargs):
        super(TFXxxForQuestionAnswering, self).__init__(config, *inputs, **kwargs)
        self.num_labels = config.num_labels

        self.transformer = TFXxxMainLayer(config, name='transformer')
        self.qa_outputs = tf.keras.layers.Dense(config.num_labels,
                                                kernel_initializer=get_initializer(config.initializer_range),
                                                name='qa_outputs')

    def call(self, inputs, **kwargs):
        outputs = self.transformer(inputs, **kwargs)

        sequence_output = outputs[0]

        logits = self.qa_outputs(sequence_output)
        start_logits, end_logits = tf.split(logits, 2, axis=-1)
        start_logits = tf.squeeze(start_logits, axis=-1)
        end_logits = tf.squeeze(end_logits, axis=-1)

        outputs = (start_logits, end_logits,) + outputs[2:]

        return outputs  # start_logits, end_logits, (hidden_states), (attentions)
transformers/modeling_tf_t5.py View file @ 727a79b3
class TFT5WithLMHeadModel(TFT5PreTrainedModel):
    def __init__(self, config, *inputs, **kwargs):
        super(TFT5WithLMHeadModel, self).__init__(config, *inputs, **kwargs)
        self.model_dim = config.d_model

        self.shared = TFSharedEmbeddings(config.vocab_size, config.d_model, name='shared')

        encoder_config = copy.deepcopy(config)
        self.encoder = TFT5MainLayer(encoder_config, name='encoder')

        decoder_config = copy.deepcopy(config)
        decoder_config.is_decoder = True
        self.decoder = TFT5MainLayer(decoder_config, name='decoder')

    def call(self, decoder_input_ids, **kwargs):
        # We allow two types of multi-inputs:
        # - traditional keyword arguments in the call method
        # - all the arguments provided as a dict in the first positional argument of call
        # The last option is useful to use the tf.keras fit() method.
        if isinstance(decoder_input_ids, dict):
            kwargs.update(decoder_input_ids)
        else:
            kwargs['decoder_input_ids'] = decoder_input_ids

        kwargs_common = dict((k, v) for k, v in kwargs.items()
                             if not k.startswith("encoder_") and not k.startswith("decoder_"))
        kwargs_encoder = kwargs_common.copy()
        kwargs_decoder = kwargs_common.copy()
        kwargs_encoder.update(dict((k[len("encoder_"):], v) for k, v in kwargs.items() if k.startswith("encoder_")))
        kwargs_decoder.update(dict((k[len("decoder_"):], v) for k, v in kwargs.items() if k.startswith("decoder_")))

        # Encode if needed (training, first prediction pass)
        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
        if encoder_hidden_states is None:
            encoder_inputs_ids = kwargs_encoder.pop("input_ids")
            hidden_states = self.shared(encoder_inputs_ids)  # Convert inputs in embeddings
            encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
            encoder_hidden_states = encoder_outputs[0]
        else:
            encoder_outputs = ()

        # Decode
        decoder_inputs_ids = kwargs_decoder.pop("input_ids")
        hidden_states = self.shared(decoder_inputs_ids)  # Convert inputs in embeddings
        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
        kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
        decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)

        sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
        lm_logits = self.shared(sequence_output, mode="linear")
        decoder_outputs = (lm_logits,) + decoder_outputs[1:]

        return decoder_outputs + encoder_outputs
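For context on the prefix-splitting convention implemented in `call` above, here is a minimal usage sketch. It assumes the `t5-small` checkpoint and a `T5Tokenizer` class are available, which this commit does not itself guarantee:

    # Usage sketch (assumptions: 't5-small' weights exist and T5Tokenizer is importable).
    # TFT5WithLMHeadModel.call routes kwargs by prefix: 'encoder_*' goes to the encoder
    # stack, 'decoder_*' to the decoder stack, and unprefixed kwargs to both.
    import tensorflow as tf
    from transformers import T5Tokenizer, TFT5WithLMHeadModel

    tokenizer = T5Tokenizer.from_pretrained('t5-small')
    model = TFT5WithLMHeadModel.from_pretrained('t5-small')

    input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :]  # Batch size 1

    # Style 1: a single dict as the first positional argument (plays well with tf.keras fit())
    outputs = model({'encoder_input_ids': input_ids, 'decoder_input_ids': input_ids})

    # Style 2: decoder ids positionally, everything else as prefixed keyword arguments
    outputs = model(input_ids, encoder_input_ids=input_ids)

    lm_logits = outputs[0]  # (batch_size, sequence_length, vocab_size)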
transformers/modeling_utils.py View file @ 727a79b3
@@ -160,8 +160,7 @@ class PreTrainedModel(nn.Module):
         base_model.vocab_size = new_num_tokens

         # Tie weights again if needed
-        if hasattr(self, 'tie_weights'):
-            self.tie_weights()
+        self.tie_weights()

         return model_embeds
@@ -458,8 +457,7 @@ class PreTrainedModel(nn.Module):
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                 model.__class__.__name__, "\n\t".join(error_msgs)))

-        if hasattr(model, 'tie_weights'):
-            model.tie_weights()  # make sure word embedding weights are still tied
+        model.tie_weights()  # make sure word embedding weights are still tied if needed

         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
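The two hunks above drop the `hasattr` guard, which is safe once the base class always defines the hook. A minimal sketch of that pattern, with illustrative names rather than the library's actual classes:

    # Sketch of the no-op-hook pattern that makes the hasattr() guard unnecessary:
    # the base class defines tie_weights() as a no-op, and models that share
    # input/output embeddings override it. Illustrative only.
    class Base:
        def tie_weights(self):
            pass  # no-op by default, safe to call on any model

    class WithLMHead(Base):
        def __init__(self):
            self.in_emb = object()   # stands in for the input embedding matrix
            self.out_emb = None      # stands in for the output projection

        def tie_weights(self):
            self.out_emb = self.in_emb  # share one matrix between input and output

    model = WithLMHead()
    model.tie_weights()  # unconditional call, no hasattr() check needed
    assert model.out_emb is model.in_emb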
transformers/tests/modeling_tf_common_test.py View file @ 727a79b3
@@ -69,6 +69,7 @@ class TFCommonTestCases:
         test_torchscript = True
         test_pruning = True
         test_resize_embeddings = True
+        is_encoder_decoder = False

         def test_initialization(self):
             pass
@@ -156,7 +157,11 @@ class TFCommonTestCases:
         def test_compile_tf_model(self):
             config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

-            input_ids = tf.keras.Input(batch_shape=(2, 2000), name='input_ids', dtype='int32')
+            if self.is_encoder_decoder:
+                input_ids = {'decoder_input_ids': tf.keras.Input(batch_shape=(2, 2000), name='decoder_input_ids', dtype='int32'),
+                             'encoder_input_ids': tf.keras.Input(batch_shape=(2, 2000), name='encoder_input_ids', dtype='int32')}
+            else:
+                input_ids = tf.keras.Input(batch_shape=(2, 2000), name='input_ids', dtype='int32')
             optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
             loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
             metric = tf.keras.metrics.SparseCategoricalAccuracy('accuracy')
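The dict branch above works because tf.keras matches dict inputs to `tf.keras.Input` tensors by name. A self-contained sketch of the same wiring with a toy model (the embedding and dense layers are assumptions for illustration, not the test's real model):

    # Toy functional model over named dict inputs, mirroring the test setup above.
    import tensorflow as tf

    enc = tf.keras.Input(batch_shape=(2, 2000), name='encoder_input_ids', dtype='int32')
    dec = tf.keras.Input(batch_shape=(2, 2000), name='decoder_input_ids', dtype='int32')

    emb = tf.keras.layers.Embedding(100, 8)          # shared toy embedding
    logits = tf.keras.layers.Dense(100)(emb(enc) + emb(dec))

    model = tf.keras.Model(inputs={'encoder_input_ids': enc, 'decoder_input_ids': dec},
                           outputs=logits)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5, clipnorm=1.0),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                  metrics=[tf.keras.metrics.SparseCategoricalAccuracy('accuracy')])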
@@ -189,7 +194,7 @@ class TFCommonTestCases:
             outputs_dict = model(inputs_dict)

             inputs_keywords = copy.deepcopy(inputs_dict)
-            input_ids = inputs_keywords.pop('input_ids')
+            input_ids = inputs_keywords.pop('input_ids', inputs_keywords.pop('decoder_input_ids', None))
             outputs_keywords = model(input_ids, **inputs_keywords)

             output_dict = outputs_dict[0].numpy()
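One subtlety in the fallback pop above (and the reason the inner `pop` is given a `None` default in this reconstruction): Python evaluates a `pop()` default eagerly, so a default-less inner `pop` would raise `KeyError` for single-stack models whose inputs carry no `decoder_input_ids` key:

    # dict.pop evaluates its default argument before the outer call runs,
    # so the inner pop must itself carry a default to be safe:
    d = {'input_ids': [1, 2, 3]}
    # d.pop('input_ids', d.pop('decoder_input_ids'))        # KeyError: 'decoder_input_ids'
    ids = d.pop('input_ids', d.pop('decoder_input_ids', None))  # OK -> [1, 2, 3]
    assert ids == [1, 2, 3]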
@@ -216,12 +221,24 @@ class TFCommonTestCases:
                      self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
                 out_len = len(outputs)

+                if self.is_encoder_decoder:
+                    self.assertEqual(out_len % 2, 0)
+                    decoder_attentions = outputs[(out_len // 2) - 1]
+                    self.assertEqual(model.config.output_attentions, True)
+                    self.assertEqual(model.config.output_hidden_states, False)
+                    self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+                    self.assertListEqual(
+                        list(decoder_attentions[0].shape[-3:]),
+                        [self.model_tester.num_attention_heads,
+                         self.model_tester.seq_length,
+                         self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+
                 # Check attention is always last and order is fine
                 config.output_attentions = True
                 config.output_hidden_states = True
                 model = model_class(config)
                 outputs = model(inputs_dict)
-                self.assertEqual(out_len + 1, len(outputs))
+                self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
                 self.assertEqual(model.config.output_attentions, True)
                 self.assertEqual(model.config.output_hidden_states, True)
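The index arithmetic in the new encoder-decoder branch follows from the model returning `decoder_outputs + encoder_outputs` as one flat tuple, so each half ends with its attention list when only attentions are enabled. A toy illustration with placeholder values:

    # Toy tuple showing why decoder attentions sit at (out_len // 2) - 1
    # and encoder attentions at out_len - 1. Contents are placeholders.
    decoder_outputs = ('dec_hidden', 'dec_attentions')
    encoder_outputs = ('enc_hidden', 'enc_attentions')
    outputs = decoder_outputs + encoder_outputs

    out_len = len(outputs)
    assert out_len % 2 == 0
    assert outputs[(out_len // 2) - 1] == 'dec_attentions'
    assert outputs[out_len - 1] == 'enc_attentions'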
transformers/tests/modeling_tf_t5_test.py View file @ 727a79b3
@@ -26,7 +26,7 @@ from .configuration_common_test import ConfigTester
 from transformers import T5Config, is_tf_available

-if False:  # is_tf_available():
+if is_tf_available():
     import tensorflow as tf
     from transformers.modeling_tf_t5 import (TFT5Model, TFT5WithLMHeadModel,
                                              TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP)
 else:
@@ -35,7 +35,8 @@ else:
 class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):

-    all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if False else ()  # is_tf_available() else ()
+    is_encoder_decoder = True
+    all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if is_tf_available() else ()

     class TFT5ModelTester(object):
@@ -45,22 +46,16 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
                      seq_length=7,
                      is_training=True,
                      use_input_mask=True,
                      use_token_type_ids=True,
                      use_labels=True,
                      vocab_size=99,
                      n_positions=14,
                      hidden_size=32,
                      num_hidden_layers=5,
                      num_attention_heads=4,
-                     intermediate_size=37,
-                     hidden_act="gelu",
-                     hidden_dropout_prob=0.1,
-                     attention_probs_dropout_prob=0.1,
-                     max_position_embeddings=512,
-                     type_vocab_size=16,
-                     type_sequence_label_size=2,
-                     initializer_range=0.02,
-                     num_labels=3,
-                     num_choices=4,
+                     d_ff=37,
+                     relative_attention_num_buckets=8,
+                     dropout_rate=0.1,
+                     initializer_factor=0.002,
                      scope=None,
                      ):
             self.parent = parent
@@ -68,22 +63,16 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
             self.seq_length = seq_length
             self.is_training = is_training
             self.use_input_mask = use_input_mask
             self.use_token_type_ids = use_token_type_ids
             self.use_labels = use_labels
             self.vocab_size = vocab_size
             self.n_positions = n_positions
             self.hidden_size = hidden_size
             self.num_hidden_layers = num_hidden_layers
             self.num_attention_heads = num_attention_heads
-            self.intermediate_size = intermediate_size
-            self.hidden_act = hidden_act
-            self.hidden_dropout_prob = hidden_dropout_prob
-            self.attention_probs_dropout_prob = attention_probs_dropout_prob
-            self.max_position_embeddings = max_position_embeddings
-            self.type_vocab_size = type_vocab_size
-            self.type_sequence_label_size = type_sequence_label_size
-            self.initializer_range = initializer_range
-            self.num_labels = num_labels
-            self.num_choices = num_choices
+            self.d_ff = d_ff
+            self.relative_attention_num_buckets = relative_attention_num_buckets
+            self.dropout_rate = dropout_rate
+            self.initializer_factor = initializer_factor
             self.scope = scope

         def prepare_config_and_inputs(self):
@@ -93,61 +82,53 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
             if self.use_input_mask:
                 input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

-            token_type_ids = None
-            if self.use_token_type_ids:
-                token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-            sequence_labels = None
             token_labels = None
-            choice_labels = None
             if self.use_labels:
-                sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-                token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-                choice_labels = ids_tensor([self.batch_size], self.num_choices)
+                token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

             config = T5Config(
                 vocab_size_or_config_json_file=self.vocab_size,
-                hidden_size=self.hidden_size,
-                num_hidden_layers=self.num_hidden_layers,
-                num_attention_heads=self.num_attention_heads,
-                intermediate_size=self.intermediate_size,
-                hidden_act=self.hidden_act,
-                hidden_dropout_prob=self.hidden_dropout_prob,
-                attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-                max_position_embeddings=self.max_position_embeddings,
-                type_vocab_size=self.type_vocab_size,
-                initializer_range=self.initializer_range)
-
-            return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
-
-        def create_and_check_t5_model(self, config, input_ids, token_type_ids, input_mask,
-                                      sequence_labels, token_labels, choice_labels):
+                n_positions=self.n_positions,
+                d_model=self.hidden_size,
+                d_ff=self.d_ff,
+                d_kv=self.hidden_size // self.num_attention_heads,
+                num_layers=self.num_hidden_layers,
+                num_heads=self.num_attention_heads,
+                relative_attention_num_buckets=self.relative_attention_num_buckets,
+                dropout_rate=self.dropout_rate,
+                initializer_factor=self.initializer_factor)
+
+            return (config, input_ids, input_mask, token_labels)
+
+        def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
             model = TFT5Model(config=config)
-            inputs = {'input_ids': input_ids,
-                      'attention_mask': input_mask,
-                      'token_type_ids': token_type_ids}
-            sequence_output, pooled_output = model(inputs)
-
-            inputs = [input_ids, input_mask]
-            sequence_output, pooled_output = model(inputs)
-
-            sequence_output, pooled_output = model(input_ids)
+            inputs = {'encoder_input_ids': input_ids,
+                      'decoder_input_ids': input_ids,
+                      'decoder_attention_mask': input_mask}
+            encoder_output, decoder_output = model(inputs)
+
+            encoder_output, decoder_output = model(input_ids,
+                                                   decoder_attention_mask=input_mask,
+                                                   encoder_input_ids=input_ids)

             result = {
-                "sequence_output": sequence_output.numpy(),
-                "pooled_output": pooled_output.numpy(),
+                "encoder_output": encoder_output.numpy(),
+                "decoder_output": decoder_output.numpy(),
             }
             self.parent.assertListEqual(
-                list(result["sequence_output"].shape),
+                list(result["encoder_output"].shape),
                 [self.batch_size, self.seq_length, self.hidden_size])
-            self.parent.assertListEqual(
-                list(result["pooled_output"].shape),
-                [self.batch_size, self.hidden_size])
+            self.parent.assertListEqual(
+                list(result["decoder_output"].shape),
+                [self.batch_size, self.seq_length, self.hidden_size])

-        def create_and_check_t5_with_lm_head(self, config, input_ids, token_type_ids, input_mask,
-                                             sequence_labels, token_labels, choice_labels):
+        def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
             model = TFT5WithLMHeadModel(config=config)
-            inputs = {'input_ids': input_ids,
-                      'attention_mask': input_mask,
-                      'token_type_ids': token_type_ids}
-            prediction_scores, = model(inputs)
+            inputs = {'encoder_input_ids': input_ids,
+                      'decoder_input_ids': input_ids,
+                      'decoder_attention_mask': input_mask}
+            prediction_scores, decoder_output = model(inputs)
             result = {
                 "prediction_scores": prediction_scores.numpy(),
             }
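As a quick sanity check on the tiny configuration the tester now builds: `d_kv` is derived as `hidden_size // num_attention_heads`, so 4 heads of size 8 reassemble the 32-dimensional model. A standalone sketch with the tester's default values (attribute names assumed to mirror the constructor arguments):

    # Tiny T5Config as the tester constructs it; 32 // 4 == 8, and
    # num_heads * d_kv == d_model by construction. Illustrative sketch only.
    from transformers import T5Config

    config = T5Config(vocab_size_or_config_json_file=99,
                      n_positions=14,
                      d_model=32,
                      d_ff=37,
                      d_kv=32 // 4,
                      num_layers=5,
                      num_heads=4,
                      relative_attention_num_buckets=8,
                      dropout_rate=0.1,
                      initializer_factor=0.002)

    assert config.num_heads * config.d_kv == config.d_model  # 4 * 8 == 32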
@@ -158,14 +139,15 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
         def prepare_config_and_inputs_for_common(self):
             config_and_inputs = self.prepare_config_and_inputs()
-            (config, input_ids, token_type_ids, input_mask,
-             sequence_labels, token_labels, choice_labels) = config_and_inputs
-            inputs_dict = {'input_ids': input_ids,
-                           'token_type_ids': token_type_ids,
-                           'attention_mask': input_mask}
+            (config, input_ids, input_mask, token_labels) = config_and_inputs
+            inputs_dict = {'encoder_input_ids': input_ids,
+                           'decoder_input_ids': input_ids,
+                           'decoder_attention_mask': input_mask}
             return config, inputs_dict

     def setUp(self):
         self.model_tester = TFT5ModelTest.TFT5ModelTester(self)
-        self.config_tester = ConfigTester(self, config_class=T5Config, hidden_size=37)
+        self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)

     def test_config(self):
         self.config_tester.run_common_tests()
@@ -181,7 +163,7 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
     @pytest.mark.slow
     def test_model_from_pretrained(self):
         cache_dir = "/tmp/transformers_test/"
-        for model_name in ['t5-base']:
+        for model_name in ['t5-small']:
             model = TFT5Model.from_pretrained(model_name, cache_dir=cache_dir)
             shutil.rmtree(cache_dir)
             self.assertIsNotNone(model)