Unverified Commit 86822a35 authored by Patrick von Platen, committed by GitHub

T5 & mT5 (#8552)

* add mt5 and t5v1_1 model

* fix tests

* correct some imports

* add tf model

* finish tf t5

* improve examples

* fix copies

* clean doc
parent 9e01f988
...@@ -248,6 +248,7 @@ conversion utilities for the following models:
model_doc/marian
model_doc/mbart
model_doc/mobilebert
model_doc/mt5
model_doc/gpt
model_doc/gpt2
model_doc/pegasus
...
MT5
-----------------------------------------------------------------------------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The mT5 model was presented in `mT5: A massively multilingual pre-trained text-to-text transformer
<https://arxiv.org/abs/2010.11934>`_ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya
Siddhant, Aditya Barua, Colin Raffel.
The abstract from the paper is the following:
*The recent "Text-to-Text Transfer Transformer" (T5) leveraged a unified text-to-text format and scale to attain
state-of-the-art results on a wide variety of English-language NLP tasks. In this paper, we introduce mT5, a
multilingual variant of T5 that was pre-trained on a new Common Crawl-based dataset covering 101 languages. We describe
the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual
benchmarks. All of the code and model checkpoints used in this work are publicly available.*
The original code can be found `here <https://github.com/google-research/multilingual-t5>`__.
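Example (a minimal sketch rather than official documentation; it simply mirrors the docstring examples added elsewhere in this commit and assumes the ``google/mt5-small`` checkpoint)::

    >>> from transformers import MT5ForConditionalGeneration, T5Tokenizer
    >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
    >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
    >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
    >>> summary = "Weiter Verhandlung in Syrien."
    >>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
    >>> loss = model(**batch).loss  # batch contains input_ids, attention_mask and labels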
MT5Config
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MT5Config
:members:
MT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MT5Model
:members:
MT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MT5ForConditionalGeneration
:members:
TFMT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFMT5Model
:members:
TFMT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFMT5ForConditionalGeneration
:members:
...@@ -498,6 +498,7 @@ if is_torch_available():
MobileBertPreTrainedModel,
load_tf_weights_in_mobilebert,
)
from .models.mt5 import MT5Config, MT5ForConditionalGeneration, MT5Model
from .models.openai import (
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
OpenAIGPTDoubleHeadsModel,
...@@ -791,6 +792,7 @@ if is_tf_available():
TFMobileBertModel,
TFMobileBertPreTrainedModel,
)
from .models.mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from .models.openai import (
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFOpenAIGPTDoubleHeadsModel,
...
...@@ -40,6 +40,7 @@ from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
from ..marian.configuration_marian import MarianConfig
from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
from ..mobilebert.configuration_mobilebert import MobileBertConfig
from ..mt5.configuration_mt5 import MT5Config
from ..openai.configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from ..pegasus.configuration_pegasus import PegasusConfig
from ..prophetnet.configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
...@@ -101,6 +102,7 @@ CONFIG_MAPPING = OrderedDict(
[
# Add configs here
("retribert", RetriBertConfig),
("mt5", MT5Config),
("t5", T5Config), ("t5", T5Config),
("mobilebert", MobileBertConfig), ("mobilebert", MobileBertConfig),
("distilbert", DistilBertConfig), ("distilbert", DistilBertConfig),
...@@ -178,6 +180,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("rag", "RAG"),
("xlm-prophetnet", "XLMProphetNet"),
("prophetnet", "ProphetNet"),
("mt5", "mT5"),
]
)
...
...@@ -120,6 +120,7 @@ from ..mobilebert.modeling_mobilebert import (
MobileBertForTokenClassification,
MobileBertModel,
)
from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model
from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration
from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel
...@@ -209,6 +210,7 @@ from .configuration_auto import (
MarianConfig,
MBartConfig,
MobileBertConfig,
MT5Config,
OpenAIGPTConfig,
PegasusConfig,
ProphetNetConfig,
...@@ -235,6 +237,7 @@ MODEL_MAPPING = OrderedDict(
[
# Base model mapping
(RetriBertConfig, RetriBertModel),
(MT5Config, MT5Model),
(T5Config, T5Model),
(DistilBertConfig, DistilBertModel),
(AlbertConfig, AlbertModel),
...@@ -376,6 +379,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(MT5Config, MT5ForConditionalGeneration),
(T5Config, T5ForConditionalGeneration),
(PegasusConfig, PegasusForConditionalGeneration),
(MarianConfig, MarianMTModel),
...
...@@ -106,6 +106,7 @@ from ..mobilebert.modeling_tf_mobilebert import (
TFMobileBertForTokenClassification,
TFMobileBertModel,
)
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from ..openai.modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration
from ..roberta.modeling_tf_roberta import (
...@@ -161,6 +162,7 @@ from .configuration_auto import (
MarianConfig,
MBartConfig,
MobileBertConfig,
MT5Config,
OpenAIGPTConfig,
PegasusConfig,
RobertaConfig,
...@@ -182,6 +184,7 @@ TF_MODEL_MAPPING = OrderedDict(
[
# Base model mapping
(LxmertConfig, TFLxmertModel),
(MT5Config, TFMT5Model),
(T5Config, TFT5Model),
(DistilBertConfig, TFDistilBertModel),
(AlbertConfig, TFAlbertModel),
...@@ -294,6 +297,7 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(MT5Config, TFMT5ForConditionalGeneration),
(T5Config, TFT5ForConditionalGeneration),
(MarianConfig, TFMarianMTModel),
(MBartConfig, TFMBartForConditionalGeneration),
...
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_tf_available, is_torch_available
from .configuration_mt5 import MT5Config
if is_torch_available():
from .modeling_mt5 import MT5ForConditionalGeneration, MT5Model
if is_tf_available():
from .modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
# coding=utf-8
# Copyright 2020, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" mT5 model configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class MT5Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.MT5Model` or a
:class:`~transformers.TFMT5Model`. It is used to instantiate an mT5 model according to the specified arguments,
defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
to that of the mT5 `google/mt5-small <https://huggingface.co/google/mt5-small>`__ architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Arguments:
vocab_size (:obj:`int`, `optional`, defaults to 250112):
Vocabulary size of the mT5 model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.MT5Model` or :class:`~transformers.TFMT5Model`.
d_model (:obj:`int`, `optional`, defaults to 512):
Size of the encoder layers and the pooler layer.
d_kv (:obj:`int`, `optional`, defaults to 64):
Size of the key, query, value projections per attention head. Unlike in the original T5 checkpoints,
:obj:`d_kv` does not have to equal :obj:`d_model // num_heads`; the attention inner dimension is
:obj:`num_heads * d_kv`.
d_ff (:obj:`int`, `optional`, defaults to 1024):
Size of the intermediate feed forward layer in each :obj:`T5Block`.
num_layers (:obj:`int`, `optional`, defaults to 8):
Number of hidden layers in the Transformer encoder.
num_decoder_layers (:obj:`int`, `optional`):
Number of hidden layers in the Transformer decoder. Will use the same value as :obj:`num_layers` if not
set.
num_heads (:obj:`int`, `optional`, defaults to 6):
Number of attention heads for each attention layer in the Transformer encoder.
relative_attention_num_buckets (:obj:`int`, `optional`, defaults to 32):
The number of buckets to use for each attention layer.
dropout_rate (:obj:`float`, `optional`, defaults to 0.1):
The ratio for all dropout layers.
layer_norm_epsilon (:obj:`float`, `optional`, defaults to 1e-6):
The epsilon used by the layer normalization layers.
initializer_factor (:obj:`float`, `optional`, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"gated-gelu"`):
Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`.
"""
model_type = "mt5"
def __init__(
self,
vocab_size=250112,
d_model=512,
d_kv=64,
d_ff=1024,
num_layers=8,
num_decoder_layers=None,
num_heads=6,
relative_attention_num_buckets=32,
dropout_rate=0.1,
layer_norm_epsilon=1e-6,
initializer_factor=1.0,
feed_forward_proj="gated-gelu",
is_encoder_decoder=True,
tokenizer_class="T5Tokenizer",
tie_word_embeddings=False,
pad_token_id=0,
eos_token_id=1,
decoder_start_token_id=0,
**kwargs
):
super().__init__(
is_encoder_decoder=is_encoder_decoder,
tokenizer_class=tokenizer_class,
tie_word_embeddings=tie_word_embeddings,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.d_model = d_model
self.d_kv = d_kv
self.d_ff = d_ff
self.num_layers = num_layers
self.num_decoder_layers = (
num_decoder_layers if num_decoder_layers is not None else self.num_layers
) # default = symmetry
self.num_heads = num_heads
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
@property
def hidden_size(self):
return self.d_model
@property
def num_attention_heads(self):
return self.num_heads
@property
def num_hidden_layers(self):
return self.num_layers
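As a hedged aside (not part of the diff itself), the configuration above behaves like any other :class:`~transformers.PretrainedConfig`; the short interpreter session below simply echoes the defaults defined in ``__init__`` and the property aliases defined above:

>>> from transformers import MT5Config
>>> config = MT5Config()  # defaults are roughly those of google/mt5-small
>>> config.vocab_size
250112
>>> config.feed_forward_proj
'gated-gelu'
>>> config.hidden_size == config.d_model  # alias provided by the hidden_size property
True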
# coding=utf-8
# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch mT5 model. """
from ...utils import logging
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
from .configuration_mt5 import MT5Config
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
class MT5Model(T5Model):
r"""
This class overrides :class:`~transformers.T5Model`. Please check the superclass for the appropriate documentation
alongside usage examples.
Examples::
>>> from transformers import MT5Model, T5Tokenizer
>>> model = MT5Model.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
>>> outputs = model(input_ids=batch.input_ids, decoder_input_ids=batch.labels)
>>> hidden_states = outputs.last_hidden_state
"""
model_type = "mt5"
config_class = MT5Config
authorized_missing_keys = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
]
keys_to_never_save = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
]
class MT5ForConditionalGeneration(T5ForConditionalGeneration):
r"""
This class overrides :class:`~transformers.T5ForConditionalGeneration`. Please check the superclass for the
appropriate documentation alongside usage examples.
Examples::
>>> from transformers import MT5ForConditionalGeneration, T5Tokenizer
>>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
>>> outputs = model(**batch)
>>> loss = outputs.loss
"""
model_type = "mt5"
config_class = MT5Config
authorized_missing_keys = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"lm_head\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
]
keys_to_never_save = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
]
# coding=utf-8
# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tensorflow mT5 model. """
from ...utils import logging
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
from .configuration_mt5 import MT5Config
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
class TFMT5Model(TFT5Model):
r"""
This class overrides :class:`~transformers.TFT5Model`. Please check the superclass for the appropriate
documentation alongside usage examples.
Examples::
>>> from transformers import TFMT5Model, T5Tokenizer
>>> model = TFMT5Model.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
>>> batch["decoder_input_ids"] = batch["labels"]
>>> del batch["labels"]
>>> outputs = model(batch)
>>> hidden_states = outputs.last_hidden_state
"""
model_type = "mt5"
config_class = MT5Config
class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):
r"""
This class overrides :class:`~transformers.TFT5ForConditionalGeneration`. Please check the superclass for the
appropriate documentation alongside usage examples.
Examples::
>>> from transformers import TFMT5ForConditionalGeneration, T5Tokenizer
>>> model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
>>> outputs = model(batch)
>>> loss = outputs.loss
"""
model_type = "mt5"
config_class = MT5Config
# coding=utf-8
# Copyright 2020, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...@@ -43,9 +43,6 @@ class T5Config(PretrainedConfig):
vocab_size (:obj:`int`, `optional`, defaults to 32128):
Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.T5Model` or :class:`~transformers.TFT5Model`.
n_positions (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
d_model (:obj:`int`, `optional`, defaults to 512):
Size of the encoder layers and the pooler layer.
d_kv (:obj:`int`, `optional`, defaults to 64):
...@@ -69,6 +66,9 @@ class T5Config(PretrainedConfig):
initializer_factor (:obj:`float`, `optional`, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"relu"`):
Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`. T5v1.1 uses
the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
""" """
model_type = "t5" model_type = "t5"
...@@ -85,6 +85,7 @@ class T5Config(PretrainedConfig): ...@@ -85,6 +85,7 @@ class T5Config(PretrainedConfig):
dropout_rate=0.1, dropout_rate=0.1,
layer_norm_epsilon=1e-6, layer_norm_epsilon=1e-6,
initializer_factor=1.0, initializer_factor=1.0,
feed_forward_proj="relu",
is_encoder_decoder=True,
pad_token_id=0,
eos_token_id=1,
...@@ -109,6 +110,7 @@ class T5Config(PretrainedConfig):
self.dropout_rate = dropout_rate
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
@property
def hidden_size(self):
...
...@@ -17,7 +17,7 @@
import argparse
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
from transformers.utils import logging
...@@ -28,7 +28,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
# Initialise PyTorch model
config = T5Config.from_json_file(config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = T5ForConditionalGeneration(config)
# Load weights from tf checkpoint
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
...
...@@ -25,6 +25,7 @@ import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...file_utils import (
DUMMY_INPUTS,
DUMMY_MASK,
...@@ -140,6 +141,9 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
continue
elif scope_names[0] == "logits":
pointer = getattr(pointer, "lm_head")
elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
pointer = getattr(pointer, f"wi_{scope_names[1]}")
continue
else:
try:
pointer = getattr(pointer, scope_names[0])
...@@ -211,10 +215,36 @@ class T5DenseReluDense(nn.Module):
return hidden_states
class T5DenseGatedGeluDense(nn.Module):
def __init__(self, config):
super().__init__()
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.gelu_act = ACT2FN["gelu_new"]
def forward(self, hidden_states):
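# gated-GELU ("GEGLU") feed forward: the GELU-activated wi_0 projection gates the linear wi_1 projection elementwise before the wo output projection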
hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
class T5LayerFF(nn.Module):
def __init__(self, config):
super().__init__()
if config.feed_forward_proj == "relu":
self.DenseReluDense = T5DenseReluDense(config)
elif config.feed_forward_proj == "gated-gelu":
self.DenseReluDense = T5DenseGatedGeluDense(config)
else:
raise ValueError(
f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
...@@ -641,6 +671,16 @@ class T5PreTrainedModel(PreTrainedModel):
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.wo, "bias") and module.wo.bias is not None:
module.wo.bias.data.zero_()
elif isinstance(module, T5DenseGatedGeluDense):
module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
module.wi_0.bias.data.zero_()
module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
module.wi_1.bias.data.zero_()
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.wo, "bias") and module.wo.bias is not None:
module.wo.bias.data.zero_()
elif isinstance(module, T5Attention):
# Mesh TensorFlow attention initialization to avoid scaling before softmax
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
...@@ -1099,8 +1139,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"lm_head\.weight",
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight", r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
] ]
...@@ -1262,9 +1300,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
)
sequence_output = decoder_outputs[0]
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim ** -0.5)
lm_logits = self.lm_head(sequence_output)
loss = None
...
# coding=utf-8
# Copyright 2020 T5 Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
...@@ -26,6 +26,7 @@ import tensorflow as tf
from transformers.modeling_tf_utils import TFWrappedEmbeddings
from ...activations_tf import get_tf_activation
from ...file_utils import (
DUMMY_INPUTS,
DUMMY_MASK,
...@@ -103,10 +104,35 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
return hidden_states
class TFT5GatedGeluDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.wi_0 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_0")
self.wi_1 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_1")
self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
self.act = get_tf_activation("gelu_new")
def call(self, hidden_states, training=False):
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.wo(hidden_states)
return hidden_states
class TFT5LayerFF(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
if config.feed_forward_proj == "relu":
self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense")
elif config.feed_forward_proj == "gated-gelu":
self.DenseReluDense = TFT5GatedGeluDense(config, name="DenseReluDense")
else:
raise ValueError(
f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
)
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
...@@ -547,9 +573,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.embed_tokens
def get_output_embeddings(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
...@@ -970,9 +993,6 @@ class TFT5Model(TFT5PreTrainedModel):
def get_input_embeddings(self):
return self.shared
def get_output_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
...@@ -1165,11 +1185,17 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
decoder_config.is_decoder = True
self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
if not config.tie_word_embeddings:
self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head")
def get_input_embeddings(self):
return self.shared
def get_output_embeddings(self):
if self.config.tie_word_embeddings:
return self.shared
else:
return self.lm_head
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
...@@ -1331,9 +1357,14 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
training=training,
)
sequence_output = decoder_outputs[0]
# T5v1.1 does not tie output word embeddings and thus does not require downscaling
if self.config.tie_word_embeddings:
sequence_output = sequence_output * (self.model_dim ** -0.5)
logits = self.get_output_embeddings()(sequence_output, mode="linear")
else:
logits = self.get_output_embeddings()(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
...
...@@ -1361,6 +1361,29 @@ def load_tf_weights_in_mobilebert(*args, **kwargs):
requires_pytorch(load_tf_weights_in_mobilebert)
class MT5Config:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class MT5ForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class MT5Model:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
...@@ -970,6 +970,24 @@ class TFMobileBertPreTrainedModel:
requires_tf(self)
class TFMT5ForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFMT5Model:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = None
...
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
if is_torch_available():
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
@require_torch
@require_sentencepiece
@require_tokenizers
class MT5IntegrationTest(unittest.TestCase):
@slow
def test_small_integration_test(self):
"""
For comparison run:
>>> import t5 # pip install t5==0.7.1
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
>>> path_to_mtf_small_mt5_checkpoint = '<fill_in>'
>>> path_to_mtf_small_mt5_spm_model_path = '<fill_in>'
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_mt5_checkpoint, batch_size=1, tpu=None)
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_mt5_spm_model_path)
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small", return_dict=True).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
labels = tokenizer("Hi I am", return_tensors="pt").input_ids
loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss
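# MtfModel.score reports a log-likelihood summed over target tokens, so the mean cross-entropy loss is scaled back up by the target length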
mtf_score = -(labels.shape[-1] * loss.item())
EXPECTED_SCORE = -84.9127
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
...@@ -490,6 +490,14 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_v1_1(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
# check that gated gelu feed forward and different word embeddings work
config = config_and_inputs[0]
config.tie_word_embeddings = False
config.feed_forward_proj = "gated-gelu"
self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
def test_with_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
...@@ -569,7 +577,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
tokenizer = T5Tokenizer.from_pretrained("t5-small") tokenizer = T5Tokenizer.from_pretrained("t5-small")
input_ids = tokenizer("Hello there", return_tensors="pt").input_ids input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
...@@ -581,6 +589,32 @@ class T5ModelIntegrationTests(unittest.TestCase): ...@@ -581,6 +589,32 @@ class T5ModelIntegrationTests(unittest.TestCase):
EXPECTED_SCORE = -19.0845 EXPECTED_SCORE = -19.0845
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4) self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
@slow
def test_small_v1_1_integration_test(self):
"""
For comparison run:
>>> import t5 # pip install t5==0.7.1
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
>>> path_to_mtf_small_t5_v1_1_checkpoint = '<fill_in>'
>>> path_to_mtf_small_spm_model_path = '<fill_in>'
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_v1_1_checkpoint, batch_size=1, tpu=None)
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small").to(torch_device)
tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")
input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
labels = tokenizer("Hi I am", return_tensors="pt").input_ids
loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss
mtf_score = -(labels.shape[-1] * loss.item())
EXPECTED_SCORE = -59.0293
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
@slow
def test_summarization(self):
model = self.model
...
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
if is_tf_available():
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
@require_tf
@require_sentencepiece
@require_tokenizers
class TFMT5ModelIntegrationTest(unittest.TestCase):
@slow
def test_small_integration_test(self):
"""
For comparison run:
>>> import t5 # pip install t5==0.7.1
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
>>> path_to_mtf_small_mt5_checkpoint = '<fill_in>'
>>> path_to_mtf_small_mt5_spm_model_path = '<fill_in>'
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_mt5_checkpoint, batch_size=1, tpu=None)
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_mt5_spm_model_path, extra_ids=100)
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = TFAutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
input_ids = tokenizer("Hello there", return_tensors="tf").input_ids
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
loss = model(input_ids, labels=labels).loss
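# the TF loss is returned per target token (no reduction), so summing it gives the MTF-style summed log-likelihood directly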
mtf_score = -tf.math.reduce_sum(loss).numpy()
EXPECTED_SCORE = -84.9127
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
...@@ -258,6 +258,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_model(*config_and_inputs)
def test_t5_model_v1_1(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config = config_and_inputs[0]
config.tie_word_embeddings = False
config.feed_forward_proj = "gated-gelu"
self.model_tester.create_and_check_t5_model(config, *config_and_inputs[1:])
def test_with_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
...@@ -296,6 +303,58 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
def model(self):
return TFT5ForConditionalGeneration.from_pretrained("t5-base")
@slow
def test_small_integration_test(self):
"""
For comparison run:
>>> import t5 # pip install t5==0.7.1
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
>>> path_to_mtf_small_t5_checkpoint = '<fill_in>'
>>> path_to_mtf_small_spm_model_path = '<fill_in>'
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None)
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
input_ids = tokenizer("Hello there", return_tensors="tf").input_ids
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
loss = model(input_ids, labels=labels).loss
mtf_score = -tf.math.reduce_sum(loss).numpy()
EXPECTED_SCORE = -19.0845
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
@slow
def test_small_v1_1_integration_test(self):
"""
For comparison run:
>>> import t5 # pip install t5==0.7.1
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
>>> path_to_mtf_small_t5_v1_1_checkpoint = '<fill_in>'
>>> path_to_mtf_small_spm_model_path = '<fill_in>'
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_v1_1_checkpoint, batch_size=1, tpu=None)
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = TFT5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small")
tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")
input_ids = tokenizer("Hello there", return_tensors="tf").input_ids
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
loss = model(input_ids, labels=labels).loss
mtf_score = -tf.math.reduce_sum(loss).numpy()
EXPECTED_SCORE = -59.0293
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
@slow
def test_summarization(self):
model = self.model
...