Commit 82f17be0 authored by Sam Tsai's avatar Sam Tsai Committed by Facebook GitHub Bot
Browse files

add typing to transforms

Summary: Add typing to transforms.

Reviewed By: wat3rBro

Differential Revision: D27145140

fbshipit-source-id: 8556427b421bf91a05692a590db175c68c4d6890
parent daf7f294
......@@ -2,18 +2,29 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import json
import random
from typing import List, Optional, Tuple
import cv2
import json
import numpy as np
import torchvision.transforms as T
from detectron2.config import CfgNode
from detectron2.data.transforms import Transform, TransformGen, NoOpTransform
from .build import TRANSFORM_OP_REGISTRY
from detectron2.data.transforms import Transform, TransformGen, NoOpTransform
import torchvision.transforms as T
class AffineTransform(Transform):
def __init__(self, M, img_w, img_h, flags=None, border_mode=None, is_inversed_M=False):
def __init__(
self,
M: np.ndarray,
img_w: int,
img_h: int,
flags: Optional[int] = None,
border_mode: Optional[int] = None,
is_inversed_M: bool = False,
):
"""
Args:
will transform img according to affine transform M
......@@ -26,7 +37,7 @@ class AffineTransform(Transform):
if border_mode is not None:
self.warp_kwargs["borderMode"] = border_mode
def apply_image(self, img):
def apply_image(self, img: np.ndarray) -> np.ndarray:
M = self.M
if self.is_inversed_M:
M = M[:2]
......@@ -38,7 +49,7 @@ class AffineTransform(Transform):
)
return img
def apply_coords(self, coords):
def apply_coords(self, coords: np.ndarray) -> np.ndarray:
# Add row of ones to enable matrix multiplication
coords = coords.T
ones = np.ones((1, coords.shape[1]))
......@@ -52,21 +63,22 @@ class AffineTransform(Transform):
class RandomPivotScaling(TransformGen):
"""
Uniformly pick a random pivot point inside image frame, scaling the image
around the pivot point using the scale factor sampled from a list of
given scales. The pivot point's location is unchanged after the transform.
Arguments:
scales: List[float]: each element can be any positive float number,
when larger than 1.0 objects become larger after transform
and vice versa.
Uniformly pick a random pivot point inside image frame, scaling the image
around the pivot point using the scale factor sampled from a list of
given scales. The pivot point's location is unchanged after the transform.
Arguments:
scales: List[float]: each element can be any positive float number,
when larger than 1.0 objects become larger after transform
and vice versa.
"""
def __init__(self, scales):
def __init__(self, scales: List[int]):
super().__init__()
self._init(locals())
self.scales = scales
def get_transform(self, img):
def get_transform(self, img: np.ndarray) -> Transform:
img_h, img_w, _ = img.shape
img_h = float(img_h)
img_w = float(img_w)
......@@ -85,10 +97,8 @@ class RandomPivotScaling(TransformGen):
rb = (img_w, img_h)
pivot = (pivot_x, pivot_y)
pts1 = np.float32([lt, pivot, rb])
pts2 = np.float32([
_interp(pivot, lt, scale),
pivot,
_interp(pivot, rb, scale)],
pts2 = np.float32(
[_interp(pivot, lt, scale), pivot, _interp(pivot, rb, scale)],
)
M = cv2.getAffineTransform(pts1, pts2)
......@@ -97,16 +107,17 @@ class RandomPivotScaling(TransformGen):
class RandomAffine(TransformGen):
"""
Apply random affine transform to the image given
probabilities and ranges in each dimension.
Apply random affine transform to the image given
probabilities and ranges in each dimension.
"""
def __init__(
self,
prob=0.5,
angle_range=(-90, 90),
translation_range=(0, 0),
scale_range=(1.0, 1.0),
shear_range=(0, 0),
prob: float = 0.5,
angle_range: Tuple[float, float] = (-90, 90),
translation_range: Tuple[float, float] = (0, 0),
scale_range: Tuple[float, float] = (1.0, 1.0),
shear_range: Tuple[float, float] = (0, 0),
):
"""
Args:
......@@ -123,7 +134,7 @@ class RandomAffine(TransformGen):
# Turn all locals into member variables.
self._init(locals())
def get_transform(self, img):
def get_transform(self, img: np.ndarray) -> Transform:
im_h, im_w = img.shape[:2]
max_size = max(im_w, im_h)
center = [im_w / 2, im_h / 2]
......@@ -148,11 +159,13 @@ class RandomAffine(TransformGen):
M = np.linalg.inv(M_inv)
# Center in output patch
img_corners = np.array([
[0, 0, im_w, im_w],
[0, im_h, 0, im_h],
[1, 1, 1, 1],
])
img_corners = np.array(
[
[0, 0, im_w, im_w],
[0, im_h, 0, im_h],
[1, 1, 1, 1],
]
)
transformed_corners = M @ img_corners
x_min = np.amin(transformed_corners[0])
x_max = np.amax(transformed_corners[0])
......@@ -184,14 +197,15 @@ class RandomAffine(TransformGen):
max_size,
flags=cv2.WARP_INVERSE_MAP + cv2.INTER_LINEAR,
border_mode=cv2.BORDER_REPLICATE,
is_inversed_M=True
is_inversed_M=True,
)
else:
return NoOpTransform()
# example repr: "RandomPivotScalingOp::[1.0, 0.75, 0.5]"
@TRANSFORM_OP_REGISTRY.register()
def RandomPivotScalingOp(cfg, arg_str, is_train):
def RandomPivotScalingOp(cfg: CfgNode, arg_str: str, is_train: bool) -> List[Transform]:
assert is_train
scales = json.loads(arg_str)
assert isinstance(scales, list)
......@@ -200,7 +214,7 @@ def RandomPivotScalingOp(cfg, arg_str, is_train):
@TRANSFORM_OP_REGISTRY.register()
def RandomAffineOp(cfg, arg_str, is_train):
def RandomAffineOp(cfg: CfgNode, arg_str: str, is_train: bool) -> List[Transform]:
assert is_train
kwargs = json.loads(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Tuple, Dict
import detectron2.data.transforms.augmentation as aug
from detectron2.data.transforms import NoOpTransform, Transform
import numpy as np
from detectron2.config import CfgNode
from detectron2.data.transforms import NoOpTransform, Transform
from .build import TRANSFORM_OP_REGISTRY, _json_load
class LocalizedBoxMotionBlurTransform(Transform):
""" Transform to blur provided bounding boxes from an image. """
def __init__(self, bounding_boxes, k=(7, 15), angle=(0, 360), direction=(-1.0, 1.0)):
def __init__(
self,
bounding_boxes: List[List[int]],
k: Tuple[float, float] = (7, 15),
angle: Tuple[float, float] = (0, 360),
direction: Tuple[float, float] = (-1.0, 1.0),
):
import imgaug.augmenters as iaa
super().__init__()
self._set_attributes(locals())
self.aug = iaa.MotionBlur(k, angle, direction, 1)
def apply_image(self, img):
bbox_regions = [img[y:y+h, x:x+w] for x, y, w, h in self.bounding_boxes]
def apply_image(self, img: np.ndarray) -> np.ndarray:
bbox_regions = [img[y : y + h, x : x + w] for x, y, w, h in self.bounding_boxes]
blurred_boxes = self.aug.augment_images(bbox_regions)
new_img = np.array(img)
for (x, y, w, h), blurred in zip(self.bounding_boxes, blurred_boxes):
new_img[y:y+h, x:x+w] = blurred
new_img[y : y + h, x : x + w] = blurred
return new_img
def apply_segmentation(self, segmentation):
def apply_segmentation(self, segmentation: np.ndarray) -> np.ndarray:
""" Apply no transform on the full-image segmentation. """
return segmentation
def apply_coords(self, coords):
def apply_coords(self, coords: np.ndarray):
""" Apply no transform on the coordinates. """
return coords
......@@ -37,31 +47,42 @@ class LocalizedBoxMotionBlurTransform(Transform):
""" The inverse is a No-op, only for geometric transforms. """
return NoOpTransform()
class LocalizedBoxMotionBlur(aug.Augmentation):
"""
Performs simulated motion blur on bounding box annotations in an image.
Randomly selects motion blur parameters from the ranges `k`, `angle`, `direction`.
"""
def __init__(self, prob=0.5, k=(7, 15), angle=(0, 360), direction=(-1.0, 1.0)):
def __init__(
self,
prob: float = 0.5,
k: Tuple[float, float] = (7, 15),
angle: Tuple[float, float] = (0, 360),
direction: Tuple[float, float] = (-1.0, 1.0),
):
super().__init__()
self._init(locals())
def _validate_bbox_xywh_within_bounds(self, bbox, img_h, img_w):
def _validate_bbox_xywh_within_bounds(
self, bbox: List[int], img_h: int, img_w: int
):
x, y, w, h = bbox
assert x >= 0, f"Invalid x {x}"
assert y >= 0, f"Invalid y {x}"
assert y+h <= img_h, f"Invalid right {x+w} (img width {img_w})"
assert y+h <= img_h, f"Invalid bottom {y+h} (img height {img_h})"
assert y + h <= img_h, f"Invalid right {x+w} (img width {img_w})"
assert y + h <= img_h, f"Invalid bottom {y+h} (img height {img_h})"
def get_transform(self, image, annotations):
def get_transform(self, image: np.ndarray, annotations: List[Dict]) -> Transform:
do_tfm = self._rand_range() < self.prob
if do_tfm:
return self._get_blur_transform(image, annotations)
else:
return NoOpTransform()
def _get_blur_transform(self, image, annotations):
def _get_blur_transform(
self, image: np.ndarray, annotations: List[Dict]
) -> Transform:
"""
Return a `Transform` that simulates motion blur within the image's bounding box regions.
"""
......@@ -78,9 +99,12 @@ class LocalizedBoxMotionBlur(aug.Augmentation):
direction=self.direction,
)
# example repr: "LocalizedBoxMotionBlurOp::{'prob': 0.5, 'k': [3,7], 'angle': [0, 360]}"
@TRANSFORM_OP_REGISTRY.register()
def RandomLocalizedBoxMotionBlurOp(cfg, arg_str, is_train):
def RandomLocalizedBoxMotionBlurOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Transform]:
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
......@@ -88,7 +112,12 @@ def RandomLocalizedBoxMotionBlurOp(cfg, arg_str, is_train):
class MotionBlurTransform(Transform):
def __init__(self, k=(7, 15), angle=(0, 360), direction=(-1.0, 1.0)):
def __init__(
self,
k: Tuple[float, float] = (7, 15),
angle: Tuple[float, float] = (0, 360),
direction: Tuple[float, float] = (-1.0, 1.0),
):
"""
Args:
will apply the specified blur to the image
......@@ -99,23 +128,29 @@ class MotionBlurTransform(Transform):
self._set_attributes(locals())
self.aug = iaa.MotionBlur(k, angle, direction, 1)
def apply_image(self, img):
def apply_image(self, img: np.ndarray) -> np.ndarray:
img = self.aug.augment_image(img)
return img
def apply_segmentation(self, segmentation):
def apply_segmentation(self, segmentation: np.ndarray) -> np.ndarray:
return segmentation
def apply_coords(self, coords):
def apply_coords(self, coords: np.ndarray) -> np.ndarray:
return coords
class RandomMotionBlur(aug.Augmentation):
"""
Apply random motion blur.
Apply random motion blur.
"""
def __init__(self, prob=0.5, k=(3, 7), angle=(0, 360), direction=(-1.0, 1.0)):
def __init__(
self,
prob: float = 0.5,
k: Tuple[float, float] = (3, 7),
angle: Tuple[float, float] = (0, 360),
direction: Tuple[float, float] = (-1.0, 1.0),
):
"""
Args:
prob (float): probability of applying transform
......@@ -127,7 +162,7 @@ class RandomMotionBlur(aug.Augmentation):
# Turn all locals into member variables.
self._init(locals())
def get_transform(self, img):
def get_transform(self, img: np.ndarray) -> Transform:
do = self._rand_range() < self.prob
if do:
return MotionBlurTransform(self.k, self.angle, self.direction)
......@@ -137,7 +172,7 @@ class RandomMotionBlur(aug.Augmentation):
# example repr: "RandomMotionBlurOp::{'prob': 0.5, 'k': [3,7], 'angle': [0, 360]}"
@TRANSFORM_OP_REGISTRY.register()
def RandomMotionBlurOp(cfg, arg_str, is_train):
def RandomMotionBlurOp(cfg: CfgNode, arg_str: str, is_train: bool) -> List[Transform]:
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Tuple, List
import numpy as np
import torch
......@@ -18,7 +19,7 @@ def get_box_union(boxes: Boxes):
return Boxes(union_bt)
def get_box_from_mask(mask: np.ndarray):
def get_box_from_mask(mask: torch.Tensor) -> Tuple[int, int, int, int]:
"""Find if there are non-zero elements per row/column first and then find
min/max position of those elements.
Only support 2d image (h x w)
......@@ -39,7 +40,9 @@ def get_box_from_mask(mask: np.ndarray):
return cmin, rmin, cmax - cmin + 1, rmax - rmin + 1
def get_min_box_aspect_ratio(bbox_xywh, target_aspect_ratio):
def get_min_box_aspect_ratio(
bbox_xywh: torch.Tensor, target_aspect_ratio: float
) -> torch.Tensor:
"""Get a minimal bbox that matches the target_aspect_ratio
target_aspect_ratio is representation by w/h
bbox are represented by pixel coordinates"""
......@@ -58,19 +61,21 @@ def get_min_box_aspect_ratio(bbox_xywh, target_aspect_ratio):
return torch.cat([new_xy, new_wh])
def get_box_center(bbox_xywh):
def get_box_center(bbox_xywh: torch.Tensor) -> torch.Tensor:
"""Get the center of the bbox"""
return torch.Tensor(bbox_xywh[:2]) + torch.Tensor(bbox_xywh[2:]) / 2.0
def get_bbox_xywh_from_center_wh(bbox_center, bbox_wh):
def get_bbox_xywh_from_center_wh(
bbox_center: torch.Tensor, bbox_wh: torch.Tensor
) -> torch.Tensor:
"""Get a bbox from bbox center and the width and height"""
bbox_wh = torch.Tensor(bbox_wh)
bbox_xy = torch.Tensor(bbox_center) - bbox_wh / 2.0
return torch.cat([bbox_xy, bbox_wh])
def get_bbox_xyxy_from_xywh(bbox_xywh):
def get_bbox_xyxy_from_xywh(bbox_xywh: torch.Tensor) -> torch.Tensor:
"""Convert the bbox from xywh format to xyxy format
bbox are represented by pixel coordinates,
the center of pixels are (x + 0.5, y + 0.5)
......@@ -85,7 +90,7 @@ def get_bbox_xyxy_from_xywh(bbox_xywh):
)
def get_bbox_xywh_from_xyxy(bbox_xyxy):
def get_bbox_xywh_from_xyxy(bbox_xyxy: torch.Tensor) -> torch.Tensor:
"""Convert the bbox from xyxy format to xywh format"""
return torch.Tensor(
[
......@@ -97,25 +102,25 @@ def get_bbox_xywh_from_xyxy(bbox_xyxy):
)
def to_boxes_from_xywh(bbox_xywh):
def to_boxes_from_xywh(bbox_xywh: torch.Tensor) -> torch.Tensor:
return Boxes(get_bbox_xyxy_from_xywh(bbox_xywh).unsqueeze(0))
def scale_bbox_center(bbox_xywh, target_scale):
def scale_bbox_center(bbox_xywh: torch.Tensor, target_scale: float) -> torch.Tensor:
"""Scale the bbox around the center of the bbox"""
box_center = get_box_center(bbox_xywh)
box_wh = torch.Tensor(bbox_xywh[2:]) * target_scale
return get_bbox_xywh_from_center_wh(box_center, box_wh)
def offset_bbox(bbox_xywh, target_offset):
def offset_bbox(bbox_xywh: torch.Tensor, target_offset: float) -> torch.Tensor:
"""Offset the bbox based on target_offset"""
box_center = get_box_center(bbox_xywh)
new_center = box_center + torch.Tensor(target_offset)
return get_bbox_xywh_from_center_wh(new_center, bbox_xywh[2:])
def clip_box_xywh(bbox_xywh, image_size_hw):
def clip_box_xywh(bbox_xywh: torch.Tensor, image_size_hw: List[int]):
"""Clip the bbox based on image_size_hw"""
h, w = image_size_hw
bbox_xyxy = get_bbox_xyxy_from_xywh(bbox_xywh)
......
......@@ -4,7 +4,9 @@
import json
import logging
from typing import List, Dict, Optional, Tuple
from detectron2.config import CfgNode
from detectron2.data import transforms as d2T
from detectron2.utils.registry import Registry
......@@ -15,7 +17,7 @@ logger = logging.getLogger(__name__)
TRANSFORM_OP_REGISTRY = Registry("D2GO_TRANSFORM_REGISTRY")
def _json_load(arg_str):
def _json_load(arg_str: str) -> Dict:
try:
return json.loads(arg_str)
except json.decoder.JSONDecodeError as e:
......@@ -25,7 +27,9 @@ def _json_load(arg_str):
# example repr: "ResizeShortestEdgeOp"
@TRANSFORM_OP_REGISTRY.register()
def ResizeShortestEdgeOp(cfg, arg_str, is_train):
def ResizeShortestEdgeOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[d2T.Transform]:
if is_train:
min_size = cfg.INPUT.MIN_SIZE_TRAIN
max_size = cfg.INPUT.MAX_SIZE_TRAIN
......@@ -47,9 +51,11 @@ def ResizeShortestEdgeOp(cfg, arg_str, is_train):
# example repr: "ResizeShortestEdgeSquareOp"
@TRANSFORM_OP_REGISTRY.register()
def ResizeShortestEdgeSquareOp(cfg, arg_str, is_train):
""" Resize the input to square using INPUT.MIN_SIZE_TRAIN or INPUT.MIN_SIZE_TEST
without keeping aspect ratio
def ResizeShortestEdgeSquareOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[d2T.Transform]:
"""Resize the input to square using INPUT.MIN_SIZE_TRAIN or INPUT.MIN_SIZE_TEST
without keeping aspect ratio
"""
if is_train:
min_size = cfg.INPUT.MIN_SIZE_TRAIN
......@@ -67,7 +73,7 @@ def ResizeShortestEdgeSquareOp(cfg, arg_str, is_train):
@TRANSFORM_OP_REGISTRY.register()
def ResizeOp(cfg, arg_str, is_train):
def ResizeOp(cfg: CfgNode, arg_str: str, is_train: bool) -> List[d2T.Transform]:
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [d2T.Resize(**kwargs)]
......@@ -76,7 +82,7 @@ def ResizeOp(cfg, arg_str, is_train):
_TRANSFORM_REPR_SEPARATOR = "::"
def parse_tfm_gen_repr(tfm_gen_repr):
def parse_tfm_gen_repr(tfm_gen_repr: str) -> Tuple[str, Optional[str]]:
if tfm_gen_repr.count(_TRANSFORM_REPR_SEPARATOR) == 0:
return tfm_gen_repr, None
elif tfm_gen_repr.count(_TRANSFORM_REPR_SEPARATOR) == 1:
......@@ -88,7 +94,7 @@ def parse_tfm_gen_repr(tfm_gen_repr):
)
def build_transform_gen(cfg, is_train):
def build_transform_gen(cfg: CfgNode, is_train: bool) -> List[d2T.Transform]:
"""
This function builds a list of TransformGen or Transform objects using the a list of
strings from cfg.D2GO_DATA.AUG_OPS.TRAIN/TEST. Each string (aka. `tfm_gen_repr`)
......
......@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List
from typing import List, Callable, Union
import detectron2.data.transforms.augmentation as aug
import numpy as np
......@@ -22,7 +22,7 @@ class InvertibleColorTransform(Transform):
coordinates such as bounding boxes should not be changed)
"""
def __init__(self, op, inverse_op):
def __init__(self, op: Callable, inverse_op: Callable):
"""
Args:
op (Callable): operation to be applied to the image,
......@@ -35,16 +35,16 @@ class InvertibleColorTransform(Transform):
super().__init__()
self._set_attributes(locals())
def apply_image(self, img):
def apply_image(self, img: np.ndarray) -> np.ndarray:
return self.op(img)
def apply_coords(self, coords):
def apply_coords(self, coords: np.ndarray) -> np.ndarray:
return coords
def inverse(self):
def inverse(self) -> Transform:
return InvertibleColorTransform(self.inverse_op, self.op)
def apply_segmentation(self, segmentation):
def apply_segmentation(self, segmentation: np.ndarray) -> np.ndarray:
return segmentation
......@@ -60,7 +60,7 @@ class RandomContrastYUV(aug.Augmentation):
super().__init__()
self._init(locals())
def get_transform(self, img):
def get_transform(self, img: np.ndarray) -> Transform:
w = np.random.uniform(self.intensity_min, self.intensity_max)
pure_gray = np.zeros_like(img)
pure_gray[:, :, 0] = 0.5
......@@ -77,7 +77,7 @@ class RandomSaturationYUV(aug.Augmentation):
super().__init__()
self._init(locals())
def get_transform(self, img):
def get_transform(self, img: np.ndarray) -> Transform:
assert (
len(img.shape) == 3 and img.shape[-1] == 3
), f"Expected (H, W, 3), image shape {img.shape}"
......@@ -87,7 +87,7 @@ class RandomSaturationYUV(aug.Augmentation):
return BlendTransform(src_image=grayscale, src_weight=1 - w, dst_weight=w)
def convert_rgb_to_yuv_bt601(image):
def convert_rgb_to_yuv_bt601(image: np.ndarray) -> np.ndarray:
"""Convert RGB image in (H, W, C) to YUV format
image: range 0 ~ 255
"""
......@@ -96,7 +96,7 @@ def convert_rgb_to_yuv_bt601(image):
return image
def convery_yuv_bt601_to_rgb(image):
def convery_yuv_bt601_to_rgb(image: np.ndarray) -> np.ndarray:
return du.convert_image_to_rgb(image, "YUV-BT.601")
......@@ -107,7 +107,7 @@ class RGB2YUVBT601(aug.Augmentation):
convert_rgb_to_yuv_bt601, convery_yuv_bt601_to_rgb
)
def get_transform(self, image):
def get_transform(self, image) -> Transform:
return self.trans
......@@ -118,11 +118,13 @@ class YUVBT6012RGB(aug.Augmentation):
convery_yuv_bt601_to_rgb, convert_rgb_to_yuv_bt601
)
def get_transform(self, image):
def get_transform(self, image) -> Transform:
return self.trans
def build_func(cfg: CfgNode, arg_str: str, is_train: bool, obj) -> List[aug.Augmentation]:
def build_func(
cfg: CfgNode, arg_str: str, is_train: bool, obj
) -> List[Union[aug.Augmentation, Transform]]:
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
......@@ -130,20 +132,28 @@ def build_func(cfg: CfgNode, arg_str: str, is_train: bool, obj) -> List[aug.Augm
@TRANSFORM_OP_REGISTRY.register()
def RandomContrastYUVOp(cfg, arg_str, is_train):
def RandomContrastYUVOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
return build_func(cfg, arg_str, is_train, obj=RandomContrastYUV)
@TRANSFORM_OP_REGISTRY.register()
def RandomSaturationYUVOp(cfg, arg_str, is_train):
def RandomSaturationYUVOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
return build_func(cfg, arg_str, is_train, obj=RandomSaturationYUV)
@TRANSFORM_OP_REGISTRY.register()
def RGB2YUVBT601Op(cfg, arg_str, is_train):
def RGB2YUVBT601Op(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
return build_func(cfg, arg_str, is_train, obj=RGB2YUVBT601)
@TRANSFORM_OP_REGISTRY.register()
def YUVBT6012RGBOp(cfg, arg_str, is_train):
def YUVBT6012RGBOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
return build_func(cfg, arg_str, is_train, obj=YUVBT6012RGB)
......@@ -3,9 +3,11 @@
import math
from typing import List, Optional, Tuple, Union, Any
import detectron2.data.transforms.augmentation as aug
import numpy as np
from detectron2.config import CfgNode
from detectron2.data.transforms import ExtentTransform, CropTransform
from detectron2.structures import BoxMode
......@@ -22,7 +24,7 @@ class CropBoundary(aug.Augmentation):
super().__init__()
self.count = count
def get_transform(self, image):
def get_transform(self, image: np.ndarray) -> Transform:
img_h, img_w = image.shape[:2]
assert self.count < img_h and self.count < img_w
assert img_h > self.count * 2
......@@ -32,13 +34,22 @@ class CropBoundary(aug.Augmentation):
class PadTransform(Transform):
def __init__(self, x0, y0, w, h, org_w, org_h, pad_mode="constant"):
def __init__(
self,
x0: int,
y0: int,
w: int,
h: int,
org_w: int,
org_h: int,
pad_mode: str = "constant",
):
super().__init__()
assert x0 + w <= org_w
assert y0 + h <= org_h
self._set_attributes(locals())
def apply_image(self, img):
def apply_image(self, img: np.ndarray) -> np.array:
"""img: HxWxC or HxW"""
assert len(img.shape) == 2 or len(img.shape) == 3
assert img.shape[0] == self.h and img.shape[1] == self.w
......@@ -64,12 +75,12 @@ InvertibleCropTransform = CropTransform
class PadBorderDivisible(aug.Augmentation):
def __init__(self, size_divisibility, pad_mode="constant"):
def __init__(self, size_divisibility: int, pad_mode: str = "constant"):
super().__init__()
self.size_divisibility = size_divisibility
self.pad_mode = pad_mode
def get_transform(self, image):
def get_transform(self, image: np.ndarray) -> Transform:
""" image: HxWxC """
assert len(image.shape) == 3 and image.shape[2] in [1, 3]
H, W = image.shape[:2]
......@@ -80,7 +91,10 @@ class PadBorderDivisible(aug.Augmentation):
class RandomCropFixedAspectRatio(aug.Augmentation):
def __init__(
self, crop_aspect_ratios_list, scale_range=None, offset_scale_range=None
self,
crop_aspect_ratios_list: List[float],
scale_range: Optional[Union[List, Tuple]] = None,
offset_scale_range: Optional[Union[List, Tuple]] = None,
):
super().__init__()
assert isinstance(crop_aspect_ratios_list, (list, tuple))
......@@ -103,21 +117,21 @@ class RandomCropFixedAspectRatio(aug.Augmentation):
self.rng = np.random.default_rng()
def _pick_aspect_ratio(self):
def _pick_aspect_ratio(self) -> float:
return self.rng.choice(self.crop_aspect_ratios_list)
def _pick_scale(self):
def _pick_scale(self) -> float:
if self.scale_range is None:
return 1.0
return self.rng.uniform(*self.scale_range)
def _pick_offset(self, box_w, box_h):
def _pick_offset(self, box_w: float, box_h: float) -> Tuple[float, float]:
if self.offset_scale_range is None:
return [0, 0]
offset_scale = self.rng.uniform(*self.offset_scale_range, size=2)
return offset_scale[0] * box_w, offset_scale[1] * box_h
def get_transform(self, image, sem_seg):
def get_transform(self, image: np.ndarray, sem_seg: np.ndarray) -> Transform:
# HWC or HW for image, HW for sem_seg
assert len(image.shape) in [2, 3]
assert len(sem_seg.shape) == 2
......@@ -148,7 +162,9 @@ class RandomCropFixedAspectRatio(aug.Augmentation):
# example repr: "CropBoundaryOp::{'count': 3}"
@TRANSFORM_OP_REGISTRY.register()
def CropBoundaryOp(cfg, arg_str, is_train):
def CropBoundaryOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
......@@ -157,7 +173,9 @@ def CropBoundaryOp(cfg, arg_str, is_train):
# example repr: "RandomCropFixedAspectRatioOp::{'crop_aspect_ratios_list': [0.5], 'scale_range': [0.8, 1.2], 'offset_scale_range': [-0.3, 0.3]}"
@TRANSFORM_OP_REGISTRY.register()
def RandomCropFixedAspectRatioOp(cfg, arg_str, is_train):
def RandomCropFixedAspectRatioOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
......@@ -165,7 +183,7 @@ def RandomCropFixedAspectRatioOp(cfg, arg_str, is_train):
class RandomInstanceCrop(aug.Augmentation):
def __init__(self, crop_scale=(0.8, 1.6)):
def __init__(self, crop_scale: Tuple[float, float] = (0.8, 1.6)):
"""
Generates a CropTransform centered around the instance.
crop_scale: [low, high] relative crop scale around the instance, this
......@@ -177,7 +195,7 @@ class RandomInstanceCrop(aug.Augmentation):
isinstance(crop_scale, (list, tuple)) and len(crop_scale) == 2
), crop_scale
def get_transform(self, image, annotations):
def get_transform(self, image: np.ndarray, annotations: List[Any]) -> Transform:
"""
This function will modify instances to set the iscrowd flag to 1 for
annotations not picked. It relies on the dataset mapper to filter those
......@@ -215,22 +233,24 @@ class RandomInstanceCrop(aug.Augmentation):
# example repr: "RandomInstanceCropOp::{'crop_scale': [0.8, 1.6]}"
@TRANSFORM_OP_REGISTRY.register()
def RandomInstanceCropOp(cfg, arg_str, is_train):
def RandomInstanceCropOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [RandomInstanceCrop(**kwargs)]
class CropBoxAug(aug.Augmentation):
""" Augmentation to crop the image based on boxes
Scale the box with `box_scale_factor` around the center before cropping
"""Augmentation to crop the image based on boxes
Scale the box with `box_scale_factor` around the center before cropping
"""
def __init__(self, box_scale_factor=1.0):
def __init__(self, box_scale_factor: float = 1.0):
super().__init__()
self.box_scale_factor = box_scale_factor
def get_transform(self, image: np.ndarray, boxes: np.ndarray):
def get_transform(self, image: np.ndarray, boxes: np.ndarray) -> Transform:
# boxes: 1 x 4 in xyxy format
assert boxes.shape[0] == 1
assert isinstance(image, np.ndarray)
......@@ -239,7 +259,7 @@ class CropBoxAug(aug.Augmentation):
box_xywh = bu.get_bbox_xywh_from_xyxy(boxes[0])
if self.box_scale_factor != 1.0:
box_xywh = bu.scale_bbox_center(box_xywh, self.box_scale_factor)
box_xywh = bu.scale_bbox_center(box_xywh, self.box_scale_factor)
box_xywh = bu.clip_box_xywh(box_xywh, [img_h, img_w])
box_xywh = box_xywh.int().tolist()
return CropTransform(*box_xywh, orig_w=img_w, orig_h=img_h)
......@@ -3,11 +3,15 @@
import logging
from typing import List, Union
from .build import TRANSFORM_OP_REGISTRY, _json_load
import detectron2.data.transforms.augmentation as aug
from detectron2.config import CfgNode
from detectron2.data import transforms as d2T
from detectron2.projects.point_rend import ColorAugSSDTransform
from .build import TRANSFORM_OP_REGISTRY, _json_load
logger = logging.getLogger(__name__)
......@@ -23,8 +27,10 @@ D2_RANDOM_TRANSFORMS = {
}
def build_func(cfg, arg_str, is_train, name):
assert is_train
def build_func(
cfg: CfgNode, arg_str: str, is_train: bool, name: str
) -> List[Union[aug.Augmentation, d2T.Transform]]:
assert is_train, "Random augmentation is for training only"
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [D2_RANDOM_TRANSFORMS[name](**kwargs)]
......@@ -35,47 +41,65 @@ def build_func(cfg, arg_str, is_train, name):
# example 3: RandomFlipOp::{"prob":0.5}
# example 4: RandomBrightnessOp::{"intensity_min":1.0, "intensity_max":2.0}
@TRANSFORM_OP_REGISTRY.register()
def RandomBrightnessOp(cfg, arg_str, is_train):
def RandomBrightnessOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomBrightness")
@TRANSFORM_OP_REGISTRY.register()
def RandomContrastOp(cfg, arg_str, is_train):
def RandomContrastOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomContrast")
@TRANSFORM_OP_REGISTRY.register()
def RandomCropOp(cfg, arg_str, is_train):
def RandomCropOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomCrop")
@TRANSFORM_OP_REGISTRY.register()
def RandomRotation(cfg, arg_str, is_train):
def RandomRotation(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomRotation")
@TRANSFORM_OP_REGISTRY.register()
def RandomExtentOp(cfg, arg_str, is_train):
def RandomExtentOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomExtent")
@TRANSFORM_OP_REGISTRY.register()
def RandomFlipOp(cfg, arg_str, is_train):
def RandomFlipOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomFlip")
@TRANSFORM_OP_REGISTRY.register()
def RandomSaturationOp(cfg, arg_str, is_train):
def RandomSaturationOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomSaturation")
@TRANSFORM_OP_REGISTRY.register()
def RandomLightingOp(cfg, arg_str, is_train):
def RandomLightingOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
return build_func(cfg, arg_str, is_train, name="RandomLighting")
@TRANSFORM_OP_REGISTRY.register()
def RandomSSDColorAugOp(cfg, arg_str, is_train):
def RandomSSDColorAugOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, d2T.Transform]]:
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
......
......@@ -2,7 +2,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Optional, Union
from typing import List, Optional, Union, Any
import numpy as np
import torch
......@@ -51,7 +51,7 @@ class AugInput:
def apply_augmentations(
self, augmentations: List[Union[Augmentation, Transform]]
) -> TransformList:
) -> AugmentationList:
"""
Equivalent of ``AugmentationList(augmentations)(self)``
"""
......@@ -70,14 +70,14 @@ class Tensor2Array(Transform):
assert len(img.shape) == 3, img.shape
return img.cpu().numpy().transpose(1, 2, 0)
def apply_coords(self, coords):
def apply_coords(self, coords: Any) -> Any:
return coords
def apply_segmentation(self, segmentation: torch.Tensor) -> np.ndarray:
assert len(segmentation.shape) == 2, segmentation.shape
return segmentation.cpu().numpy()
def inverse(self):
def inverse(self) -> Transform:
return Array2Tensor()
......@@ -93,12 +93,12 @@ class Array2Tensor(Transform):
assert len(img.shape) == 3, img.shape
return torch.from_numpy(img.transpose(2, 0, 1).astype("float32"))
def apply_coords(self, coords):
def apply_coords(self, coords: Any) -> Any:
return coords
def apply_segmentation(self, segmentation: np.ndarray) -> torch.Tensor:
assert len(segmentation.shape) == 2, segmentation.shape
return torch.from_numpy(segmentation.astype("long"))
def inverse(self):
def inverse(self) -> Transform:
return Tensor2Array()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment