Unverified Commit 1f60df81 authored by Thomas Chaigneau's avatar Thomas Chaigneau Committed by GitHub
Browse files

Add Camembert to models exportable with ONNX (#14059)



Add Camembert to models exportable with ONNX
Co-authored-by: default avatarThomas.Chaigneau <thomas.chaigneau@arkea.com>
Co-authored-by: default avatarMichael Benayoun <mickbenayoun@gmail.com>
parent 0c3174c7
......@@ -43,6 +43,7 @@ Ready-made configurations include the following models:
- ALBERT
- BART
- BERT
- CamemBERT
- DistilBERT
- GPT Neo
- LayoutLM
......
......@@ -28,7 +28,7 @@ from ...file_utils import (
_import_structure = {
"configuration_camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"],
"configuration_camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig", "CamembertOnnxConfig"],
}
if is_sentencepiece_available():
......@@ -62,7 +62,7 @@ if is_tf_available():
if TYPE_CHECKING:
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig, CamembertOnnxConfig
if is_sentencepiece_available():
from .tokenization_camembert import CamembertTokenizer
......
......@@ -15,6 +15,10 @@
# limitations under the License.
""" CamemBERT configuration """
from collections import OrderedDict
from typing import Mapping
from ...onnx import OnnxConfig
from ...utils import logging
from ..roberta.configuration_roberta import RobertaConfig
......@@ -35,3 +39,14 @@ class CamembertConfig(RobertaConfig):
"""
model_type = "camembert"
class CamembertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
]
)
......@@ -5,6 +5,7 @@ from .. import is_torch_available
from ..models.albert import AlbertOnnxConfig
from ..models.bart import BartOnnxConfig
from ..models.bert import BertOnnxConfig
from ..models.camembert import CamembertOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig
......@@ -62,6 +63,14 @@ class FeaturesManager:
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
"mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
"camembert": supported_features_mapping(
"default",
"causal-lm",
"sequence-classification",
"token-classification",
"question-answering",
onnx_config_cls=CamembertOnnxConfig,
),
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment