Unverified commit eb16be41, authored by mrbean, committed by GitHub

add onnx support for deberta and debertav2 (#17617)



* add onnx support for debertav2

* debertav2 -> deberta-v2 in onnx features file

* remove causal lm

* add deberta-v2-xlarge to onnx tests

* use self.type().dtype() in xsoftmax
Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>

* remove hack for deberta

* remove unused imports

* Update src/transformers/models/deberta_v2/configuration_deberta_v2.py
Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>

* use generate dummy inputs

* linter

* add imports

* add support for deberta v1 as well

* deberta does not support multiple choice

* Update src/transformers/models/deberta/configuration_deberta.py
Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>

* Update src/transformers/models/deberta_v2/configuration_deberta_v2.py
Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>

* one line ordered dict

* fire build
Co-authored-by: Jingya HUANG <44135271+JingyaHuang@users.noreply.github.com>
parent 8fcbe275
@@ -58,6 +58,8 @@ Ready-made configurations include the following architectures:

 - ConvNeXT
 - Data2VecText
 - Data2VecVision
+- DeBERTa
+- DeBERTa-v2
 - DeiT
 - DistilBERT
 - ELECTRA
......
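With these configurations wired up, either model can be exported through the transformers.onnx package. A minimal sketch, assuming a standard install with the ONNX extras (the checkpoint and output path are illustrative; the CLI equivalent would be roughly "python -m transformers.onnx --model=microsoft/deberta-base onnx/"):

from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.models.deberta import DebertaOnnxConfig
from transformers.onnx import export

checkpoint = "microsoft/deberta-base"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

# Build the ONNX config added in this commit and export at its default opset (12)
onnx_config = DebertaOnnxConfig(model.config)
onnx_path = Path("onnx/model.onnx")
onnx_path.parent.mkdir(parents=True, exist_ok=True)
onnx_inputs, onnx_outputs = export(tokenizer, model, onnx_config, onnx_config.default_onnx_opset, onnx_path)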
@@ -28,7 +28,7 @@ from ...utils import (

 _import_structure = {
-    "configuration_deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig"],
+    "configuration_deberta": ["DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaConfig", "DebertaOnnxConfig"],
     "tokenization_deberta": ["DebertaTokenizer"],
 }
@@ -74,7 +74,7 @@ else:

 if TYPE_CHECKING:
-    from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig
+    from .configuration_deberta import DEBERTA_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaConfig, DebertaOnnxConfig
     from .tokenization_deberta import DebertaTokenizer

     try:
......
@@ -13,8 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ DeBERTa model configuration"""
+from collections import OrderedDict
+from typing import Any, Mapping, Optional, Union
+
+from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging
@@ -137,3 +141,41 @@ class DebertaConfig(PretrainedConfig):
         self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size)
         self.pooler_dropout = pooler_dropout
         self.pooler_hidden_act = pooler_hidden_act
+
+
+# Copied from transformers.models.deberta_v2.configuration_deberta_v2.DebertaV2OnnxConfig
+class DebertaOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        if self._config.type_vocab_size > 0:
+            return OrderedDict(
+                [("input_ids", dynamic_axis), ("attention_mask", dynamic_axis), ("token_type_ids", dynamic_axis)]
+            )
+        else:
+            return OrderedDict([("input_ids", dynamic_axis), ("attention_mask", dynamic_axis)])
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 12
+
+    def generate_dummy_inputs(
+        self,
+        preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
+        batch_size: int = -1,
+        seq_length: int = -1,
+        num_choices: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+        num_channels: int = 3,
+        image_width: int = 40,
+        image_height: int = 40,
+        tokenizer: "PreTrainedTokenizerBase" = None,
+    ) -> Mapping[str, Any]:
+        dummy_inputs = super().generate_dummy_inputs(preprocessor=preprocessor, framework=framework)
+        if self._config.type_vocab_size == 0 and "token_type_ids" in dummy_inputs:
+            del dummy_inputs["token_type_ids"]
+        return dummy_inputs
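To see what the class above does in practice, a small sketch (not part of the commit; the checkpoint name is illustrative): DebertaConfig defaults to type_vocab_size=0, so token_type_ids drop out of both the declared inputs and the generated dummy inputs.

from transformers import AutoTokenizer, DebertaConfig, TensorType
from transformers.models.deberta import DebertaOnnxConfig

onnx_config = DebertaOnnxConfig(DebertaConfig())  # type_vocab_size defaults to 0 for DeBERTa

print(onnx_config.inputs)
# OrderedDict([('input_ids', {0: 'batch', 1: 'sequence'}),
#              ('attention_mask', {0: 'batch', 1: 'sequence'})])

tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-base")
dummy = onnx_config.generate_dummy_inputs(tokenizer, framework=TensorType.PYTORCH)
print("token_type_ids" in dummy)  # False: deleted because type_vocab_size == 0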
@@ -129,7 +129,9 @@ class XSoftmax(torch.autograd.Function):
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
             to_i=sym_help.cast_pytorch_to_onnx["Byte"],
         )
-        output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
+        output = masked_fill(
+            g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
+        )
         output = softmax(g, output, dim)
         return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
......
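For context on the XSoftmax change: inside the symbolic() export path, self is a traced torch._C.Value rather than a torch.Tensor, so self.dtype is not available during ONNX export; self.type().dtype() reads the dtype off the value's JIT type instead, which is what allows the earlier DeBERTa-specific export hack to be removed.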
@@ -28,7 +28,7 @@ from ...utils import (

 _import_structure = {
-    "configuration_deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config"],
+    "configuration_deberta_v2": ["DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP", "DebertaV2Config", "DebertaV2OnnxConfig"],
     "tokenization_deberta_v2": ["DebertaV2Tokenizer"],
 }
@@ -75,7 +75,11 @@ else:

 if TYPE_CHECKING:
-    from .configuration_deberta_v2 import DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP, DebertaV2Config
+    from .configuration_deberta_v2 import (
+        DEBERTA_V2_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        DebertaV2Config,
+        DebertaV2OnnxConfig,
+    )
     from .tokenization_deberta_v2 import DebertaV2Tokenizer

     try:
......
@@ -13,8 +13,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ DeBERTa-v2 model configuration"""
+from collections import OrderedDict
+from typing import Any, Mapping, Optional, Union
+
+from ... import FeatureExtractionMixin, PreTrainedTokenizerBase, TensorType
 from ...configuration_utils import PretrainedConfig
+from ...onnx import OnnxConfig
 from ...utils import logging
@@ -139,3 +143,40 @@ class DebertaV2Config(PretrainedConfig):
         self.pooler_hidden_size = kwargs.get("pooler_hidden_size", hidden_size)
         self.pooler_dropout = pooler_dropout
         self.pooler_hidden_act = pooler_hidden_act
+
+
+class DebertaV2OnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.task == "multiple-choice":
+            dynamic_axis = {0: "batch", 1: "choice", 2: "sequence"}
+        else:
+            dynamic_axis = {0: "batch", 1: "sequence"}
+        if self._config.type_vocab_size > 0:
+            return OrderedDict(
+                [("input_ids", dynamic_axis), ("attention_mask", dynamic_axis), ("token_type_ids", dynamic_axis)]
+            )
+        else:
+            return OrderedDict([("input_ids", dynamic_axis), ("attention_mask", dynamic_axis)])
+
+    @property
+    def default_onnx_opset(self) -> int:
+        return 12
+
+    def generate_dummy_inputs(
+        self,
+        preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin"],
+        batch_size: int = -1,
+        seq_length: int = -1,
+        num_choices: int = -1,
+        is_pair: bool = False,
+        framework: Optional[TensorType] = None,
+        num_channels: int = 3,
+        image_width: int = 40,
+        image_height: int = 40,
+        tokenizer: "PreTrainedTokenizerBase" = None,
+    ) -> Mapping[str, Any]:
+        dummy_inputs = super().generate_dummy_inputs(preprocessor=preprocessor, framework=framework)
+        if self._config.type_vocab_size == 0 and "token_type_ids" in dummy_inputs:
+            del dummy_inputs["token_type_ids"]
+        return dummy_inputs
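This v2 class is the original that DebertaOnnxConfig is copied from; the practical difference between the two models shows up in the feature table further down, where only deberta-v2 registers multiple-choice. Under that task the declared axes gain a choice dimension, as in this sketch (not part of the commit):

from transformers import DebertaV2Config
from transformers.models.deberta_v2 import DebertaV2OnnxConfig

onnx_config = DebertaV2OnnxConfig(DebertaV2Config(), task="multiple-choice")
print(onnx_config.inputs["input_ids"])  # {0: 'batch', 1: 'choice', 2: 'sequence'}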
@@ -132,7 +132,9 @@ class XSoftmax(torch.autograd.Function):
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
             to_i=sym_help.cast_pytorch_to_onnx["Byte"],
         )
-        output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
+        output = masked_fill(
+            g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
+        )
         output = softmax(g, output, dim)
         return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
......
@@ -557,7 +557,9 @@ class XSoftmax(torch.autograd.Function):
             g.op("Sub", g.op("Constant", value_t=torch.tensor(1, dtype=torch.int64)), mask_cast_value),
             to_i=sym_help.cast_pytorch_to_onnx["Byte"],
         )
-        output = masked_fill(g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.dtype).min)))
+        output = masked_fill(
+            g, self, r_mask, g.op("Constant", value_t=torch.tensor(torch.finfo(self.type().dtype()).min))
+        )
         output = softmax(g, output, dim)
         return masked_fill(g, output, r_mask, g.op("Constant", value_t=torch.tensor(0, dtype=torch.uint8)))
......
@@ -351,7 +351,7 @@ def validate_model_outputs(
     logger.info("Validating ONNX model...")

     if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
-        raise ValueError("You cannot provide both a tokenizer and a preprocessor to validatethe model outputs.")
+        raise ValueError("You cannot provide both a tokenizer and a preprocessor to validate the model outputs.")
     if tokenizer is not None:
         warnings.warn(
             "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
......
@@ -207,6 +207,23 @@ class FeaturesManager:
             "question-answering",
             onnx_config_cls="models.data2vec.Data2VecTextOnnxConfig",
         ),
+        "deberta": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "sequence-classification",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls="models.deberta.DebertaOnnxConfig",
+        ),
+        "deberta-v2": supported_features_mapping(
+            "default",
+            "masked-lm",
+            "sequence-classification",
+            "multiple-choice",
+            "token-classification",
+            "question-answering",
+            onnx_config_cls="models.deberta_v2.DebertaV2OnnxConfig",
+        ),
         "deit": supported_features_mapping(
             "default", "image-classification", "masked-im", onnx_config_cls="models.deit.DeiTOnnxConfig"
         ),
......
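The new table entries can be sanity-checked through FeaturesManager, for example:

from transformers.onnx.features import FeaturesManager

# Maps each registered feature to its OnnxConfig constructor; note no causal-lm
# and no multiple-choice for deberta, matching the commit message above
print(sorted(FeaturesManager.get_supported_features_for_model_type("deberta")))
# ['default', 'masked-lm', 'question-answering', 'sequence-classification', 'token-classification']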
@@ -180,6 +180,8 @@ PYTORCH_EXPORT_MODELS = {
     ("ibert", "kssteven/ibert-roberta-base"),
     ("camembert", "camembert-base"),
     ("convbert", "YituTech/conv-bert-base"),
+    ("deberta", "microsoft/deberta-base"),
+    ("deberta-v2", "microsoft/deberta-v2-xlarge"),
     ("convnext", "facebook/convnext-tiny-224"),
     ("distilbert", "distilbert-base-cased"),
     ("electra", "google/electra-base-generator"),
......
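These two checkpoints feed the parameterized ONNX export tests, which are marked slow; an invocation along the lines of RUN_SLOW=1 python -m pytest tests/onnx/test_onnx_v2.py -k deberta (illustrative) should exercise both new entries.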