chenpangpang / transformers / Commits

Commit 727a79b3, authored Nov 08, 2019 by thomwolf
added TF2 model and tests - updated templates
parent 8fda532c

Showing 11 changed files with 651 additions and 386 deletions (+651 / -386):

    templates/adding_a_new_model/modeling_tf_xxx.py    +2    -0
    templates/adding_a_new_model/modeling_xxx.py       +2    -0
    transformers/__init__.py                           +3    -0
    transformers/configuration_auto.py                 +5    -1
    transformers/configuration_t5.py                   +1    -2
    transformers/modeling_t5.py                        +57   -20
    transformers/modeling_tf_pytorch_utils.py          +2    -2
    transformers/modeling_tf_t5.py                     +506  -285
    transformers/modeling_utils.py                     +2    -4
    transformers/tests/modeling_tf_common_test.py      +20   -3
    transformers/tests/modeling_tf_t5_test.py          +51   -69
templates/adding_a_new_model/modeling_tf_xxx.py

@@ -26,6 +26,8 @@ import logging
 import math
 import os
 import sys
+import copy
+import itertools
 from io import open
 import numpy as np
templates/adding_a_new_model/modeling_xxx.py

@@ -25,6 +25,8 @@ import logging
 import math
 import os
 import sys
+import copy
+import itertools
 from io import open
 import torch
transformers/__init__.py

@@ -158,6 +158,9 @@ if is_tf_available():
                                     TFCTRLLMHeadModel,
                                     TF_CTRL_PRETRAINED_MODEL_ARCHIVE_MAP)
+    from .modeling_tf_t5 import (TFT5PreTrainedModel, TFT5Model, TFT5WithLMHeadModel,
+                                 TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP)
+
     # TF 2.0 <=> PyTorch conversion utilities
     from .modeling_tf_pytorch_utils import (convert_tf_weight_name_to_pt_weight_name,
                                             load_pytorch_checkpoint_in_tf2_model,
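With these exports in place, the new TF 2.0 T5 classes become importable from the package root like the other TF models. A minimal sketch, assuming TensorFlow 2 is installed so that is_tf_available() is true:

    from transformers import TFT5Model, TFT5WithLMHeadModel, TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP

    # shortcut names mapped to the hosted pretrained TF weights
    print(sorted(TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP))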
transformers/configuration_auto.py

@@ -27,6 +27,7 @@ from .configuration_xlm import XLMConfig
 from .configuration_roberta import RobertaConfig
 from .configuration_distilbert import DistilBertConfig
 from .configuration_ctrl import CTRLConfig
+from .configuration_t5 import T5Config

 logger = logging.getLogger(__name__)

@@ -64,6 +65,7 @@ class AutoConfig(object):
         The configuration class to instantiate is selected as the first pattern matching
         in the `pretrained_model_name_or_path` string (in the following order):
+            - contains `t5`: T5Config (T5 model)
             - contains `distilbert`: DistilBertConfig (DistilBERT model)
             - contains `bert`: BertConfig (Bert model)
             - contains `openai-gpt`: OpenAIGPTConfig (OpenAI GPT model)

@@ -114,7 +116,9 @@ class AutoConfig(object):
             assert unused_kwargs == {'foo': False}
         """
-        if 'distilbert' in pretrained_model_name_or_path:
+        if 't5' in pretrained_model_name_or_path:
+            return T5Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        elif 'distilbert' in pretrained_model_name_or_path:
             return DistilBertConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
         elif 'roberta' in pretrained_model_name_or_path:
             return RobertaConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
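Because the `t5` pattern is now checked first, any model name containing "t5" resolves to a T5Config before the other branches are tried. A minimal sketch of how the new branch is exercised, assuming the hosted 't5-small' configuration is reachable:

    from transformers import AutoConfig

    config = AutoConfig.from_pretrained('t5-small')
    print(type(config).__name__)  # T5Config, matched before the 'distilbert' / 'bert' patterns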
transformers/configuration_t5.py

@@ -27,8 +27,7 @@ from .configuration_utils import PretrainedConfig
 logger = logging.getLogger(__name__)

 T5_PRETRAINED_CONFIG_ARCHIVE_MAP = {
-    't5-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-uncased-config.json",
-    't5-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-uncased-config.json",
+    't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-config.json",
 }
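The only shortcut name left in the map is 't5-small', so pretrained config loading goes through that single entry. A small sketch, assuming the hosted config file is available for download:

    from transformers import T5Config

    config = T5Config.from_pretrained('t5-small')  # resolved via T5_PRETRAINED_CONFIG_ARCHIVE_MAP
    print(config.d_model, config.num_layers)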
transformers/modeling_t5.py

@@ -41,8 -41,7 @@ logger = logging.getLogger(__name__)
 # for the pretrained weights provided with the models
 ####################################################
 T5_PRETRAINED_MODEL_ARCHIVE_MAP = {
-    't5-base-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-base-uncased-pytorch_model.bin",
-    't5-large-uncased': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-large-uncased-pytorch_model.bin",
+    't5-small': "https://s3.amazonaws.com/models.huggingface.co/bert/t5-small-pytorch_model.bin",
 }
 ####################################################
@@ -442,7 +441,7 @@ class T5PreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.LayerNorm):
             module.bias.data.zero_()
             module.weight.data.fill_(factor * 1.0)
-        elif isinstance(module, T5Model):
+        elif isinstance(module, (T5Model, T5WithLMHeadModel)):
             # Mesh TensorFlow embeddings initialization
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
             module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
@@ -502,11 +501,10 @@ class T5Stack(T5PreTrainedModel):
         # ourselves in which case we just need to make it broadcastable to all heads.
         if attention_mask.dim() == 3:
             extended_attention_mask = attention_mask[:, None, :, :]
         # Provided a padding mask of dimensions [batch_size, seq_length]
         # - if the model is a decoder, apply a causal mask in addition to the padding mask
         # - if the model is an encoder, make the mask broadcastable to [batch_size, num_heads, seq_length, seq_length]
-        if attention_mask.dim() == 2:
+        elif attention_mask.dim() == 2:
             if self.config.is_decoder:
                 seq_ids = torch.arange(seq_length, device=hidden_states.device)
                 causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
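The decoder branch above builds its causal mask by broadcasting two views of the same position-index vector against each other. A standalone sketch of that comparison with toy sizes (the values are illustrative only, not taken from the diff):

    import torch

    batch_size, seq_length = 2, 4
    seq_ids = torch.arange(seq_length)
    # repeat -> (batch_size, seq_length, seq_length); compare column index <= row index
    causal_mask = seq_ids[None, None, :].repeat(batch_size, seq_length, 1) <= seq_ids[None, :, None]
    print(causal_mask[0].int())
    # lower-triangular pattern: row 0 -> [1, 0, 0, 0], row 3 -> [1, 1, 1, 1]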
@@ -593,7 +591,7 @@ class T5Stack(T5PreTrainedModel):
 T5_START_DOCSTRING = r"""    The T5 model was proposed in
     `Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer`_
     by Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang, Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu.
-    It's an encoder decoder pre-trained transformer.
+    It's an encoder decoder transformer pre-trained in a text-to-text denoising generative setting.

     This model is a PyTorch `torch.nn.Module`_ sub-class. Use it as a regular PyTorch Module and
     refer to the PyTorch documentation for all matter related to general usage and behavior.

@@ -634,16 +632,13 @@ T5_INPUTS_DOCSTRING = r"""
             Mask to avoid performing attention on padding token indices.
             Mask values selected in ``[0, 1]``:
             ``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
-        **position_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
-            Indices of positions of each input sequence tokens in the position embeddings.
-            Selected in the range ``[0, config.max_position_embeddings - 1]``.
         **head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
             Mask to nullify selected heads of the self-attention modules.
             Mask values selected in ``[0, 1]``:
             ``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
 """

-@add_start_docstrings("The bare single stack (encoder or decoder) of a T5 Model transformer outputting raw hidden-states"
+@add_start_docstrings("The bare T5 Model transformer outputting raw hidden-states"
                       "without any specific head on top.",
                       T5_START_DOCSTRING, T5_INPUTS_DOCSTRING)
 class T5Model(T5PreTrainedModel):

@@ -661,8 +656,8 @@ class T5Model(T5PreTrainedModel):
         Examples::

-            tokenizer = T5Tokenizer.from_pretrained('t5-base-uncased')
-            model = T5Model.from_pretrained('t5-base-uncased')
+            tokenizer = T5Tokenizer.from_pretrained('t5-small')
+            model = T5Model.from_pretrained('t5-small')
             input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
             outputs = model(input_ids)
             last_hidden_states = outputs[0]  # The last hidden-state is the first element of the output tuple

@@ -752,8 +747,8 @@ class T5WithLMHeadModel(T5PreTrainedModel):
         Examples::

-            tokenizer = T5Tokenizer.from_pretrained('t5-base-uncased')
-            model = T5WithLMHeadModel.from_pretrained('t5-base-uncased')
+            tokenizer = T5Tokenizer.from_pretrained('t5-small')
+            model = T5WithLMHeadModel.from_pretrained('t5-small')
             input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0)  # Batch size 1
             outputs = model(input_ids, lm_labels=input_ids)
             loss, prediction_scores = outputs[:2]
@@ -763,31 +758,73 @@ class T5WithLMHeadModel(T5PreTrainedModel):
         super(T5WithLMHeadModel, self).__init__(config)
         self.model_dim = config.d_model

-        self.transformer = T5Model(config)
+        self.shared = nn.Embedding(config.vocab_size, config.d_model)
+
+        encoder_config = copy.deepcopy(config)
+        self.encoder = T5Stack(encoder_config)
+
+        decoder_config = copy.deepcopy(config)
+        decoder_config.is_decoder = True
+        self.decoder = T5Stack(decoder_config)
+
         self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

         self.init_weights()

+    def get_input_embeddings(self):
+        return self.shared
+
+    def set_input_embeddings(self, new_embeddings):
+        self.shared = new_embeddings
+
     def get_output_embeddings(self):
         return self.lm_head

     def forward(self, **kwargs):
+        # keyword arguments come in 3 flavors: encoder-specific (prefixed by
+        # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
+        # that apply to the model as whole.
+        # We let the specific kwargs override the common ones in case of conflict.
         lm_labels = kwargs.pop('decoder_lm_labels', None)
-        outputs = self.transformer(**kwargs)
-        sequence_output = outputs[0]
+
+        kwargs_common = dict((k, v) for k, v in kwargs.items()
+                             if not k.startswith("encoder_") and not k.startswith("decoder_"))
+        kwargs_encoder = kwargs_common.copy()
+        kwargs_decoder = kwargs_common.copy()
+        kwargs_encoder.update(dict((k[len("encoder_"):], v) for k, v in kwargs.items() if k.startswith("encoder_")))
+        kwargs_decoder.update(dict((k[len("decoder_"):], v) for k, v in kwargs.items() if k.startswith("decoder_")))
+
+        # Encode if needed (training, first prediction pass)
+        encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
+        if encoder_hidden_states is None:
+            encoder_inputs_ids = kwargs_encoder.pop("input_ids")
+            hidden_states = self.shared(encoder_inputs_ids)  # Convert inputs in embeddings
+            encoder_outputs = self.encoder(hidden_states, **kwargs_encoder)
+            encoder_hidden_states = encoder_outputs[0]
+        else:
+            encoder_outputs = ()
+
+        # Decode
+        decoder_inputs_ids = kwargs_decoder.pop("input_ids")
+        hidden_states = self.shared(decoder_inputs_ids)  # Convert inputs in embeddings
+        kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
+        kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
+        decoder_outputs = self.decoder(hidden_states, **kwargs_decoder)
+        sequence_output = decoder_outputs[0]

         # Rescale output before projecting on vocab
         # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
         sequence_output = sequence_output * (self.model_dim ** -0.5)
         lm_logits = self.lm_head(sequence_output)

-        outputs = (lm_logits,) + outputs[1:]  # Add hidden states and attention if they are here
+        decoder_outputs = (lm_logits,) + decoder_outputs[1:]  # Add hidden states and attention if they are here
         if lm_labels is not None:
             shift_logits = lm_logits[..., :-1, :].contiguous()
             shift_labels = lm_labels[..., 1:].contiguous()
             loss_fct = CrossEntropyLoss(ignore_index=-1)
             loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                             shift_labels.view(-1))
-            outputs = (loss,) + outputs  # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
+            decoder_outputs = (loss,) + decoder_outputs  # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666

-        return outputs  # (lm_loss), lm_logits, (hidden_states), (attentions)
+        return decoder_outputs + encoder_outputs
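The reworked forward routes keyword arguments by prefix, so a training call passes encoder- and decoder-specific tensors under `encoder_*` / `decoder_*` names and the labels as `decoder_lm_labels`. A hedged sketch of such a call with dummy token ids; the default-constructed config and the exact output layout are assumptions, not part of the diff:

    import torch
    from transformers import T5Config
    from transformers.modeling_t5 import T5WithLMHeadModel

    config = T5Config()                        # assumption: default hyperparameters build a usable model
    model = T5WithLMHeadModel(config)

    input_ids = torch.randint(0, 99, (1, 8))   # dummy token ids, batch size 1
    outputs = model(encoder_input_ids=input_ids,
                    decoder_input_ids=input_ids,
                    decoder_lm_labels=input_ids)
    loss, lm_logits = outputs[0], outputs[1]   # loss comes first because decoder_lm_labels was given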
transformers/modeling_tf_pytorch_utils.py

@@ -156,7 +156,7 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
                 e.args += (symbolic_weight.shape, array.shape)
                 raise e

-        logger.info("Initialize TF weight {}".format(symbolic_weight.name))
+        # logger.warning("Initialize TF weight {}".format(symbolic_weight.name))

         weight_value_tuples.append((symbolic_weight, array))
         all_pytorch_weights.discard(name)

@@ -269,7 +269,7 @@ def load_tf2_weights_in_pytorch_model(pt_model, tf_weights, allow_missing_keys=F
                 e.args += (pt_weight.shape, array.shape)
                 raise e

-        logger.info("Initialize PyTorch weight {}".format(pt_weight_name))
+        # logger.warning("Initialize PyTorch weight {}".format(pt_weight_name))

         new_pt_params_dict[pt_weight_name] = torch.from_numpy(array)
         loaded_pt_weights_data_ptr[pt_weight.data_ptr()] = torch.from_numpy(array)
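Both per-weight log lines are commented out, so converting a checkpoint no longer prints one line per tensor. For context, a hedged sketch of how this conversion path is typically driven for the new TF T5 model; the checkpoint path is a placeholder and anything beyond the first two arguments of the helper is an assumption:

    from transformers import T5Config, TFT5Model, load_pytorch_checkpoint_in_tf2_model

    config = T5Config.from_pretrained('t5-small')
    tf_model = TFT5Model(config)
    # 'pytorch_model.bin' stands in for a real PyTorch T5 checkpoint on disk
    tf_model = load_pytorch_checkpoint_in_tf2_model(tf_model, 'pytorch_model.bin')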
transformers/modeling_tf_t5.py

(diff collapsed in the original view: +506 / -285)
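The collapsed file holds the new TF 2.0 implementation (TFT5Model, TFT5WithLMHeadModel). A hedged sketch of the intended usage, mirroring the dict-style inputs exercised by the updated tests below; the dummy ids and the download of 't5-small' are assumptions:

    import tensorflow as tf
    from transformers import T5Config, TFT5Model

    config = T5Config.from_pretrained('t5-small')
    model = TFT5Model(config)

    input_ids = tf.constant([[71, 307, 18, 9, 1]])   # dummy token ids, batch size 1
    inputs = {'encoder_input_ids': input_ids,
              'decoder_input_ids': input_ids,
              'decoder_attention_mask': tf.ones_like(input_ids)}
    encoder_output, decoder_output = model(inputs)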
transformers/modeling_utils.py

@@ -160,7 +160,6 @@ class PreTrainedModel(nn.Module):
         base_model.vocab_size = new_num_tokens

         # Tie weights again if needed
-        if hasattr(self, 'tie_weights'):
-            self.tie_weights()
+        self.tie_weights()

         return model_embeds

@@ -458,8 +457,7 @@ class PreTrainedModel(nn.Module):
             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                                model.__class__.__name__, "\n\t".join(error_msgs)))
-        if hasattr(model, 'tie_weights'):
-            model.tie_weights()  # make sure word embedding weights are still tied if needed
+        model.tie_weights()  # make sure word embedding weights are still tied

         # Set model in evaluation mode to desactivate DropOut modules by default
         model.eval()
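Dropping the hasattr guards implies that tie_weights is now a method every PreTrainedModel provides, with a harmless default for models that do not share embeddings. A hedged, self-contained sketch of that base-class-hook pattern with toy classes, not the library's actual code:

    class Base:
        def tie_weights(self):
            pass  # default is a no-op; subclasses that share embeddings override it

    class TiedModel(Base):
        def __init__(self):
            self.input_weight = [0.0] * 4
            self.output_weight = None

        def tie_weights(self):
            self.output_weight = self.input_weight  # point both heads at the same storage

    def after_resize(model):
        model.tie_weights()  # safe for every subclass: no hasattr() check needed

    after_resize(TiedModel())
    after_resize(Base())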
transformers/tests/modeling_tf_common_test.py

@@ -69,6 +69,7 @@ class TFCommonTestCases:
        test_torchscript = True
        test_pruning = True
        test_resize_embeddings = True
+       is_encoder_decoder = False

        def test_initialization(self):
            pass

@@ -156,6 +157,10 @@ class TFCommonTestCases:
        def test_compile_tf_model(self):
            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

-           input_ids = tf.keras.Input(batch_shape=(2, 2000), name='input_ids', dtype='int32')
+           if self.is_encoder_decoder:
+               input_ids = {'decoder_input_ids': tf.keras.Input(batch_shape=(2, 2000), name='decoder_input_ids', dtype='int32'),
+                            'encoder_input_ids': tf.keras.Input(batch_shape=(2, 2000), name='encoder_input_ids', dtype='int32')}
+           else:
+               input_ids = tf.keras.Input(batch_shape=(2, 2000), name='input_ids', dtype='int32')
            optimizer = tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0)
            loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

@@ -189,7 +194,7 @@ class TFCommonTestCases:
            outputs_dict = model(inputs_dict)

            inputs_keywords = copy.deepcopy(inputs_dict)
-           input_ids = inputs_keywords.pop('input_ids')
+           input_ids = inputs_keywords.pop('input_ids', inputs_keywords.pop('decoder_input_ids'))
            outputs_keywords = model(input_ids, **inputs_keywords)

            output_dict = outputs_dict[0].numpy()

@@ -216,12 +221,24 @@ class TFCommonTestCases:
                     self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
                out_len = len(outputs)

+               if self.is_encoder_decoder:
+                   self.assertEqual(out_len % 2, 0)
+                   decoder_attentions = outputs[(out_len // 2) - 1]
+                   self.assertEqual(model.config.output_attentions, True)
+                   self.assertEqual(model.config.output_hidden_states, False)
+                   self.assertEqual(len(decoder_attentions), self.model_tester.num_hidden_layers)
+                   self.assertListEqual(
+                       list(decoder_attentions[0].shape[-3:]),
+                       [self.model_tester.num_attention_heads,
+                        self.model_tester.seq_length,
+                        self.model_tester.key_len if hasattr(self.model_tester, 'key_len') else self.model_tester.seq_length])
+
                # Check attention is always last and order is fine
                config.output_attentions = True
                config.output_hidden_states = True
                model = model_class(config)
                outputs = model(inputs_dict)
-               self.assertEqual(out_len + 1, len(outputs))
+               self.assertEqual(out_len + (2 if self.is_encoder_decoder else 1), len(outputs))
                self.assertEqual(model.config.output_attentions, True)
                self.assertEqual(model.config.output_hidden_states, True)
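For encoder-decoder models the Keras compile test now feeds a dict of symbolic inputs instead of a single tensor. A minimal standalone sketch of why that works with the functional API, using a toy stand-in model rather than a transformers class; everything except the Input definitions and the optimizer/loss settings is an illustrative assumption:

    import tensorflow as tf

    # two symbolic inputs, as in the encoder-decoder branch of test_compile_tf_model
    enc = tf.keras.Input(batch_shape=(2, 2000), name='encoder_input_ids', dtype='int32')
    dec = tf.keras.Input(batch_shape=(2, 2000), name='decoder_input_ids', dtype='int32')

    # toy stand-in for the real model: shared embedding plus a projection to a vocabulary
    emb = tf.keras.layers.Embedding(input_dim=100, output_dim=8)
    logits = tf.keras.layers.Dense(100)(emb(enc) + emb(dec))

    model = tf.keras.Model(inputs={'encoder_input_ids': enc, 'decoder_input_ids': dec}, outputs=logits)
    model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=3e-5, epsilon=1e-08, clipnorm=1.0),
                  loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))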
transformers/tests/modeling_tf_t5_test.py

@@ -26,7 +26,7 @@ from .configuration_common_test import ConfigTester
 from transformers import T5Config, is_tf_available

-if False:  # is_tf_available():
+if is_tf_available():
     import tensorflow as tf
     from transformers.modeling_tf_t5 import (TFT5Model, TFT5WithLMHeadModel, TF_T5_PRETRAINED_MODEL_ARCHIVE_MAP)
 else:

@@ -35,7 +35,8 @@ else:
 class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):

-    all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if False else ()  # is_tf_available() else ()
+    is_encoder_decoder = True
+    all_model_classes = (TFT5Model, TFT5WithLMHeadModel) if is_tf_available() else ()

     class TFT5ModelTester(object):

@@ -45,22 +46,16 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
                     seq_length=7,
                     is_training=True,
                     use_input_mask=True,
-                    use_token_type_ids=True,
                     use_labels=True,
                     vocab_size=99,
                     n_positions=14,
                     hidden_size=32,
                     num_hidden_layers=5,
                     num_attention_heads=4,
-                    intermediate_size=37,
-                    hidden_act="gelu",
-                    hidden_dropout_prob=0.1,
-                    attention_probs_dropout_prob=0.1,
-                    max_position_embeddings=512,
-                    type_vocab_size=16,
-                    type_sequence_label_size=2,
-                    initializer_range=0.02,
-                    num_labels=3,
-                    num_choices=4,
+                    d_ff=37,
+                    relative_attention_num_buckets=8,
+                    dropout_rate=0.1,
+                    initializer_factor=0.002,
                     scope=None,
                     ):
            self.parent = parent

@@ -68,22 +63,16 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
            self.seq_length = seq_length
            self.is_training = is_training
            self.use_input_mask = use_input_mask
-           self.use_token_type_ids = use_token_type_ids
            self.use_labels = use_labels
            self.vocab_size = vocab_size
            self.n_positions = n_positions
            self.hidden_size = hidden_size
            self.num_hidden_layers = num_hidden_layers
            self.num_attention_heads = num_attention_heads
-           self.intermediate_size = intermediate_size
-           self.hidden_act = hidden_act
-           self.hidden_dropout_prob = hidden_dropout_prob
-           self.attention_probs_dropout_prob = attention_probs_dropout_prob
-           self.max_position_embeddings = max_position_embeddings
-           self.type_vocab_size = type_vocab_size
-           self.type_sequence_label_size = type_sequence_label_size
-           self.initializer_range = initializer_range
-           self.num_labels = num_labels
-           self.num_choices = num_choices
+           self.d_ff = d_ff
+           self.relative_attention_num_buckets = relative_attention_num_buckets
+           self.dropout_rate = dropout_rate
+           self.initializer_factor = initializer_factor
            self.scope = scope

        def prepare_config_and_inputs(self):

@@ -93,61 +82,53 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
            if self.use_input_mask:
                input_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)

-           token_type_ids = None
-           if self.use_token_type_ids:
-               token_type_ids = ids_tensor([self.batch_size, self.seq_length], self.type_vocab_size)
-
-           sequence_labels = None
            token_labels = None
-           choice_labels = None
            if self.use_labels:
-               sequence_labels = ids_tensor([self.batch_size], self.type_sequence_label_size)
-               token_labels = ids_tensor([self.batch_size, self.seq_length], self.num_labels)
-               choice_labels = ids_tensor([self.batch_size], self.num_choices)
+               token_labels = ids_tensor([self.batch_size, self.seq_length], self.vocab_size)

            config = T5Config(
                vocab_size_or_config_json_file=self.vocab_size,
-               hidden_size=self.hidden_size,
-               num_hidden_layers=self.num_hidden_layers,
-               num_attention_heads=self.num_attention_heads,
-               intermediate_size=self.intermediate_size,
-               hidden_act=self.hidden_act,
-               hidden_dropout_prob=self.hidden_dropout_prob,
-               attention_probs_dropout_prob=self.attention_probs_dropout_prob,
-               max_position_embeddings=self.max_position_embeddings,
-               type_vocab_size=self.type_vocab_size,
-               initializer_range=self.initializer_range)
+               n_positions=self.n_positions,
+               d_model=self.hidden_size,
+               d_ff=self.d_ff,
+               d_kv=self.hidden_size // self.num_attention_heads,
+               num_layers=self.num_hidden_layers,
+               num_heads=self.num_attention_heads,
+               relative_attention_num_buckets=self.relative_attention_num_buckets,
+               dropout_rate=self.dropout_rate,
+               initializer_factor=self.initializer_factor)

-           return config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
+           return (config, input_ids, input_mask, token_labels)

-       def create_and_check_t5_model(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+       def create_and_check_t5_model(self, config, input_ids, input_mask, token_labels):
            model = TFT5Model(config=config)
-           inputs = {'input_ids': input_ids,
-                     'attention_mask': input_mask,
-                     'token_type_ids': token_type_ids}
-           sequence_output, pooled_output = model(inputs)
+           inputs = {'encoder_input_ids': input_ids,
+                     'decoder_input_ids': input_ids,
+                     'decoder_attention_mask': input_mask}
+           encoder_output, decoder_output = model(inputs)

-           inputs = [input_ids, input_mask]
-           sequence_output, pooled_output = model(inputs)
-
-           sequence_output, pooled_output = model(input_ids)
+           encoder_output, decoder_output = model(input_ids,
+                                                  decoder_attention_mask=input_mask,
+                                                  encoder_input_ids=input_ids)

            result = {
-               "sequence_output": sequence_output.numpy(),
-               "pooled_output": pooled_output.numpy(),
+               "encoder_output": encoder_output.numpy(),
+               "decoder_output": decoder_output.numpy(),
            }
            self.parent.assertListEqual(
-               list(result["sequence_output"].shape),
+               list(result["encoder_output"].shape),
                [self.batch_size, self.seq_length, self.hidden_size])
-           self.parent.assertListEqual(list(result["pooled_output"].shape), [self.batch_size, self.hidden_size])
+           self.parent.assertListEqual(
+               list(result["decoder_output"].shape),
+               [self.batch_size, self.seq_length, self.hidden_size])

-       def create_and_check_t5_with_lm_head(self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels):
+       def create_and_check_t5_with_lm_head(self, config, input_ids, input_mask, token_labels):
            model = TFT5WithLMHeadModel(config=config)
-           inputs = {'input_ids': input_ids,
-                     'attention_mask': input_mask,
-                     'token_type_ids': token_type_ids}
-           prediction_scores, = model(inputs)
+           inputs = {'encoder_input_ids': input_ids,
+                     'decoder_input_ids': input_ids,
+                     'decoder_attention_mask': input_mask}
+           prediction_scores, decoder_output = model(inputs)
            result = {
                "prediction_scores": prediction_scores.numpy(),
            }

@@ -158,14 +139,15 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
        def prepare_config_and_inputs_for_common(self):
            config_and_inputs = self.prepare_config_and_inputs()
-           (config, input_ids, token_type_ids, input_mask,
-            sequence_labels, token_labels, choice_labels) = config_and_inputs
-           inputs_dict = {'input_ids': input_ids,
-                          'token_type_ids': token_type_ids,
-                          'attention_mask': input_mask}
+           (config, input_ids, input_mask, token_labels) = config_and_inputs
+           inputs_dict = {'encoder_input_ids': input_ids,
+                          'decoder_input_ids': input_ids,
+                          'decoder_attention_mask': input_mask}
            return config, inputs_dict

    def setUp(self):
        self.model_tester = TFT5ModelTest.TFT5ModelTester(self)
-       self.config_tester = ConfigTester(self, config_class=T5Config, hidden_size=37)
+       self.config_tester = ConfigTester(self, config_class=T5Config, d_model=37)

    def test_config(self):
        self.config_tester.run_common_tests()

@@ -181,7 +163,7 @@ class TFT5ModelTest(TFCommonTestCases.TFCommonModelTester):
    @pytest.mark.slow
    def test_model_from_pretrained(self):
        cache_dir = "/tmp/transformers_test/"
-       for model_name in ['t5-base']:
+       for model_name in ['t5-small']:
            model = TFT5Model.from_pretrained(model_name, cache_dir=cache_dir)
            shutil.rmtree(cache_dir)
            self.assertIsNotNone(model)
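All of the tester's checks build the TF models from the same tiny configuration. Spelled out as a standalone snippet it looks like this, with the values copied from the hunk above; constructing T5Config with exactly these keyword arguments outside the test harness is otherwise an assumption:

    from transformers import T5Config

    # mirror of TFT5ModelTester's configuration: hidden_size=32, 5 layers, 4 heads
    config = T5Config(vocab_size_or_config_json_file=99,
                      n_positions=14,
                      d_model=32,
                      d_ff=37,
                      d_kv=32 // 4,
                      num_layers=5,
                      num_heads=4,
                      relative_attention_num_buckets=8,
                      dropout_rate=0.1,
                      initializer_factor=0.002)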