Commit ab50d5d9 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by Facebook GitHub Bot
Browse files

Add AugMix Transform in D2go

Summary:
Pull Request resolved: https://github.com/facebookresearch/d2go/pull/322

TorchVision has recently added the AugMix Augmentantion. This diff adds support of the specific transform to D2 (https://github.com/facebookresearch/d2go/commit/87374efb134e539090e0b5c476809dc35bf6aedb)go

Reviewed By: newstzpz

Differential Revision: D37578243

fbshipit-source-id: b793715ccb24a3bd999a40c51d8c9a75f22110a3
parent bbc14dd7
...@@ -78,6 +78,37 @@ class TrivialAugmentWideImage(Transform): ...@@ -78,6 +78,37 @@ class TrivialAugmentWideImage(Transform):
raise NotImplementedError() raise NotImplementedError()
class AugMixImage(Transform):
"""AugMix transform, only support image transformation"""
def __init__(
self,
severity: int = 3,
mixture_width: int = 3,
chain_depth: int = -1,
alpha: float = 1.0,
all_ops: bool = True,
interpolation: tvtf.functional.InterpolationMode = tvtf.functional.InterpolationMode.NEAREST,
fill: Optional[List[float]] = None,
):
transform = tvtf.AugMix(
severity, mixture_width, chain_depth, alpha, all_ops, interpolation, fill
)
self.transform = ToTensorWrapper(transform)
def apply_image(self, img: np.ndarray) -> np.array:
assert (
img.dtype == np.uint8
), f"Only uint8 image format is supported, got {img.dtype}"
return self.transform(img)
def apply_coords(self, coords: np.ndarray) -> np.ndarray:
raise NotImplementedError()
def apply_segmentation(self, segmentation: np.ndarray) -> np.ndarray:
raise NotImplementedError()
# example repr: 'RandAugmentImageOp::{"magnitude": 9}' # example repr: 'RandAugmentImageOp::{"magnitude": 9}'
@TRANSFORM_OP_REGISTRY.register() @TRANSFORM_OP_REGISTRY.register()
def RandAugmentImageOp( def RandAugmentImageOp(
...@@ -98,3 +129,14 @@ def TrivialAugmentWideImageOp( ...@@ -98,3 +129,14 @@ def TrivialAugmentWideImageOp(
kwargs = _json_load(arg_str) if arg_str is not None else {} kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict) assert isinstance(kwargs, dict)
return [TrivialAugmentWideImage(**kwargs)] return [TrivialAugmentWideImage(**kwargs)]
# example repr: 'AugMixImageOp::{"severity": 3}'
@TRANSFORM_OP_REGISTRY.register()
def AugMixImageOp(
cfg: CfgNode, arg_str: str, is_train: bool
) -> List[Union[aug.Augmentation, Transform]]:
assert is_train
kwargs = _json_load(arg_str) if arg_str is not None else {}
assert isinstance(kwargs, dict)
return [AugMixImage(**kwargs)]
...@@ -44,3 +44,21 @@ class TestDataTransformsAutoAug(unittest.TestCase): ...@@ -44,3 +44,21 @@ class TestDataTransformsAutoAug(unittest.TestCase):
self.assertEqual(img.shape, trans_img.shape) self.assertEqual(img.shape, trans_img.shape)
self.assertEqual(img.dtype, trans_img.dtype) self.assertEqual(img.dtype, trans_img.dtype)
def test_aug_mix_transforms(self):
default_cfg = Detectron2GoRunner().get_default_cfg()
img = np.concatenate(
[
(np.random.uniform(0, 1, size=(80, 60, 1)) * 255).astype(np.uint8),
(np.random.uniform(0, 1, size=(80, 60, 1)) * 255).astype(np.uint8),
(np.random.uniform(0, 1, size=(80, 60, 1)) * 255).astype(np.uint8),
],
axis=2,
)
default_cfg.D2GO_DATA.AUG_OPS.TRAIN = ['AugMixImageOp::{"severity": 3}']
tfm = build_transform_gen(default_cfg, is_train=True)
trans_img, _ = apply_augmentations(tfm, img)
self.assertEqual(img.shape, trans_img.shape)
self.assertEqual(img.dtype, trans_img.dtype)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment