Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
76d13de5
"vscode:/vscode.git/clone" did not exist on "84eec9e6ba55c5aceee2a92fd820fcca4b67c510"
Unverified
Commit
76d13de5
authored
Jun 28, 2022
by
regisss
Committed by
GitHub
Jun 28, 2022
Browse files
Add ONNX support for DETR (#17904)
parent
bfcd5743
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
55 additions
and
2 deletions
+55
-2
docs/source/en/serialization.mdx
docs/source/en/serialization.mdx
+1
-0
src/transformers/models/detr/__init__.py
src/transformers/models/detr/__init__.py
+2
-2
src/transformers/models/detr/configuration_detr.py
src/transformers/models/detr/configuration_detr.py
+28
-0
src/transformers/onnx/config.py
src/transformers/onnx/config.py
+13
-0
src/transformers/onnx/features.py
src/transformers/onnx/features.py
+10
-0
tests/onnx/test_onnx_v2.py
tests/onnx/test_onnx_v2.py
+1
-0
No files found.
docs/source/en/serialization.mdx
View file @
76d13de5
...
...
@@ -62,6 +62,7 @@ Ready-made configurations include the following architectures:
- DeBERTa
- DeBERTa-v2
- DeiT
- DETR
- DistilBERT
- ELECTRA
- FlauBERT
...
...
src/transformers/models/detr/__init__.py
View file @
76d13de5
...
...
@@ -21,7 +21,7 @@ from typing import TYPE_CHECKING
from
...utils
import
OptionalDependencyNotAvailable
,
_LazyModule
,
is_timm_available
,
is_vision_available
_import_structure
=
{
"configuration_detr"
:
[
"DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"
,
"DetrConfig"
]}
_import_structure
=
{
"configuration_detr"
:
[
"DETR_PRETRAINED_CONFIG_ARCHIVE_MAP"
,
"DetrConfig"
,
"DetrOnnxConfig"
]}
try
:
if
not
is_vision_available
():
...
...
@@ -47,7 +47,7 @@ else:
if
TYPE_CHECKING
:
from
.configuration_detr
import
DETR_PRETRAINED_CONFIG_ARCHIVE_MAP
,
DetrConfig
from
.configuration_detr
import
DETR_PRETRAINED_CONFIG_ARCHIVE_MAP
,
DetrConfig
,
DetrOnnxConfig
try
:
if
not
is_vision_available
():
...
...
src/transformers/models/detr/configuration_detr.py
View file @
76d13de5
...
...
@@ -14,7 +14,13 @@
# limitations under the License.
""" DETR model configuration"""
from
collections
import
OrderedDict
from
typing
import
Mapping
from
packaging
import
version
from
...configuration_utils
import
PretrainedConfig
from
...onnx
import
OnnxConfig
from
...utils
import
logging
...
...
@@ -204,3 +210,25 @@ class DetrConfig(PretrainedConfig):
@
property
def
hidden_size
(
self
)
->
int
:
return
self
.
d_model
class
DetrOnnxConfig
(
OnnxConfig
):
torch_onnx_minimum_version
=
version
.
parse
(
"1.11"
)
@
property
def
inputs
(
self
)
->
Mapping
[
str
,
Mapping
[
int
,
str
]]:
return
OrderedDict
(
[
(
"pixel_values"
,
{
0
:
"batch"
,
1
:
"sequence"
}),
(
"pixel_mask"
,
{
0
:
"batch"
,
1
:
"sequence"
}),
]
)
@
property
def
atol_for_validation
(
self
)
->
float
:
return
1e-5
@
property
def
default_onnx_opset
(
self
)
->
int
:
return
12
src/transformers/onnx/config.py
View file @
76d13de5
...
...
@@ -77,9 +77,22 @@ class OnnxConfig(ABC):
"causal-lm"
:
OrderedDict
({
"logits"
:
{
0
:
"batch"
,
1
:
"sequence"
}}),
"default"
:
OrderedDict
({
"last_hidden_state"
:
{
0
:
"batch"
,
1
:
"sequence"
}}),
"image-classification"
:
OrderedDict
({
"logits"
:
{
0
:
"batch"
,
1
:
"sequence"
}}),
"image-segmentation"
:
OrderedDict
(
{
"logits"
:
{
0
:
"batch"
,
1
:
"sequence"
},
"pred_boxes"
:
{
0
:
"batch"
,
1
:
"sequence"
},
"pred_masks"
:
{
0
:
"batch"
,
1
:
"sequence"
},
}
),
"masked-im"
:
OrderedDict
({
"logits"
:
{
0
:
"batch"
,
1
:
"sequence"
}}),
"masked-lm"
:
OrderedDict
({
"logits"
:
{
0
:
"batch"
,
1
:
"sequence"
}}),
"multiple-choice"
:
OrderedDict
({
"logits"
:
{
0
:
"batch"
}}),
"object-detection"
:
OrderedDict
(
{
"logits"
:
{
0
:
"batch"
,
1
:
"sequence"
},
"pred_boxes"
:
{
0
:
"batch"
,
1
:
"sequence"
},
}
),
"question-answering"
:
OrderedDict
(
{
"start_logits"
:
{
0
:
"batch"
,
1
:
"sequence"
},
...
...
src/transformers/onnx/features.py
View file @
76d13de5
...
...
@@ -15,9 +15,11 @@ if is_torch_available():
AutoModel
,
AutoModelForCausalLM
,
AutoModelForImageClassification
,
AutoModelForImageSegmentation
,
AutoModelForMaskedImageModeling
,
AutoModelForMaskedLM
,
AutoModelForMultipleChoice
,
AutoModelForObjectDetection
,
AutoModelForQuestionAnswering
,
AutoModelForSeq2SeqLM
,
AutoModelForSequenceClassification
,
...
...
@@ -83,8 +85,10 @@ class FeaturesManager:
"sequence-classification"
:
AutoModelForSequenceClassification
,
"token-classification"
:
AutoModelForTokenClassification
,
"multiple-choice"
:
AutoModelForMultipleChoice
,
"object-detection"
:
AutoModelForObjectDetection
,
"question-answering"
:
AutoModelForQuestionAnswering
,
"image-classification"
:
AutoModelForImageClassification
,
"image-segmentation"
:
AutoModelForImageSegmentation
,
"masked-im"
:
AutoModelForMaskedImageModeling
,
}
if
is_tf_available
():
...
...
@@ -227,6 +231,12 @@ class FeaturesManager:
"deit"
:
supported_features_mapping
(
"default"
,
"image-classification"
,
"masked-im"
,
onnx_config_cls
=
"models.deit.DeiTOnnxConfig"
),
"detr"
:
supported_features_mapping
(
"default"
,
"object-detection"
,
"image-segmentation"
,
onnx_config_cls
=
"models.detr.DetrOnnxConfig"
,
),
"distilbert"
:
supported_features_mapping
(
"default"
,
"masked-lm"
,
...
...
tests/onnx/test_onnx_v2.py
View file @
76d13de5
...
...
@@ -183,6 +183,7 @@ PYTORCH_EXPORT_MODELS = {
(
"deberta"
,
"microsoft/deberta-base"
),
(
"deberta-v2"
,
"microsoft/deberta-v2-xlarge"
),
(
"convnext"
,
"facebook/convnext-tiny-224"
),
(
"detr"
,
"facebook/detr-resnet-50"
),
(
"distilbert"
,
"distilbert-base-cased"
),
(
"electra"
,
"google/electra-base-generator"
),
(
"resnet"
,
"microsoft/resnet-50"
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment