Unverified Commit 74a20740 authored by Sangbum Daniel Choi's avatar Sangbum Daniel Choi Committed by GitHub
Browse files

New model support RTDETR (#29077)

* fill out docs string in configuration
https://github.com/huggingface/transformers/pull/29077/files/75dcd3a0e82cca36f12178b65bbd071ab7b25088#r1506391856



* reduce the input image size for the tests

* remove the unappropriate tests

* only 5 failes exists

* make style

* fill up missed architecture for object detection in docs

* fix auto modeling

* simple fix in missing import

* major change including backbone refactor and objectdetectionoutput refactor

* minor fix only 4 fails left

* intermediate fix

* revert __init__.py

* revert __init__.py

* make style

* fixes in pr_docs

* intermediate fix

* make style

* two fixes

* pass doctest

* only one fix left

* intermediate commit

* all fixed

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/convert_rt_detr_original_pytorch_checkpoint_to_pytorch.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/configuration_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update tests/models/rt_detr/test_modeling_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* function class above the model definition in dice_loss

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* simple fix

* layernorm add config.layer_norm_eps

* fix inputs_docstring

* make style

* simple fix

* add custom coco loading test in image_processor

* fix error in BaseModelOutput
https://github.com/huggingface/transformers/pull/29077#discussion_r1516657790



* simple typo

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* intermediate fix

* fix with load_backbone format

* remove unused configuration

* 3 fix test left

* make style

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py
Co-authored-by: default avatarSounak Dey <dey.sounak@gmail.com>

* change last_hidden_state to first index

* all pass fix
TO DO: minor update in comments

* make fix-copies

* remove deepcopy

* pr_document fix

* revert deepcopy due to the issue of unexpceted behavior in decoderlayer

* add atol in final

* add no_split_module

* _no_split_modules = None

* device transfer for model parallelism

* minor fix

* make fix-copies

* fix typo

* add test_image_processor with post_processing

* Update src/transformers/models/rt_detr/configuration_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add config in RTDETRPredictionHead

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* set lru_cache with max_size 32

* Update src/transformers/models/rt_detr/configuration_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* add lru_cache import and configuration change

* change the order of definition

* make fix-copies

* add docs and change config error

* revert strange make-fix

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* test pass

* fix get_clones related and remove deepcopy

* Update src/transformers/models/rt_detr/configuration_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/configuration_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* nit for paper section

* Update src/transformers/models/rt_detr/configuration_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* rename denoising related parameters

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* check the image transformation logic

* make style

* make style

* Update src/transformers/models/rt_detr/configuration_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* pe_encoding -> positional_encoding_temperature

* remove TODO

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* remove eval_idx since transformer DETR is giving all decoder output

* Update src/transformers/models/rt_detr/configuration_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/configuration_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* change variable name

* make style and docs import update

* Revert "Update src/transformers/models/rt_detr/image_processing_rt_detr.py"

This reverts commit 74aa3e1de0ca0cd3d354161d38ef28b4389c0eee.

* fix typo

* add postprocessing in docs

* move import scipy to top

* change varaible name

* make fix-copies

* remove eval_idx in test

* move to after first sentence

* update image_processor since box loss requires normalized one

* change appropriate name to auxiliary_outputs

* Update src/transformers/models/rt_detr/__init__.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update src/transformers/models/rt_detr/__init__.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/en/model_doc/rt_detr.md
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/en/model_doc/rt_detr.md
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* make style

* remove panoptic related comments

* make style

* revert valid_processor_keys

* fix aux related test

* make style

* change origination from config to backbone API

* enable the dn_loss

* fix test and conversion

* renewal weight initialization

* change initializer_range

* make fix-up

* fix the loss issue in the auxiliary output and denoising part

* change weight loss to original RTDETR

* fix in initialization

* sync shape format of dn and aux

* make style

* stable fine-tuning and compatible conversion for resnet101

* make style

* skip input_embed

* change encoder related variable

* enable converting rtdetr_r101

* add r101 related conversion code

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update docs/source/en/model_doc/rt_detr.md
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/configuration_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/__init__.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/__init__.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/image_processing_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* change name _shape to _reshape

* Update src/transformers/__init__.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Update src/transformers/__init__.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* maket style

* make fix-copies

* remove deprecated import

* more fix

* remove last_hidden_state for task-specific model

* Revert "remove last_hidden_state for task-specific model"

This reverts commit ccb7a34051d69b9fc7aa17ed8644664d3fdbdaca.

* minore change in convert

* remove print

* make style and fix-copies

* add custom rtdetr backbone for r18, r34

* remove print

* change copied

* add pad_size

* make style

* change layertype to optional to pass the CI

* make style

* add test in modeling_resnet_rt_detr

* make fix-copies

* skip tmp file test

* fix comment

* add docs

* change to modeling_resnet file format

* enabling resnet50 above

* Update src/transformers/models/rt_detr/modeling_rt_detr.py
Co-authored-by: default avatarJason Wu <jasonkit@users.noreply.github.com>

* enable all the rtdetr model :)

* finish except CI

* add RTDetrResNetBackbone

* make fix-copies

* fix
TO DO: CI enable

* make style

* rename test

* add docs

* add special fix

* revert resnet

* Update src/transformers/models/rt_detr/modeling_rt_detr_resnet.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* add more comment

* remove swin comment

* Update src/transformers/models/rt_detr/configuration_rt_detr.py
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* rename convert and add verify backbone

* Update docs/source/en/_toctree.yml
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/en/model_doc/rt_detr.md
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Update docs/source/en/model_doc/rt_detr.md
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>

* make style

* requests for docs

* more general test docs

* general script docs

* make fix-copies

* final commit

* Revert "Update src/transformers/models/rt_detr/configuration_rt_detr.py"

This reverts commit d136225cd3f64f510d303ce1d227698174f43fff.

* skip test_model_get_set_embeddings

* remove target

* add changes

* make fix-copies

* remove decoder_attention_mask

* add load_backbone function for auto_backbone

* remove comment

* fix repo name

* Update src/transformers/models/rt_detr/configuration_rt_detr.py
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>

* final commit

* remove unused downsample_in_bottleneck

* new test for autobackbone

* change to appropriate indices

* test fix

* fix dict in test_image_processor

* fix test

* [run-slow] rt_detr, rt_detr_resnet

* change the slow test

* [run-slow] rt_detr

* [run-slow] rt_detr, rt_detr_resnet

* make in to same cuda in CSPRepLayer

* [run-slow] rt_detr, rt_detr_resnet

---------
Co-authored-by: default avataramyeroberts <22614925+amyeroberts@users.noreply.github.com>
Co-authored-by: default avatarSounak Dey <dey.sounak@gmail.com>
Co-authored-by: default avatarNielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: default avatarJason Wu <jasonkit@users.noreply.github.com>
Co-authored-by: default avatarChoiSangBum <choisangbum@ChoiSangBumui-MacBookPro.local>
parent 8b7cd402
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import unittest
import requests
from transformers.testing_utils import require_torch, require_vision, slow
from transformers.utils import is_torch_available, is_vision_available
from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
if is_vision_available():
from PIL import Image
from transformers import RTDetrImageProcessor
if is_torch_available():
import torch
class RTDetrImageProcessingTester(unittest.TestCase):
def __init__(
self,
parent,
batch_size=4,
num_channels=3,
do_resize=True,
size=None,
do_rescale=True,
rescale_factor=1 / 255,
do_normalize=False,
do_pad=False,
return_tensors="pt",
):
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.do_resize = do_resize
self.size = size if size is not None else {"height": 640, "width": 640}
self.do_rescale = do_rescale
self.rescale_factor = rescale_factor
self.do_normalize = do_normalize
self.do_pad = do_pad
self.return_tensors = return_tensors
def prepare_image_processor_dict(self):
return {
"do_resize": self.do_resize,
"size": self.size,
"do_rescale": self.do_rescale,
"rescale_factor": self.rescale_factor,
"do_normalize": self.do_normalize,
"do_pad": self.do_pad,
"return_tensors": self.return_tensors,
}
def get_expected_values(self):
return self.size["height"], self.size["width"]
def expected_output_image_shape(self, images):
height, width = self.get_expected_values()
return self.num_channels, height, width
def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=30,
max_resolution=400,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)
@require_torch
@require_vision
class RtDetrImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = RTDetrImageProcessor if is_vision_available() else None
def setUp(self):
super().setUp()
self.image_processor_tester = RTDetrImageProcessingTester(self)
@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()
def test_image_processor_properties(self):
image_processing = self.image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))
self.assertTrue(hasattr(image_processing, "resample"))
self.assertTrue(hasattr(image_processing, "do_rescale"))
self.assertTrue(hasattr(image_processing, "rescale_factor"))
self.assertTrue(hasattr(image_processing, "return_tensors"))
def test_image_processor_from_dict_with_kwargs(self):
image_processor = self.image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"height": 640, "width": 640})
def test_valid_coco_detection_annotations(self):
# prepare image and target
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
target = json.loads(f.read())
params = {"image_id": 39769, "annotations": target}
# encode them
image_processing = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
# legal encodings (single image)
_ = image_processing(images=image, annotations=params, return_tensors="pt")
_ = image_processing(images=image, annotations=[params], return_tensors="pt")
# legal encodings (batch of one image)
_ = image_processing(images=[image], annotations=params, return_tensors="pt")
_ = image_processing(images=[image], annotations=[params], return_tensors="pt")
# legal encoding (batch of more than one image)
n = 5
_ = image_processing(images=[image] * n, annotations=[params] * n, return_tensors="pt")
# example of an illegal encoding (missing the 'image_id' key)
with self.assertRaises(ValueError) as e:
image_processing(images=image, annotations={"annotations": target}, return_tensors="pt")
self.assertTrue(str(e.exception).startswith("Invalid COCO detection annotations"))
# example of an illegal encoding (unequal lengths of images and annotations)
with self.assertRaises(ValueError) as e:
image_processing(images=[image] * n, annotations=[params] * (n - 1), return_tensors="pt")
self.assertTrue(str(e.exception) == "The number of images (5) and annotations (4) do not match.")
@slow
def test_call_pytorch_with_coco_detection_annotations(self):
# prepare image and target
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
target = json.loads(f.read())
target = {"image_id": 39769, "annotations": target}
# encode them
image_processing = RTDetrImageProcessor.from_pretrained("PekingU/rtdetr_r50vd")
encoding = image_processing(images=image, annotations=target, return_tensors="pt")
# verify pixel values
expected_shape = torch.Size([1, 3, 640, 640])
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
expected_slice = torch.tensor([0.5490, 0.5647, 0.5725])
self.assertTrue(torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-4))
# verify area
expected_area = torch.tensor([2827.9883, 5403.4761, 235036.7344, 402070.2188, 71068.8281, 79601.2812])
self.assertTrue(torch.allclose(encoding["labels"][0]["area"], expected_area))
# verify boxes
expected_boxes_shape = torch.Size([6, 4])
self.assertEqual(encoding["labels"][0]["boxes"].shape, expected_boxes_shape)
expected_boxes_slice = torch.tensor([0.5503, 0.2765, 0.0604, 0.2215])
self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"][0], expected_boxes_slice, atol=1e-3))
# verify image_id
expected_image_id = torch.tensor([39769])
self.assertTrue(torch.allclose(encoding["labels"][0]["image_id"], expected_image_id))
# verify is_crowd
expected_is_crowd = torch.tensor([0, 0, 0, 0, 0, 0])
self.assertTrue(torch.allclose(encoding["labels"][0]["iscrowd"], expected_is_crowd))
# verify class_labels
expected_class_labels = torch.tensor([75, 75, 63, 65, 17, 17])
self.assertTrue(torch.allclose(encoding["labels"][0]["class_labels"], expected_class_labels))
# verify orig_size
expected_orig_size = torch.tensor([480, 640])
self.assertTrue(torch.allclose(encoding["labels"][0]["orig_size"], expected_orig_size))
# verify size
expected_size = torch.tensor([640, 640])
self.assertTrue(torch.allclose(encoding["labels"][0]["size"], expected_size))
@slow
def test_image_processor_outputs(self):
image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
image_processing = self.image_processing_class(**self.image_processor_dict)
encoding = image_processing(images=image, return_tensors="pt")
# verify pixel values: shape
expected_shape = torch.Size([1, 3, 640, 640])
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
# verify pixel values: output values
expected_slice = torch.tensor([0.5490196347236633, 0.5647059082984924, 0.572549045085907])
self.assertTrue(torch.allclose(encoding["pixel_values"][0, 0, 0, :3], expected_slice, atol=1e-5))
def test_multiple_images_processor_outputs(self):
images_urls = [
"http://images.cocodataset.org/val2017/000000000139.jpg",
"http://images.cocodataset.org/val2017/000000000285.jpg",
"http://images.cocodataset.org/val2017/000000000632.jpg",
"http://images.cocodataset.org/val2017/000000000724.jpg",
"http://images.cocodataset.org/val2017/000000000776.jpg",
"http://images.cocodataset.org/val2017/000000000785.jpg",
"http://images.cocodataset.org/val2017/000000000802.jpg",
"http://images.cocodataset.org/val2017/000000000872.jpg",
]
images = []
for url in images_urls:
image = Image.open(requests.get(url, stream=True).raw)
images.append(image)
# apply image processing
image_processing = self.image_processing_class(**self.image_processor_dict)
encoding = image_processing(images=images, return_tensors="pt")
# verify if pixel_values is part of the encoding
self.assertIn("pixel_values", encoding)
# verify pixel values: shape
expected_shape = torch.Size([8, 3, 640, 640])
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
# verify pixel values: output values
expected_slices = torch.tensor(
[
[0.5333333611488342, 0.5568627715110779, 0.5647059082984924],
[0.5372549295425415, 0.4705882668495178, 0.4274510145187378],
[0.3960784673690796, 0.35686275362968445, 0.3686274588108063],
[0.20784315466880798, 0.1882353127002716, 0.15294118225574493],
[0.364705890417099, 0.364705890417099, 0.3686274588108063],
[0.8078432083129883, 0.8078432083129883, 0.8078432083129883],
[0.4431372880935669, 0.4431372880935669, 0.4431372880935669],
[0.19607844948768616, 0.21176472306251526, 0.3607843220233917],
]
)
self.assertTrue(torch.allclose(encoding["pixel_values"][:, 1, 0, :3], expected_slices, atol=1e-5))
@slow
def test_batched_coco_detection_annotations(self):
image_0 = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
image_1 = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png").resize((800, 800))
with open("./tests/fixtures/tests_samples/COCO/coco_annotations.txt", "r") as f:
target = json.loads(f.read())
annotations_0 = {"image_id": 39769, "annotations": target}
annotations_1 = {"image_id": 39769, "annotations": target}
# Adjust the bounding boxes for the resized image
w_0, h_0 = image_0.size
w_1, h_1 = image_1.size
for i in range(len(annotations_1["annotations"])):
coords = annotations_1["annotations"][i]["bbox"]
new_bbox = [
coords[0] * w_1 / w_0,
coords[1] * h_1 / h_0,
coords[2] * w_1 / w_0,
coords[3] * h_1 / h_0,
]
annotations_1["annotations"][i]["bbox"] = new_bbox
images = [image_0, image_1]
annotations = [annotations_0, annotations_1]
image_processing = RTDetrImageProcessor()
encoding = image_processing(
images=images,
annotations=annotations,
return_segmentation_masks=True,
return_tensors="pt", # do_convert_annotations=True
)
# Check the pixel values have been padded
postprocessed_height, postprocessed_width = 640, 640
expected_shape = torch.Size([2, 3, postprocessed_height, postprocessed_width])
self.assertEqual(encoding["pixel_values"].shape, expected_shape)
# Check the bounding boxes have been adjusted for padded images
self.assertEqual(encoding["labels"][0]["boxes"].shape, torch.Size([6, 4]))
self.assertEqual(encoding["labels"][1]["boxes"].shape, torch.Size([6, 4]))
expected_boxes_0 = torch.tensor(
[
[0.6879, 0.4609, 0.0755, 0.3691],
[0.2118, 0.3359, 0.2601, 0.1566],
[0.5011, 0.5000, 0.9979, 1.0000],
[0.5010, 0.5020, 0.9979, 0.9959],
[0.3284, 0.5944, 0.5884, 0.8112],
[0.8394, 0.5445, 0.3213, 0.9110],
]
)
expected_boxes_1 = torch.tensor(
[
[0.5503, 0.2765, 0.0604, 0.2215],
[0.1695, 0.2016, 0.2080, 0.0940],
[0.5006, 0.4933, 0.9977, 0.9865],
[0.5008, 0.5002, 0.9983, 0.9955],
[0.2627, 0.5456, 0.4707, 0.8646],
[0.7715, 0.4115, 0.4570, 0.7161],
]
)
self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1e-3))
self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1e-3))
# Check if do_convert_annotations=False, then the annotations are not converted to centre_x, centre_y, width, height
# format and not in the range [0, 1]
encoding = image_processing(
images=images,
annotations=annotations,
return_segmentation_masks=True,
do_convert_annotations=False,
return_tensors="pt",
)
self.assertEqual(encoding["labels"][0]["boxes"].shape, torch.Size([6, 4]))
self.assertEqual(encoding["labels"][1]["boxes"].shape, torch.Size([6, 4]))
# Convert to absolute coordinates
unnormalized_boxes_0 = torch.vstack(
[
expected_boxes_0[:, 0] * postprocessed_width,
expected_boxes_0[:, 1] * postprocessed_height,
expected_boxes_0[:, 2] * postprocessed_width,
expected_boxes_0[:, 3] * postprocessed_height,
]
).T
unnormalized_boxes_1 = torch.vstack(
[
expected_boxes_1[:, 0] * postprocessed_width,
expected_boxes_1[:, 1] * postprocessed_height,
expected_boxes_1[:, 2] * postprocessed_width,
expected_boxes_1[:, 3] * postprocessed_height,
]
).T
# Convert from centre_x, centre_y, width, height to x_min, y_min, x_max, y_max
expected_boxes_0 = torch.vstack(
[
unnormalized_boxes_0[:, 0] - unnormalized_boxes_0[:, 2] / 2,
unnormalized_boxes_0[:, 1] - unnormalized_boxes_0[:, 3] / 2,
unnormalized_boxes_0[:, 0] + unnormalized_boxes_0[:, 2] / 2,
unnormalized_boxes_0[:, 1] + unnormalized_boxes_0[:, 3] / 2,
]
).T
expected_boxes_1 = torch.vstack(
[
unnormalized_boxes_1[:, 0] - unnormalized_boxes_1[:, 2] / 2,
unnormalized_boxes_1[:, 1] - unnormalized_boxes_1[:, 3] / 2,
unnormalized_boxes_1[:, 0] + unnormalized_boxes_1[:, 2] / 2,
unnormalized_boxes_1[:, 1] + unnormalized_boxes_1[:, 3] / 2,
]
).T
self.assertTrue(torch.allclose(encoding["labels"][0]["boxes"], expected_boxes_0, rtol=1))
self.assertTrue(torch.allclose(encoding["labels"][1]["boxes"], expected_boxes_1, rtol=1))
This diff is collapsed.
# coding=utf-8
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from transformers import RTDetrResNetConfig
from transformers.testing_utils import require_torch, torch_device
from transformers.utils.import_utils import is_torch_available
from ...test_backbone_common import BackboneTesterMixin
from ...test_modeling_common import floats_tensor, ids_tensor
if is_torch_available():
from transformers import RTDetrResNetBackbone
class RTDetrResNetModelTester:
def __init__(
self,
parent,
batch_size=3,
image_size=32,
num_channels=3,
embeddings_size=10,
hidden_sizes=[10, 20, 30, 40],
depths=[1, 1, 2, 1],
is_training=True,
use_labels=True,
hidden_act="relu",
num_labels=3,
scope=None,
out_features=["stage2", "stage3", "stage4"],
out_indices=[2, 3, 4],
):
self.parent = parent
self.batch_size = batch_size
self.image_size = image_size
self.num_channels = num_channels
self.embeddings_size = embeddings_size
self.hidden_sizes = hidden_sizes
self.depths = depths
self.is_training = is_training
self.use_labels = use_labels
self.hidden_act = hidden_act
self.num_labels = num_labels
self.scope = scope
self.num_stages = len(hidden_sizes)
self.out_features = out_features
self.out_indices = out_indices
def prepare_config_and_inputs(self):
pixel_values = floats_tensor([self.batch_size, self.num_channels, self.image_size, self.image_size])
labels = None
if self.use_labels:
labels = ids_tensor([self.batch_size], self.num_labels)
config = self.get_config()
return config, pixel_values, labels
def get_config(self):
return RTDetrResNetConfig(
num_channels=self.num_channels,
embeddings_size=self.embeddings_size,
hidden_sizes=self.hidden_sizes,
depths=self.depths,
hidden_act=self.hidden_act,
num_labels=self.num_labels,
out_features=self.out_features,
out_indices=self.out_indices,
)
def create_and_check_backbone(self, config, pixel_values, labels):
model = RTDetrResNetBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), len(config.out_features))
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[1], 4, 4])
# verify channels
self.parent.assertEqual(len(model.channels), len(config.out_features))
self.parent.assertListEqual(model.channels, config.hidden_sizes[1:])
# verify backbone works with out_features=None
config.out_features = None
model = RTDetrResNetBackbone(config=config)
model.to(torch_device)
model.eval()
result = model(pixel_values)
# verify feature maps
self.parent.assertEqual(len(result.feature_maps), 1)
self.parent.assertListEqual(list(result.feature_maps[0].shape), [self.batch_size, self.hidden_sizes[-1], 1, 1])
# verify channels
self.parent.assertEqual(len(model.channels), 1)
self.parent.assertListEqual(model.channels, [config.hidden_sizes[-1]])
def prepare_config_and_inputs_for_common(self):
config_and_inputs = self.prepare_config_and_inputs()
config, pixel_values, labels = config_and_inputs
inputs_dict = {"pixel_values": pixel_values}
return config, inputs_dict
@require_torch
class RTDetrResNetBackboneTest(BackboneTesterMixin, unittest.TestCase):
all_model_classes = (RTDetrResNetBackbone,) if is_torch_available() else ()
has_attentions = False
config_class = RTDetrResNetConfig
def setUp(self):
self.model_tester = RTDetrResNetModelTester(self)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment