Unverified Commit dde718e7 authored by NielsRogge, committed by GitHub

[DETR and friends] Remove is_timm_available (#21814)



* First draft

* Fix to_dict

* Improve conversion script

* Update config

* Remove timm dependency

* Fix dummies

* Fix typo, add integration test

* Upload 101 model as well

* Remove timm dummies

* Fix style

---------
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 2156662d
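The net effect of the changes below: DETR and friends (DETR, Deformable DETR, Conditional DETR, Table Transformer) are exposed as torch-backed objects rather than timm-backed ones, so they can run with a native Hugging Face backbone and no timm install. A minimal sketch of what this enables, assuming only the torch and vision dependencies are present:

from transformers import DetrConfig, DetrModel, ResNetConfig

# Back DETR with a native ResNet config instead of a timm backbone
backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-50")
config = DetrConfig(use_timm_backbone=False, backbone_config=backbone_config)

# Randomly initialized model; no timm import is triggered anywhere on this path
model = DetrModel(config)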
@@ -866,52 +866,6 @@ else:
_import_structure["models.vit_hybrid"].extend(["ViTHybridImageProcessor"])
_import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
# Timm-backed objects
try:
if not (is_timm_available() and is_vision_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_timm_and_vision_objects
_import_structure["utils.dummy_timm_and_vision_objects"] = [
name for name in dir(dummy_timm_and_vision_objects) if not name.startswith("_")
]
else:
_import_structure["models.deformable_detr"].extend(
[
"DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"DeformableDetrForObjectDetection",
"DeformableDetrModel",
"DeformableDetrPreTrainedModel",
]
)
_import_structure["models.detr"].extend(
[
"DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"DetrForObjectDetection",
"DetrForSegmentation",
"DetrModel",
"DetrPreTrainedModel",
]
)
_import_structure["models.table_transformer"].extend(
[
"TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TableTransformerForObjectDetection",
"TableTransformerModel",
"TableTransformerPreTrainedModel",
]
)
_import_structure["models.conditional_detr"].extend(
[
"CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"ConditionalDetrForObjectDetection",
"ConditionalDetrForSegmentation",
"ConditionalDetrModel",
"ConditionalDetrPreTrainedModel",
]
)
# PyTorch-backed objects
try:
@@ -1309,6 +1263,15 @@ else:
"CodeGenPreTrainedModel",
]
)
_import_structure["models.conditional_detr"].extend(
[
"CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"ConditionalDetrForObjectDetection",
"ConditionalDetrForSegmentation",
"ConditionalDetrModel",
"ConditionalDetrPreTrainedModel",
]
)
_import_structure["models.convbert"].extend(
[
"CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1406,6 +1369,14 @@ else:
"DecisionTransformerPreTrainedModel",
]
)
_import_structure["models.deformable_detr"].extend(
[
"DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"DeformableDetrForObjectDetection",
"DeformableDetrModel",
"DeformableDetrPreTrainedModel",
]
)
_import_structure["models.deit"].extend(
[
"DEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -1424,6 +1395,15 @@ else:
"DetaPreTrainedModel",
]
)
_import_structure["models.detr"].extend(
[
"DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"DetrForObjectDetection",
"DetrForSegmentation",
"DetrModel",
"DetrPreTrainedModel",
]
)
_import_structure["models.dinat"].extend(
[
"DINAT_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -2372,6 +2352,14 @@ else:
"load_tf_weights_in_t5",
]
)
_import_structure["models.table_transformer"].extend(
[
"TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TableTransformerForObjectDetection",
"TableTransformerModel",
"TableTransformerPreTrainedModel",
]
)
_import_structure["models.tapas"].extend(
[
"TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
@@ -4398,39 +4386,6 @@ if TYPE_CHECKING:
from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
# Modeling
try:
if not (is_timm_available() and is_vision_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_timm_and_vision_objects import *
else:
from .models.conditional_detr import (
CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
ConditionalDetrForObjectDetection,
ConditionalDetrForSegmentation,
ConditionalDetrModel,
ConditionalDetrPreTrainedModel,
)
from .models.deformable_detr import (
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
DeformableDetrForObjectDetection,
DeformableDetrModel,
DeformableDetrPreTrainedModel,
)
from .models.detr import (
DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
DetrForObjectDetection,
DetrForSegmentation,
DetrModel,
DetrPreTrainedModel,
)
from .models.table_transformer import (
TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TableTransformerForObjectDetection,
TableTransformerModel,
TableTransformerPreTrainedModel,
)
try:
if not is_torch_available():
raise OptionalDependencyNotAvailable()
@@ -4767,6 +4722,13 @@ if TYPE_CHECKING:
CodeGenModel,
CodeGenPreTrainedModel,
)
from .models.conditional_detr import (
CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
ConditionalDetrForObjectDetection,
ConditionalDetrForSegmentation,
ConditionalDetrModel,
ConditionalDetrPreTrainedModel,
)
from .models.convbert import (
CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
ConvBertForMaskedLM,
@@ -4848,6 +4810,12 @@ if TYPE_CHECKING:
DecisionTransformerModel,
DecisionTransformerPreTrainedModel,
)
from .models.deformable_detr import (
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
DeformableDetrForObjectDetection,
DeformableDetrModel,
DeformableDetrPreTrainedModel,
)
from .models.deit import (
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
DeiTForImageClassification,
@@ -4862,6 +4830,13 @@ if TYPE_CHECKING:
DetaModel,
DetaPreTrainedModel,
)
from .models.detr import (
DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
DetrForObjectDetection,
DetrForSegmentation,
DetrModel,
DetrPreTrainedModel,
)
from .models.dinat import (
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
DinatBackbone,
@@ -5626,6 +5601,12 @@ if TYPE_CHECKING:
T5PreTrainedModel,
load_tf_weights_in_t5,
)
from .models.table_transformer import (
TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TableTransformerForObjectDetection,
TableTransformerModel,
TableTransformerPreTrainedModel,
)
from .models.tapas import (
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TapasForMaskedLM,
@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
@@ -35,7 +35,7 @@ else:
_import_structure["image_processing_conditional_detr"] = ["ConditionalDetrImageProcessor"]
try:
if not is_timm_available():
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
@@ -66,7 +66,7 @@ if TYPE_CHECKING:
from .image_processing_conditional_detr import ConditionalDetrImageProcessor
try:
if not is_timm_available():
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
@@ -1101,12 +1101,12 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
images (`ImageInput`):
Image or batch of images to preprocess.
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotionation is for object
List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty.
If annotionation is for segmentation, the annotations should be a dictionary with the following keys:
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
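For reference, a sketch of the annotation payloads the docstring above describes. The outer keys ("image_id", "annotations", "segments_info") come from the docstring; the inner detection fields follow the usual COCO convention and are illustrative only:

# Hypothetical COCO-style object-detection annotation for one image
detection_annotation = {
    "image_id": 42,
    "annotations": [
        {"bbox": [10.0, 20.0, 50.0, 80.0], "category_id": 1, "area": 4000.0, "iscrowd": 0},
    ],  # may be an empty list if the image has no annotations
}

# Segmentation-style annotation for the same image
segmentation_annotation = {
    "image_id": 42,
    "segments_info": [],  # may be empty if the image has no segments
}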
@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {
@@ -31,7 +31,7 @@ else:
_import_structure["image_processing_deformable_detr"] = ["DeformableDetrImageProcessor"]
try:
if not is_timm_available():
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
from .image_processing_deformable_detr import DeformableDetrImageProcessor
try:
if not is_timm_available():
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
@@ -1099,12 +1099,12 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
images (`ImageInput`):
Image or batch of images to preprocess.
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotionation is for object
List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty.
If annotionation is for segmentation, the annotations should be a dictionary with the following keys:
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig", "DetrOnnxConfig"]}
@@ -29,7 +29,7 @@ else:
_import_structure["image_processing_detr"] = ["DetrImageProcessor"]
try:
if not is_timm_available():
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
@@ -56,7 +56,7 @@ if TYPE_CHECKING:
from .image_processing_detr import DetrImageProcessor
try:
if not is_timm_available():
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
@@ -14,8 +14,9 @@
# limitations under the License.
""" DETR model configuration"""
import copy
from collections import OrderedDict
from typing import Mapping
from typing import Dict, Mapping
from packaging import version
@@ -187,6 +188,8 @@ class DetrConfig(PretrainedConfig):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
# set timm attributes to None
dilation, backbone, use_pretrained_backbone = None, None, None
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
@@ -233,6 +236,28 @@
def hidden_size(self) -> int:
return self.d_model
@classmethod
def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs):
"""Instantiate a [`DetrConfig`] (or a derived class) from a pre-trained backbone model configuration.
Args:
backbone_config ([`PretrainedConfig`]):
The backbone configuration.
Returns:
[`DetrConfig`]: An instance of a configuration object
"""
return cls(backbone_config=backbone_config, **kwargs)
def to_dict(self) -> Dict[str, any]:
"""
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. Returns:
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
"""
output = copy.deepcopy(self.__dict__)
if output["backbone_config"] is not None:
output["backbone_config"] = self.backbone_config.to_dict()
output["model_type"] = self.__class__.model_type
return output
class DetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11")
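A short usage sketch of the two additions to DetrConfig above (from_backbone_config and the to_dict override); the ResNet checkpoint mirrors the one used in the conversion script:

from transformers import DetrConfig, ResNetConfig

backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-50")

# Build a DETR config directly from a backbone configuration
config = DetrConfig.from_backbone_config(backbone_config)

# The override serializes the nested backbone config into a plain dict
config_dict = config.to_dict()
assert isinstance(config_dict["backbone_config"], dict)
assert config_dict["model_type"] == "detr"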
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
# Copyright 2023 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -33,16 +33,16 @@ logger = logging.get_logger(__name__)
def get_detr_config(model_name):
config = DetrConfig(use_timm_backbone=False)
# set backbone attributes
if "resnet50" in model_name:
pass
elif "resnet101" in model_name:
config.backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-101")
# initialize config
if "resnet-50" in model_name:
backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-50")
elif "resnet-101" in model_name:
backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-101")
else:
raise ValueError("Model name should include either resnet50 or resnet101")
config = DetrConfig(use_timm_backbone=False, backbone_config=backbone_config)
# set label attributes
is_panoptic = "panoptic" in model_name
if is_panoptic:
@@ -286,7 +286,7 @@ def prepare_img():
@torch.no_grad()
def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
def convert_detr_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
"""
Copy/paste/tweak model's weights to our DETR structure.
"""
@@ -295,8 +295,12 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
config, is_panoptic = get_detr_config(model_name)
# load original model from torch hub
model_name_to_original_name = {
"detr-resnet-50": "detr_resnet50",
"detr-resnet-101": "detr_resnet101",
}
logger.info(f"Converting model {model_name}...")
detr = torch.hub.load("facebookresearch/detr", model_name, pretrained=True).eval()
detr = torch.hub.load("facebookresearch/detr", model_name_to_original_name[model_name], pretrained=True).eval()
state_dict = detr.state_dict()
# rename keys
for src, dest in create_rename_keys(config):
@@ -344,9 +348,6 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
original_outputs = detr(pixel_values)
outputs = model(pixel_values)
print("Logits:", outputs.logits[0, :3, :3])
print("Original logits:", original_outputs["pred_logits"][0, :3, :3])
assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-3)
assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-3)
if is_panoptic:
@@ -360,15 +361,26 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
# Upload model and image processor to the hub
logger.info("Uploading PyTorch model and image processor to the hub...")
model.push_to_hub(f"nielsr/{model_name}")
processor.push_to_hub(f"nielsr/{model_name}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--model_name", default="detr_resnet50", type=str, help="Name of the DETR model you'd like to convert."
"--model_name",
default="detr-resnet-50",
type=str,
choices=["detr-resnet-50", "detr-resnet-101"],
help="Name of the DETR model you'd like to convert.",
)
parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
)
parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.")
args = parser.parse_args()
convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path)
convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
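For reference, a sketch of calling the updated conversion entry point from Python; the module path is an assumption (file names are not shown in this diff) and the dump folder is hypothetical:

# Module path is an assumption; adjust it to wherever the conversion script lives.
from transformers.models.detr.convert_detr_to_pytorch import convert_detr_checkpoint

convert_detr_checkpoint(
    model_name="detr-resnet-50",                     # or "detr-resnet-101"
    pytorch_dump_folder_path="./detr-resnet-50-hf",  # hypothetical output folder
    push_to_hub=False,                               # set True to upload model and processor
)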
@@ -1065,12 +1065,12 @@ class DetrImageProcessor(BaseImageProcessor):
images (`ImageInput`):
Image or batch of images to preprocess.
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotionation is for object
List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty.
If annotionation is for segmentation, the annotations should be a dictionary with the following keys:
If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty.
@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = {
@@ -26,7 +26,7 @@ _import_structure = {
}
try:
if not is_timm_available():
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
@@ -47,7 +47,7 @@ if TYPE_CHECKING:
)
try:
if not is_timm_available():
if not is_torch_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
pass
@@ -189,6 +189,8 @@ class TableTransformerConfig(PretrainedConfig):
backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config)
# set timm attributes to None
dilation, backbone, use_pretrained_backbone = None, None, None
self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config
@@ -1661,6 +1661,37 @@ class CodeGenPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
class ConditionalDetrForObjectDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ConditionalDetrForSegmentation(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ConditionalDetrModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class ConditionalDetrPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2073,6 +2104,30 @@ class DecisionTransformerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
class DeformableDetrForObjectDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DeformableDetrModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DeformableDetrPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -2135,6 +2190,37 @@ class DetaPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"])
DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
class DetrForObjectDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DetrForSegmentation(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DetrModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class DetrPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
@@ -6040,6 +6126,30 @@ def load_tf_weights_in_t5(*args, **kwargs):
requires_backends(load_tf_weights_in_t5, ["torch"])
TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TableTransformerForObjectDetection(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TableTransformerModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
class TableTransformerPreTrainedModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = None
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
class ConditionalDetrForObjectDetection(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class ConditionalDetrForSegmentation(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class ConditionalDetrModel(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class ConditionalDetrPreTrainedModel(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
class DeformableDetrForObjectDetection(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DeformableDetrModel(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DeformableDetrPreTrainedModel(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
class DetrForObjectDetection(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DetrForSegmentation(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DetrModel(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class DetrPreTrainedModel(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
class TableTransformerForObjectDetection(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class TableTransformerModel(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
class TableTransformerPreTrainedModel(metaclass=DummyObject):
_backends = ["timm", "vision"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["timm", "vision"])
@@ -20,7 +20,7 @@ import math
import unittest
from transformers import DetrConfig, is_timm_available, is_vision_available
from transformers.testing_utils import require_timm, require_vision, slow, torch_device
from transformers.testing_utils import require_timm, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin
@@ -510,7 +510,7 @@ def prepare_img():
@require_timm
@require_vision
@slow
class DetrModelIntegrationTests(unittest.TestCase):
class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50") if is_vision_available() else None
@@ -626,3 +626,33 @@ class DetrModelIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(results["segmentation"][:3, :3], expected_slice_segmentation, atol=1e-4))
self.assertTrue(len(results["segments_info"]), expected_number_of_segments)
self.assertDictEqual(results["segments_info"][0], expected_first_segment)
@require_vision
@require_torch
@slow
class DetrModelIntegrationTests(unittest.TestCase):
@cached_property
def default_feature_extractor(self):
return (
DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
if is_vision_available()
else None
)
def test_inference_no_head(self):
model = DetrModel.from_pretrained("facebook/detr-resnet-50", revision="no_timm").to(torch_device)
feature_extractor = self.default_feature_extractor
image = prepare_img()
encoding = feature_extractor(images=image, return_tensors="pt").to(torch_device)
with torch.no_grad():
outputs = model(**encoding)
expected_shape = torch.Size((1, 100, 256))
assert outputs.last_hidden_state.shape == expected_shape
expected_slice = torch.tensor(
[[0.0616, -0.5146, -0.4032], [-0.7629, -0.4934, -1.7153], [-0.4768, -0.6403, -0.7826]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
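Outside the test suite, the same timm-free checkpoint can be loaded directly; a minimal sketch, assuming the no_timm revision remains available on the Hub:

import requests
import torch
from PIL import Image
from transformers import DetrImageProcessor, DetrModel

# Standard COCO sample image used throughout the DETR tests
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
model = DetrModel.from_pretrained("facebook/detr-resnet-50", revision="no_timm")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)
print(outputs.last_hidden_state.shape)  # torch.Size([1, 100, 256]) per the test above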