Unverified commit dde718e7, authored by NielsRogge, committed by GitHub
Browse files

[DETR and friends] Remove is_timm_available (#21814)



* First draft

* Fix to_dict

* Improve conversion script

* Update config

* Remove timm dependency

* Fix dummies

* Fix typo, add integration test

* Upload 101 model as well

* Remove timm dummies

* Fix style

---------
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
parent 2156662d
...@@ -866,52 +866,6 @@ else: ...@@ -866,52 +866,6 @@ else:
_import_structure["models.vit_hybrid"].extend(["ViTHybridImageProcessor"]) _import_structure["models.vit_hybrid"].extend(["ViTHybridImageProcessor"])
_import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"]) _import_structure["models.yolos"].extend(["YolosFeatureExtractor", "YolosImageProcessor"])
# Timm-backed objects
try:
if not (is_timm_available() and is_vision_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_timm_and_vision_objects
_import_structure["utils.dummy_timm_and_vision_objects"] = [
name for name in dir(dummy_timm_and_vision_objects) if not name.startswith("_")
]
else:
_import_structure["models.deformable_detr"].extend(
[
"DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"DeformableDetrForObjectDetection",
"DeformableDetrModel",
"DeformableDetrPreTrainedModel",
]
)
_import_structure["models.detr"].extend(
[
"DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"DetrForObjectDetection",
"DetrForSegmentation",
"DetrModel",
"DetrPreTrainedModel",
]
)
_import_structure["models.table_transformer"].extend(
[
"TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TableTransformerForObjectDetection",
"TableTransformerModel",
"TableTransformerPreTrainedModel",
]
)
_import_structure["models.conditional_detr"].extend(
[
"CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"ConditionalDetrForObjectDetection",
"ConditionalDetrForSegmentation",
"ConditionalDetrModel",
"ConditionalDetrPreTrainedModel",
]
)
# PyTorch-backed objects # PyTorch-backed objects
try: try:
...@@ -1309,6 +1263,15 @@ else: ...@@ -1309,6 +1263,15 @@ else:
"CodeGenPreTrainedModel", "CodeGenPreTrainedModel",
] ]
) )
_import_structure["models.conditional_detr"].extend(
[
"CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"ConditionalDetrForObjectDetection",
"ConditionalDetrForSegmentation",
"ConditionalDetrModel",
"ConditionalDetrPreTrainedModel",
]
)
_import_structure["models.convbert"].extend( _import_structure["models.convbert"].extend(
[ [
"CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST", "CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -1406,6 +1369,14 @@ else: ...@@ -1406,6 +1369,14 @@ else:
"DecisionTransformerPreTrainedModel", "DecisionTransformerPreTrainedModel",
] ]
) )
_import_structure["models.deformable_detr"].extend(
[
"DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"DeformableDetrForObjectDetection",
"DeformableDetrModel",
"DeformableDetrPreTrainedModel",
]
)
_import_structure["models.deit"].extend( _import_structure["models.deit"].extend(
[ [
"DEIT_PRETRAINED_MODEL_ARCHIVE_LIST", "DEIT_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -1424,6 +1395,15 @@ else: ...@@ -1424,6 +1395,15 @@ else:
"DetaPreTrainedModel", "DetaPreTrainedModel",
] ]
) )
_import_structure["models.detr"].extend(
[
"DETR_PRETRAINED_MODEL_ARCHIVE_LIST",
"DetrForObjectDetection",
"DetrForSegmentation",
"DetrModel",
"DetrPreTrainedModel",
]
)
_import_structure["models.dinat"].extend( _import_structure["models.dinat"].extend(
[ [
"DINAT_PRETRAINED_MODEL_ARCHIVE_LIST", "DINAT_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -2372,6 +2352,14 @@ else: ...@@ -2372,6 +2352,14 @@ else:
"load_tf_weights_in_t5", "load_tf_weights_in_t5",
] ]
) )
_import_structure["models.table_transformer"].extend(
[
"TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST",
"TableTransformerForObjectDetection",
"TableTransformerModel",
"TableTransformerPreTrainedModel",
]
)
_import_structure["models.tapas"].extend( _import_structure["models.tapas"].extend(
[ [
"TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST", "TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST",
...@@ -4398,39 +4386,6 @@ if TYPE_CHECKING: ...@@ -4398,39 +4386,6 @@ if TYPE_CHECKING:
from .models.yolos import YolosFeatureExtractor, YolosImageProcessor from .models.yolos import YolosFeatureExtractor, YolosImageProcessor
# Modeling # Modeling
try:
if not (is_timm_available() and is_vision_available()):
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_timm_and_vision_objects import *
else:
from .models.conditional_detr import (
CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
ConditionalDetrForObjectDetection,
ConditionalDetrForSegmentation,
ConditionalDetrModel,
ConditionalDetrPreTrainedModel,
)
from .models.deformable_detr import (
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
DeformableDetrForObjectDetection,
DeformableDetrModel,
DeformableDetrPreTrainedModel,
)
from .models.detr import (
DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
DetrForObjectDetection,
DetrForSegmentation,
DetrModel,
DetrPreTrainedModel,
)
from .models.table_transformer import (
TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TableTransformerForObjectDetection,
TableTransformerModel,
TableTransformerPreTrainedModel,
)
try: try:
if not is_torch_available(): if not is_torch_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
...@@ -4767,6 +4722,13 @@ if TYPE_CHECKING: ...@@ -4767,6 +4722,13 @@ if TYPE_CHECKING:
CodeGenModel, CodeGenModel,
CodeGenPreTrainedModel, CodeGenPreTrainedModel,
) )
from .models.conditional_detr import (
CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
ConditionalDetrForObjectDetection,
ConditionalDetrForSegmentation,
ConditionalDetrModel,
ConditionalDetrPreTrainedModel,
)
from .models.convbert import ( from .models.convbert import (
CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST, CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
ConvBertForMaskedLM, ConvBertForMaskedLM,
...@@ -4848,6 +4810,12 @@ if TYPE_CHECKING: ...@@ -4848,6 +4810,12 @@ if TYPE_CHECKING:
DecisionTransformerModel, DecisionTransformerModel,
DecisionTransformerPreTrainedModel, DecisionTransformerPreTrainedModel,
) )
from .models.deformable_detr import (
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
DeformableDetrForObjectDetection,
DeformableDetrModel,
DeformableDetrPreTrainedModel,
)
from .models.deit import ( from .models.deit import (
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST, DEIT_PRETRAINED_MODEL_ARCHIVE_LIST,
DeiTForImageClassification, DeiTForImageClassification,
...@@ -4862,6 +4830,13 @@ if TYPE_CHECKING: ...@@ -4862,6 +4830,13 @@ if TYPE_CHECKING:
DetaModel, DetaModel,
DetaPreTrainedModel, DetaPreTrainedModel,
) )
from .models.detr import (
DETR_PRETRAINED_MODEL_ARCHIVE_LIST,
DetrForObjectDetection,
DetrForSegmentation,
DetrModel,
DetrPreTrainedModel,
)
from .models.dinat import ( from .models.dinat import (
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST, DINAT_PRETRAINED_MODEL_ARCHIVE_LIST,
DinatBackbone, DinatBackbone,
...@@ -5626,6 +5601,12 @@ if TYPE_CHECKING: ...@@ -5626,6 +5601,12 @@ if TYPE_CHECKING:
T5PreTrainedModel, T5PreTrainedModel,
load_tf_weights_in_t5, load_tf_weights_in_t5,
) )
from .models.table_transformer import (
TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST,
TableTransformerForObjectDetection,
TableTransformerModel,
TableTransformerPreTrainedModel,
)
from .models.tapas import ( from .models.tapas import (
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST, TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST,
TapasForMaskedLM, TapasForMaskedLM,
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = { _import_structure = {
...@@ -35,7 +35,7 @@ else: ...@@ -35,7 +35,7 @@ else:
_import_structure["image_processing_conditional_detr"] = ["ConditionalDetrImageProcessor"] _import_structure["image_processing_conditional_detr"] = ["ConditionalDetrImageProcessor"]
try: try:
if not is_timm_available(): if not is_torch_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
pass pass
...@@ -66,7 +66,7 @@ if TYPE_CHECKING: ...@@ -66,7 +66,7 @@ if TYPE_CHECKING:
from .image_processing_conditional_detr import ConditionalDetrImageProcessor from .image_processing_conditional_detr import ConditionalDetrImageProcessor
try: try:
if not is_timm_available(): if not is_torch_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
pass pass
......
...@@ -1101,12 +1101,12 @@ class ConditionalDetrImageProcessor(BaseImageProcessor): ...@@ -1101,12 +1101,12 @@ class ConditionalDetrImageProcessor(BaseImageProcessor):
images (`ImageInput`): images (`ImageInput`):
Image or batch of images to preprocess. Image or batch of images to preprocess.
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*): annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotionation is for object List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys: detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id. - "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty. dictionary. An image can have no annotations, in which case the list should be empty.
If annotionation is for segmentation, the annotations should be a dictionary with the following keys: If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id. - "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary. - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty. An image can have no segments, in which case the list should be empty.
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = { _import_structure = {
...@@ -31,7 +31,7 @@ else: ...@@ -31,7 +31,7 @@ else:
_import_structure["image_processing_deformable_detr"] = ["DeformableDetrImageProcessor"] _import_structure["image_processing_deformable_detr"] = ["DeformableDetrImageProcessor"]
try: try:
if not is_timm_available(): if not is_torch_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
pass pass
...@@ -57,7 +57,7 @@ if TYPE_CHECKING: ...@@ -57,7 +57,7 @@ if TYPE_CHECKING:
from .image_processing_deformable_detr import DeformableDetrImageProcessor from .image_processing_deformable_detr import DeformableDetrImageProcessor
try: try:
if not is_timm_available(): if not is_torch_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
pass pass
......
...@@ -1099,12 +1099,12 @@ class DeformableDetrImageProcessor(BaseImageProcessor): ...@@ -1099,12 +1099,12 @@ class DeformableDetrImageProcessor(BaseImageProcessor):
images (`ImageInput`): images (`ImageInput`):
Image or batch of images to preprocess. Image or batch of images to preprocess.
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*): annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotionation is for object List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys: detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id. - "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty. dictionary. An image can have no annotations, in which case the list should be empty.
If annotionation is for segmentation, the annotations should be a dictionary with the following keys: If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id. - "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary. - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty. An image can have no segments, in which case the list should be empty.
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available, is_vision_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_vision_available
_import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig", "DetrOnnxConfig"]} _import_structure = {"configuration_detr": ["DETR_PRETRAINED_CONFIG_ARCHIVE_MAP", "DetrConfig", "DetrOnnxConfig"]}
...@@ -29,7 +29,7 @@ else: ...@@ -29,7 +29,7 @@ else:
_import_structure["image_processing_detr"] = ["DetrImageProcessor"] _import_structure["image_processing_detr"] = ["DetrImageProcessor"]
try: try:
if not is_timm_available(): if not is_torch_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
pass pass
...@@ -56,7 +56,7 @@ if TYPE_CHECKING: ...@@ -56,7 +56,7 @@ if TYPE_CHECKING:
from .image_processing_detr import DetrImageProcessor from .image_processing_detr import DetrImageProcessor
try: try:
if not is_timm_available(): if not is_torch_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
pass pass
......
...@@ -14,8 +14,9 @@ ...@@ -14,8 +14,9 @@
# limitations under the License. # limitations under the License.
""" DETR model configuration""" """ DETR model configuration"""
import copy
from collections import OrderedDict from collections import OrderedDict
from typing import Mapping from typing import Dict, Mapping
from packaging import version from packaging import version
...@@ -187,6 +188,8 @@ class DetrConfig(PretrainedConfig): ...@@ -187,6 +188,8 @@ class DetrConfig(PretrainedConfig):
backbone_model_type = backbone_config.get("model_type") backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
# set timm attributes to None
dilation, backbone, use_pretrained_backbone = None, None, None
self.use_timm_backbone = use_timm_backbone self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config self.backbone_config = backbone_config
...@@ -233,6 +236,28 @@ class DetrConfig(PretrainedConfig): ...@@ -233,6 +236,28 @@ class DetrConfig(PretrainedConfig):
def hidden_size(self) -> int: def hidden_size(self) -> int:
return self.d_model return self.d_model
@classmethod
def from_backbone_config(cls, backbone_config: PretrainedConfig, **kwargs):
    """Build a [`DetrConfig`] (or a subclass) around an existing backbone configuration.

    Args:
        backbone_config ([`PretrainedConfig`]):
            The backbone configuration, passed to the constructor as `backbone_config`.
        **kwargs:
            Additional keyword arguments, forwarded unchanged to the constructor.

    Returns:
        [`DetrConfig`]: The configuration object produced by calling `cls`.
    """
    config = cls(backbone_config=backbone_config, **kwargs)
    return config
def to_dict(self) -> Dict[str, any]:
    """
    Serialize this instance to a Python dictionary, overriding the default [`~PretrainedConfig.to_dict`] so
    that the nested backbone configuration is serialized too.

    Returns:
        `Dict[str, any]`: A deep copy of the instance attributes in which `backbone_config` (when it is not
        `None`) is replaced by the result of its own `to_dict()` and `model_type` is set from the class.
    """
    # NOTE(review): `any` in the annotation is the builtin function, not `typing.Any` - presumably a typo.
    # Left as is here because correcting it requires an additional `typing` import.
    serialized = copy.deepcopy(self.__dict__)
    has_backbone_config = serialized["backbone_config"] is not None
    if has_backbone_config:
        # Serialize from the live attribute rather than from the deep-copied object.
        serialized["backbone_config"] = self.backbone_config.to_dict()
    serialized["model_type"] = self.__class__.model_type
    return serialized
class DetrOnnxConfig(OnnxConfig): class DetrOnnxConfig(OnnxConfig):
torch_onnx_minimum_version = version.parse("1.11") torch_onnx_minimum_version = version.parse("1.11")
......
# coding=utf-8 # coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. # Copyright 2023 The HuggingFace Inc. team.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
...@@ -33,16 +33,16 @@ logger = logging.get_logger(__name__) ...@@ -33,16 +33,16 @@ logger = logging.get_logger(__name__)
def get_detr_config(model_name): def get_detr_config(model_name):
config = DetrConfig(use_timm_backbone=False) # initialize config
if "resnet-50" in model_name:
# set backbone attributes backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-50")
if "resnet50" in model_name: elif "resnet-101" in model_name:
pass backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-101")
elif "resnet101" in model_name:
config.backbone_config = ResNetConfig.from_pretrained("microsoft/resnet-101")
else: else:
raise ValueError("Model name should include either resnet50 or resnet101") raise ValueError("Model name should include either resnet50 or resnet101")
config = DetrConfig(use_timm_backbone=False, backbone_config=backbone_config)
# set label attributes # set label attributes
is_panoptic = "panoptic" in model_name is_panoptic = "panoptic" in model_name
if is_panoptic: if is_panoptic:
...@@ -286,7 +286,7 @@ def prepare_img(): ...@@ -286,7 +286,7 @@ def prepare_img():
@torch.no_grad() @torch.no_grad()
def convert_detr_checkpoint(model_name, pytorch_dump_folder_path): def convert_detr_checkpoint(model_name, pytorch_dump_folder_path=None, push_to_hub=False):
""" """
Copy/paste/tweak model's weights to our DETR structure. Copy/paste/tweak model's weights to our DETR structure.
""" """
...@@ -295,8 +295,12 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path): ...@@ -295,8 +295,12 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
config, is_panoptic = get_detr_config(model_name) config, is_panoptic = get_detr_config(model_name)
# load original model from torch hub # load original model from torch hub
model_name_to_original_name = {
"detr-resnet-50": "detr_resnet50",
"detr-resnet-101": "detr_resnet101",
}
logger.info(f"Converting model {model_name}...") logger.info(f"Converting model {model_name}...")
detr = torch.hub.load("facebookresearch/detr", model_name, pretrained=True).eval() detr = torch.hub.load("facebookresearch/detr", model_name_to_original_name[model_name], pretrained=True).eval()
state_dict = detr.state_dict() state_dict = detr.state_dict()
# rename keys # rename keys
for src, dest in create_rename_keys(config): for src, dest in create_rename_keys(config):
...@@ -344,9 +348,6 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path): ...@@ -344,9 +348,6 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
original_outputs = detr(pixel_values) original_outputs = detr(pixel_values)
outputs = model(pixel_values) outputs = model(pixel_values)
print("Logits:", outputs.logits[0, :3, :3])
print("Original logits:", original_outputs["pred_logits"][0, :3, :3])
assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-3) assert torch.allclose(outputs.logits, original_outputs["pred_logits"], atol=1e-3)
assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-3) assert torch.allclose(outputs.pred_boxes, original_outputs["pred_boxes"], atol=1e-3)
if is_panoptic: if is_panoptic:
...@@ -360,15 +361,26 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path): ...@@ -360,15 +361,26 @@ def convert_detr_checkpoint(model_name, pytorch_dump_folder_path):
model.save_pretrained(pytorch_dump_folder_path) model.save_pretrained(pytorch_dump_folder_path)
processor.save_pretrained(pytorch_dump_folder_path) processor.save_pretrained(pytorch_dump_folder_path)
if push_to_hub:
# Upload model and image processor to the hub
logger.info("Uploading PyTorch model and image processor to the hub...")
model.push_to_hub(f"nielsr/{model_name}")
processor.push_to_hub(f"nielsr/{model_name}")
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument( parser.add_argument(
"--model_name", default="detr_resnet50", type=str, help="Name of the DETR model you'd like to convert." "--model_name",
default="detr-resnet-50",
type=str,
choices=["detr-resnet-50", "detr-resnet-101"],
help="Name of the DETR model you'd like to convert.",
) )
parser.add_argument( parser.add_argument(
"--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model." "--pytorch_dump_folder_path", default=None, type=str, help="Path to the folder to output PyTorch model."
) )
parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to the hub or not.")
args = parser.parse_args() args = parser.parse_args()
convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path) convert_detr_checkpoint(args.model_name, args.pytorch_dump_folder_path, args.push_to_hub)
...@@ -1065,12 +1065,12 @@ class DetrImageProcessor(BaseImageProcessor): ...@@ -1065,12 +1065,12 @@ class DetrImageProcessor(BaseImageProcessor):
images (`ImageInput`): images (`ImageInput`):
Image or batch of images to preprocess. Image or batch of images to preprocess.
annotations (`AnnotationType` or `List[AnnotationType]`, *optional*): annotations (`AnnotationType` or `List[AnnotationType]`, *optional*):
List of annotations associated with the image or batch of images. If annotionation is for object List of annotations associated with the image or batch of images. If annotation is for object
detection, the annotations should be a dictionary with the following keys: detection, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id. - "image_id" (`int`): The image id.
- "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a - "annotations" (`List[Dict]`): List of annotations for an image. Each annotation should be a
dictionary. An image can have no annotations, in which case the list should be empty. dictionary. An image can have no annotations, in which case the list should be empty.
If annotionation is for segmentation, the annotations should be a dictionary with the following keys: If annotation is for segmentation, the annotations should be a dictionary with the following keys:
- "image_id" (`int`): The image id. - "image_id" (`int`): The image id.
- "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary. - "segments_info" (`List[Dict]`): List of segments for an image. Each segment should be a dictionary.
An image can have no segments, in which case the list should be empty. An image can have no segments, in which case the list should be empty.
......
...@@ -14,7 +14,7 @@ ...@@ -14,7 +14,7 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_timm_available from ...utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available
_import_structure = { _import_structure = {
...@@ -26,7 +26,7 @@ _import_structure = { ...@@ -26,7 +26,7 @@ _import_structure = {
} }
try: try:
if not is_timm_available(): if not is_torch_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
pass pass
...@@ -47,7 +47,7 @@ if TYPE_CHECKING: ...@@ -47,7 +47,7 @@ if TYPE_CHECKING:
) )
try: try:
if not is_timm_available(): if not is_torch_available():
raise OptionalDependencyNotAvailable() raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable: except OptionalDependencyNotAvailable:
pass pass
......
...@@ -189,6 +189,8 @@ class TableTransformerConfig(PretrainedConfig): ...@@ -189,6 +189,8 @@ class TableTransformerConfig(PretrainedConfig):
backbone_model_type = backbone_config.get("model_type") backbone_model_type = backbone_config.get("model_type")
config_class = CONFIG_MAPPING[backbone_model_type] config_class = CONFIG_MAPPING[backbone_model_type]
backbone_config = config_class.from_dict(backbone_config) backbone_config = config_class.from_dict(backbone_config)
# set timm attributes to None
dilation, backbone, use_pretrained_backbone = None, None, None
self.use_timm_backbone = use_timm_backbone self.use_timm_backbone = use_timm_backbone
self.backbone_config = backbone_config self.backbone_config = backbone_config
......
...@@ -1661,6 +1661,37 @@ class CodeGenPreTrainedModel(metaclass=DummyObject): ...@@ -1661,6 +1661,37 @@ class CodeGenPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
# Placeholder for `ConditionalDetrForObjectDetection`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class ConditionalDetrForObjectDetection(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
# Placeholder for `ConditionalDetrForSegmentation`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class ConditionalDetrForSegmentation(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
# Placeholder for `ConditionalDetrModel`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class ConditionalDetrModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
# Placeholder for `ConditionalDetrPreTrainedModel`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class ConditionalDetrPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None CONVBERT_PRETRAINED_MODEL_ARCHIVE_LIST = None
...@@ -2073,6 +2104,30 @@ class DecisionTransformerPreTrainedModel(metaclass=DummyObject): ...@@ -2073,6 +2104,30 @@ class DecisionTransformerPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
# Placeholder for `DeformableDetrForObjectDetection`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class DeformableDetrForObjectDetection(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
# Placeholder for `DeformableDetrModel`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class DeformableDetrModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
# Placeholder for `DeformableDetrPreTrainedModel`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class DeformableDetrPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = None DEIT_PRETRAINED_MODEL_ARCHIVE_LIST = None
...@@ -2135,6 +2190,37 @@ class DetaPreTrainedModel(metaclass=DummyObject): ...@@ -2135,6 +2190,37 @@ class DetaPreTrainedModel(metaclass=DummyObject):
requires_backends(self, ["torch"]) requires_backends(self, ["torch"])
DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
# Placeholder for `DetrForObjectDetection`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class DetrForObjectDetection(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
# Placeholder for `DetrForSegmentation`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class DetrForSegmentation(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
# Placeholder for `DetrModel`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class DetrModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
# Placeholder for `DetrPreTrainedModel`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class DetrPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
DINAT_PRETRAINED_MODEL_ARCHIVE_LIST = None DINAT_PRETRAINED_MODEL_ARCHIVE_LIST = None
...@@ -6040,6 +6126,30 @@ def load_tf_weights_in_t5(*args, **kwargs): ...@@ -6040,6 +6126,30 @@ def load_tf_weights_in_t5(*args, **kwargs):
requires_backends(load_tf_weights_in_t5, ["torch"]) requires_backends(load_tf_weights_in_t5, ["torch"])
TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
# Placeholder for `TableTransformerForObjectDetection`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class TableTransformerForObjectDetection(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
# Placeholder for `TableTransformerModel`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class TableTransformerModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
# Placeholder for `TableTransformerPreTrainedModel`: construction only calls
# `requires_backends` with "torch" (presumably raising when torch is unavailable).
class TableTransformerPreTrainedModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])
TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = None TAPAS_PRETRAINED_MODEL_ARCHIVE_LIST = None
......
# This file is autogenerated by the command `make fix-copies`, do not edit.
from ..utils import DummyObject, requires_backends
CONDITIONAL_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
# Placeholder for `ConditionalDetrForObjectDetection`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class ConditionalDetrForObjectDetection(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
# Placeholder for `ConditionalDetrForSegmentation`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class ConditionalDetrForSegmentation(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
# Placeholder for `ConditionalDetrModel`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class ConditionalDetrModel(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
# Placeholder for `ConditionalDetrPreTrainedModel`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class ConditionalDetrPreTrainedModel(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
DEFORMABLE_DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
# Placeholder for `DeformableDetrForObjectDetection`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class DeformableDetrForObjectDetection(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
# Placeholder for `DeformableDetrModel`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class DeformableDetrModel(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
# Placeholder for `DeformableDetrPreTrainedModel`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class DeformableDetrPreTrainedModel(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
DETR_PRETRAINED_MODEL_ARCHIVE_LIST = None
# Placeholder for `DetrForObjectDetection`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class DetrForObjectDetection(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
# Placeholder for `DetrForSegmentation`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class DetrForSegmentation(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
# Placeholder for `DetrModel`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class DetrModel(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
# Placeholder for `DetrPreTrainedModel`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class DetrPreTrainedModel(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
TABLE_TRANSFORMER_PRETRAINED_MODEL_ARCHIVE_LIST = None
# Placeholder for `TableTransformerForObjectDetection`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class TableTransformerForObjectDetection(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
# Placeholder for `TableTransformerModel`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class TableTransformerModel(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
# Placeholder for `TableTransformerPreTrainedModel`: construction only calls
# `requires_backends` with "timm" and "vision" (presumably raising when either is unavailable).
class TableTransformerPreTrainedModel(metaclass=DummyObject):
    _backends = ["timm", "vision"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["timm", "vision"])
...@@ -20,7 +20,7 @@ import math ...@@ -20,7 +20,7 @@ import math
import unittest import unittest
from transformers import DetrConfig, is_timm_available, is_vision_available from transformers import DetrConfig, is_timm_available, is_vision_available
from transformers.testing_utils import require_timm, require_vision, slow, torch_device from transformers.testing_utils import require_timm, require_torch, require_vision, slow, torch_device
from transformers.utils import cached_property from transformers.utils import cached_property
from ...generation.test_utils import GenerationTesterMixin from ...generation.test_utils import GenerationTesterMixin
...@@ -510,7 +510,7 @@ def prepare_img(): ...@@ -510,7 +510,7 @@ def prepare_img():
@require_timm @require_timm
@require_vision @require_vision
@slow @slow
class DetrModelIntegrationTests(unittest.TestCase): class DetrModelIntegrationTestsTimmBackbone(unittest.TestCase):
@cached_property @cached_property
def default_feature_extractor(self): def default_feature_extractor(self):
return DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50") if is_vision_available() else None return DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50") if is_vision_available() else None
...@@ -626,3 +626,33 @@ class DetrModelIntegrationTests(unittest.TestCase): ...@@ -626,3 +626,33 @@ class DetrModelIntegrationTests(unittest.TestCase):
self.assertTrue(torch.allclose(results["segmentation"][:3, :3], expected_slice_segmentation, atol=1e-4)) self.assertTrue(torch.allclose(results["segmentation"][:3, :3], expected_slice_segmentation, atol=1e-4))
self.assertTrue(len(results["segments_info"]), expected_number_of_segments) self.assertTrue(len(results["segments_info"]), expected_number_of_segments)
self.assertDictEqual(results["segments_info"][0], expected_first_segment) self.assertDictEqual(results["segments_info"][0], expected_first_segment)
@require_vision
@require_torch
@slow
class DetrModelIntegrationTests(unittest.TestCase):
    """Integration test that loads the `no_timm` revision of `facebook/detr-resnet-50`."""

    @cached_property
    def default_feature_extractor(self):
        # Guarded by `is_vision_available()`; yields None when the vision backend is missing.
        return (
            DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50", revision="no_timm")
            if is_vision_available()
            else None
        )

    def test_inference_no_head(self):
        """Check the shape and the top-left 3x3 slice of the base model's last hidden state."""
        model = DetrModel.from_pretrained("facebook/detr-resnet-50", revision="no_timm").to(torch_device)

        feature_extractor = self.default_feature_extractor
        image = prepare_img()
        encoding = feature_extractor(images=image, return_tensors="pt").to(torch_device)

        with torch.no_grad():
            outputs = model(**encoding)

        expected_shape = torch.Size((1, 100, 256))
        # Use a TestCase assertion instead of a bare `assert`: it is not stripped under `python -O`,
        # reports both values on failure, and matches the other checks in this class.
        self.assertEqual(outputs.last_hidden_state.shape, expected_shape)
        expected_slice = torch.tensor(
            [[0.0616, -0.5146, -0.4032], [-0.7629, -0.4934, -1.7153], [-0.4768, -0.6403, -0.7826]]
        ).to(torch_device)
        self.assertTrue(torch.allclose(outputs.last_hidden_state[0, :3, :3], expected_slice, atol=1e-4))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment