"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "22b41b3f8a5cdb37e686d18d8d9a24eb98a331ec"
Unverified commit c4fa908f authored by Virus, committed by GitHub

Adds IBERT to models exportable with ONNX (#14868)

* Add IBertOnnxConfig and tests

* add all the supported features for IBERT and remove outputs in IbertOnnxConfig

* use OnnxConfig

* fix codestyle

* remove serialization.rst

* codestyle
parent efb35a41
@@ -40,6 +40,7 @@ Ready-made configurations include the following models:
 - CamemBERT
 - DistilBERT
 - GPT Neo
+- I-BERT
 - LayoutLM
 - Longformer
 - Marian
......
@@ -22,7 +22,7 @@ from ...file_utils import _LazyModule, is_torch_available
 _import_structure = {
-    "configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig"],
+    "configuration_ibert": ["IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "IBertConfig", "IBertOnnxConfig"],
 }
 if is_torch_available():
@@ -38,7 +38,7 @@ if is_torch_available():
     ]
 if TYPE_CHECKING:
-    from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig
+    from .configuration_ibert import IBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, IBertConfig, IBertOnnxConfig
     if is_torch_available():
         from .modeling_ibert import (
......
@@ -15,6 +15,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ I-BERT configuration"""
+from collections import OrderedDict
+from typing import Mapping
+
+from transformers.onnx import OnnxConfig

 from ...configuration_utils import PretrainedConfig
 from ...utils import logging
@@ -122,3 +126,14 @@ class IBertConfig(PretrainedConfig):
         self.position_embedding_type = position_embedding_type
         self.quant_mode = quant_mode
         self.force_dequant = force_dequant
+
+
+class IBertOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return OrderedDict(
+            [
+                ("input_ids", {0: "batch", 1: "sequence"}),
+                ("attention_mask", {0: "batch", 1: "sequence"}),
+            ]
+        )
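The new `IBertOnnxConfig` only overrides the `inputs` property, which declares the model inputs and their dynamic ONNX axes. A minimal sketch of what that looks like in use; instantiating a default `IBertConfig()` here is an illustrative choice, not something this diff does:

```python
# Sketch only: inspect the dynamic-axes mapping exposed by the new ONNX config.
from transformers import IBertConfig
from transformers.models.ibert import IBertOnnxConfig

config = IBertConfig()  # default config, purely for illustration
onnx_config = IBertOnnxConfig(config)

# The `inputs` property added in this diff maps each input name to its dynamic axes
print(onnx_config.inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
#              ('attention_mask', {0: 'batch', 1: 'sequence'})])
```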
@@ -9,6 +9,7 @@ from ..models.camembert import CamembertOnnxConfig
 from ..models.distilbert import DistilBertOnnxConfig
 from ..models.gpt2 import GPT2OnnxConfig
 from ..models.gpt_neo import GPTNeoOnnxConfig
+from ..models.ibert import IBertOnnxConfig
 from ..models.layoutlm import LayoutLMOnnxConfig
 from ..models.longformer import LongformerOnnxConfig
 from ..models.marian import MarianOnnxConfig
@@ -125,6 +126,15 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls=BertOnnxConfig,
         ),
+        "ibert": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "sequence-classification",
+            # "multiple-choice",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls=IBertOnnxConfig,
+        ),
         "camembert": supported_features_mapping(
             "default",
             "masked-lm",
......
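For context, the new `"ibert"` entry is what `FeaturesManager` resolves when an export is requested for an I-BERT model. A minimal sketch, assuming the `sequence-classification` feature listed above and the `kssteven/ibert-roberta-base` checkpoint used in the test below:

```python
# Sketch only: resolve the ONNX config registered for ibert + sequence-classification.
from transformers import AutoModelForSequenceClassification
from transformers.onnx.features import FeaturesManager

# A fresh classification head is fine here; we only need the model type and config
model = AutoModelForSequenceClassification.from_pretrained("kssteven/ibert-roberta-base")

# Returns the model type ("ibert") and a factory for the registered ONNX config class
model_kind, onnx_config_factory = FeaturesManager.check_supported_model_or_raise(
    model, feature="sequence-classification"
)
onnx_config = onnx_config_factory(model.config)
print(model_kind, list(onnx_config.inputs))
```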
@@ -171,6 +171,7 @@ class OnnxConfigWithPastTestCaseV2(TestCase):
 PYTORCH_EXPORT_MODELS = {
     ("albert", "hf-internal-testing/tiny-albert"),
     ("bert", "bert-base-cased"),
+    ("ibert", "kssteven/ibert-roberta-base"),
     ("camembert", "camembert-base"),
     ("distilbert", "distilbert-base-cased"),
     # ("longFormer", "longformer-base-4096"),
......
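Roughly, the parameterized export test exercises each (architecture, checkpoint) pair along the lines sketched below for the new I-BERT entry. The temporary output path, opset choice, and `atol` tolerance here are illustrative assumptions, not values taken from the test file:

```python
# Sketch only: export the I-BERT checkpoint to ONNX and validate it against PyTorch.
from pathlib import Path
from tempfile import TemporaryDirectory

from transformers import AutoModel, AutoTokenizer
from transformers.models.ibert import IBertOnnxConfig
from transformers.onnx import export, validate_model_outputs

checkpoint = "kssteven/ibert-roberta-base"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)
onnx_config = IBertOnnxConfig(model.config)

with TemporaryDirectory() as tmp_dir:
    onnx_path = Path(tmp_dir) / "model.onnx"
    # export returns the input and output names actually present in the ONNX graph
    onnx_inputs, onnx_outputs = export(
        tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path
    )
    # Compare ONNX Runtime results against the PyTorch reference within a small tolerance
    validate_model_outputs(onnx_config, tokenizer, model, onnx_path, onnx_outputs, atol=1e-4)
```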