Unverified Commit c4fa908f authored by Virus's avatar Virus Committed by GitHub
Browse files

Adds IBERT to models exportable with ONNX (#14868)

* Add IBertOnnxConfig and tests

* add all the supported features for IBERT and remove outputs in IbertOnnxConfig

* use OnnxConfig

* fix codestyle

* remove serialization.rst

* codestyle
parent efb35a41
......@@ -40,6 +40,7 @@ Ready-made configurations include the following models:
- CamemBERT
- DistilBERT
- GPT Neo
- I-BERT
- LayoutLM
- Longformer
- Marian
......
......@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_torch_available
_import_structure = {
"configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
"configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig", "IBertOnnxConfig"],
}
if is_torch_available():
......@@ -38,7 +38,7 @@ if is_torch_available():
]
if TYPE_CHECKING:
from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig, IBertOnnxConfig
if is_torch_available():
from .modeling_ibert import (
......
......@@ -15,6 +15,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
""" I-BERT configuration"""
from collections import OrderedDict
from typing import Mapping
from transformers.onnx import OnnxConfig
from ...configuration_utils import PretrainedConfig
from ...utils import logging
......@@ -122,3 +126,14 @@ class IBertConfig(PretrainedConfig):
self.position_embedding_type = position_embedding_type
self.quant_mode = quant_mode
self.force_dequant = force_dequant
class IBertOnnxConfig(OnnxConfig):
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("input_ids", {0: "batch", 1: "sequence"}),
("attention_mask", {0: "batch", 1: "sequence"}),
]
)
......@@ -9,6 +9,7 @@ from ..models.camembert import CamembertOnnxConfig
from ..models.distilbert import DistilBertOnnxConfig
from ..models.gpt2 import GPT2OnnxConfig
from ..models.gpt_neo import GPTNeoOnnxConfig
from ..models.ibert import IBertOnnxConfig
from ..models.layoutlm import LayoutLMOnnxConfig
from ..models.longformer import LongformerOnnxConfig
from ..models.marian import MarianOnnxConfig
......@@ -125,6 +126,15 @@ class FeaturesManager:
"question-answering",
onnx_config_cls=BertOnnxConfig,
),
"ibert": supported_features_mapping(
"default",
"masked-lm",
"sequence-classification",
# "multiple-choice",
"token-classification",
"question-answering",
onnx_config_cls=IBertOnnxConfig,
),
"camembert": supported_features_mapping(
"default",
"masked-lm",
......
......@@ -171,6 +171,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
PYTORCH_EXPORT_MODELS = {
("albert", "hf-internal-testing/tiny-albert"),
("bert", "bert-base-cased"),
("ibert", "kssteven/ibert-roberta-base"),
("camembert", "camembert-base"),
("distilbert", "distilbert-base-cased"),
# ("longFormer", "longformer-base-4096"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment