"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "b4eef63a1de97b9bbd8d54b83ede16e34afe3529"
Unverified commit a317e6c3, authored by Patrick von Platen, committed by GitHub

[Flax] Correctly Add MT5 (#12988)



* finish PR

* finish mt5

* push

* up

* Update tests/test_modeling_flax_mt5.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent da9754a3
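
In short, this commit adds dedicated Flax MT5 classes (FlaxMT5Model and FlaxMT5ForConditionalGeneration) and wires them into the package imports, the Flax auto-model mappings, the docs, and a slow integration test. A minimal usage sketch, adapted from the docstring examples in the new modeling_flax_mt5.py further down (same google/mt5-small checkpoint and example inputs; illustrative, not copied verbatim from the diff):

# Minimal usage of the new Flax MT5 classes, mirroring the docstring examples below.
from transformers import FlaxMT5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")
model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")

article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
summary = "Weiter Verhandlung in Syrien."

inputs = tokenizer(article, return_tensors="np")
with tokenizer.as_target_tokenizer():
    decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids

outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
print(outputs.logits.shape)  # (1, target_length, vocab_size)
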
@@ -428,7 +428,7 @@ Flax), PyTorch, and/or TensorFlow.
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
 | mBART                       | ✅             | ✅             | ✅              | ✅                 | ✅           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+
-| mT5                         | ✅             | ✅             | ✅              | ✅                 |              |
+| mT5                         | ✅             | ✅             | ✅              | ✅                 | ✅           |
 +-----------------------------+----------------+----------------+-----------------+--------------------+--------------+

 .. toctree::
......
@@ -94,3 +94,17 @@ TFMT5EncoderModel
 .. autoclass:: transformers.TFMT5EncoderModel
     :members:
+
+
+FlaxMT5Model
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxMT5Model
+    :members:
+
+
+FlaxMT5ForConditionalGeneration
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.FlaxMT5ForConditionalGeneration
+    :members:
@@ -1691,6 +1691,7 @@ if is_flax_available():
             "FlaxMBartPreTrainedModel",
         ]
     )
+    _import_structure["models.mt5"].extend(["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"])
     _import_structure["models.roberta"].extend(
         [
             "FlaxRobertaForMaskedLM",
@@ -3120,6 +3121,7 @@ if TYPE_CHECKING:
             FlaxMBartModel,
             FlaxMBartPreTrainedModel,
         )
+        from .models.mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
         from .models.roberta import (
             FlaxRobertaForMaskedLM,
             FlaxRobertaForMultipleChoice,
......
@@ -62,6 +62,7 @@ from ..mbart.modeling_flax_mbart import (
     FlaxMBartForSequenceClassification,
     FlaxMBartModel,
 )
+from ..mt5.modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
 from ..roberta.modeling_flax_roberta import (
     FlaxRobertaForMaskedLM,
     FlaxRobertaForMultipleChoice,
@@ -109,7 +110,7 @@ FLAX_MODEL_MAPPING = OrderedDict(
         (ViTConfig, FlaxViTModel),
         (MBartConfig, FlaxMBartModel),
         (T5Config, FlaxT5Model),
-        (MT5Config, FlaxT5Model),
+        (MT5Config, FlaxMT5Model),
         (Wav2Vec2Config, FlaxWav2Vec2Model),
         (MarianConfig, FlaxMarianModel),
     ]
@@ -125,7 +126,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING = OrderedDict(
         (ElectraConfig, FlaxElectraForPreTraining),
         (MBartConfig, FlaxMBartForConditionalGeneration),
         (T5Config, FlaxT5ForConditionalGeneration),
-        (MT5Config, FlaxT5ForConditionalGeneration),
+        (MT5Config, FlaxMT5ForConditionalGeneration),
         (Wav2Vec2Config, FlaxWav2Vec2ForPreTraining),
     ]
 )
@@ -147,7 +148,7 @@ FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING = OrderedDict(
         # Model for Seq2Seq Causal LM mapping
         (BartConfig, FlaxBartForConditionalGeneration),
         (T5Config, FlaxT5ForConditionalGeneration),
-        (MT5Config, FlaxT5ForConditionalGeneration),
+        (MT5Config, FlaxMT5ForConditionalGeneration),
         (MarianConfig, FlaxMarianMTModel),
     ]
 )
......
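
The three mapping fixes above mean the Flax auto classes now resolve an MT5Config to the MT5-specific classes rather than the plain T5 ones. A small sketch of the expected behaviour, assuming the existing FlaxAutoModelForSeq2SeqLM auto class backed by FLAX_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING:

# Expected class resolution after the mapping fix; google/mt5-small ships an MT5Config.
from transformers import FlaxAutoModelForSeq2SeqLM

model = FlaxAutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")
print(type(model).__name__)  # FlaxMT5ForConditionalGeneration (the mapping previously pointed at the plain FlaxT5 class)
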
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
 from ...file_utils import (
     _LazyModule,
+    is_flax_available,
     is_sentencepiece_available,
     is_tf_available,
     is_tokenizers_available,
@@ -51,6 +52,9 @@ if is_torch_available():
 if is_tf_available():
     _import_structure["modeling_tf_mt5"] = ["TFMT5EncoderModel", "TFMT5ForConditionalGeneration", "TFMT5Model"]

+if is_flax_available():
+    _import_structure["modeling_flax_mt5"] = ["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]
+
 if TYPE_CHECKING:
     from .configuration_mt5 import MT5Config
@@ -61,6 +65,9 @@ if TYPE_CHECKING:
     if is_tf_available():
         from .modeling_tf_mt5 import TFMT5EncoderModel, TFMT5ForConditionalGeneration, TFMT5Model

+    if is_flax_available():
+        from .modeling_flax_mt5 import FlaxMT5ForConditionalGeneration, FlaxMT5Model
+
 else:
     import sys
......
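
Both this sub-package __init__ and the top-level __init__ edited earlier follow the library's lazy-import pattern: names are registered in _import_structure and only imported for real under TYPE_CHECKING, so importing transformers stays cheap and does not require jax/flax. A simplified, hypothetical sketch of that pattern (not the actual _LazyModule implementation):

# Hypothetical, stripped-down version of the lazy-import idea: the submodule is
# only imported when one of its symbols is first accessed (PEP 562 __getattr__).
import importlib

_import_structure = {"modeling_flax_mt5": ["FlaxMT5ForConditionalGeneration", "FlaxMT5Model"]}
_symbol_to_module = {sym: mod for mod, syms in _import_structure.items() for sym in syms}


def __getattr__(name):
    if name in _symbol_to_module:
        module = importlib.import_module("." + _symbol_to_module[name], __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
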
# coding=utf-8
# Copyright 2021 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Flax mT5 model. """

from ...utils import logging
from ..t5.modeling_flax_t5 import FlaxT5ForConditionalGeneration, FlaxT5Model
from .configuration_mt5 import MT5Config


logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "T5Config"
_TOKENIZER_FOR_DOC = "T5Tokenizer"


class FlaxMT5Model(FlaxT5Model):
    r"""
    This class overrides :class:`~transformers.FlaxT5Model`. Please check the superclass for the appropriate
    documentation alongside usage examples.

    Examples::

        >>> from transformers import FlaxMT5Model, T5Tokenizer

        >>> model = FlaxMT5Model.from_pretrained("google/mt5-small")
        >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

        >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
        >>> summary = "Weiter Verhandlung in Syrien."
        >>> inputs = tokenizer(article, return_tensors="np")

        >>> with tokenizer.as_target_tokenizer():
        ...     decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids

        >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=decoder_input_ids)
        >>> hidden_states = outputs.last_hidden_state
    """

    model_type = "mt5"
    config_class = MT5Config


class FlaxMT5ForConditionalGeneration(FlaxT5ForConditionalGeneration):
    r"""
    This class overrides :class:`~transformers.FlaxT5ForConditionalGeneration`. Please check the superclass for the
    appropriate documentation alongside usage examples.

    Examples::

        >>> from transformers import FlaxMT5ForConditionalGeneration, T5Tokenizer

        >>> model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
        >>> tokenizer = T5Tokenizer.from_pretrained("google/mt5-small")

        >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
        >>> summary = "Weiter Verhandlung in Syrien."
        >>> inputs = tokenizer(article, return_tensors="np")

        >>> with tokenizer.as_target_tokenizer():
        ...     decoder_input_ids = tokenizer(summary, return_tensors="np").input_ids

        >>> outputs = model(**inputs, decoder_input_ids=decoder_input_ids)
        >>> logits = outputs.logits
    """

    model_type = "mt5"
    config_class = MT5Config
@@ -642,6 +642,24 @@ class FlaxMBartPreTrainedModel:
         requires_backends(cls, ["flax"])


+class FlaxMT5ForConditionalGeneration:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
+class FlaxMT5Model:
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["flax"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["flax"])
+
+
 class FlaxRobertaForMaskedLM:
     def __init__(self, *args, **kwargs):
         requires_backends(self, ["flax"])
......
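
The dummy objects above exist so the new names can always be imported from transformers even when jax/flax is not installed; actually using them then fails with an informative error from requires_backends. A sketch of the expected behaviour, assuming an environment without the flax backend:

# Assumes jax/flax are NOT installed in this environment.
from transformers import FlaxMT5Model  # the import itself still succeeds (dummy object)

try:
    FlaxMT5Model.from_pretrained("google/mt5-small")
except ImportError as err:
    # requires_backends raises an ImportError explaining that the flax backend
    # (jax, jaxlib, flax) must be installed to use this class
    print(err)
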
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest

from transformers import is_flax_available
from transformers.testing_utils import require_sentencepiece, require_tokenizers, require_torch, slow


if is_flax_available():
    import optax
    from flax.training.common_utils import onehot

    from transformers import AutoTokenizer, FlaxMT5ForConditionalGeneration
    from transformers.models.t5.modeling_flax_t5 import shift_tokens_right


@require_torch
@require_sentencepiece
@require_tokenizers
class MT5IntegrationTest(unittest.TestCase):
    @slow
    def test_small_integration_test(self):
        """
        For comparison run:
        >>> import t5  # pip install t5==0.7.1
        >>> from t5.data.sentencepiece_vocabulary import SentencePieceVocabulary

        >>> path_to_mtf_small_mt5_checkpoint = '<fill_in>'
        >>> path_to_mtf_small_mt5_spm_model_path = '<fill_in>'
        >>> t5_model = t5.models.MtfModel(model_dir=path_to_mtf_small_mt5_checkpoint, batch_size=1, tpu=None)
        >>> vocab = SentencePieceVocabulary(path_to_mtf_small_mt5_spm_model_path)
        >>> score = t5_model.score(inputs=["Hello there"], targets=["Hi I am"], vocabulary=vocab)
        """

        model = FlaxMT5ForConditionalGeneration.from_pretrained("google/mt5-small")
        tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")

        input_ids = tokenizer("Hello there", return_tensors="np").input_ids
        labels = tokenizer("Hi I am", return_tensors="np").input_ids

        decoder_input_ids = shift_tokens_right(labels, model.config.pad_token_id, model.config.decoder_start_token_id)

        logits = model(input_ids, decoder_input_ids=decoder_input_ids).logits
        loss = optax.softmax_cross_entropy(logits, onehot(labels, logits.shape[-1])).mean()

        mtf_score = -(labels.shape[-1] * loss.item())

        EXPECTED_SCORE = -84.9127
        self.assertTrue(abs(mtf_score - EXPECTED_SCORE) < 1e-4)
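
The test is gated behind @slow, so it only runs when RUN_SLOW=1 is set in the environment. The mtf_score arithmetic follows from the loss definition: optax.softmax_cross_entropy with one-hot labels gives the per-position negative log-probability, so labels.shape[-1] * loss is the summed negative log-likelihood and its negation is log p(target | input), which is what the Mesh-TensorFlow score in the docstring reports. A small illustrative sketch of that equivalence (helper name is hypothetical, batch size 1 assumed):

import jax
import jax.numpy as jnp


def sequence_log_likelihood(logits, labels):
    # Sum of log p(label_t | ...) over target positions; for batch size 1 this equals
    # -(labels.shape[-1] * optax.softmax_cross_entropy(logits, onehot(labels, V)).mean()).
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    token_log_probs = jnp.take_along_axis(log_probs, labels[..., None], axis=-1).squeeze(-1)
    return token_log_probs.sum()
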
@@ -82,8 +82,7 @@ IGNORE_NON_TESTED = PRIVATE_MODELS.copy() + [
 # trigger the common tests.
 TEST_FILES_WITH_NO_COMMON_TESTS = [
     "test_modeling_camembert.py",
-    "test_modeling_flax_bert.py",
-    "test_modeling_flax_roberta.py",
+    "test_modeling_flax_mt5.py",
     "test_modeling_mbart.py",
     "test_modeling_mt5.py",
     "test_modeling_pegasus.py",
......