"...resnet50_tensorflow.git" did not exist on "60613a8729122012c3f6739550de981d01087326"
Unverified Commit 1f60df81 authored by Thomas Chaigneau's avatar Thomas Chaigneau Committed by GitHub
Browse files

Add Camembert to models exportable with ONNX (#14059)



Add Camembert to models exportable with ONNX
Co-authored-by: default avatarThomas.Chaigneau <thomas.chaigneau@arkea.com>
Co-authored-by: default avatarMichael Benayoun <mickbenayoun@gmail.com>
parent 0c3174c7
...@@ -43,6 +43,7 @@ Ready-made configurations include the following models: ...@@ -43,6 +43,7 @@ Ready-made configurations include the following models:
- ALBERT - ALBERT
- BART - BART
- BERT - BERT
- CamemBERT
- DistilBERT - DistilBERT
- GPT Neo - GPT Neo
- LayoutLM - LayoutLM
......
...@@ -28,7 +28,7 @@ from ...file_utils import ( ...@@ -28,7 +28,7 @@ from ...file_utils import (
_import_structure = { _import_structure = {
"configuration_camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig"], "configuration_camembert": ["CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "CamembertConfig", "CamembertOnnxConfig"],
} }
if is_sentencepiece_available(): if is_sentencepiece_available():
...@@ -62,7 +62,7 @@ if is_tf_available(): ...@@ -62,7 +62,7 @@ if is_tf_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig from .configuration_camembert import CAMEMBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, CamembertConfig, CamembertOnnxConfig
if is_sentencepiece_available(): if is_sentencepiece_available():
from .tokenization_camembert import CamembertTokenizer from .tokenization_camembert import CamembertTokenizer
......
...@@ -15,6 +15,10 @@ ...@@ -15,6 +15,10 @@
# limitations under the License. # limitations under the License.
""" CamemBERT configuration """ """ CamemBERT configuration """
from collections import OrderedDict
from typing import Mapping
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
from ..roberta.configuration_roberta import RobertaConfig from ..roberta.configuration_roberta import RobertaConfig
...@@ -35,3 +39,14 @@ class CamembertConfig(RobertaConfig): ...@@ -35,3 +39,14 @@ class CamembertConfig(RobertaConfig):
""" """
model_type = "camembert" model_type = "camembert"
class CamembertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
]
)
...@@ -5,6 +5,7 @@ from .. import is_torch_available ...@@ -5,6 +5,7 @@ from .. import is_torch_available
from ..models.albert import AlbertOnnxConfig from ..models.albert import AlbertOnnxConfig
from ..models.bart import BartOnnxConfig from ..models.bart import BartOnnxConfig
from ..models.bert import BertOnnxConfig from ..models.bert import BertOnnxConfig
from ..models.camembert import CamembertOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig from ..models.distilbert import DistilBertOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig from ..models.gpt_neo import GPTNeoOnnxConfig
...@@ -62,6 +63,14 @@ class FeaturesManager: ...@@ -62,6 +63,14 @@ class FeaturesManager:
"bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig), "bart": supported_features_mapping("default", onnx_config_cls=BartOnnxConfig),
"mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig), "mbart": supported_features_mapping("default", onnx_config_cls=MBartOnnxConfig),
"bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig), "bert": supported_features_mapping("default", onnx_config_cls=BertOnnxConfig),
"camembert": supported_features_mapping(
"default",
"causal-lm",
"sequence-classification",
"token-classification",
"question-answering",
onnx_config_cls=CamembertOnnxConfig,
),
"distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig), "distilbert": supported_features_mapping("default", onnx_config_cls=DistilBertOnnxConfig),
"gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig), "gpt2": supported_features_mapping("default", onnx_config_cls=GPT2OnnxConfig),
"longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig), "longformer": supported_features_mapping("default", onnx_config_cls=LongformerOnnxConfig),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment