Unverified Commit 76d13de5 authored by regisss's avatar regisss Committed by GitHub
Browse files

Add ONNX support for DETR (#17904)

parent bfcd5743
...@@ -62,6 +62,7 @@ Ready-made configurations include the following architectures: ...@@ -62,6 +62,7 @@ Ready-made configurations include the following architectures:
- DeBERTa - DeBERTa
- DeBERTa-v2 - DeBERTa-v2
- DeiT - DeiT
- DETR
- DistilBERT - DistilBERT
- ELECTRA - ELECTRA
- FlauBERT - FlauBERT
......
...@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING ...@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available
_import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig"]} _import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig", "DetrOnnxConfig"]}
try: try:
if not is_vision_available(): if not is_vision_available():
...@@ -47,7 +47,7 @@ else: ...@@ -47,7 +47,7 @@ else:
if TYPE_CHECKING: if TYPE_CHECKING:
from .configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig from .configuration_detr import DETR_PRETRAINED_CONFIG_ARCHIVE_MAP, DetrConfig, DetrOnnxConfig
try: try:
if not is_vision_available(): if not is_vision_available():
......
...@@ -14,7 +14,13 @@ ...@@ -14,7 +14,13 @@
# limitations under the License. # limitations under the License.
""" DETR model configuration""" """ DETR model configuration"""
from collections import OrderedDict
from typing import Mapping
from packaging import version
from ...configuration_utils import PretrainedConfig from ...configuration_utils import PretrainedConfig
from ...onnx import OnnxConfig
from ...utils import logging from ...utils import logging
...@@ -204,3 +210,25 @@ class DetrConfig(PretrainedConfig): ...@@ -204,3 +210,25 @@ class DetrConfig(PretrainedConfig):
@property @property
def hidden_size(self) -> int: def hidden_size(self) -> int:
return self.d_model return self.d_model
class DetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
@property
def inputs(self) -> Mapping[str, Mapping[int, str]]:
return OrderedDict(
[
("pixel_values", {0: "batch", 1: "sequence"}),
("pixel_mask", {0: "batch", 1: "sequence"}),
]
)
@property
def atol_for_validation(self) -> float:
return 1e-5
@property
def default_onnx_opset(self) -> int:
return 12
...@@ -77,9 +77,22 @@ class OnnxConfig(ABC): ...@@ -77,9 +77,22 @@ class OnnxConfig(ABC):
"causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}), "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
"image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), "image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"image-segmentation": OrderedDict(
{
"logits": {0: "batch", 1: "sequence"},
"pred_boxes": {0: "batch", 1: "sequence"},
"pred_masks": {0: "batch", 1: "sequence"},
}
),
"masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), "masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}), "masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
"multiple-choice": OrderedDict({"logits": {0: "batch"}}), "multiple-choice": OrderedDict({"logits": {0: "batch"}}),
"object-detection": OrderedDict(
{
"logits": {0: "batch", 1: "sequence"},
"pred_boxes": {0: "batch", 1: "sequence"},
}
),
"question-answering": OrderedDict( "question-answering": OrderedDict(
{ {
"start_logits": {0: "batch", 1: "sequence"}, "start_logits": {0: "batch", 1: "sequence"},
......
...@@ -15,9 +15,11 @@ if is_torch_available(): ...@@ -15,9 +15,11 @@ if is_torch_available():
AutoModel, AutoModel,
AutoModelForCausalLM, AutoModelForCausalLM,
AutoModelForImageClassification, AutoModelForImageClassification,
AutoModelForImageSegmentation,
AutoModelForMaskedImageModeling, AutoModelForMaskedImageModeling,
AutoModelForMaskedLM, AutoModelForMaskedLM,
AutoModelForMultipleChoice, AutoModelForMultipleChoice,
AutoModelForObjectDetection,
AutoModelForQuestionAnswering, AutoModelForQuestionAnswering,
AutoModelForSeq2SeqLM, AutoModelForSeq2SeqLM,
AutoModelForSequenceClassification, AutoModelForSequenceClassification,
...@@ -83,8 +85,10 @@ class FeaturesManager: ...@@ -83,8 +85,10 @@ class FeaturesManager:
"sequence-classification": AutoModelForSequenceClassification, "sequence-classification": AutoModelForSequenceClassification,
"token-classification": AutoModelForTokenClassification, "token-classification": AutoModelForTokenClassification,
"multiple-choice": AutoModelForMultipleChoice, "multiple-choice": AutoModelForMultipleChoice,
"object-detection": AutoModelForObjectDetection,
"question-answering": AutoModelForQuestionAnswering, "question-answering": AutoModelForQuestionAnswering,
"image-classification": AutoModelForImageClassification, "image-classification": AutoModelForImageClassification,
"image-segmentation": AutoModelForImageSegmentation,
"masked-im": AutoModelForMaskedImageModeling, "masked-im": AutoModelForMaskedImageModeling,
} }
if is_tf_available(): if is_tf_available():
...@@ -227,6 +231,12 @@ class FeaturesManager: ...@@ -227,6 +231,12 @@ class FeaturesManager:
"deit": supported_features_mapping( "deit": supported_features_mapping(
"default", "image-classification", "masked-im", onnx_config_cls="models.deit.DeiTOnnxConfig" "default", "image-classification", "masked-im", onnx_config_cls="models.deit.DeiTOnnxConfig"
), ),
"detr": supported_features_mapping(
"default",
"object-detection",
"image-segmentation",
onnx_config_cls="models.detr.DetrOnnxConfig",
),
"distilbert": supported_features_mapping( "distilbert": supported_features_mapping(
"default", "default",
"masked-lm", "masked-lm",
......
...@@ -183,6 +183,7 @@ PYTORCH_EXPORT_MODELS = { ...@@ -183,6 +183,7 @@ PYTORCH_EXPORT_MODELS = {
("deberta", "microsoft/deberta-base"), ("deberta", "microsoft/deberta-base"),
("deberta-v2", "microsoft/deberta-v2-xlarge"), ("deberta-v2", "microsoft/deberta-v2-xlarge"),
("convnext", "facebook/convnext-tiny-224"), ("convnext", "facebook/convnext-tiny-224"),
("detr", "facebook/detr-resnet-50"),
("distilbert", "distilbert-base-cased"), ("distilbert", "distilbert-base-cased"),
("electra", "google/electra-base-generator"), ("electra", "google/electra-base-generator"),
("resnet", "microsoft/resnet-50"), ("resnet", "microsoft/resnet-50"),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment