Unverified Commit 76d13de5 authored by regisss's avatar regisss Committed by GitHub
Browse files

Add ONNX support for DETR (#17904)

parent bfcd5743
......@@ -62,6 +62,7 @@ Ready-made configurations include the following architectures:
- DeBERTa
- DeBERTa-v2
- DeiT
- DETR
- DistilBERT
- ELECTRA
- FlauBERT
......
......@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available
_import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig"]}
_import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig", "DetrOnnxConfig"]}
try:
if not is_vision_available():
......@@ -47,7 +47,7 @@ else:
if TYPE_CHECKING:
from .configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig
from .configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig, DetrOnnxConfig
try:
if not is_vision_available():
......
......@@ -14,7 +14,13 @@
# limitations under the License.
""" DETR model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging
......@@ -204,3 +210,25 @@ class DetrConfig(PretrainedConfig):
@property
def hidden_size(self) -> int:
return self.d_model
class DetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
("pixel_mask", {0: "batch", 1: "sequence"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-5
@property
def default_onnx_opset(self) -> int:
return 12
......@@ -77,9 +77,22 @@ class OnnxConfig(ABC):
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
"image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"image-segmentation": OrderedDict(
{
"logits": {0: "batch", 1: "sequence"},
"pred_boxes": {0: "batch", 1: "sequence"},
"pred_masks": {0: "batch", 1: "sequence"},
}
),
"masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"multiple-choice": OrderedDict({"logits": {0: "batch"}}),
"object-detection": OrderedDict(
{
"logits": {0: "batch", 1: "sequence"},
"pred_boxes": {0: "batch", 1: "sequence"},
}
),
"question-answering": OrderedDict(
{
"start_logits": {0: "batch", 1: "sequence"},
......
......@@ -15,9 +15,11 @@ if is_torch_available():
AutoModel,
AutoModelForCausalLM,
AutoModelForImageClassification,
AutoModelForImageSegmentation,
AutoModelForMaskedImageModeling,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForObjectDetection,
AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification,
......@@ -83,8 +85,10 @@ class FeaturesManager:
"sequence-classification": AutoModelForSequenceClassification,
"token-classification": AutoModelForTokenClassification,
"multiple-choice": AutoModelForMultipleChoice,
"object-detection": AutoModelForObjectDetection,
"question-answering": AutoModelForQuestionAnswering,
"image-classification": AutoModelForImageClassification,
"image-segmentation": AutoModelForImageSegmentation,
"masked-im": AutoModelForMaskedImageModeling,
}
if is_tf_available():
......@@ -227,6 +231,12 @@ class FeaturesManager:
"deit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls="models.deit.DeiTOnnxConfig"
),
"detr": supported_features_mapping(
"default",
"object-detection",
"image-segmentation",
onnx_config_cls="models.detr.DetrOnnxConfig",
),
"distilbert": supported_features_mapping(
"default",
"masked-lm",
......
......@@ -183,6 +183,7 @@ PYTORCH_EXPORT_MODELS = {
("deberta", "microsoft/deberta-base"),
("deberta-v2", "microsoft/deberta-v2-xlarge"),
("convnext", "facebook/convnext-tiny-224"),
("detr", "facebook/detr-resnet-50"),
("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"),
("resnet", "microsoft/resnet-50"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment