Unverified commit 3a780cc5, authored by amyeroberts and committed by GitHub

Image transforms functionality used instead (#20278)

* Image transforms functionality used instead

* Import torch

* Import rather than copy

* Update src/transformers/models/conditional_detr/feature_extraction_conditional_detr.py
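
All of the per-model helper copies removed in this diff implement the same box and color conversions, which now live in a single place, `transformers.image_transforms`. As a rough usage sketch (assuming the shared helper keeps the torch-tensor behaviour of the removed copies shown below; the variable names are illustrative only):

```python
import torch

from transformers.image_transforms import center_to_corners_format

# DETR-style models predict boxes as (center_x, center_y, width, height).
boxes_center = torch.tensor([[0.5, 0.5, 0.2, 0.4]])

# Convert to (x_0, y_0, x_1, y_1) corners, e.g. before computing IoU or drawing boxes.
boxes_corners = center_to_corners_format(boxes_center)
# -> tensor([[0.4000, 0.3000, 0.6000, 0.7000]])
```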
parent 3fad6ae3
@@ -14,7 +14,7 @@
# limitations under the License.
import warnings
-from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Tuple, Union

import numpy as np
...
@@ -38,13 +38,14 @@ if is_vision_available():
    )

-if TYPE_CHECKING:
-    if is_torch_available():
-        import torch
-    if is_tf_available():
-        import tensorflow as tf
-    if is_flax_available():
-        import jax.numpy as jnp
+if is_torch_available():
+    import torch
+
+if is_tf_available():
+    import tensorflow as tf
+
+if is_flax_available():
+    import jax.numpy as jnp


def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDimension, str]) -> np.ndarray:
...
@@ -22,8 +22,9 @@ import numpy as np
from PIL import Image

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, is_torch_available, logging
+from ...image_transforms import center_to_corners_format, corners_to_center_format, rgb_to_id
+from ...image_utils import ImageFeatureExtractionMixin
+from ...utils import TensorType, is_torch_available, is_torch_tensor, logging

if is_torch_available():
...
@@ -36,29 +37,6 @@ logger = logging.get_logger(__name__)
ImageInput = Union[Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]]

-# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
-def center_to_corners_format(x):
-    """
-    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
-    (x_0, y_0, x_1, y_1).
-    """
-    center_x, center_y, width, height = x.unbind(-1)
-    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
-    return torch.stack(b, dim=-1)
-
-
-# Copied from transformers.models.detr.feature_extraction_detr.corners_to_center_format
-def corners_to_center_format(x):
-    """
-    Converts a NumPy array of bounding boxes of shape (number of bounding boxes, 4) of corners format (x_0, y_0, x_1,
-    y_1) to center format (center_x, center_y, width, height).
-    """
-    x_transposed = x.T
-    x0, y0, x1, y1 = x_transposed[0], x_transposed[1], x_transposed[2], x_transposed[3]
-    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
-    return np.stack(b, axis=-1)
-
-
# Copied from transformers.models.detr.feature_extraction_detr.masks_to_boxes
def masks_to_boxes(masks):
    """
...
@@ -93,15 +71,6 @@ def masks_to_boxes(masks):
    return np.stack([x_min, y_min, x_max, y_max], 1)

-# Copied from transformers.models.detr.feature_extraction_detr.rgb_to_id
-def rgb_to_id(color):
-    if isinstance(color, np.ndarray) and len(color.shape) == 3:
-        if color.dtype == np.uint8:
-            color = color.astype(np.int32)
-        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
-    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
-
-
# Copied from transformers.models.detr.feature_extraction_detr.binary_mask_to_rle
def binary_mask_to_rle(mask):
    """
...
@@ -33,6 +33,7 @@ from ...utils import (
    add_start_docstrings_to_model_forward,
    is_scipy_available,
    is_timm_available,
+    is_vision_available,
    logging,
    replace_return_docstrings,
    requires_backends,
...
@@ -46,6 +47,9 @@ if is_scipy_available():
if is_timm_available():
    from timm import create_model

+if is_vision_available():
+    from transformers.image_transforms import center_to_corners_format
+

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "ConditionalDetrConfig"
...
@@ -2596,17 +2600,6 @@ def generalized_box_iou(boxes1, boxes2):
    return iou - (area - union) / area

-# Copied from transformers.models.detr.modeling_detr.center_to_corners_format
-def center_to_corners_format(x):
-    """
-    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
-    (x_0, y_0, x_1, y_1).
-    """
-    center_x, center_y, width, height = x.unbind(-1)
-    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
-    return torch.stack(b, dim=-1)
-
-
# Copied from transformers.models.detr.modeling_detr._max_by_axis
def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
...
@@ -22,8 +22,9 @@ import numpy as np
from PIL import Image

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, is_torch_available, logging
+from ...image_transforms import center_to_corners_format, corners_to_center_format, rgb_to_id
+from ...image_utils import ImageFeatureExtractionMixin
+from ...utils import TensorType, is_torch_available, is_torch_tensor, logging

if is_torch_available():
...
@@ -36,29 +37,6 @@ logger = logging.get_logger(__name__)
ImageInput = Union[Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]]

-# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
-def center_to_corners_format(x):
-    """
-    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
-    (x_0, y_0, x_1, y_1).
-    """
-    center_x, center_y, width, height = x.unbind(-1)
-    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
-    return torch.stack(b, dim=-1)
-
-
-# Copied from transformers.models.detr.feature_extraction_detr.corners_to_center_format
-def corners_to_center_format(x):
-    """
-    Converts a NumPy array of bounding boxes of shape (number of bounding boxes, 4) of corners format (x_0, y_0, x_1,
-    y_1) to center format (center_x, center_y, width, height).
-    """
-    x_transposed = x.T
-    x0, y0, x1, y1 = x_transposed[0], x_transposed[1], x_transposed[2], x_transposed[3]
-    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
-    return np.stack(b, axis=-1)
-
-
# Copied from transformers.models.detr.feature_extraction_detr.masks_to_boxes
def masks_to_boxes(masks):
    """
...
@@ -93,32 +71,6 @@ def masks_to_boxes(masks):
    return np.stack([x_min, y_min, x_max, y_max], 1)

-# Copied from transformers.models.detr.feature_extraction_detr.rgb_to_id
-def rgb_to_id(color):
-    if isinstance(color, np.ndarray) and len(color.shape) == 3:
-        if color.dtype == np.uint8:
-            color = color.astype(np.int32)
-        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
-    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
-
-
-# Copied from transformers.models.detr.feature_extraction_detr.id_to_rgb
-def id_to_rgb(id_map):
-    if isinstance(id_map, np.ndarray):
-        id_map_copy = id_map.copy()
-        rgb_shape = tuple(list(id_map.shape) + [3])
-        rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
-        for i in range(3):
-            rgb_map[..., i] = id_map_copy % 256
-            id_map_copy //= 256
-        return rgb_map
-    color = []
-    for _ in range(3):
-        color.append(id_map % 256)
-        id_map //= 256
-    return color
-
-
class DeformableDetrFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
    r"""
    Constructs a Deformable DETR feature extractor. Differs only in the postprocessing of object detection compared to
...
@@ -35,6 +35,7 @@ from ...file_utils import (
    is_scipy_available,
    is_timm_available,
    is_torch_cuda_available,
+    is_vision_available,
    replace_return_docstrings,
    requires_backends,
)
...
@@ -58,6 +59,9 @@ if is_torch_cuda_available() and is_ninja_available():
else:
    MultiScaleDeformableAttention = None

+if is_vision_available():
+    from transformers.image_transforms import center_to_corners_format
+

class MultiScaleDeformableAttentionFunction(Function):
    @staticmethod
...
@@ -2417,17 +2421,6 @@ def generalized_box_iou(boxes1, boxes2):
    return iou - (area - union) / area

-# Copied from transformers.models.detr.modeling_detr.center_to_corners_format
-def center_to_corners_format(x):
-    """
-    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
-    (x_0, y_0, x_1, y_1).
-    """
-    center_x, center_y, width, height = x.unbind(-1)
-    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
-    return torch.stack(b, dim=-1)
-
-
# Copied from transformers.models.detr.modeling_detr._max_by_axis
def _max_by_axis(the_list):
    # type: (List[List[int]]) -> List[int]
...
@@ -24,8 +24,9 @@ import numpy as np
from PIL import Image

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, is_torch_available, logging
+from ...image_transforms import center_to_corners_format, corners_to_center_format, id_to_rgb, rgb_to_id
+from ...image_utils import ImageFeatureExtractionMixin
+from ...utils import TensorType, is_torch_available, is_torch_tensor, logging

if is_torch_available():
...
@@ -38,28 +39,6 @@ logger = logging.get_logger(__name__)
ImageInput = Union[Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]]

-# 2 functions below inspired by https://github.com/facebookresearch/detr/blob/master/util/box_ops.py
-def center_to_corners_format(x):
-    """
-    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
-    (x_0, y_0, x_1, y_1).
-    """
-    center_x, center_y, width, height = x.unbind(-1)
-    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
-    return torch.stack(b, dim=-1)
-
-
-def corners_to_center_format(x):
-    """
-    Converts a NumPy array of bounding boxes of shape (number of bounding boxes, 4) of corners format (x_0, y_0, x_1,
-    y_1) to center format (center_x, center_y, width, height).
-    """
-    x_transposed = x.T
-    x0, y0, x1, y1 = x_transposed[0], x_transposed[1], x_transposed[2], x_transposed[3]
-    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
-    return np.stack(b, axis=-1)
-
-
def masks_to_boxes(masks):
    """
    Compute the bounding boxes around the provided panoptic segmentation masks.
...
@@ -93,33 +72,6 @@ def masks_to_boxes(masks):
    return np.stack([x_min, y_min, x_max, y_max], 1)

-# 2 functions below copied from https://github.com/cocodataset/panopticapi/blob/master/panopticapi/utils.py
-# Copyright (c) 2018, Alexander Kirillov
-# All rights reserved.
-def rgb_to_id(color):
-    if isinstance(color, np.ndarray) and len(color.shape) == 3:
-        if color.dtype == np.uint8:
-            color = color.astype(np.int32)
-        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
-    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
-
-
-def id_to_rgb(id_map):
-    if isinstance(id_map, np.ndarray):
-        id_map_copy = id_map.copy()
-        rgb_shape = tuple(list(id_map.shape) + [3])
-        rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
-        for i in range(3):
-            rgb_map[..., i] = id_map_copy % 256
-            id_map_copy //= 256
-        return rgb_map
-    color = []
-    for _ in range(3):
-        color.append(id_map % 256)
-        id_map //= 256
-    return color
-
-
def binary_mask_to_rle(mask):
    """
    Args:
...
@@ -33,6 +33,7 @@ from ...utils import (
    add_start_docstrings_to_model_forward,
    is_scipy_available,
    is_timm_available,
+    is_vision_available,
    logging,
    replace_return_docstrings,
    requires_backends,
...
@@ -46,6 +47,9 @@ if is_scipy_available():
if is_timm_available():
    from timm import create_model

+if is_vision_available():
+    from transformers.image_transforms import center_to_corners_format
+

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "DetrConfig"
...
@@ -2284,17 +2288,6 @@ def generalized_box_iou(boxes1, boxes2):
    return iou - (area - union) / area

-# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
-def center_to_corners_format(x):
-    """
-    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
-    (x_0, y_0, x_1, y_1).
-    """
-    center_x, center_y, width, height = x.unbind(-1)
-    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
-    return torch.stack(b, dim=-1)
-
-
# below: taken from https://github.com/facebookresearch/detr/blob/master/util/misc.py#L306
...
@@ -22,8 +22,9 @@ from PIL import Image

from transformers.image_utils import PILImageResampling

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
-from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
-from ...utils import TensorType, is_torch_available, logging
+from ...image_transforms import center_to_corners_format
+from ...image_utils import ImageFeatureExtractionMixin
+from ...utils import TensorType, is_torch_available, is_torch_tensor, logging

if is_torch_available():
...
@@ -32,17 +33,6 @@ if is_torch_available():
logger = logging.get_logger(__name__)

-# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
-def center_to_corners_format(x):
-    """
-    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
-    (x_0, y_0, x_1, y_1).
-    """
-    center_x, center_y, width, height = x.unbind(-1)
-    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
-    return torch.stack(b, dim=-1)
-
-
# Copied from transformers.models.detr.modeling_detr._upcast
def _upcast(t):
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
...
@@ -31,12 +31,17 @@ from ...utils import (
    ModelOutput,
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
+    is_vision_available,
    logging,
    replace_return_docstrings,
)
from .configuration_owlvit import OwlViTConfig, OwlViTTextConfig, OwlViTVisionConfig

+
+if is_vision_available():
+    from transformers.image_transforms import center_to_corners_format
+

logger = logging.get_logger(__name__)

_CHECKPOINT_FOR_DOC = "google/owlvit-base-patch32"
...
@@ -114,17 +119,6 @@ class OwlViTOutput(ModelOutput):
        )

-# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
-def center_to_corners_format(x):
-    """
-    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
-    (x_0, y_0, x_1, y_1).
-    """
-    center_x, center_y, width, height = x.unbind(-1)
-    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
-    return torch.stack(b, dim=-1)
-
-
# Copied from transformers.models.detr.modeling_detr._upcast
def _upcast(t: torch.Tensor) -> torch.Tensor:
    # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
...
@@ -33,6 +33,7 @@ from ...utils import (
    add_start_docstrings_to_model_forward,
    is_scipy_available,
    is_timm_available,
+    is_vision_available,
    logging,
    replace_return_docstrings,
    requires_backends,
...
@@ -46,6 +47,9 @@ if is_scipy_available():
if is_timm_available():
    from timm import create_model

+if is_vision_available():
+    from transformers.image_transforms import center_to_corners_format
+

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "TableTransformerConfig"
...
@@ -1929,14 +1933,3 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
    else:
        raise ValueError("Only 3-dimensional tensors are supported")
    return NestedTensor(tensor, mask)
-
-
-# Copied from transformers.models.detr.modeling_detr.center_to_corners_format
-def center_to_corners_format(x):
-    """
-    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
-    (x_0, y_0, x_1, y_1).
-    """
-    center_x, center_y, width, height = x.unbind(-1)
-    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
-    return torch.stack(b, dim=-1)
...
@@ -22,6 +22,7 @@ import numpy as np
from PIL import Image

from ...feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from ...image_transforms import center_to_corners_format, corners_to_center_format, rgb_to_id
from ...image_utils import ImageFeatureExtractionMixin, is_torch_tensor
from ...utils import TensorType, is_torch_available, logging
...
@@ -36,29 +37,6 @@ logger = logging.get_logger(__name__)
ImageInput = Union[Image.Image, np.ndarray, "torch.Tensor", List[Image.Image], List[np.ndarray], List["torch.Tensor"]]

-# Copied from transformers.models.detr.feature_extraction_detr.center_to_corners_format
-def center_to_corners_format(x):
-    """
-    Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
-    (x_0, y_0, x_1, y_1).
-    """
-    center_x, center_y, width, height = x.unbind(-1)
-    b = [(center_x - 0.5 * width), (center_y - 0.5 * height), (center_x + 0.5 * width), (center_y + 0.5 * height)]
-    return torch.stack(b, dim=-1)
-
-
-# Copied from transformers.models.detr.feature_extraction_detr.corners_to_center_format
-def corners_to_center_format(x):
-    """
-    Converts a NumPy array of bounding boxes of shape (number of bounding boxes, 4) of corners format (x_0, y_0, x_1,
-    y_1) to center format (center_x, center_y, width, height).
-    """
-    x_transposed = x.T
-    x0, y0, x1, y1 = x_transposed[0], x_transposed[1], x_transposed[2], x_transposed[3]
-    b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
-    return np.stack(b, axis=-1)
-
-
# Copied from transformers.models.detr.feature_extraction_detr.masks_to_boxes
def masks_to_boxes(masks):
    """
...
@@ -93,32 +71,6 @@ def masks_to_boxes(masks):
    return np.stack([x_min, y_min, x_max, y_max], 1)

-# Copied from transformers.models.detr.feature_extraction_detr.rgb_to_id
-def rgb_to_id(color):
-    if isinstance(color, np.ndarray) and len(color.shape) == 3:
-        if color.dtype == np.uint8:
-            color = color.astype(np.int32)
-        return color[:, :, 0] + 256 * color[:, :, 1] + 256 * 256 * color[:, :, 2]
-    return int(color[0] + 256 * color[1] + 256 * 256 * color[2])
-
-
-# Copied from transformers.models.detr.feature_extraction_detr.id_to_rgb
-def id_to_rgb(id_map):
-    if isinstance(id_map, np.ndarray):
-        id_map_copy = id_map.copy()
-        rgb_shape = tuple(list(id_map.shape) + [3])
-        rgb_map = np.zeros(rgb_shape, dtype=np.uint8)
-        for i in range(3):
-            rgb_map[..., i] = id_map_copy % 256
-            id_map_copy //= 256
-        return rgb_map
-    color = []
-    for _ in range(3):
-        color.append(id_map % 256)
-        id_map //= 256
-    return color
-
-
class YolosFeatureExtractor(FeatureExtractionMixin, ImageFeatureExtractionMixin):
    r"""
    Constructs a YOLOS feature extractor.
...
@@ -46,7 +46,7 @@ if is_scipy_available():
    from scipy.optimize import linear_sum_assignment

if is_vision_available():
-    from transformers.models.detr.feature_extraction_detr import center_to_corners_format
+    from transformers.image_transforms import center_to_corners_format

logger = logging.get_logger(__name__)
...
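
For reference, the `rgb_to_id` and `id_to_rgb` helpers that the panoptic feature extractors now import from `transformers.image_transforms` encode segment ids in base 256 across the R, G and B channels, exactly as in the copies removed above. A minimal round-trip sketch (the array names are illustrative only):

```python
import numpy as np

from transformers.image_transforms import id_to_rgb, rgb_to_id

# A tiny panoptic id map with two segment ids.
segment_ids = np.array([[3, 70000]], dtype=np.int32)

# id_to_rgb stores the base-256 digits of each id in the R, G and B channels (uint8)...
rgb = id_to_rgb(segment_ids)  # shape (1, 2, 3)

# ...and rgb_to_id reverses the encoding: id = R + 256 * G + 256**2 * B.
recovered = rgb_to_id(rgb)

assert (recovered == segment_ids).all()
```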