Unverified commit 01485cee authored by Javier de la Rosa, committed by GitHub

Add missing support for Flax XLM-RoBERTa (#15900)



* Adding Flax XLM-RoBERTa

* Add Flax to __init__

* Adding doc and dummy objects

* Add tests

* Add Flax XLM-R models autodoc

* Fix tests

* Add Flax XLM-RoBERTa to TEST_FILES_WITH_NO_COMMON_TESTS

* Update src/transformers/models/xlm_roberta/modeling_flax_xlm_roberta.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update tests/xlm_roberta/test_modeling_flax_xlm_roberta.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update tests/xlm_roberta/test_modeling_flax_xlm_roberta.py
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Remove test on large Flax XLM-RoBERTa

* Add tokenizer to the test
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent 89c7d9cf
......@@ -257,7 +257,7 @@ Flax), PyTorch, and/or TensorFlow.
| WavLM | ❌ | ❌ | ✅ | ❌ | ❌ |
| XGLM | ✅ | ✅ | ✅ | ❌ | ✅ |
| XLM | ✅ | ❌ | ✅ | ✅ | ❌ |
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ❌ |
| XLM-RoBERTa | ✅ | ✅ | ✅ | ✅ | ✅ |
| XLM-RoBERTa-XL | ❌ | ❌ | ✅ | ❌ | ❌ |
| XLMProphetNet | ✅ | ❌ | ✅ | ❌ | ❌ |
| XLNet | ✅ | ✅ | ✅ | ✅ | ❌ |
......
......@@ -124,3 +124,33 @@ This model was contributed by [stefan-it](https://huggingface.co/stefan-it). The
[[autodoc]] TFXLMRobertaForQuestionAnswering
- call
## FlaxXLMRobertaModel
[[autodoc]] FlaxXLMRobertaModel
- __call__
## FlaxXLMRobertaForMaskedLM
[[autodoc]] FlaxXLMRobertaForMaskedLM
- __call__
## FlaxXLMRobertaForSequenceClassification
[[autodoc]] FlaxXLMRobertaForSequenceClassification
- __call__
## FlaxXLMRobertaForMultipleChoice
[[autodoc]] FlaxXLMRobertaForMultipleChoice
- __call__
## FlaxXLMRobertaForTokenClassification
[[autodoc]] FlaxXLMRobertaForTokenClassification
- __call__
## FlaxXLMRobertaForQuestionAnswering
[[autodoc]] FlaxXLMRobertaForQuestionAnswering
- __call__
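
The autodoc entries above cover the six new Flax classes. A minimal usage sketch for one of them, FlaxXLMRobertaForMaskedLM (the checkpoint name and the fill-mask flow are illustrative assumptions, not part of this diff; if a checkpoint only ships PyTorch weights, `from_pt=True` can be passed to `from_pretrained`):

# Illustrative sketch only; checkpoint name and fill-mask flow are assumptions.
import jax.numpy as jnp
from transformers import AutoTokenizer, FlaxXLMRobertaForMaskedLM

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
model = FlaxXLMRobertaForMaskedLM.from_pretrained("xlm-roberta-base")

inputs = tokenizer("The capital of France is <mask>.", return_tensors="np")
logits = model(**inputs).logits  # (batch_size, sequence_length, vocab_size)

# take the highest-scoring token at the <mask> position
mask_index = int(jnp.argmax(inputs["input_ids"][0] == tokenizer.mask_token_id))
predicted_id = int(jnp.argmax(logits[0, mask_index]))
print(tokenizer.decode([predicted_id]))
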
......@@ -2348,6 +2348,16 @@ if is_flax_available():
"FlaxXGLMPreTrainedModel",
]
)
_import_structure["models.xlm_roberta"].extend(
[
"FlaxXLMRobertaForMaskedLM",
"FlaxXLMRobertaForMultipleChoice",
"FlaxXLMRobertaForQuestionAnswering",
"FlaxXLMRobertaForSequenceClassification",
"FlaxXLMRobertaForTokenClassification",
"FlaxXLMRobertaModel",
]
)
else:
from .utils import dummy_flax_objects
......@@ -4268,6 +4278,14 @@ if TYPE_CHECKING:
FlaxWav2Vec2PreTrainedModel,
)
from .models.xglm import FlaxXGLMForCausalLM, FlaxXGLMModel, FlaxXGLMPreTrainedModel
from .models.xlm_roberta import (
FlaxXLMRobertaForMaskedLM,
FlaxXLMRobertaForMultipleChoice,
FlaxXLMRobertaForQuestionAnswering,
FlaxXLMRobertaForSequenceClassification,
FlaxXLMRobertaForTokenClassification,
FlaxXLMRobertaModel,
)
else:
# Import the same objects as dummies to get them in the namespace.
# They will raise an import error if the user tries to instantiate / use them.
......
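
The entries added to `_import_structure` above make the new class names importable from the top-level `transformers` package without importing Flax eagerly. A simplified, hypothetical sketch of the lazy-import idea behind this structure (not the library's actual `_LazyModule` implementation):

# Simplified sketch of lazy imports via a module-level __getattr__ (PEP 562);
# the real library uses its own _LazyModule helper instead.
import importlib

_import_structure = {"models.xlm_roberta": ["FlaxXLMRobertaModel"]}
_name_to_module = {
    name: module for module, names in _import_structure.items() for name in names
}


def __getattr__(name):
    # the submodule is imported only when the attribute is first accessed
    if name in _name_to_module:
        module = importlib.import_module("." + _name_to_module[name], __name__)
        return getattr(module, name)
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
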
......@@ -35,6 +35,7 @@ FLAX_MODEL_MAPPING_NAMES = OrderedDict(
("distilbert", "FlaxDistilBertModel"),
("albert", "FlaxAlbertModel"),
("roberta", "FlaxRobertaModel"),
("xlm-roberta", "FlaxXLMRobertaModel"),
("bert", "FlaxBertModel"),
("beit", "FlaxBeitModel"),
("big_bird", "FlaxBigBirdModel"),
......@@ -60,6 +61,7 @@ FLAX_MODEL_FOR_PRETRAINING_MAPPING_NAMES = OrderedDict(
# Model for pre-training mapping
("albert", "FlaxAlbertForPreTraining"),
("roberta", "FlaxRobertaForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
("bert", "FlaxBertForPreTraining"),
("big_bird", "FlaxBigBirdForPreTraining"),
("bart", "FlaxBartForConditionalGeneration"),
......@@ -78,6 +80,7 @@ FLAX_MODEL_FOR_MASKED_LM_MAPPING_NAMES = OrderedDict(
("distilbert", "FlaxDistilBertForMaskedLM"),
("albert", "FlaxAlbertForMaskedLM"),
("roberta", "FlaxRobertaForMaskedLM"),
("xlm-roberta", "FlaxXLMRobertaForMaskedLM"),
("bert", "FlaxBertForMaskedLM"),
("big_bird", "FlaxBigBirdForMaskedLM"),
("bart", "FlaxBartForConditionalGeneration"),
......@@ -132,6 +135,7 @@ FLAX_MODEL_FOR_SEQUENCE_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("distilbert", "FlaxDistilBertForSequenceClassification"),
("albert", "FlaxAlbertForSequenceClassification"),
("roberta", "FlaxRobertaForSequenceClassification"),
("xlm-roberta", "FlaxXLMRobertaForSequenceClassification"),
("bert", "FlaxBertForSequenceClassification"),
("big_bird", "FlaxBigBirdForSequenceClassification"),
("bart", "FlaxBartForSequenceClassification"),
......@@ -147,6 +151,7 @@ FLAX_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES = OrderedDict(
("distilbert", "FlaxDistilBertForQuestionAnswering"),
("albert", "FlaxAlbertForQuestionAnswering"),
("roberta", "FlaxRobertaForQuestionAnswering"),
("xlm-roberta", "FlaxXLMRobertaForQuestionAnswering"),
("bert", "FlaxBertForQuestionAnswering"),
("big_bird", "FlaxBigBirdForQuestionAnswering"),
("bart", "FlaxBartForQuestionAnswering"),
......@@ -162,6 +167,7 @@ FLAX_MODEL_FOR_TOKEN_CLASSIFICATION_MAPPING_NAMES = OrderedDict(
("distilbert", "FlaxDistilBertForTokenClassification"),
("albert", "FlaxAlbertForTokenClassification"),
("roberta", "FlaxRobertaForTokenClassification"),
("xlm-roberta", "FlaxXLMRobertaForTokenClassification"),
("bert", "FlaxBertForTokenClassification"),
("big_bird", "FlaxBigBirdForTokenClassification"),
("electra", "FlaxElectraForTokenClassification"),
......@@ -175,6 +181,7 @@ FLAX_MODEL_FOR_MULTIPLE_CHOICE_MAPPING_NAMES = OrderedDict(
("distilbert", "FlaxDistilBertForMultipleChoice"),
("albert", "FlaxAlbertForMultipleChoice"),
("roberta", "FlaxRobertaForMultipleChoice"),
("xlm-roberta", "FlaxXLMRobertaForMultipleChoice"),
("bert", "FlaxBertForMultipleChoice"),
("big_bird", "FlaxBigBirdForMultipleChoice"),
("electra", "FlaxElectraForMultipleChoice"),
......
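
These mapping entries are what let the Flax auto classes dispatch on `model_type == "xlm-roberta"` and resolve to the new classes. A short illustrative sketch (the checkpoint name is an assumption, not part of this diff):

# Sketch: with the mappings above, the Flax auto classes resolve XLM-RoBERTa
# checkpoints to the new Flax classes (checkpoint name is an assumption).
from transformers import FlaxAutoModel, FlaxAutoModelForSequenceClassification

model = FlaxAutoModel.from_pretrained("xlm-roberta-base")
print(type(model).__name__)  # FlaxXLMRobertaModel

clf = FlaxAutoModelForSequenceClassification.from_pretrained(
    "xlm-roberta-base", num_labels=3
)  # pretrained encoder with a randomly initialized classification head
print(type(clf).__name__)  # FlaxXLMRobertaForSequenceClassification
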
......@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING
from ...file_utils import (
_LazyModule,
is_flax_available,
is_sentencepiece_available,
is_tf_available,
is_tokenizers_available,
......@@ -64,6 +65,15 @@ if is_tf_available():
"TFXLMRobertaModel",
]
if is_flax_available():
_import_structure["modeling_flax_xlm_roberta"] = [
"FlaxXLMRobertaForMaskedLM",
"FlaxXLMRobertaForMultipleChoice",
"FlaxXLMRobertaForQuestionAnswering",
"FlaxXLMRobertaForSequenceClassification",
"FlaxXLMRobertaForTokenClassification",
"FlaxXLMRobertaModel",
]
if TYPE_CHECKING:
from .configuration_xlm_roberta import (
......@@ -101,6 +111,16 @@ if TYPE_CHECKING:
TFXLMRobertaModel,
)
if is_flax_available():
from .modeling_flax_xlm_roberta import (
FlaxXLMRobertaForMaskedLM,
FlaxXLMRobertaForMultipleChoice,
FlaxXLMRobertaForQuestionAnswering,
FlaxXLMRobertaForSequenceClassification,
FlaxXLMRobertaForTokenClassification,
FlaxXLMRobertaModel,
)
else:
import sys
......
# coding=utf-8
# Copyright 2022 Facebook AI Research and the HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Flax XLM-RoBERTa model."""
from ...file_utils import add_start_docstrings
from ...utils import logging
from ..roberta.modeling_flax_roberta import (
FlaxRobertaForMaskedLM,
FlaxRobertaForMultipleChoice,
FlaxRobertaForQuestionAnswering,
FlaxRobertaForSequenceClassification,
FlaxRobertaForTokenClassification,
FlaxRobertaModel,
)
from .configuration_xlm_roberta import XLMRobertaConfig
logger = logging.get_logger(__name__)
XLM_ROBERTA_PRETRAINED_MODEL_ARCHIVE_LIST = [
"xlm-roberta-base",
"xlm-roberta-large",
"xlm-roberta-large-finetuned-conll02-dutch",
"xlm-roberta-large-finetuned-conll02-spanish",
"xlm-roberta-large-finetuned-conll03-english",
"xlm-roberta-large-finetuned-conll03-german",
# See all XLM-RoBERTa models at https://huggingface.co/models?filter=xlm-roberta
]
XLM_ROBERTA_START_DOCSTRING = r"""
This model inherits from [`FlaxPreTrainedModel`]. Check the superclass documentation for the generic methods the
library implements for all its models (such as downloading, saving and converting weights from PyTorch models).
This model is also a Flax Linen [flax.linen.Module](https://flax.readthedocs.io/en/latest/flax.linen.html#module)
subclass. Use it as a regular Flax linen Module and refer to the Flax documentation for all matter related to
general usage and behavior.
Finally, this model supports inherent JAX features such as:
- [Just-In-Time (JIT) compilation](https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit)
- [Automatic Differentiation](https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation)
- [Vectorization](https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap)
- [Parallelization](https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap)
Parameters:
config ([`XLMRobertaConfig`]): Model configuration class with all the parameters of the
model. Initializing with a config file does not load the weights associated with the model, only the
configuration. Check out the [`~FlaxPreTrainedModel.from_pretrained`] method to load the model weights.
"""
@add_start_docstrings(
"The bare XLM-RoBERTa Model transformer outputting raw hidden-states without any specific head on top.",
XLM_ROBERTA_START_DOCSTRING,
)
class FlaxXLMRobertaModel(FlaxRobertaModel):
"""
This class overrides [`FlaxRobertaModel`]. Please check the superclass for the appropriate documentation alongside
usage examples.
"""
config_class = XLMRobertaConfig
@add_start_docstrings(
"""XLM-RoBERTa Model with a `language modeling` head on top.""",
XLM_ROBERTA_START_DOCSTRING,
)
class FlaxXLMRobertaForMaskedLM(FlaxRobertaForMaskedLM):
"""
This class overrides [`FlaxRobertaForMaskedLM`]. Please check the superclass for the appropriate documentation
alongside usage examples.
"""
config_class = XLMRobertaConfig
@add_start_docstrings(
"""
XLM-RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
pooled output) e.g. for GLUE tasks.
""",
XLM_ROBERTA_START_DOCSTRING,
)
class FlaxXLMRobertaForSequenceClassification(FlaxRobertaForSequenceClassification):
"""
This class overrides [`FlaxRobertaForSequenceClassification`]. Please check the superclass for the appropriate
documentation alongside usage examples.
"""
config_class = XLMRobertaConfig
@add_start_docstrings(
"""
XLM-RoBERTa Model with a multiple choice classification head on top (a linear layer on top of the pooled output and
a softmax) e.g. for RocStories/SWAG tasks.
""",
XLM_ROBERTA_START_DOCSTRING,
)
class FlaxXLMRobertaForMultipleChoice(FlaxRobertaForMultipleChoice):
"""
This class overrides [`FlaxRobertaForMultipleChoice`]. Please check the superclass for the appropriate
documentation alongside usage examples.
"""
config_class = XLMRobertaConfig
@add_start_docstrings(
"""
XLM-RoBERTa Model with a token classification head on top (a linear layer on top of the hidden-states output) e.g.
for Named-Entity-Recognition (NER) tasks.
""",
XLM_ROBERTA_START_DOCSTRING,
)
class FlaxXLMRobertaForTokenClassification(FlaxRobertaForTokenClassification):
"""
This class overrides [`FlaxRobertaForTokenClassification`]. Please check the superclass for the appropriate
documentation alongside usage examples.
"""
config_class = XLMRobertaConfig
@add_start_docstrings(
"""
XLM-RoBERTa Model with a span classification head on top for extractive question-answering tasks like SQuAD
(linear layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
""",
XLM_ROBERTA_START_DOCSTRING,
)
class FlaxXLMRobertaForQuestionAnswering(FlaxRobertaForQuestionAnswering):
"""
This class overrides [`FlaxRobertaForQuestionAnswering`]. Please check the superclass for the appropriate
documentation alongside usage examples.
"""
config_class = XLMRobertaConfig
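
All of the classes in this file follow the same pattern: they reuse the Flax RoBERTa implementation unchanged and only swap in `XLMRobertaConfig`, so the architecture and weights handling come entirely from RoBERTa. A hedged sketch of the same pattern applied to a hypothetical derived model (the names below are illustrative and do not exist in the library):

# Illustrative only: the "override config_class" pattern used above, applied to
# a hypothetical RoBERTa-derived model. MyRobertaConfig / FlaxMyRobertaModel
# are made-up names, not part of transformers.
from transformers import RobertaConfig
from transformers.models.roberta.modeling_flax_roberta import FlaxRobertaModel


class MyRobertaConfig(RobertaConfig):
    model_type = "my-roberta"


class FlaxMyRobertaModel(FlaxRobertaModel):
    """Same architecture as Flax RoBERTa; only the config class differs."""

    config_class = MyRobertaConfig
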
......@@ -989,3 +989,45 @@ class FlaxXGLMPreTrainedModel(metaclass=DummyObject):
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxXLMRobertaForMaskedLM(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxXLMRobertaForMultipleChoice(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxXLMRobertaForQuestionAnswering(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxXLMRobertaForSequenceClassification(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxXLMRobertaForTokenClassification(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
class FlaxXLMRobertaModel(metaclass=DummyObject):
_backends = ["flax"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["flax"])
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import AutoTokenizer, is_flax_available
from transformers.testing_utils import require_flax, require_sentencepiece, require_tokenizers, slow
if is_flax_available():
import jax.numpy as jnp
from transformers import FlaxXLMRobertaModel
@require_sentencepiece
@require_tokenizers
@require_flax
class FlaxXLMRobertaModelIntegrationTest(unittest.TestCase):
@slow
def test_flax_xlm_roberta_base(self):
model = FlaxXLMRobertaModel.from_pretrained("xlm-roberta-base")
tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
text = "The dog is cute and lives in the garden house"
input_ids = jnp.array([tokenizer.encode(text)])
expected_output_shape = (1, 12, 768) # batch_size, sequence_length, embedding_vector_dim
expected_output_values_last_dim = jnp.array(
[[-0.0101, 0.1218, -0.0803, 0.0801, 0.1327, 0.0776, -0.1215, 0.2383, 0.3338, 0.3106, 0.0300, 0.0252]]
)
output = model(input_ids)["last_hidden_state"]
self.assertEqual(output.shape, expected_output_shape)
# compare the actual values for a slice of last dim
self.assertTrue(jnp.allclose(output[:, :, -1], expected_output_values_last_dim, atol=1e-3))
......@@ -102,6 +102,7 @@ TEST_FILES_WITH_NO_COMMON_TESTS = [
"camembert/test_modeling_tf_camembert.py",
"mt5/test_modeling_tf_mt5.py",
"xlm_roberta/test_modeling_tf_xlm_roberta.py",
"xlm_roberta/test_modeling_flax_xlm_roberta.py",
"xlm_prophetnet/test_modeling_xlm_prophetnet.py",
"xlm_roberta/test_modeling_xlm_roberta.py",
"vision_text_dual_encoder/test_modeling_vision_text_dual_encoder.py",
......