Unverified Commit 0aac9ba2 authored by Thomas Chaigneau's avatar Thomas Chaigneau Committed by GitHub
Browse files

Add Flaubert OnnxConfig to Transformers (#16279)



* Add Flaubert to ONNX to make it available for conversion.

* Fixed features for FlauBERT. fixup command remove flaubert to docs list.
Co-authored-by: default avatarChainYo <t.chaigneau.tc@gmail.com>
parent 9fef6683
......@@ -52,6 +52,7 @@ Ready-made configurations include the following architectures:
- Data2VecText
- DistilBERT
- ELECTRA
- FlauBERT
- GPT Neo
- I-BERT
- LayoutLM
......
......@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_tf_available, is_torch_available
_import_structure = {
"configuration_flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig"],
"configuration_flaubert": ["FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "FlaubertConfig", "FlaubertOnnxConfig"],
"tokenization_flaubert": ["FlaubertTokenizer"],
}
......@@ -52,7 +52,7 @@ if is_tf_available():
if TYPE_CHECKING:
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig
from .configuration_flaubert import FLAUBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, FlaubertConfig, FlaubertOnnxConfig
from .tokenization_flaubert import FlaubertTokenizer
if is_torch_available():
......
......@@ -14,6 +14,10 @@
# limitations under the License.
""" Flaubert configuration, based on XLM."""
from collections import OrderedDict
from typing import Mapping
from ...onnx import OnnxConfig
from ...utils import logging
from ..xlm.configuration_xlm import XLMConfig
......@@ -137,3 +141,14 @@ class FlaubertConfig(XLMConfig):
self.layerdrop = layerdrop
self.pre_norm = pre_norm
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, **kwargs)
class FlaubertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
]
)
......@@ -8,6 +8,7 @@ from ..models.bert import BertOnnxConfig
from ..models.camembert import CamembertOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig
from ..models.electra import ElectraOnnxConfig
from ..models.flaubert import FlaubertOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.ibert import IBertOnnxConfig
......@@ -179,6 +180,15 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=DistilBertOnnxConfig,
),
"flaubert": supported_features_mapping(
"default",
"masked-lm",
"causal-lm",
"sequence-classification",
"token-classification",
"question-answering",
onnx_config_cls=FlaubertOnnxConfig,
),
"marian": supported_features_mapping(
"default",
"default-with-past",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment