"docs/git@developer.sourcefind.cn:OpenDAS/pytorch3d.git" did not exist on "c2e62a50871a7f23d93723f27188e59094181ab7"
Unverified Commit 121a780c authored by Philip Meier, committed by GitHub

Cleanup conversion transforms (#6801)

* remove copy from convert_color_space

* remove copy from convert_format_bounding_box

* remove .to_* methods from features

* remove unnecessary clones

* add perf todos

* refactor convert_color_space

* lint

* remove another clone

* and another clone

* remove a missed copy
parent 5421f12a
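In short, the `copy` keyword is removed from the conversion kernels and a no-op conversion now returns its input unchanged, so callers that relied on `copy=True` clone explicitly. A rough before/after sketch (illustrative only, assuming the prototype API touched by this diff):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0]])

# Same-format conversion is now a no-op that hands back the input tensor ...
same = F.convert_format_bounding_box(
    boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=features.BoundingBoxFormat.XYXY
)
assert same is boxes  # no implicit clone anymore

# ... so callers that must not share memory with the input clone explicitly.
xywh = F.convert_format_bounding_box(
    boxes.clone(), old_format=features.BoundingBoxFormat.XYXY, new_format=features.BoundingBoxFormat.XYWH
)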
@@ -461,9 +461,7 @@ def reference_affine_bounding_box(bounding_box, *, format, spatial_size, angle,
             ],
             dtype=bbox.dtype,
         )
-        return F.convert_format_bounding_box(
-            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-        )
+        return F.convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format)

     if bounding_box.ndim < 2:
         bounding_box = [bounding_box]
@@ -556,17 +554,12 @@ KERNEL_INFOS.extend(

 def sample_inputs_convert_format_bounding_box():
-    formats = set(features.BoundingBoxFormat)
-    for bounding_box_loader in make_bounding_box_loaders(formats=formats):
-        old_format = bounding_box_loader.format
-        for params in combinations_grid(new_format=formats - {old_format}, copy=(True, False)):
-            yield ArgsKwargs(bounding_box_loader, old_format=old_format, **params)
+    formats = list(features.BoundingBoxFormat)
+    for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
+        yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)


-def reference_convert_format_bounding_box(bounding_box, old_format, new_format, copy):
-    if not copy:
-        raise pytest.UsageError("Reference for `convert_format_bounding_box` only supports `copy=True`")
-
+def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
     return torchvision.ops.box_convert(
         bounding_box, in_fmt=old_format.kernel_name.lower(), out_fmt=new_format.kernel_name.lower()
     )
@@ -574,8 +567,7 @@ def reference_convert_format_bounding_box(bounding_box, old_format, new_format,

 def reference_inputs_convert_format_bounding_box():
     for args_kwargs in sample_inputs_convert_color_space_image_tensor():
-        (image_loader, *other_args), kwargs = args_kwargs
-        if len(image_loader.shape) == 2 and kwargs.setdefault("copy", True):
+        if len(args_kwargs.args[0].shape) == 2:
             yield args_kwargs
@@ -600,11 +592,11 @@ def sample_inputs_convert_color_space_image_tensor():
         for image_loader in make_image_loaders(
             sizes=["random"], color_spaces=[color_space], dtypes=[torch.float32], constant_alpha=True
         ):
-            yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space, copy=False)
+            yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space)


 @pil_reference_wrapper
-def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space, copy=True):
+def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space):
     color_space_pil = features.ColorSpace.from_pil_mode(image_pil.mode)
     if color_space_pil != old_color_space:
         raise pytest.UsageError(
@@ -612,7 +604,7 @@ def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_c
             f"from {old_color_space} to {color_space_pil}"
         )

-    return F.convert_color_space_image_pil(image_pil, color_space=new_color_space, copy=copy)
+    return F.convert_color_space_image_pil(image_pil, color_space=new_color_space)


 def reference_inputs_convert_color_space_image_tensor():
......
@@ -478,9 +478,7 @@ def test_correctness_rotate_bounding_box(angle, expand, center):
             device=bbox.device,
         )
         return (
-            convert_format_bounding_box(
-                out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
-            ),
+            convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format),
             (height, width),
         )
@@ -733,14 +731,16 @@ def test_correctness_pad_bounding_box(device, padding):
         bbox_format = bbox.format
         bbox_dtype = bbox.dtype
-        bbox = convert_format_bounding_box(bbox, old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY)
+        bbox = (
+            bbox.clone()
+            if bbox_format == features.BoundingBoxFormat.XYXY
+            else convert_format_bounding_box(bbox, bbox_format, features.BoundingBoxFormat.XYXY)
+        )
         bbox[0::2] += pad_left
         bbox[1::2] += pad_up
-        bbox = convert_format_bounding_box(
-            bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
-        )
+        bbox = convert_format_bounding_box(bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format)
         if bbox.dtype != bbox_dtype:
             # Temporary cast to original dtype
             # e.g. float32 -> int
@@ -840,9 +840,7 @@ def test_correctness_perspective_bounding_box(device, startpoints, endpoints):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(
-            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
-        )
+        return convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format)

     spatial_size = (32, 38)
@@ -903,7 +901,7 @@ def test_correctness_center_crop_bounding_box(device, output_size):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_, copy=False)
+        return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_)

     for bboxes in make_bounding_boxes(extra_dims=((4,),)):
         bboxes = bboxes.to(device)
......
...@@ -110,18 +110,6 @@ class Image(_Feature): ...@@ -110,18 +110,6 @@ class Image(_Feature):
def num_channels(self) -> int: def num_channels(self) -> int:
return self.shape[-3] return self.shape[-3]
def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
if isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
return Image.wrap_like(
self,
self._F.convert_color_space_image_tensor(
self.as_subclass(torch.Tensor), old_color_space=self.color_space, new_color_space=color_space, copy=copy
),
color_space=color_space,
)
def horizontal_flip(self) -> Image: def horizontal_flip(self) -> Image:
output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor)) output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
return Image.wrap_like(self, output) return Image.wrap_like(self, output)
......
...@@ -66,18 +66,6 @@ class Video(_Feature): ...@@ -66,18 +66,6 @@ class Video(_Feature):
def num_frames(self) -> int: def num_frames(self) -> int:
return self.shape[-4] return self.shape[-4]
def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Video:
if isinstance(color_space, str):
color_space = ColorSpace.from_str(color_space.upper())
return Video.wrap_like(
self,
self._F.convert_color_space_video(
self.as_subclass(torch.Tensor), old_color_space=self.color_space, new_color_space=color_space, copy=copy
),
color_space=color_space,
)
def horizontal_flip(self) -> Video: def horizontal_flip(self) -> Video:
output = self._F.horizontal_flip_video(self.as_subclass(torch.Tensor)) output = self._F.horizontal_flip_video(self.as_subclass(torch.Tensor))
return Video.wrap_like(self, output) return Video.wrap_like(self, output)
......
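Note: with `Image.to_color_space` and `Video.to_color_space` removed above, color-space conversion of these features goes through the dispatcher, which converts the plain tensor and re-wraps it (see the `convert_color_space` hunk further down). A minimal sketch, assuming the prototype API in this diff; the input values are made up:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# Hypothetical 3x4x4 RGB image wrapped as a prototype feature.
image = features.Image(torch.rand(3, 4, 4), color_space=features.ColorSpace.RGB)

# Previously: image.to_color_space("gray"); now the dispatcher handles the feature type.
gray = F.convert_color_space(image, color_space=features.ColorSpace.GRAY)
assert isinstance(gray, features.Image) and gray.color_space == features.ColorSpace.GRAY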
@@ -265,7 +265,7 @@ class SimpleCopyPaste(_RandomApplyTransform):
                 # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
                 xyxy_boxes[:, 2:] += 1
                 boxes = F.convert_format_bounding_box(
-                    xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
+                    xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format
                 )
                 out_target["boxes"] = torch.cat([boxes, paste_boxes])
......
@@ -655,9 +655,7 @@ class RandomIoUCrop(Transform):
                     continue

                 # check for any valid boxes with centers within the crop area
-                xyxy_bboxes = F.convert_format_bounding_box(
-                    bboxes, old_format=bboxes.format, new_format=features.BoundingBoxFormat.XYXY, copy=True
-                )
+                xyxy_bboxes = F.convert_format_bounding_box(bboxes, bboxes.format, features.BoundingBoxFormat.XYXY)
                 cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
                 cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
                 is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
@@ -801,22 +799,21 @@ class FixedSizeCrop(Transform):
         top = int(offset_height * r)
         left = int(offset_width * r)

-        bounding_boxes: Optional[torch.Tensor]
         try:
             bounding_boxes = query_bounding_box(flat_inputs)
         except ValueError:
             bounding_boxes = None

         if needs_crop and bounding_boxes is not None:
-            bounding_boxes = cast(
-                features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=new_height, width=new_width)
-            )
-            bounding_boxes = features.BoundingBox.wrap_like(
-                bounding_boxes,
-                F.clamp_bounding_box(
-                    bounding_boxes, format=bounding_boxes.format, spatial_size=bounding_boxes.spatial_size
-                ),
-            )
-            height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:]
+            format = bounding_boxes.format
+            bounding_boxes, spatial_size = F.crop_bounding_box(
+                bounding_boxes, format=format, top=top, left=left, height=new_height, width=new_width
+            )
+            bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size)
+            height_and_width = F.convert_format_bounding_box(
+                bounding_boxes, old_format=format, new_format=features.BoundingBoxFormat.XYWH
+            )[..., 2:]
             is_valid = torch.all(height_and_width > 0, dim=-1)
         else:
             is_valid = None
......
@@ -50,7 +50,6 @@ class ConvertColorSpace(Transform):
         self,
         color_space: Union[str, features.ColorSpace],
         old_color_space: Optional[Union[str, features.ColorSpace]] = None,
-        copy: bool = True,
     ) -> None:
         super().__init__()
@@ -62,14 +61,10 @@ class ConvertColorSpace(Transform):
             old_color_space = features.ColorSpace.from_str(old_color_space)
         self.old_color_space = old_color_space

-        self.copy = copy
-
     def _transform(
         self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
     ) -> Union[features.ImageType, features.VideoType]:
-        return F.convert_color_space(
-            inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
-        )
+        return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space)


 class ClampBoundingBoxes(Transform):
......
@@ -36,14 +36,18 @@ def horizontal_flip_bounding_box(
 ) -> torch.Tensor:
     shape = bounding_box.shape

-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
+    # BoundingBoxFormat instead of converting back and forth
+    bounding_box = (
+        bounding_box.clone()
+        if format == features.BoundingBoxFormat.XYXY
+        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]]

     return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(shape)
@@ -73,14 +77,18 @@ def vertical_flip_bounding_box(
 ) -> torch.Tensor:
     shape = bounding_box.shape

-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
+    # BoundingBoxFormat instead of converting back and forth
+    bounding_box = (
+        bounding_box.clone()
+        if format == features.BoundingBoxFormat.XYXY
+        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]]

     return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(shape)
@@ -394,8 +402,9 @@ def affine_bounding_box(
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     original_shape = bounding_box.shape
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+
+    bounding_box = (
+        convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center)
@@ -403,7 +412,7 @@ def affine_bounding_box(
     # out_bboxes should be of shape [N boxes, 4]

     return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(original_shape)
@@ -583,8 +592,8 @@ def rotate_bounding_box(
         center = None

     original_shape = bounding_box.shape
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    bounding_box = (
+        convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     out_bboxes, spatial_size = _affine_bounding_box_xyxy(
@@ -599,9 +608,9 @@ def rotate_bounding_box(
     )

     return (
-        convert_format_bounding_box(
-            out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-        ).reshape(original_shape),
+        convert_format_bounding_box(out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format).reshape(
+            original_shape
+        ),
         spatial_size,
     )
@@ -818,8 +827,12 @@ def crop_bounding_box(
     height: int,
     width: int,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
+    # BoundingBoxFormat instead of converting back and forth
+    bounding_box = (
+        bounding_box.clone()
+        if format == features.BoundingBoxFormat.XYXY
+        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     )

     # Crop or implicit pad if left and/or top have negative values:
@@ -827,9 +840,7 @@ def crop_bounding_box(
     bounding_box[..., 1::2] -= top

     return (
-        convert_format_bounding_box(
-            bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-        ),
+        convert_format_bounding_box(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format),
         (height, width),
     )
@@ -896,8 +907,8 @@ def perspective_bounding_box(
         raise ValueError("Argument perspective_coeffs should have 8 float values")

     original_shape = bounding_box.shape
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    bounding_box = (
+        convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
@@ -967,7 +978,7 @@ def perspective_bounding_box(
     # out_bboxes should be of shape [N boxes, 4]

     return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(original_shape)
@@ -1061,8 +1072,8 @@ def elastic_bounding_box(
     displacement = displacement.to(bounding_box.device)

     original_shape = bounding_box.shape
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    bounding_box = (
+        convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
@@ -1088,7 +1099,7 @@ def elastic_bounding_box(
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

     return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(original_shape)
......
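The bounding-box kernels above now share one pattern: clone when the input is already XYXY, so the in-place arithmetic does not leak back to the caller, otherwise convert to XYXY, then convert back without the former `copy=False`. A stripped-down sketch of that pattern using a hypothetical `shift_bounding_box` kernel (not part of this diff):

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F


def shift_bounding_box(
    bounding_box: torch.Tensor, format: features.BoundingBoxFormat, dx: int, dy: int
) -> torch.Tensor:
    # Clone or convert, so the caller's tensor is never mutated in place.
    bounding_box = (
        bounding_box.clone()
        if format == features.BoundingBoxFormat.XYXY
        else F.convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
    )
    bounding_box[..., 0::2] += dx
    bounding_box[..., 1::2] += dy
    # A real format change already produces a new tensor, so no copy flag is needed here.
    return F.convert_format_bounding_box(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format)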
@@ -125,13 +125,10 @@ def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:

 def convert_format_bounding_box(
-    bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, copy: bool = True
+    bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
 ) -> torch.Tensor:
     if new_format == old_format:
-        if copy:
-            return bounding_box.clone()
-        else:
-            return bounding_box
+        return bounding_box

     if old_format == BoundingBoxFormat.XYWH:
         bounding_box = _xywh_to_xyxy(bounding_box)
@@ -149,12 +146,16 @@ def convert_format_bounding_box(
 def clamp_bounding_box(
     bounding_box: torch.Tensor, format: BoundingBoxFormat, spatial_size: Tuple[int, int]
 ) -> torch.Tensor:
-    # TODO: (PERF) Possible speed up clamping if we have different implementations for each bbox format.
-    # Not sure if they yield equivalent results.
-    xyxy_boxes = convert_format_bounding_box(bounding_box, format, BoundingBoxFormat.XYXY)
+    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
+    # BoundingBoxFormat instead of converting back and forth
+    xyxy_boxes = (
+        bounding_box.clone()
+        if format == BoundingBoxFormat.XYXY
+        else convert_format_bounding_box(bounding_box, format, BoundingBoxFormat.XYXY)
+    )
     xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
     xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
-    return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format, copy=False)
+    return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format)


 def _split_alpha(image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -192,13 +193,10 @@ def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor:

 def convert_color_space_image_tensor(
-    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
+    image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
 ) -> torch.Tensor:
     if new_color_space == old_color_space:
-        if copy:
-            return image.clone()
-        else:
-            return image
+        return image

     if old_color_space == ColorSpace.OTHER or new_color_space == ColorSpace.OTHER:
         raise RuntimeError(f"Conversion to or from {ColorSpace.OTHER} is not supported.")
@@ -242,34 +240,29 @@ _COLOR_SPACE_TO_PIL_MODE = {

 @torch.jit.unused
-def convert_color_space_image_pil(
-    image: PIL.Image.Image, color_space: ColorSpace, copy: bool = True
-) -> PIL.Image.Image:
+def convert_color_space_image_pil(image: PIL.Image.Image, color_space: ColorSpace) -> PIL.Image.Image:
     old_mode = image.mode
     try:
         new_mode = _COLOR_SPACE_TO_PIL_MODE[color_space]
     except KeyError:
         raise ValueError(f"Conversion from {ColorSpace.from_pil_mode(old_mode)} to {color_space} is not supported.")

-    if not copy and image.mode == new_mode:
+    if image.mode == new_mode:
         return image

     return image.convert(new_mode)


 def convert_color_space_video(
-    video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
+    video: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace
 ) -> torch.Tensor:
-    return convert_color_space_image_tensor(
-        video, old_color_space=old_color_space, new_color_space=new_color_space, copy=copy
-    )
+    return convert_color_space_image_tensor(video, old_color_space=old_color_space, new_color_space=new_color_space)


 def convert_color_space(
     inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT],
     color_space: ColorSpace,
     old_color_space: Optional[ColorSpace] = None,
-    copy: bool = True,
 ) -> Union[features.ImageTypeJIT, features.VideoTypeJIT]:
     if isinstance(inpt, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
@@ -279,10 +272,16 @@ def convert_color_space(
                 "In order to convert the color space of simple tensors, "
                 "the `old_color_space=...` parameter needs to be passed."
             )
-        return convert_color_space_image_tensor(
-            inpt, old_color_space=old_color_space, new_color_space=color_space, copy=copy
-        )
-    elif isinstance(inpt, (features.Image, features.Video)):
-        return inpt.to_color_space(color_space, copy=copy)
+        return convert_color_space_image_tensor(inpt, old_color_space=old_color_space, new_color_space=color_space)
+    elif isinstance(inpt, features.Image):
+        output = convert_color_space_image_tensor(
+            inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
+        )
+        return features.Image.wrap_like(inpt, output, color_space=color_space)
+    elif isinstance(inpt, features.Video):
+        output = convert_color_space_video(
+            inpt.as_subclass(torch.Tensor), old_color_space=inpt.color_space, new_color_space=color_space
+        )
+        return features.Video.wrap_like(inpt, output, color_space=color_space)
     else:
-        return convert_color_space_image_pil(inpt, color_space, copy=copy)
+        return convert_color_space_image_pil(inpt, color_space)
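With `copy` gone, `convert_color_space` only dispatches on the input type: simple tensors need `old_color_space`, `features.Image` and `features.Video` carry their color space and are re-wrapped via `wrap_like`, and PIL images fall through to `convert_color_space_image_pil`, which now returns the input unchanged when the mode already matches. A hedged usage sketch with illustrative values:

import torch
import PIL.Image
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

tensor_rgb = torch.rand(3, 8, 8)
# Plain tensor: the source color space must be given explicitly.
tensor_gray = F.convert_color_space(
    tensor_rgb, color_space=features.ColorSpace.GRAY, old_color_space=features.ColorSpace.RGB
)

# Wrapped feature: the color space is read from the input and the output is re-wrapped.
image = features.Image(tensor_rgb, color_space=features.ColorSpace.RGB)
image_gray = F.convert_color_space(image, color_space=features.ColorSpace.GRAY)

# PIL image: converted via the PIL mode mapping, or returned as-is if the mode already matches.
pil_gray = F.convert_color_space(PIL.Image.new("RGB", (8, 8)), color_space=features.ColorSpace.GRAY)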