Commit 82f17be0 authored by Sam Tsai's avatar Sam Tsai Committed by Facebook GitHub Bot
Browse files

add typing to transforms

Summary: Add typing to transform.

Reviewed By: wat3rBro

Differential Revision: D27145140

fbshipit-source-id: 8556427b421bf91a05692a590db175c68c4d6890
parent daf7f294
...@@ -2,18 +2,29 @@ ...@@ -2,18 +2,29 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json
import random import random
from typing import List, Optional, Tuple
import cv2 import cv2
import json
import numpy as np import numpy as np
import torchvision.transforms as T
from detectron2.config import CfgNode
from detectron2.data.transforms import Transform, TransformGen, NoOpTransform
from .build import TRANSFORM_OP_REGISTRY from .build import TRANSFORM_OP_REGISTRY
from detectron2.data.transforms import Transform, TransformGen, NoOpTransform
import torchvision.transforms as T
class AffineTransform(Transform): class AffineTransform(Transform):
def __init__(self, M, img_w, img_h, flags=None, border_mode=None, is_inversed_M=False): def __init__(
self,
M: np.ndarray,
img_w: int,
img_h: int,
flags: Optional[int] = None,
border_mode: Optional[int] = None,
is_inversed_M: bool = False,
):
""" """
Args: Args:
will transform img according to affine transform M will transform img according to affine transform M
...@@ -26,7 +37,7 @@ class AffineTransform(Transform): ...@@ -26,7 +37,7 @@ class AffineTransform(Transform):
if border_mode is not None: if border_mode is not None:
self.warp_kwargs["borderMode"] = border_mode self.warp_kwargs["borderMode"] = border_mode
def apply_image(self, img): def apply_image(self, img: np.ndarray) -> np.ndarray:
M = self.M M = self.M
if self.is_inversed_M: if self.is_inversed_M:
M = M[:2] M = M[:2]
...@@ -38,7 +49,7 @@ class AffineTransform(Transform): ...@@ -38,7 +49,7 @@ class AffineTransform(Transform):
) )
return img return img
def apply_coords(self, coords): def apply_coords(self, coords: np.ndarray) -> np.ndarray:
# Add row of ones to enable matrix multiplication # Add row of ones to enable matrix multiplication
coords = coords.T coords = coords.T
ones = np.ones((1, coords.shape[1])) ones = np.ones((1, coords.shape[1]))
...@@ -52,21 +63,22 @@ class AffineTransform(Transform): ...@@ -52,21 +63,22 @@ class AffineTransform(Transform):
class RandomPivotScaling(TransformGen): class RandomPivotScaling(TransformGen):
""" """
Uniformly pick a random pivot point inside image frame, scaling the image Uniformly pick a random pivot point inside image frame, scaling the image
around the pivot point using the scale factor sampled from a list of around the pivot point using the scale factor sampled from a list of
given scales. The pivot point's location is unchanged after the transform. given scales. The pivot point's location is unchanged after the transform.
Arguments: Arguments:
scales: List[float]: each element can be any positive float number, scales: List[float]: each element can be any positive float number,
when larger than 1.0 objects become larger after transform when larger than 1.0 objects become larger after transform
and vice versa. and vice versa.
""" """
def __init__(self, scales):
def __init__(self, scales: List[float]):
    """
    Args:
        scales: positive scale factors to sample from; values larger than
            1.0 enlarge objects after the transform, smaller values shrink
            them (see class docstring / example repr "[1.0, 0.75, 0.5]").
    """
    super().__init__()
    # Record constructor args as attributes (TransformGen convention).
    self._init(locals())
    self.scales = scales
def get_transform(self, img): def get_transform(self, img: np.ndarray) -> Transform:
img_h, img_w, _ = img.shape img_h, img_w, _ = img.shape
img_h = float(img_h) img_h = float(img_h)
img_w = float(img_w) img_w = float(img_w)
...@@ -85,10 +97,8 @@ class RandomPivotScaling(TransformGen): ...@@ -85,10 +97,8 @@ class RandomPivotScaling(TransformGen):
rb = (img_w, img_h) rb = (img_w, img_h)
pivot = (pivot_x, pivot_y) pivot = (pivot_x, pivot_y)
pts1 = np.float32([lt, pivot, rb]) pts1 = np.float32([lt, pivot, rb])
pts2 = np.float32([ pts2 = np.float32(
_interp(pivot, lt, scale), [_interp(pivot, lt, scale), pivot, _interp(pivot, rb, scale)],
pivot,
_interp(pivot, rb, scale)],
) )
M = cv2.getAffineTransform(pts1, pts2) M = cv2.getAffineTransform(pts1, pts2)
...@@ -97,16 +107,17 @@ class RandomPivotScaling(TransformGen): ...@@ -97,16 +107,17 @@ class RandomPivotScaling(TransformGen):
class RandomAffine(TransformGen): class RandomAffine(TransformGen):
""" """
Apply random affine transform to the image given Apply random affine transform to the image given
probabilities and ranges in each dimension. probabilities and ranges in each dimension.
""" """
def __init__( def __init__(
self, self,
prob=0.5, prob: float = 0.5,
angle_range=(-90, 90), angle_range: Tuple[float, float] = (-90, 90),
translation_range=(0, 0), translation_range: Tuple[float, float] = (0, 0),
scale_range=(1.0, 1.0), scale_range: Tuple[float, float] = (1.0, 1.0),
shear_range=(0, 0), shear_range: Tuple[float, float] = (0, 0),
): ):
""" """
Args: Args:
...@@ -123,7 +134,7 @@ class RandomAffine(TransformGen): ...@@ -123,7 +134,7 @@ class RandomAffine(TransformGen):
# Turn all locals into member variables. # Turn all locals into member variables.
self._init(locals()) self._init(locals())
def get_transform(self, img): def get_transform(self, img: np.ndarray) -> Transform:
im_h, im_w = img.shape[:2] im_h, im_w = img.shape[:2]
max_size = max(im_w, im_h) max_size = max(im_w, im_h)
center = [im_w / 2, im_h / 2] center = [im_w / 2, im_h / 2]
...@@ -148,11 +159,13 @@ class RandomAffine(TransformGen): ...@@ -148,11 +159,13 @@ class RandomAffine(TransformGen):
M = np.linalg.inv(M_inv) M = np.linalg.inv(M_inv)
# Center in output patch # Center in output patch
img_corners = np.array([ img_corners = np.array(
[0, 0, im_w, im_w], [
[0, im_h, 0, im_h], [0, 0, im_w, im_w],
[1, 1, 1, 1], [0, im_h, 0, im_h],
]) [1, 1, 1, 1],
]
)
transformed_corners = M @ img_corners transformed_corners = M @ img_corners
x_min = np.amin(transformed_corners[0]) x_min = np.amin(transformed_corners[0])
x_max = np.amax(transformed_corners[0]) x_max = np.amax(transformed_corners[0])
...@@ -184,14 +197,15 @@ class RandomAffine(TransformGen): ...@@ -184,14 +197,15 @@ class RandomAffine(TransformGen):
max_size, max_size,
flags=cv2.WARP_INVERSE_MAP + cv2.INTER_LINEAR, flags=cv2.WARP_INVERSE_MAP + cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REPLICATE, border_mode=cv2.BORDER_REPLICATE,
is_inversed_M=True is_inversed_M=True,
) )
else: else:
return NoOpTransform() return NoOpTransform()
# example repr: "RandomPivotScalingOp::[1.0, 0.75, 0.5]" # example repr: "RandomPivotScalingOp::[1.0, 0.75, 0.5]"
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomPivotScalingOp(cfg, arg_str, is_train): def RandomPivotScalingOp(cfg: CfgNode, arg_str: str, is_train: bool) -> List[Transform]:
assert is_train assert is_train
scales = json.loads(arg_str) scales = json.loads(arg_str)
assert isinstance(scales, list) assert isinstance(scales, list)
...@@ -200,7 +214,7 @@ def RandomPivotScalingOp(cfg, arg_str, is_train): ...@@ -200,7 +214,7 @@ def RandomPivotScalingOp(cfg, arg_str, is_train):
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomAffineOp(cfg, arg_str, is_train): def RandomAffineOp(cfg: CfgNode, arg_str: str, is_train: bool) -> List[Transform]:
assert is_train assert is_train
kwargs = json.loads(arg_str) if arg_str is not None else {} kwargs = json.loads(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict) assert isinstance(kwargs, dict)
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Tuple, Dict
import detectron2.data.transforms.augmentation as aug import detectron2.data.transforms.augmentation as aug
from detectron2.data.transforms import NoOpTransform, Transform
import numpy as np import numpy as np
from detectron2.config import CfgNode
from detectron2.data.transforms import NoOpTransform, Transform
from .build import TRANSFORM_OP_REGISTRY, _json_load from .build import TRANSFORM_OP_REGISTRY, _json_load
class LocalizedBoxMotionBlurTransform(Transform): class LocalizedBoxMotionBlurTransform(Transform):
""" Transform to blur provided bounding boxes from an image. """ """ Transform to blur provided bounding boxes from an image. """
def __init__(self, bounding_boxes, k=(7, 15), angle=(0, 360), direction=(-1.0, 1.0)):
def __init__(
self,
bounding_boxes: List[List[int]],
k: Tuple[float, float] = (7, 15),
angle: Tuple[float, float] = (0, 360),
direction: Tuple[float, float] = (-1.0, 1.0),
):
import imgaug.augmenters as iaa import imgaug.augmenters as iaa
super().__init__() super().__init__()
self._set_attributes(locals()) self._set_attributes(locals())
self.aug = iaa.MotionBlur(k, angle, direction, 1) self.aug = iaa.MotionBlur(k, angle, direction, 1)
def apply_image(self, img: np.ndarray) -> np.ndarray:
    """Return a copy of ``img`` with every stored bounding-box region
    motion-blurred; pixels outside the boxes are left untouched."""
    crops = [img[y : y + h, x : x + w] for x, y, w, h in self.bounding_boxes]
    blurred_crops = self.aug.augment_images(crops)
    result = np.array(img)  # copy so the input image is not mutated
    for (x, y, w, h), crop in zip(self.bounding_boxes, blurred_crops):
        result[y : y + h, x : x + w] = crop
    return result
def apply_segmentation(self, segmentation): def apply_segmentation(self, segmentation: np.ndarray) -> np.ndarray:
""" Apply no transform on the full-image segmentation. """ """ Apply no transform on the full-image segmentation. """
return segmentation return segmentation
def apply_coords(self, coords: np.ndarray) -> np.ndarray:
    """Identity: the blur is purely photometric, so coordinates are
    returned unchanged. (Return annotation added for consistency with
    ``apply_segmentation``, typed in the same pass.)"""
    return coords
...@@ -37,31 +47,42 @@ class LocalizedBoxMotionBlurTransform(Transform): ...@@ -37,31 +47,42 @@ class LocalizedBoxMotionBlurTransform(Transform):
""" The inverse is a No-op, only for geometric transforms. """ """ The inverse is a No-op, only for geometric transforms. """
return NoOpTransform() return NoOpTransform()
class LocalizedBoxMotionBlur(aug.Augmentation): class LocalizedBoxMotionBlur(aug.Augmentation):
""" """
Performs faked motion blur on bounding box annotations in an image. Performs faked motion blur on bounding box annotations in an image.
Randomly selects motion blur parameters from the ranges `k`, `angle`, `direction`. Randomly selects motion blur parameters from the ranges `k`, `angle`, `direction`.
""" """
def __init__(self, prob=0.5, k=(7, 15), angle=(0, 360), direction=(-1.0, 1.0)): def __init__(
self,
prob: float = 0.5,
k: Tuple[float, float] = (7, 15),
angle: Tuple[float, float] = (0, 360),
direction: Tuple[float, float] = (-1.0, 1.0),
):
super().__init__() super().__init__()
self._init(locals()) self._init(locals())
def _validate_bbox_xywh_within_bounds(self, bbox, img_h, img_w): def _validate_bbox_xywh_within_bounds(
self, bbox: List[int], img_h: int, img_w: int
):
x, y, w, h = bbox x, y, w, h = bbox
assert x >= 0, f"Invalid x {x}" assert x >= 0, f"Invalid x {x}"
assert y >= 0, f"Invalid y {x}" assert y >= 0, f"Invalid y {x}"
assert y+h <= img_h, f"Invalid right {x+w} (img width {img_w})" assert y + h <= img_h, f"Invalid right {x+w} (img width {img_w})"
assert y+h <= img_h, f"Invalid bottom {y+h} (img height {img_h})" assert y + h <= img_h, f"Invalid bottom {y+h} (img height {img_h})"
def get_transform(self, image: np.ndarray, annotations: List[Dict]) -> Transform:
    """With probability ``self.prob`` return a localized blur transform
    built from ``annotations``; otherwise return a no-op."""
    if self._rand_range() < self.prob:
        return self._get_blur_transform(image, annotations)
    return NoOpTransform()
def _get_blur_transform(self, image, annotations): def _get_blur_transform(
self, image: np.ndarray, annotations: List[Dict]
) -> Transform:
""" """
Return a `Transform` that simulates motion blur within the image's bounding box regions. Return a `Transform` that simulates motion blur within the image's bounding box regions.
""" """
...@@ -78,9 +99,12 @@ class LocalizedBoxMotionBlur(aug.Augmentation): ...@@ -78,9 +99,12 @@ class LocalizedBoxMotionBlur(aug.Augmentation):
direction=self.direction, direction=self.direction,
) )
# example repr: "LocalizedBoxMotionBlurOp::{'prob': 0.5, 'k': [3,7], 'angle': [0, 360]}" # example repr: "LocalizedBoxMotionBlurOp::{'prob': 0.5, 'k': [3,7], 'angle': [0, 360]}"
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomLocalizedBoxMotionBlurOp(cfg, arg_str, is_train): def RandomLocalizedBoxMotionBlurOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Transform]:
assert is_train assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {} kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict) assert isinstance(kwargs, dict)
...@@ -88,7 +112,12 @@ def RandomLocalizedBoxMotionBlurOp(cfg, arg_str, is_train): ...@@ -88,7 +112,12 @@ def RandomLocalizedBoxMotionBlurOp(cfg, arg_str, is_train):
class MotionBlurTransform(Transform): class MotionBlurTransform(Transform):
def __init__(self, k=(7, 15), angle=(0, 360), direction=(-1.0, 1.0)): def __init__(
self,
k: Tuple[float, float] = (7, 15),
angle: Tuple[float, float] = (0, 360),
direction: Tuple[float, float] = (-1.0, 1.0),
):
""" """
Args: Args:
will apply the specified blur to the image will apply the specified blur to the image
...@@ -99,23 +128,29 @@ class MotionBlurTransform(Transform): ...@@ -99,23 +128,29 @@ class MotionBlurTransform(Transform):
self._set_attributes(locals()) self._set_attributes(locals())
self.aug = iaa.MotionBlur(k, angle, direction, 1) self.aug = iaa.MotionBlur(k, angle, direction, 1)
def apply_image(self, img): def apply_image(self, img: np.ndarray) -> np.ndarray:
img = self.aug.augment_image(img) img = self.aug.augment_image(img)
return img return img
def apply_segmentation(self, segmentation): def apply_segmentation(self, segmentation: np.ndarray) -> np.ndarray:
return segmentation return segmentation
def apply_coords(self, coords): def apply_coords(self, coords: np.ndarray) -> np.ndarray:
return coords return coords
class RandomMotionBlur(aug.Augmentation): class RandomMotionBlur(aug.Augmentation):
""" """
Apply random motion blur. Apply random motion blur.
""" """
def __init__(self, prob=0.5, k=(3, 7), angle=(0, 360), direction=(-1.0, 1.0)): def __init__(
self,
prob: float = 0.5,
k: Tuple[float, float] = (3, 7),
angle: Tuple[float, float] = (0, 360),
direction: Tuple[float, float] = (-1.0, 1.0),
):
""" """
Args: Args:
prob (float): probability of applying transform prob (float): probability of applying transform
...@@ -127,7 +162,7 @@ class RandomMotionBlur(aug.Augmentation): ...@@ -127,7 +162,7 @@ class RandomMotionBlur(aug.Augmentation):
# Turn all locals into member variables. # Turn all locals into member variables.
self._init(locals()) self._init(locals())
def get_transform(self, img): def get_transform(self, img: np.ndarray) -> Transform:
do = self._rand_range() < self.prob do = self._rand_range() < self.prob
if do: if do:
return MotionBlurTransform(self.k, self.angle, self.direction) return MotionBlurTransform(self.k, self.angle, self.direction)
...@@ -137,7 +172,7 @@ class RandomMotionBlur(aug.Augmentation): ...@@ -137,7 +172,7 @@ class RandomMotionBlur(aug.Augmentation):
# example repr: "RandomMotionBlurOp::{'prob': 0.5, 'k': [3,7], 'angle': [0, 360]}" # example repr: "RandomMotionBlurOp::{'prob': 0.5, 'k': [3,7], 'angle': [0, 360]}"
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomMotionBlurOp(cfg, arg_str, is_train): def RandomMotionBlurOp(cfg: CfgNode, arg_str: str, is_train: bool) -> List[Transform]:
assert is_train assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {} kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict) assert isinstance(kwargs, dict)
......
#!/usr/bin/env python3 #!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Tuple, List
import numpy as np import numpy as np
import torch import torch
...@@ -18,7 +19,7 @@ def get_box_union(boxes: Boxes): ...@@ -18,7 +19,7 @@ def get_box_union(boxes: Boxes):
return Boxes(union_bt) return Boxes(union_bt)
def get_box_from_mask(mask: np.ndarray): def get_box_from_mask(mask: torch.Tensor) -> Tuple[int, int, int, int]:
"""Find if there are non-zero elements per row/column first and then find """Find if there are non-zero elements per row/column first and then find
min/max position of those elements. min/max position of those elements.
Only support 2d image (h x w) Only support 2d image (h x w)
...@@ -39,7 +40,9 @@ def get_box_from_mask(mask: np.ndarray): ...@@ -39,7 +40,9 @@ def get_box_from_mask(mask: np.ndarray):
return cmin, rmin, cmax - cmin + 1, rmax - rmin + 1 return cmin, rmin, cmax - cmin + 1, rmax - rmin + 1
def get_min_box_aspect_ratio(bbox_xywh, target_aspect_ratio): def get_min_box_aspect_ratio(
bbox_xywh: torch.Tensor, target_aspect_ratio: float
) -> torch.Tensor:
"""Get a minimal bbox that matches the target_aspect_ratio """Get a minimal bbox that matches the target_aspect_ratio
target_aspect_ratio is representation by w/h target_aspect_ratio is representation by w/h
bbox are represented by pixel coordinates""" bbox are represented by pixel coordinates"""
...@@ -58,19 +61,21 @@ def get_min_box_aspect_ratio(bbox_xywh, target_aspect_ratio): ...@@ -58,19 +61,21 @@ def get_min_box_aspect_ratio(bbox_xywh, target_aspect_ratio):
return torch.cat([new_xy, new_wh]) return torch.cat([new_xy, new_wh])
def get_box_center(bbox_xywh: torch.Tensor) -> torch.Tensor:
    """Return the (x, y) center of an XYWH box."""
    top_left = torch.Tensor(bbox_xywh[:2])
    half_extent = torch.Tensor(bbox_xywh[2:]) / 2.0
    return top_left + half_extent
def get_bbox_xywh_from_center_wh(
    bbox_center: torch.Tensor, bbox_wh: torch.Tensor
) -> torch.Tensor:
    """Build an XYWH box from its center point and its (width, height)."""
    wh = torch.Tensor(bbox_wh)
    top_left = torch.Tensor(bbox_center) - wh / 2.0
    return torch.cat([top_left, wh])
def get_bbox_xyxy_from_xywh(bbox_xywh): def get_bbox_xyxy_from_xywh(bbox_xywh: torch.Tensor) -> torch.Tensor:
"""Convert the bbox from xywh format to xyxy format """Convert the bbox from xywh format to xyxy format
bbox are represented by pixel coordinates, bbox are represented by pixel coordinates,
the center of pixels are (x + 0.5, y + 0.5) the center of pixels are (x + 0.5, y + 0.5)
...@@ -85,7 +90,7 @@ def get_bbox_xyxy_from_xywh(bbox_xywh): ...@@ -85,7 +90,7 @@ def get_bbox_xyxy_from_xywh(bbox_xywh):
) )
def get_bbox_xywh_from_xyxy(bbox_xyxy): def get_bbox_xywh_from_xyxy(bbox_xyxy: torch.Tensor) -> torch.Tensor:
"""Convert the bbox from xyxy format to xywh format""" """Convert the bbox from xyxy format to xywh format"""
return torch.Tensor( return torch.Tensor(
[ [
...@@ -97,25 +102,25 @@ def get_bbox_xywh_from_xyxy(bbox_xyxy): ...@@ -97,25 +102,25 @@ def get_bbox_xywh_from_xyxy(bbox_xyxy):
) )
def to_boxes_from_xywh(bbox_xywh: torch.Tensor) -> Boxes:
    """Wrap a single XYWH box into a detectron2 ``Boxes`` (N=1, XYXY format).

    Fixed return annotation: the function returns ``Boxes``, not
    ``torch.Tensor`` as previously annotated.
    """
    return Boxes(get_bbox_xyxy_from_xywh(bbox_xywh).unsqueeze(0))
def scale_bbox_center(bbox_xywh: torch.Tensor, target_scale: float) -> torch.Tensor:
    """Scale an XYWH box about its own center; the center point stays fixed."""
    fixed_center = get_box_center(bbox_xywh)
    new_wh = torch.Tensor(bbox_xywh[2:]) * target_scale
    return get_bbox_xywh_from_center_wh(fixed_center, new_wh)
def offset_bbox(bbox_xywh: torch.Tensor, target_offset: float) -> torch.Tensor:
    """Translate an XYWH box by ``target_offset``; width/height unchanged."""
    moved_center = get_box_center(bbox_xywh) + torch.Tensor(target_offset)
    return get_bbox_xywh_from_center_wh(moved_center, bbox_xywh[2:])
def clip_box_xywh(bbox_xywh, image_size_hw): def clip_box_xywh(bbox_xywh: torch.Tensor, image_size_hw: List[int]):
"""Clip the bbox based on image_size_hw""" """Clip the bbox based on image_size_hw"""
h, w = image_size_hw h, w = image_size_hw
bbox_xyxy = get_bbox_xyxy_from_xywh(bbox_xywh) bbox_xyxy = get_bbox_xyxy_from_xywh(bbox_xywh)
......
...@@ -4,7 +4,9 @@ ...@@ -4,7 +4,9 @@
import json import json
import logging import logging
from typing import List, Dict, Optional, Tuple
from detectron2.config import CfgNode
from detectron2.data import transforms as d2T from detectron2.data import transforms as d2T
from detectron2.utils.registry import Registry from detectron2.utils.registry import Registry
...@@ -15,7 +17,7 @@ logger = logging.getLogger(__name__) ...@@ -15,7 +17,7 @@ logger = logging.getLogger(__name__)
TRANSFORM_OP_REGISTRY = Registry("D2GO_TRANSFORM_REGISTRY") TRANSFORM_OP_REGISTRY = Registry("D2GO_TRANSFORM_REGISTRY")
def _json_load(arg_str): def _json_load(arg_str: str) -> Dict:
try: try:
return json.loads(arg_str) return json.loads(arg_str)
except json.decoder.JSONDecodeError as e: except json.decoder.JSONDecodeError as e:
...@@ -25,7 +27,9 @@ def _json_load(arg_str): ...@@ -25,7 +27,9 @@ def _json_load(arg_str):
# example repr: "ResizeShortestEdgeOp" # example repr: "ResizeShortestEdgeOp"
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def ResizeShortestEdgeOp(cfg, arg_str, is_train): def ResizeShortestEdgeOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[d2T.Transform]:
if is_train: if is_train:
min_size = cfg.INPUT.MIN_SIZE_TRAIN min_size = cfg.INPUT.MIN_SIZE_TRAIN
max_size = cfg.INPUT.MAX_SIZE_TRAIN max_size = cfg.INPUT.MAX_SIZE_TRAIN
...@@ -47,9 +51,11 @@ def ResizeShortestEdgeOp(cfg, arg_str, is_train): ...@@ -47,9 +51,11 @@ def ResizeShortestEdgeOp(cfg, arg_str, is_train):
# example repr: "ResizeShortestEdgeSquareOp" # example repr: "ResizeShortestEdgeSquareOp"
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def ResizeShortestEdgeSquareOp(cfg, arg_str, is_train): def ResizeShortestEdgeSquareOp(
""" Resize the input to square using INPUT.MIN_SIZE_TRAIN or INPUT.MIN_SIZE_TEST cfg: CfgNode, arg_str: str, is_train: bool
without keeping aspect ratio ) -> List[d2T.Transform]:
"""Resize the input to square using INPUT.MIN_SIZE_TRAIN or INPUT.MIN_SIZE_TEST
without keeping aspect ratio
""" """
if is_train: if is_train:
min_size = cfg.INPUT.MIN_SIZE_TRAIN min_size = cfg.INPUT.MIN_SIZE_TRAIN
...@@ -67,7 +73,7 @@ def ResizeShortestEdgeSquareOp(cfg, arg_str, is_train): ...@@ -67,7 +73,7 @@ def ResizeShortestEdgeSquareOp(cfg, arg_str, is_train):
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def ResizeOp(cfg, arg_str, is_train): def ResizeOp(cfg: CfgNode, arg_str: str, is_train: bool) -> List[d2T.Transform]:
kwargs = _json_load(arg_str) if arg_str is not None else {} kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict) assert isinstance(kwargs, dict)
return [d2T.Resize(**kwargs)] return [d2T.Resize(**kwargs)]
...@@ -76,7 +82,7 @@ def ResizeOp(cfg, arg_str, is_train): ...@@ -76,7 +82,7 @@ def ResizeOp(cfg, arg_str, is_train):
_TRANSFORM_REPR_SEPARATOR = "::" _TRANSFORM_REPR_SEPARATOR = "::"
def parse_tfm_gen_repr(tfm_gen_repr): def parse_tfm_gen_repr(tfm_gen_repr: str) -> Tuple[str, Optional[str]]:
if tfm_gen_repr.count(_TRANSFORM_REPR_SEPARATOR) == 0: if tfm_gen_repr.count(_TRANSFORM_REPR_SEPARATOR) == 0:
return tfm_gen_repr, None return tfm_gen_repr, None
elif tfm_gen_repr.count(_TRANSFORM_REPR_SEPARATOR) == 1: elif tfm_gen_repr.count(_TRANSFORM_REPR_SEPARATOR) == 1:
...@@ -88,7 +94,7 @@ def parse_tfm_gen_repr(tfm_gen_repr): ...@@ -88,7 +94,7 @@ def parse_tfm_gen_repr(tfm_gen_repr):
) )
def build_transform_gen(cfg, is_train): def build_transform_gen(cfg: CfgNode, is_train: bool) -> List[d2T.Transform]:
""" """
This function builds a list of TransformGen or Transform objects using the a list of This function builds a list of TransformGen or Transform objects using the a list of
strings from cfg.D2GO_DATA.AUG_OPS.TRAIN/TEST. Each string (aka. `tfm_gen_repr`) strings from cfg.D2GO_DATA.AUG_OPS.TRAIN/TEST. Each string (aka. `tfm_gen_repr`)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List from typing import List, Callable, Union
import detectron2.data.transforms.augmentation as aug import detectron2.data.transforms.augmentation as aug
import numpy as np import numpy as np
...@@ -22,7 +22,7 @@ class InvertibleColorTransform(Transform): ...@@ -22,7 +22,7 @@ class InvertibleColorTransform(Transform):
coordinates such as bounding boxes should not be changed) coordinates such as bounding boxes should not be changed)
""" """
def __init__(self, op, inverse_op): def __init__(self, op: Callable, inverse_op: Callable):
""" """
Args: Args:
op (Callable): operation to be applied to the image, op (Callable): operation to be applied to the image,
...@@ -35,16 +35,16 @@ class InvertibleColorTransform(Transform): ...@@ -35,16 +35,16 @@ class InvertibleColorTransform(Transform):
super().__init__() super().__init__()
self._set_attributes(locals()) self._set_attributes(locals())
def apply_image(self, img): def apply_image(self, img: np.ndarray) -> np.ndarray:
return self.op(img) return self.op(img)
def apply_coords(self, coords): def apply_coords(self, coords: np.ndarray) -> np.ndarray:
return coords return coords
def inverse(self): def inverse(self) -> Transform:
return InvertibleColorTransform(self.inverse_op, self.op) return InvertibleColorTransform(self.inverse_op, self.op)
def apply_segmentation(self, segmentation): def apply_segmentation(self, segmentation: np.ndarray) -> np.ndarray:
return segmentation return segmentation
...@@ -60,7 +60,7 @@ class RandomContrastYUV(aug.Augmentation): ...@@ -60,7 +60,7 @@ class RandomContrastYUV(aug.Augmentation):
super().__init__() super().__init__()
self._init(locals()) self._init(locals())
def get_transform(self, img): def get_transform(self, img: np.ndarray) -> Transform:
w = np.random.uniform(self.intensity_min, self.intensity_max) w = np.random.uniform(self.intensity_min, self.intensity_max)
pure_gray = np.zeros_like(img) pure_gray = np.zeros_like(img)
pure_gray[:, :, 0] = 0.5 pure_gray[:, :, 0] = 0.5
...@@ -77,7 +77,7 @@ class RandomSaturationYUV(aug.Augmentation): ...@@ -77,7 +77,7 @@ class RandomSaturationYUV(aug.Augmentation):
super().__init__() super().__init__()
self._init(locals()) self._init(locals())
def get_transform(self, img): def get_transform(self, img: np.ndarray) -> Transform:
assert ( assert (
len(img.shape) == 3 and img.shape[-1] == 3 len(img.shape) == 3 and img.shape[-1] == 3
), f"Expected (H, W, 3), image shape {img.shape}" ), f"Expected (H, W, 3), image shape {img.shape}"
...@@ -87,7 +87,7 @@ class RandomSaturationYUV(aug.Augmentation): ...@@ -87,7 +87,7 @@ class RandomSaturationYUV(aug.Augmentation):
return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w) return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)
def convert_rgb_to_yuv_bt601(image): def convert_rgb_to_yuv_bt601(image: np.ndarray) -> np.ndarray:
"""Convert RGB image in (H, W, C) to YUV format """Convert RGB image in (H, W, C) to YUV format
image: range 0 ~ 255 image: range 0 ~ 255
""" """
...@@ -96,7 +96,7 @@ def convert_rgb_to_yuv_bt601(image): ...@@ -96,7 +96,7 @@ def convert_rgb_to_yuv_bt601(image):
return image return image
def convery_yuv_bt601_to_rgb(image): def convery_yuv_bt601_to_rgb(image: np.ndarray) -> np.ndarray:
return du.convert_image_to_rgb(image, "YUV-BT.601") return du.convert_image_to_rgb(image, "YUV-BT.601")
...@@ -107,7 +107,7 @@ class RGB2YUVBT601(aug.Augmentation): ...@@ -107,7 +107,7 @@ class RGB2YUVBT601(aug.Augmentation):
convert_rgb_to_yuv_bt601, convery_yuv_bt601_to_rgb convert_rgb_to_yuv_bt601, convery_yuv_bt601_to_rgb
) )
def get_transform(self, image): def get_transform(self, image) -> Transform:
return self.trans return self.trans
...@@ -118,11 +118,13 @@ class YUVBT6012RGB(aug.Augmentation): ...@@ -118,11 +118,13 @@ class YUVBT6012RGB(aug.Augmentation):
convery_yuv_bt601_to_rgb, convert_rgb_to_yuv_bt601 convery_yuv_bt601_to_rgb, convert_rgb_to_yuv_bt601
) )
def get_transform(self, image): def get_transform(self, image) -> Transform:
return self.trans return self.trans
def build_func(cfg: CfgNode, arg_str: str, is_train: bool, obj) -> List[aug.Augmentation]: def build_func(
cfg: CfgNode, arg_str: str, is_train: bool, obj
) -> List[Union[aug.Augmentation, Transform]]:
assert is_train assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {} kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict) assert isinstance(kwargs, dict)
...@@ -130,20 +132,28 @@ def build_func(cfg: CfgNode, arg_str: str, is_train: bool, obj) -> List[aug.Augm ...@@ -130,20 +132,28 @@ def build_func(cfg: CfgNode, arg_str: str, is_train: bool, obj) -> List[aug.Augm
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomContrastYUVOp(cfg, arg_str, is_train): def RandomContrastYUVOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
return build_func(cfg, arg_str, is_train, obj=RandomContrastYUV) return build_func(cfg, arg_str, is_train, obj=RandomContrastYUV)
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomSaturationYUVOp(cfg, arg_str, is_train): def RandomSaturationYUVOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
return build_func(cfg, arg_str, is_train, obj=RandomSaturationYUV) return build_func(cfg, arg_str, is_train, obj=RandomSaturationYUV)
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RGB2YUVBT601Op(cfg, arg_str, is_train): def RGB2YUVBT601Op(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
return build_func(cfg, arg_str, is_train, obj=RGB2YUVBT601) return build_func(cfg, arg_str, is_train, obj=RGB2YUVBT601)
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def YUVBT6012RGBOp(cfg, arg_str, is_train): def YUVBT6012RGBOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
return build_func(cfg, arg_str, is_train, obj=YUVBT6012RGB) return build_func(cfg, arg_str, is_train, obj=YUVBT6012RGB)
...@@ -3,9 +3,11 @@ ...@@ -3,9 +3,11 @@
import math import math
from typing import List, Optional, Tuple, Union, Any
import detectron2.data.transforms.augmentation as aug import detectron2.data.transforms.augmentation as aug
import numpy as np import numpy as np
from detectron2.config import CfgNode
from detectron2.data.transforms import ExtentTransform, CropTransform from detectron2.data.transforms import ExtentTransform, CropTransform
from detectron2.structures import BoxMode from detectron2.structures import BoxMode
...@@ -22,7 +24,7 @@ class CropBoundary(aug.Augmentation): ...@@ -22,7 +24,7 @@ class CropBoundary(aug.Augmentation):
super().__init__() super().__init__()
self.count = count self.count = count
def get_transform(self, image): def get_transform(self, image: np.ndarray) -> Transform:
img_h, img_w = image.shape[:2] img_h, img_w = image.shape[:2]
assert self.count < img_h and self.count < img_w assert self.count < img_h and self.count < img_w
assert img_h > self.count * 2 assert img_h > self.count * 2
...@@ -32,13 +34,22 @@ class CropBoundary(aug.Augmentation): ...@@ -32,13 +34,22 @@ class CropBoundary(aug.Augmentation):
class PadTransform(Transform): class PadTransform(Transform):
def __init__(self, x0, y0, w, h, org_w, org_h, pad_mode="constant"): def __init__(
self,
x0: int,
y0: int,
w: int,
h: int,
org_w: int,
org_h: int,
pad_mode: str = "constant",
):
super().__init__() super().__init__()
assert x0 + w <= org_w assert x0 + w <= org_w
assert y0 + h <= org_h assert y0 + h <= org_h
self._set_attributes(locals()) self._set_attributes(locals())
def apply_image(self, img): def apply_image(self, img: np.ndarray) -> np.array:
"""img: HxWxC or HxW""" """img: HxWxC or HxW"""
assert len(img.shape) == 2 or len(img.shape) == 3 assert len(img.shape) == 2 or len(img.shape) == 3
assert img.shape[0] == self.h and img.shape[1] == self.w assert img.shape[0] == self.h and img.shape[1] == self.w
...@@ -64,12 +75,12 @@ InvertibleCropTransform = CropTransform ...@@ -64,12 +75,12 @@ InvertibleCropTransform = CropTransform
class PadBorderDivisible(aug.Augmentation): class PadBorderDivisible(aug.Augmentation):
def __init__(self, size_divisibility, pad_mode="constant"): def __init__(self, size_divisibility: int, pad_mode: str = "constant"):
super().__init__() super().__init__()
self.size_divisibility = size_divisibility self.size_divisibility = size_divisibility
self.pad_mode = pad_mode self.pad_mode = pad_mode
def get_transform(self, image): def get_transform(self, image: np.ndarray) -> Transform:
""" image: HxWxC """ """ image: HxWxC """
assert len(image.shape) == 3 and image.shape[2] in [1, 3] assert len(image.shape) == 3 and image.shape[2] in [1, 3]
H, W = image.shape[:2] H, W = image.shape[:2]
...@@ -80,7 +91,10 @@ class PadBorderDivisible(aug.Augmentation): ...@@ -80,7 +91,10 @@ class PadBorderDivisible(aug.Augmentation):
class RandomCropFixedAspectRatio(aug.Augmentation): class RandomCropFixedAspectRatio(aug.Augmentation):
def __init__( def __init__(
self, crop_aspect_ratios_list, scale_range=None, offset_scale_range=None self,
crop_aspect_ratios_list: List[float],
scale_range: Optional[Union[List, Tuple]] = None,
offset_scale_range: Optional[Union[List, Tuple]] = None,
): ):
super().__init__() super().__init__()
assert isinstance(crop_aspect_ratios_list, (list, tuple)) assert isinstance(crop_aspect_ratios_list, (list, tuple))
...@@ -103,21 +117,21 @@ class RandomCropFixedAspectRatio(aug.Augmentation): ...@@ -103,21 +117,21 @@ class RandomCropFixedAspectRatio(aug.Augmentation):
self.rng = np.random.default_rng() self.rng = np.random.default_rng()
def _pick_aspect_ratio(self): def _pick_aspect_ratio(self) -> float:
return self.rng.choice(self.crop_aspect_ratios_list) return self.rng.choice(self.crop_aspect_ratios_list)
def _pick_scale(self): def _pick_scale(self) -> float:
if self.scale_range is None: if self.scale_range is None:
return 1.0 return 1.0
return self.rng.uniform(*self.scale_range) return self.rng.uniform(*self.scale_range)
def _pick_offset(self, box_w, box_h): def _pick_offset(self, box_w: float, box_h: float) -> Tuple[float, float]:
if self.offset_scale_range is None: if self.offset_scale_range is None:
return [0, 0] return [0, 0]
offset_scale = self.rng.uniform(*self.offset_scale_range, size=2) offset_scale = self.rng.uniform(*self.offset_scale_range, size=2)
return offset_scale[0] * box_w, offset_scale[1] * box_h return offset_scale[0] * box_w, offset_scale[1] * box_h
def get_transform(self, image, sem_seg): def get_transform(self, image: np.ndarray, sem_seg: np.ndarray) -> Transform:
# HWC or HW for image, HW for sem_seg # HWC or HW for image, HW for sem_seg
assert len(image.shape) in [2, 3] assert len(image.shape) in [2, 3]
assert len(sem_seg.shape) == 2 assert len(sem_seg.shape) == 2
...@@ -148,7 +162,9 @@ class RandomCropFixedAspectRatio(aug.Augmentation): ...@@ -148,7 +162,9 @@ class RandomCropFixedAspectRatio(aug.Augmentation):
# example repr: "CropBoundaryOp::{'count': 3}" # example repr: "CropBoundaryOp::{'count': 3}"
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def CropBoundaryOp(cfg, arg_str, is_train): def CropBoundaryOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
assert is_train assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {} kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict) assert isinstance(kwargs, dict)
...@@ -157,7 +173,9 @@ def CropBoundaryOp(cfg, arg_str, is_train): ...@@ -157,7 +173,9 @@ def CropBoundaryOp(cfg, arg_str, is_train):
# example repr: "RandomCropFixedAspectRatioOp::{'crop_aspect_ratios_list': [0.5], 'scale_range': [0.8, 1.2], 'offset_scale_range': [-0.3, 0.3]}" # example repr: "RandomCropFixedAspectRatioOp::{'crop_aspect_ratios_list': [0.5], 'scale_range': [0.8, 1.2], 'offset_scale_range': [-0.3, 0.3]}"
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomCropFixedAspectRatioOp(cfg, arg_str, is_train): def RandomCropFixedAspectRatioOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
assert is_train assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {} kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict) assert isinstance(kwargs, dict)
...@@ -165,7 +183,7 @@ def RandomCropFixedAspectRatioOp(cfg, arg_str, is_train): ...@@ -165,7 +183,7 @@ def RandomCropFixedAspectRatioOp(cfg, arg_str, is_train):
class RandomInstanceCrop(aug.Augmentation): class RandomInstanceCrop(aug.Augmentation):
def __init__(self, crop_scale=(0.8, 1.6)): def __init__(self, crop_scale: Tuple[float, float] = (0.8, 1.6)):
""" """
Generates a CropTransform centered around the instance. Generates a CropTransform centered around the instance.
crop_scale: [low, high] relative crop scale around the instance, this crop_scale: [low, high] relative crop scale around the instance, this
...@@ -177,7 +195,7 @@ class RandomInstanceCrop(aug.Augmentation): ...@@ -177,7 +195,7 @@ class RandomInstanceCrop(aug.Augmentation):
isinstance(crop_scale, (list, tuple)) and len(crop_scale) == 2 isinstance(crop_scale, (list, tuple)) and len(crop_scale) == 2
), crop_scale ), crop_scale
def get_transform(self, image, annotations): def get_transform(self, image: np.ndarray, annotations: List[Any]) -> Transform:
""" """
This function will modify instances to set the iscrowd flag to 1 for This function will modify instances to set the iscrowd flag to 1 for
annotations not picked. It relies on the dataset mapper to filter those annotations not picked. It relies on the dataset mapper to filter those
...@@ -215,22 +233,24 @@ class RandomInstanceCrop(aug.Augmentation): ...@@ -215,22 +233,24 @@ class RandomInstanceCrop(aug.Augmentation):
# example repr: "RandomInstanceCropOp::{'crop_scale': [0.8, 1.6]}" # example repr: "RandomInstanceCropOp::{'crop_scale': [0.8, 1.6]}"
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomInstanceCropOp(cfg, arg_str, is_train): def RandomInstanceCropOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
kwargs = _json_load(arg_str) if arg_str is not None else {} kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict) assert isinstance(kwargs, dict)
return [RandomInstanceCrop(**kwargs)] return [RandomInstanceCrop(**kwargs)]
class CropBoxAug(aug.Augmentation): class CropBoxAug(aug.Augmentation):
""" Augmentation to crop the image based on boxes """Augmentation to crop the image based on boxes
Scale the box with `box_scale_factor` around the center before cropping Scale the box with `box_scale_factor` around the center before cropping
""" """
def __init__(self, box_scale_factor=1.0):
def __init__(self, box_scale_factor: float = 1.0):
super().__init__() super().__init__()
self.box_scale_factor = box_scale_factor self.box_scale_factor = box_scale_factor
def get_transform(self, image: np.ndarray, boxes: np.ndarray): def get_transform(self, image: np.ndarray, boxes: np.ndarray) -> Transform:
# boxes: 1 x 4 in xyxy format # boxes: 1 x 4 in xyxy format
assert boxes.shape[0] == 1 assert boxes.shape[0] == 1
assert isinstance(image, np.ndarray) assert isinstance(image, np.ndarray)
...@@ -239,7 +259,7 @@ class CropBoxAug(aug.Augmentation): ...@@ -239,7 +259,7 @@ class CropBoxAug(aug.Augmentation):
box_xywh = bu.get_bbox_xywh_from_xyxy(boxes[0]) box_xywh = bu.get_bbox_xywh_from_xyxy(boxes[0])
if self.box_scale_factor != 1.0: if self.box_scale_factor != 1.0:
box_xywh = bu.scale_bbox_center(box_xywh, self.box_scale_factor) box_xywh = bu.scale_bbox_center(box_xywh, self.box_scale_factor)
box_xywh = bu.clip_box_xywh(box_xywh, [img_h, img_w]) box_xywh = bu.clip_box_xywh(box_xywh, [img_h, img_w])
box_xywh = box_xywh.int().tolist() box_xywh = box_xywh.int().tolist()
return CropTransform(*box_xywh, orig_w=img_w, orig_h=img_h) return CropTransform(*box_xywh, orig_w=img_w, orig_h=img_h)
...@@ -3,11 +3,15 @@ ...@@ -3,11 +3,15 @@
import logging import logging
from typing import List, Union
from .build import TRANSFORM_OP_REGISTRY, _json_load import detectron2.data.transforms.augmentation as aug
from detectron2.config import CfgNode
from detectron2.data import transforms as d2T from detectron2.data import transforms as d2T
from detectron2.projects.point_rend import ColorAugSSDTransform from detectron2.projects.point_rend import ColorAugSSDTransform
from .build import TRANSFORM_OP_REGISTRY, _json_load
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -23,8 +27,10 @@ D2_RANDOM_TRANSFORMS = { ...@@ -23,8 +27,10 @@ D2_RANDOM_TRANSFORMS = {
} }
def build_func(cfg, arg_str, is_train, name): def build_func(
assert is_train cfg: CfgNode, arg_str: str, is_train: bool, name: str
) -> List[Union[aug.Augmentation, d2T.Transform]]:
assert is_train, "Random augmentation is for training only"
kwargs = _json_load(arg_str) if arg_str is not None else {} kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict) assert isinstance(kwargs, dict)
return [D2_RANDOM_TRANSFORMS[name](**kwargs)] return [D2_RANDOM_TRANSFORMS[name](**kwargs)]
...@@ -35,47 +41,65 @@ def build_func(cfg, arg_str, is_train, name): ...@@ -35,47 +41,65 @@ def build_func(cfg, arg_str, is_train, name):
# example 3: RandomFlipOp::{"prob":0.5} # example 3: RandomFlipOp::{"prob":0.5}
# example 4: RandomBrightnessOp::{"intensity_min":1.0, "intensity_max":2.0} # example 4: RandomBrightnessOp::{"intensity_min":1.0, "intensity_max":2.0}
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomBrightnessOp(cfg, arg_str, is_train): def RandomBrightnessOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomBrightness") return build_func(cfg, arg_str, is_train, name="RandomBrightness")
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomContrastOp(cfg, arg_str, is_train): def RandomContrastOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomContrast") return build_func(cfg, arg_str, is_train, name="RandomContrast")
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomCropOp(cfg, arg_str, is_train): def RandomCropOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomCrop") return build_func(cfg, arg_str, is_train, name="RandomCrop")
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomRotation(cfg, arg_str, is_train): def RandomRotation(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomRotation") return build_func(cfg, arg_str, is_train, name="RandomRotation")
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomExtentOp(cfg, arg_str, is_train): def RandomExtentOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomExtent") return build_func(cfg, arg_str, is_train, name="RandomExtent")
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomFlipOp(cfg, arg_str, is_train): def RandomFlipOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomFlip") return build_func(cfg, arg_str, is_train, name="RandomFlip")
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomSaturationOp(cfg, arg_str, is_train): def RandomSaturationOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomSaturation") return build_func(cfg, arg_str, is_train, name="RandomSaturation")
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomLightingOp(cfg, arg_str, is_train): def RandomLightingOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomLighting") return build_func(cfg, arg_str, is_train, name="RandomLighting")
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandomSSDColorAugOp(cfg, arg_str, is_train): def RandomSSDColorAugOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
assert is_train assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {} kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict) assert isinstance(kwargs, dict)
......
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Optional, Union from typing import List, Optional, Union, Any
import numpy as np import numpy as np
import torch import torch
...@@ -51,7 +51,7 @@ class AugInput: ...@@ -51,7 +51,7 @@ class AugInput:
def apply_augmentations( def apply_augmentations(
self, augmentations: List[Union[Augmentation, Transform]] self, augmentations: List[Union[Augmentation, Transform]]
) -> TransformList: ) -> AugmentationList:
""" """
Equivalent of ``AugmentationList(augmentations)(self)`` Equivalent of ``AugmentationList(augmentations)(self)``
""" """
...@@ -70,14 +70,14 @@ class Tensor2Array(Transform): ...@@ -70,14 +70,14 @@ class Tensor2Array(Transform):
assert len(img.shape) == 3, img.shape assert len(img.shape) == 3, img.shape
return img.cpu().numpy().transpose(1, 2, 0) return img.cpu().numpy().transpose(1, 2, 0)
def apply_coords(self, coords): def apply_coords(self, coords: Any) -> Any:
return coords return coords
def apply_segmentation(self, segmentation: torch.Tensor) -> np.ndarray: def apply_segmentation(self, segmentation: torch.Tensor) -> np.ndarray:
assert len(segmentation.shape) == 2, segmentation.shape assert len(segmentation.shape) == 2, segmentation.shape
return segmentation.cpu().numpy() return segmentation.cpu().numpy()
def inverse(self): def inverse(self) -> Transform:
return Array2Tensor() return Array2Tensor()
...@@ -93,12 +93,12 @@ class Array2Tensor(Transform): ...@@ -93,12 +93,12 @@ class Array2Tensor(Transform):
assert len(img.shape) == 3, img.shape assert len(img.shape) == 3, img.shape
return torch.from_numpy(img.transpose(2, 0, 1).astype("float32")) return torch.from_numpy(img.transpose(2, 0, 1).astype("float32"))
def apply_coords(self, coords): def apply_coords(self, coords: Any) -> Any:
return coords return coords
def apply_segmentation(self, segmentation: np.ndarray) -> torch.Tensor: def apply_segmentation(self, segmentation: np.ndarray) -> torch.Tensor:
assert len(segmentation.shape) == 2, segmentation.shape assert len(segmentation.shape) == 2, segmentation.shape
return torch.from_numpy(segmentation.astype("long")) return torch.from_numpy(segmentation.astype("long"))
def inverse(self): def inverse(self) -> Transform:
return Tensor2Array() return Tensor2Array()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment