Commit 6f036248 authored by Peizhao Zhang's avatar Peizhao Zhang Committed by Facebook GitHub Bot
Browse files

Added randaug and trivialaug to d2go.

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/259

Added randaug and trivialaug to d2go.
* Only apply to images.

Reviewed By: wat3rBro

Differential Revision: D36605753

fbshipit-source-id: f72ea6f3de65ab115e4c6612343e93a092cbd7ea
parent 4c643114
...@@ -4,4 +4,4 @@ ...@@ -4,4 +4,4 @@
# import all modules to make sure Registry works # import all modules to make sure Registry works
# @fb-only: from . import fb # isort:skip # noqa # @fb-only: from . import fb # isort:skip # noqa
from . import affine, blur, box_utils, color_yuv, crop, d2_native # noqa from . import affine, auto_aug, blur, box_utils, color_yuv, crop, d2_native # noqa
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import List, Optional, Union
import detectron2.data.transforms.augmentation as aug
import numpy as np
import torchvision.transforms as tvtf
from d2go.data.transforms.tensor import Array2Tensor, Tensor2Array
from detectron2.config import CfgNode
from fvcore.transforms.transform import Transform
from .build import _json_load, TRANSFORM_OP_REGISTRY
class ToTensorWrapper:
def __init__(self, transform):
self.a2t = Array2Tensor(preserve_dtype=True)
self.transform = transform
self.t2a = Tensor2Array()
def __call__(self, img: np.ndarray):
return self.t2a.apply_image(self.transform(self.a2t.apply_image(img)))
class RandAugmentImage(Transform):
"""Rand Augment transform, only support image transformation"""
def __init__(
self,
num_ops: int = 2,
magnitude: int = 9,
num_magnitude_bins: int = 31,
interpolation: tvtf.functional.InterpolationMode = tvtf.functional.InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
):
transform = tvtf.RandAugment(
num_ops, magnitude, num_magnitude_bins, interpolation, fill
)
self.transform = ToTensorWrapper(transform)
def apply_image(self, img: np.ndarray) -> np.array:
assert (
img.dtype == np.uint8
), f"Only uint8 image format is supported, got {img.dtype}"
return self.transform(img)
def apply_coords(self, coords: np.ndarray) -> np.ndarray:
raise NotImplementedError()
def apply_segmentation(self, segmentation: np.ndarray) -> np.ndarray:
raise NotImplementedError()
class TrivialAugmentWideImage(Transform):
"""TrivialAugmentWide transform, only support image transformation"""
def __init__(
self,
num_magnitude_bins: int = 31,
interpolation: tvtf.functional.InterpolationMode = tvtf.functional.InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
):
transform = tvtf.TrivialAugmentWide(num_magnitude_bins, interpolation, fill)
self.transform = ToTensorWrapper(transform)
def apply_image(self, img: np.ndarray) -> np.array:
assert (
img.dtype == np.uint8
), f"Only uint8 image format is supported, got {img.dtype}"
return self.transform(img)
def apply_coords(self, coords: np.ndarray) -> np.ndarray:
raise NotImplementedError()
def apply_segmentation(self, segmentation: np.ndarray) -> np.ndarray:
raise NotImplementedError()
# example repr: 'RandAugmentImageOp::{"magnitude": 9}'
@TRANSFORM_OP_REGISTRY.register()
def RandAugmentImageOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [RandAugmentImage(**kwargs)]
# example repr: 'TrivialAugmentWideImageOp::{"num_magnitude_bins": 31}'
@TRANSFORM_OP_REGISTRY.register()
def TrivialAugmentWideImageOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [TrivialAugmentWideImage(**kwargs)]
...@@ -8,7 +8,7 @@ import numpy as np ...@@ -8,7 +8,7 @@ import numpy as np
import torch import torch
from detectron2.data.transforms.augmentation import Augmentation, AugmentationList from detectron2.data.transforms.augmentation import Augmentation, AugmentationList
from detectron2.structures import Boxes from detectron2.structures import Boxes
from fvcore.transforms.transform import Transform, TransformList from fvcore.transforms.transform import Transform
class AugInput: class AugInput:
...@@ -84,14 +84,26 @@ class Tensor2Array(Transform): ...@@ -84,14 +84,26 @@ class Tensor2Array(Transform):
class Array2Tensor(Transform): class Array2Tensor(Transform):
"""Convert image np array (HWC) to torch tensor (CHW)""" """Convert image np array (HWC) to torch tensor (CHW)"""
def __init__(self): def __init__(self, preserve_dtype: bool = False):
"""
preserve_dtype: always convert to float32 if False
"""
super().__init__() super().__init__()
self.preserve_dtype = preserve_dtype
def apply_image(self, img: np.ndarray) -> torch.Tensor: def apply_image(self, img: np.ndarray) -> torch.Tensor:
# HWC -> CHW # HW(C) -> CHW
assert isinstance(img, np.ndarray) assert isinstance(img, np.ndarray)
assert len(img.shape) == 3, img.shape assert len(img.shape) in [2, 3], img.shape
return torch.from_numpy(img.transpose(2, 0, 1).astype("float32"))
if len(img.shape) == 2:
# HW -> HWC
img = np.expand_dims(img, axis=2)
if not self.preserve_dtype:
img = img.astype("float32")
return torch.from_numpy(img.transpose(2, 0, 1))
def apply_coords(self, coords: Any) -> Any: def apply_coords(self, coords: Any) -> Any:
return coords return coords
......
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import unittest
import numpy as np
from d2go.data.transforms.build import build_transform_gen
from d2go.runner import Detectron2GoRunner
from detectron2.data.transforms import apply_augmentations
class TestDataTransformsAutoAug(unittest.TestCase):
def test_rand_aug_transforms(self):
default_cfg = Detectron2GoRunner().get_default_cfg()
img = np.concatenate(
[
(np.random.uniform(0, 1, size=(80, 60, 1)) * 255).astype(np.uint8),
(np.random.uniform(0, 1, size=(80, 60, 1)) * 255).astype(np.uint8),
(np.random.uniform(0, 1, size=(80, 60, 1)) * 255).astype(np.uint8),
],
axis=2,
)
default_cfg.D2GO_DATA.AUG_OPS.TRAIN = ['RandAugmentImageOp::{"num_ops": 20}']
tfm = build_transform_gen(default_cfg, is_train=True)
trans_img, _ = apply_augmentations(tfm, img)
self.assertEqual(img.shape, trans_img.shape)
self.assertEqual(img.dtype, trans_img.dtype)
def test_trivial_aug_transforms(self):
default_cfg = Detectron2GoRunner().get_default_cfg()
img = np.concatenate(
[
(np.random.uniform(0, 1, size=(80, 60, 1)) * 255).astype(np.uint8),
],
axis=2,
)
default_cfg.D2GO_DATA.AUG_OPS.TRAIN = ["TrivialAugmentWideImageOp"]
tfm = build_transform_gen(default_cfg, is_train=True)
trans_img, _ = apply_augmentations(tfm, img)
self.assertEqual(img.shape, trans_img.shape)
self.assertEqual(img.dtype, trans_img.dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment