Unverified Commit 86822a35 authored by Patrick von Platen's avatar Patrick von Platen Committed by GitHub
Browse files

T5 & mT5 (#8552)

* add mt5 and t5v1_1 model

* fix tests

* correct some imports

* add tf model

* finish tf t5

* improve examples

* fix copies

* clean doc
parent 9e01f988
......@@ -248,6 +248,7 @@ conversion utilities for the following models:
model_doc/marian
model_doc/mbart
model_doc/mobilebert
model_doc/mt5
model_doc/gpt
model_doc/gpt2
model_doc/pegasus
......
MT5
-----------------------------------------------------------------------------------------------------------------------
Overview
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The mT5 model was presented in `mT5: A massively multilingual pre-trained text-to-text transformer
<https://arxiv.org/abs/2010.11934>`_ by Linting Xue, Noah Constant, Adam Roberts, Mihir Kale, Rami Al-Rfou, Aditya
Siddhant, Aditya Barua, Colin Raffel.
The abstract from the paper is the following:
*The recent "Text-to-Text Transfer Transformer" (T5) leveraged a unified text-to-text format and scale to attain
state-of-the-art results on a wide variety of English-language NLP tasks. In this paper, we introduce mT5, a
multilingual variant of T5 that was pre-trained on a new Common Crawl-based dataset covering 101 languages. We describe
the design and modified training of mT5 and demonstrate its state-of-the-art performance on many multilingual
benchmarks. All of the code and model checkpoints*
The original code can be found `here <https://github.com/google-research/multilingual-t5>`__.
MT5Config
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MT5Config
:members:
MT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MT5Model
:members:
MT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.MT5ForConditionalGeneration
:members:
TFMT5Model
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFMT5Model
:members:
TFMT5ForConditionalGeneration
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: transformers.TFMT5ForConditionalGeneration
:members:
......@@ -498,6 +498,7 @@ if is_torch_available():
MobileBertPreTrainedModel,
load_tf_weights_in_mobilebert,
)
from .models.mt5 import MT5Config, MT5ForConditionalGeneration, MT5Model
from .models.openai import (
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
OpenAIGPTDoubleHeadsModel,
......@@ -791,6 +792,7 @@ if is_tf_available():
TFMobileBertModel,
TFMobileBertPreTrainedModel,
)
from .models.mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from .models.openai import (
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST,
TFOpenAIGPTDoubleHeadsModel,
......
......@@ -40,6 +40,7 @@ from ..lxmert.configuration_lxmert import LXMERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
from ..marian.configuration_marian import MarianConfig
from ..mbart.configuration_mbart import MBART_PRETRAINED_CONFIG_ARCHIVE_MAP, MBartConfig
from ..mobilebert.configuration_mobilebert import MobileBertConfig
from ..mt5.configuration_mt5 import MT5Config
from ..openai.configuration_openai import OPENAI_GPT_PRETRAINED_CONFIG_ARCHIVE_MAP, OpenAIGPTConfig
from ..pegasus.configuration_pegasus import PegasusConfig
from ..prophetnet.configuration_prophetnet import PROPHETNET_PRETRAINED_CONFIG_ARCHIVE_MAP, ProphetNetConfig
......@@ -101,6 +102,7 @@ CONFIG_MAPPING = OrderedDict(
[
# Add configs here
("retribert", RetriBertConfig),
("mt5", MT5Config),
("t5", T5Config),
("mobilebert", MobileBertConfig),
("distilbert", DistilBertConfig),
......@@ -178,6 +180,7 @@ MODEL_NAMES_MAPPING = OrderedDict(
("rag", "RAG"),
("xlm-prophetnet", "XLMProphetNet"),
("prophetnet", "ProphetNet"),
("mt5", "mT5"),
]
)
......
......@@ -120,6 +120,7 @@ from ..mobilebert.modeling_mobilebert import (
MobileBertForTokenClassification,
MobileBertModel,
)
from ..mt5.modeling_mt5 import MT5ForConditionalGeneration, MT5Model
from ..openai.modeling_openai import OpenAIGPTForSequenceClassification, OpenAIGPTLMHeadModel, OpenAIGPTModel
from ..pegasus.modeling_pegasus import PegasusForConditionalGeneration
from ..prophetnet.modeling_prophetnet import ProphetNetForCausalLM, ProphetNetForConditionalGeneration, ProphetNetModel
......@@ -209,6 +210,7 @@ from .configuration_auto import (
MarianConfig,
MBartConfig,
MobileBertConfig,
MT5Config,
OpenAIGPTConfig,
PegasusConfig,
ProphetNetConfig,
......@@ -235,6 +237,7 @@ MODEL_MAPPING = OrderedDict(
[
# Base model mapping
(RetriBertConfig, RetriBertModel),
(MT5Config, MT5Model),
(T5Config, T5Model),
(DistilBertConfig, DistilBertModel),
(AlbertConfig, AlbertModel),
......@@ -376,6 +379,7 @@ MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(MT5Config, MT5ForConditionalGeneration),
(T5Config, T5ForConditionalGeneration),
(PegasusConfig, PegasusForConditionalGeneration),
(MarianConfig, MarianMTModel),
......
......@@ -106,6 +106,7 @@ from ..mobilebert.modeling_tf_mobilebert import (
TFMobileBertForTokenClassification,
TFMobileBertModel,
)
from ..mt5.modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
from ..openai.modeling_tf_openai import TFOpenAIGPTLMHeadModel, TFOpenAIGPTModel
from ..pegasus.modeling_tf_pegasus import TFPegasusForConditionalGeneration
from ..roberta.modeling_tf_roberta import (
......@@ -161,6 +162,7 @@ from .configuration_auto import (
MarianConfig,
MBartConfig,
MobileBertConfig,
MT5Config,
OpenAIGPTConfig,
PegasusConfig,
RobertaConfig,
......@@ -182,6 +184,7 @@ TF_MODEL_MAPPING = OrderedDict(
[
# Base model mapping
(LxmertConfig, TFLxmertModel),
(MT5Config, TFMT5Model),
(T5Config, TFT5Model),
(DistilBertConfig, TFDistilBertModel),
(AlbertConfig, TFAlbertModel),
......@@ -294,6 +297,7 @@ TF_MODEL_FOR_MASKED_LM_MAPPING = OrderedDict(
TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
[
# Model for Seq2Seq Causal LM mapping
(MT5Config, TFMT5ForConditionalGeneration),
(T5Config, TFT5ForConditionalGeneration),
(MarianConfig, TFMarianMTModel),
(MBartConfig, TFMBartForConditionalGeneration),
......
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
from ...file_utils import is_tf_available, is_torch_available
from .configuration_mt5 import MT5Config
if is_torch_available():
from .modeling_mt5 import MT5ForConditionalGeneration, MT5Model
if is_tf_available():
from .modeling_tf_mt5 import TFMT5ForConditionalGeneration, TFMT5Model
# coding=utf-8
# Copyright 2020, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" mT5 model configuration """
from ...configuration_utils import PretrainedConfig
from ...utils import logging
logger = logging.get_logger(__name__)
class MT5Config(PretrainedConfig):
r"""
This is the configuration class to store the configuration of a :class:`~transformers.MT5Model` or a
:class:`~transformers.TFMT5Model`. It is used to instantiate a mT5 model according to the specified arguments,
defining the model architecture. Instantiating a configuration with the defaults will yield a similar configuration
to that of the mT5 `google/mt5-small <https://huggingface.co/google/mt5-small>`__ architecture.
Configuration objects inherit from :class:`~transformers.PretrainedConfig` and can be used to control the model
outputs. Read the documentation from :class:`~transformers.PretrainedConfig` for more information.
Arguments:
vocab_size (:obj:`int`, `optional`, defaults to 32128):
Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.T5Model` or :class:`~transformers.TFT5Model`.
d_model (:obj:`int`, `optional`, defaults to 512):
Size of the encoder layers and the pooler layer.
d_kv (:obj:`int`, `optional`, defaults to 64):
Size of the key, query, value projections per attention head. :obj:`d_kv` has to be equal to :obj:`d_model
// num_heads`.
d_ff (:obj:`int`, `optional`, defaults to 1024):
Size of the intermediate feed forward layer in each :obj:`T5Block`.
num_layers (:obj:`int`, `optional`, defaults to 8):
Number of hidden layers in the Transformer encoder.
num_decoder_layers (:obj:`int`, `optional`):
Number of hidden layers in the Transformer decoder. Will use the same value as :obj:`num_layers` if not
set.
num_heads (:obj:`int`, `optional`, defaults to 6):
Number of attention heads for each attention layer in the Transformer encoder.
relative_attention_num_buckets (:obj:`int`, `optional`, defaults to 32):
The number of buckets to use for each attention layer.
dropout_rate (:obj:`float`, `optional`, defaults to 0.1):
The ratio for all dropout layers.
layer_norm_eps (:obj:`float`, `optional`, defaults to 1e-6):
The epsilon used by the layer normalization layers.
initializer_factor (:obj:`float`, `optional`, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"gated-gelu"`):
Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`.
"""
model_type = "mt5"
def __init__(
self,
vocab_size=250112,
d_model=512,
d_kv=64,
d_ff=1024,
num_layers=8,
num_decoder_layers=None,
num_heads=6,
relative_attention_num_buckets=32,
dropout_rate=0.1,
layer_norm_epsilon=1e-6,
initializer_factor=1.0,
feed_forward_proj="gated-gelu",
is_encoder_decoder=True,
tokenizer_class="T5Tokenizer",
tie_word_embeddings=False,
pad_token_id=0,
eos_token_id=1,
decoder_start_token_id=0,
**kwargs
):
super().__init__(
is_encoder_decoder=is_encoder_decoder,
tokenizer_class=tokenizer_class,
tie_word_embeddings=tie_word_embeddings,
pad_token_id=pad_token_id,
eos_token_id=eos_token_id,
decoder_start_token_id=decoder_start_token_id,
**kwargs,
)
self.vocab_size = vocab_size
self.d_model = d_model
self.d_kv = d_kv
self.d_ff = d_ff
self.num_layers = num_layers
self.num_decoder_layers = (
num_decoder_layers if num_decoder_layers is not None else self.num_layers
) # default = symmetry
self.num_heads = num_heads
self.relative_attention_num_buckets = relative_attention_num_buckets
self.dropout_rate = dropout_rate
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
@property
def hidden_size(self):
return self.d_model
@property
def num_attention_heads(self):
return self.num_heads
@property
def num_hidden_layers(self):
return self.num_layers
# coding=utf-8
# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch mT5 model. """
from ...utils import logging
from ..t5.modeling_t5 import T5ForConditionalGeneration, T5Model
from .configuration_mt5 import MT5Config
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
class MT5Model(T5Model):
r"""
This class overrides :class:`~transformers.T5Model`. Please check the superclass for the appropriate documentation
alongside usage examples.
Examples::
>>> from transformers import MT5Model, T5Tokenizer
>>> model = MT5Model.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
>>> outputs = model(input_ids=batch.input_ids, decoder_input_ids=batch.labels)
>>> hidden_states = outputs.last_hidden_state
"""
model_type = "mt5"
config_class = MT5Config
authorized_missing_keys = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
]
keys_to_never_save = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
]
class MT5ForConditionalGeneration(T5ForConditionalGeneration):
r"""
This class overrides :class:`~transformers.T5ForConditionalGeneration`. Please check the superclass for the
appropriate documentation alongside usage examples.
Examples::
>>> from transformers import MT5ForConditionalGeneration, T5Tokenizer
>>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="pt")
>>> outputs = model(**batch)
>>> loss = outputs.loss
"""
model_type = "mt5"
config_class = MT5Config
authorized_missing_keys = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"lm_head\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
]
keys_to_never_save = [
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
]
# coding=utf-8
# Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Tensorflow mT5 model. """
from ...utils import logging
from ..t5.modeling_tf_t5 import TFT5ForConditionalGeneration, TFT5Model
from .configuration_mt5 import MT5Config
logger = logging.get_logger(__name__)
_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"
class TFMT5Model(TFT5Model):
r"""
This class overrides :class:`~transformers.TFT5Model`. Please check the superclass for the appropriate
documentation alongside usage examples.
Examples::
>>> from transformers import TFMT5Model, T5Tokenizer
>>> model = TFMT5Model.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
>>> batch["decoder_input_ids"] = batch["labels"]
>>> del batch["labels"]
>>> outputs = model(batch)
>>> hidden_states = outputs.last_hidden_state
"""
model_type = "mt5"
config_class = MT5Config
class TFMT5ForConditionalGeneration(TFT5ForConditionalGeneration):
r"""
This class overrides :class:`~transformers.TFT5ForConditionalGeneration`. Please check the superclass for the
appropriate documentation alongside usage examples.
Examples::
>>> from transformers import TFMT5ForConditionalGeneration, T5Tokenizer
>>> model = TFMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
>>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
>>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
>>> summary = "Weiter Verhandlung in Syrien."
>>> batch = tokenizer.prepare_seq2seq_batch(src_texts=[article], tgt_texts=[summary], return_tensors="tf")
>>> outputs = model(batch)
>>> loss = outputs.loss
"""
model_type = "mt5"
config_class = MT5Config
# coding=utf-8
# Copyright 2010, The T5 Authors and HuggingFace Inc.
# Copyright 2020, The T5 Authors and HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -43,9 +43,6 @@ class T5Config(PretrainedConfig):
vocab_size (:obj:`int`, `optional`, defaults to 32128):
Vocabulary size of the T5 model. Defines the number of different tokens that can be represented by the
:obj:`inputs_ids` passed when calling :class:`~transformers.T5Model` or :class:`~transformers.TFT5Model`.
n_positions (:obj:`int`, `optional`, defaults to 512):
The maximum sequence length that this model might ever be used with. Typically set this to something large
just in case (e.g., 512 or 1024 or 2048).
d_model (:obj:`int`, `optional`, defaults to 512):
Size of the encoder layers and the pooler layer.
d_kv (:obj:`int`, `optional`, defaults to 64):
......@@ -69,6 +66,9 @@ class T5Config(PretrainedConfig):
initializer_factor (:obj:`float`, `optional`, defaults to 1):
A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
testing).
feed_forward_proj (:obj:`string`, `optional`, defaults to :obj:`"relu"`):
Type of feed forward layer to be used. Should be one of :obj:`"relu"` or :obj:`"gated-gelu"`. T5v1.1 uses
the :obj:`"gated-gelu"` feed forward projection. Original T5 uses :obj:`"relu"`.
"""
model_type = "t5"
......@@ -85,6 +85,7 @@ class T5Config(PretrainedConfig):
dropout_rate=0.1,
layer_norm_epsilon=1e-6,
initializer_factor=1.0,
feed_forward_proj="relu",
is_encoder_decoder=True,
pad_token_id=0,
eos_token_id=1,
......@@ -109,6 +110,7 @@ class T5Config(PretrainedConfig):
self.dropout_rate = dropout_rate
self.layer_norm_epsilon = layer_norm_epsilon
self.initializer_factor = initializer_factor
self.feed_forward_proj = feed_forward_proj
@property
def hidden_size(self):
......
......@@ -17,7 +17,7 @@
import argparse
from transformers import T5Config, T5Model, load_tf_weights_in_t5
from transformers import T5Config, T5ForConditionalGeneration, load_tf_weights_in_t5
from transformers.utils import logging
......@@ -28,7 +28,7 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, config_file, pytorch_du
# Initialise PyTorch model
config = T5Config.from_json_file(config_file)
print("Building PyTorch model from configuration: {}".format(str(config)))
model = T5Model(config)
model = T5ForConditionalGeneration(config)
# Load weights from tf checkpoint
load_tf_weights_in_t5(model, config, tf_checkpoint_path)
......
......@@ -25,6 +25,7 @@ import torch.nn.functional as F
from torch import nn
from torch.nn import CrossEntropyLoss
from ...activations import ACT2FN
from ...file_utils import (
DUMMY_INPUTS,
DUMMY_MASK,
......@@ -140,6 +141,9 @@ def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
continue
elif scope_names[0] == "logits":
pointer = getattr(pointer, "lm_head")
elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
pointer = getattr(pointer, f"wi_{scope_names[1]}")
continue
else:
try:
pointer = getattr(pointer, scope_names[0])
......@@ -211,10 +215,36 @@ class T5DenseReluDense(nn.Module):
return hidden_states
class T5DenseGatedGeluDense(nn.Module):
def __init__(self, config):
super().__init__()
self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
self.dropout = nn.Dropout(config.dropout_rate)
self.gelu_act = ACT2FN["gelu_new"]
def forward(self, hidden_states):
hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
class T5LayerFF(nn.Module):
def __init__(self, config):
super().__init__()
if config.feed_forward_proj == "relu":
self.DenseReluDense = T5DenseReluDense(config)
elif config.feed_forward_proj == "gated-gelu":
self.DenseReluDense = T5DenseGatedGeluDense(config)
else:
raise ValueError(
f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
)
self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dropout = nn.Dropout(config.dropout_rate)
......@@ -641,6 +671,16 @@ class T5PreTrainedModel(PreTrainedModel):
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.wo, "bias") and module.wo.bias is not None:
module.wo.bias.data.zero_()
elif isinstance(module, T5DenseGatedGeluDense):
module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
module.wi_0.bias.data.zero_()
module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
module.wi_1.bias.data.zero_()
module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
if hasattr(module.wo, "bias") and module.wo.bias is not None:
module.wo.bias.data.zero_()
elif isinstance(module, T5Attention):
# Mesh TensorFlow attention initialization to avoid scaling before softmax
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
......@@ -1099,8 +1139,6 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"lm_head\.weight",
r"encoder\.embed_tokens\.weight",
r"decoder\.embed_tokens\.weight",
r"decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight",
]
......@@ -1262,9 +1300,12 @@ class T5ForConditionalGeneration(T5PreTrainedModel):
)
sequence_output = decoder_outputs[0]
if self.config.tie_word_embeddings:
# Rescale output before projecting on vocab
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
sequence_output = sequence_output * (self.model_dim ** -0.5)
lm_logits = self.lm_head(sequence_output)
loss = None
......
# coding=utf-8
# Copyright 2018 T5 Authors and The HuggingFace Inc. team.
# Copyright 2020 T5 Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
......@@ -26,6 +26,7 @@ import tensorflow as tf
from transformers.modeling_tf_utils import TFWrappedEmbeddings
from ...activations_tf import get_tf_activation
from ...file_utils import (
DUMMY_INPUTS,
DUMMY_MASK,
......@@ -103,10 +104,35 @@ class TFT5DenseReluDense(tf.keras.layers.Layer):
return hidden_states
class TFT5GatedGeluDense(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.wi_0 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_0")
self.wi_1 = tf.keras.layers.Dense(config.d_ff, use_bias=False, name="wi_1")
self.wo = tf.keras.layers.Dense(config.d_model, use_bias=False, name="wo")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
self.act = get_tf_activation("gelu_new")
def call(self, hidden_states, training=False):
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states, training=training)
hidden_states = self.wo(hidden_states)
return hidden_states
class TFT5LayerFF(tf.keras.layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
if config.feed_forward_proj == "relu":
self.DenseReluDense = TFT5DenseReluDense(config, name="DenseReluDense")
elif config.feed_forward_proj == "gated-gelu":
self.DenseReluDense = TFT5GatedGeluDense(config, name="DenseReluDense")
else:
raise ValueError(
f"{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`"
)
self.layer_norm = TFT5LayerNorm(epsilon=config.layer_norm_epsilon, name="layer_norm")
self.dropout = tf.keras.layers.Dropout(config.dropout_rate)
......@@ -547,9 +573,6 @@ class TFT5MainLayer(tf.keras.layers.Layer):
def get_input_embeddings(self):
return self.embed_tokens
def get_output_embeddings(self):
return self.embed_tokens
def set_embed_tokens(self, embed_tokens):
self.embed_tokens = embed_tokens
......@@ -970,9 +993,6 @@ class TFT5Model(TFT5PreTrainedModel):
def get_input_embeddings(self):
return self.shared
def get_output_embeddings(self):
return self.shared
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
self.shared.vocab_size = self.shared.weight.shape[0]
......@@ -1165,11 +1185,17 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
decoder_config.is_decoder = True
self.decoder = TFT5MainLayer(decoder_config, embed_tokens, name="decoder")
if not config.tie_word_embeddings:
self.lm_head = tf.keras.layers.Dense(config.vocab_size, use_bias=False, name="lm_head")
def get_input_embeddings(self):
return self.shared
def get_output_embeddings(self):
if self.config.tie_word_embeddings:
return self.shared
else:
return self.lm_head
def set_input_embeddings(self, new_embeddings):
self.shared.weight = new_embeddings
......@@ -1331,9 +1357,14 @@ class TFT5ForConditionalGeneration(TFT5PreTrainedModel, TFCausalLanguageModeling
training=training,
)
sequence_output = decoder_outputs[0] * (self.model_dim ** -0.5)
embed_tokens = self.get_output_embeddings()
logits = embed_tokens(sequence_output, mode="linear")
sequence_output = decoder_outputs[0]
# T5v1.1 does not tie output word embeddings and thus does not require downscaling
if self.config.tie_word_embeddings:
sequence_output = sequence_output * (self.model_dim ** -0.5)
logits = self.get_output_embeddings()(sequence_output, mode="linear")
else:
logits = self.get_output_embeddings()(sequence_output)
loss = None if labels is None else self.compute_loss(labels, logits)
......
......@@ -1361,6 +1361,29 @@ def load_tf_weights_in_mobilebert(*args, **kwargs):
requires_pytorch(load_tf_weights_in_mobilebert)
class MT5Config:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
class MT5ForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
class MT5Model:
def __init__(self, *args, **kwargs):
requires_pytorch(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_pytorch(self)
OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
......@@ -970,6 +970,24 @@ class TFMobileBertPreTrainedModel:
requires_tf(self)
class TFMT5ForConditionalGeneration:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
class TFMT5Model:
def __init__(self, *args, **kwargs):
requires_tf(self)
@classmethod
def from_pretrained(self, *args, **kwargs):
requires_tf(self)
TF_OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
import unittest
from transformers import is_torch_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow, torch_device
if is_torch_available():
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
@require_torch
@require_sentencepiece
@require_tokenizers
class MT5IntegrationTest(unittest.TestCase):
@slow
def test_small_integration_test(self):
"""
For comparision run:
>>> import t5 # pip install t5==0.7.1
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
>>> path_to_mtf_small_mt5_checkpoint = '<fill_in>'
>>> path_to_mtf_small_mt5_spm_model_path = '<fill_in>'
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_mt5_checkpoint, batch_size=1, tpu=None)
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_mt5_spm_model_path)
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small", return_dict=True).to(torch_device)
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
labels = tokenizer("Hi I am", return_tensors="pt").input_ids
loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss
mtf_score = -(labels.shape[-1] * loss.item())
EXPECTED_SCORE = -84.9127
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
......@@ -490,6 +490,14 @@ class T5ModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_model(*config_and_inputs)
def test_model_v1_1(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
# check that gated gelu feed forward and different word embeddings work
config = config_and_inputs[0]
config.tie_word_embeddings = False
config.feed_forward_proj = "gated-gelu"
self.model_tester.create_and_check_model(config, *config_and_inputs[1:])
def test_with_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_with_lm_head(*config_and_inputs)
......@@ -569,7 +577,7 @@ class T5ModelIntegrationTests(unittest.TestCase):
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = T5ForConditionalGeneration.from_pretrained("t5-small", return_dict=True).to(torch_device)
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(torch_device)
tokenizer = T5Tokenizer.from_pretrained("t5-small")
input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
......@@ -581,6 +589,32 @@ class T5ModelIntegrationTests(unittest.TestCase):
EXPECTED_SCORE = -19.0845
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
@slow
def test_small_v1_1_integration_test(self):
"""
For comparision run:
>>> import t5 # pip install t5==0.7.1
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
>>> path_to_mtf_small_t5_v1_1_checkpoint = '<fill_in>'
>>> path_to_mtf_small_spm_model_path = '<fill_in>'
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_v1_1_checkpoint, batch_size=1, tpu=None)
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = T5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small").to(torch_device)
tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")
input_ids = tokenizer("Hello there", return_tensors="pt").input_ids
labels = tokenizer("Hi I am", return_tensors="pt").input_ids
loss = model(input_ids.to(torch_device), labels=labels.to(torch_device)).loss
mtf_score = -(labels.shape[-1] * loss.item())
EXPECTED_SCORE = -59.0293
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
@slow
def test_summarization(self):
model = self.model
......
# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import is_tf_available
from transformers.testing_utils import require_sentencepiece, require_tf, require_tokenizers, slow
if is_tf_available():
import tensorflow as tf
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
@require_tf
@require_sentencepiece
@require_tokenizers
class TFMT5ModelIntegrationTest(unittest.TestCase):
@slow
def test_small_integration_test(self):
"""
For comparision run:
>>> import t5 # pip install t5==0.7.1
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
>>> path_to_mtf_small_mt5_checkpoint = '<fill_in>'
>>> path_to_mtf_small_mt5_spm_model_path = '<fill_in>'
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_mt5_checkpoint, batch_size=1, tpu=None)
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_mt5_spm_model_path, extra_ids=100)
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = TFAutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
input_ids = tokenizer("Hello there", return_tensors="tf").input_ids
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
loss = model(input_ids, labels=labels).loss
mtf_score = -tf.math.reduce_sum(loss).numpy()
EXPECTED_SCORE = -84.9127
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
......@@ -258,6 +258,13 @@ class TFT5ModelTest(TFModelTesterMixin, unittest.TestCase):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_model(*config_and_inputs)
def test_t5_model_v1_1(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
config = config_and_inputs[0]
config.tie_word_embeddings = False
config.feed_forward_proj = "gated-gelu"
self.model_tester.create_and_check_t5_model(config, *config_and_inputs[1:])
def test_with_lm_head(self):
config_and_inputs = self.model_tester.prepare_config_and_inputs()
self.model_tester.create_and_check_t5_with_lm_head(*config_and_inputs)
......@@ -296,6 +303,58 @@ class TFT5ModelIntegrationTests(unittest.TestCase):
def model(self):
return TFT5ForConditionalGeneration.from_pretrained("t5-base")
@slow
def test_small_integration_test(self):
"""
For comparision run:
>>> import t5 # pip install t5==0.7.1
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
>>> path_to_mtf_small_t5_checkpoint = '<fill_in>'
>>> path_to_mtf_small_spm_model_path = '<fill_in>'
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_checkpoint, batch_size=1, tpu=None)
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
input_ids = tokenizer("Hello there", return_tensors="tf").input_ids
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
loss = model(input_ids, labels=labels).loss
mtf_score = -tf.math.reduce_sum(loss).numpy()
EXPECTED_SCORE = -19.0845
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
@slow
def test_small_v1_1_integration_test(self):
"""
For comparision run:
>>> import t5 # pip install t5==0.7.1
>>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary
>>> path_to_mtf_small_t5_v1.1_checkpoint = '<fill_in>'
>>> path_to_mtf_small_spm_model_path = '<fill_in>'
>>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_t5_v1.1_checkpoint, batch_size=1, tpu=None)
>>> vocab = SentencePieceVocabulary(path_to_mtf_small_spm_model_path, extra_ids=100)
>>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
"""
model = TFT5ForConditionalGeneration.from_pretrained("google/t5-v1_1-small")
tokenizer = T5Tokenizer.from_pretrained("google/t5-v1_1-small")
input_ids = tokenizer("Hello there", return_tensors="tf").input_ids
labels = tokenizer("Hi I am", return_tensors="tf").input_ids
loss = model(input_ids, labels=labels).loss
mtf_score = -tf.math.reduce_sum(loss).numpy()
EXPECTED_SCORE = -59.0293
self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
@slow
def test_summarization(self):
model = self.model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment