Commit 96674ab0 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 416886349
parent 8d41d6c0
...@@ -55,9 +55,14 @@ class Parser(hyperparams.Config): ...@@ -55,9 +55,14 @@ class Parser(hyperparams.Config):
aug_rand_hflip: bool = False aug_rand_hflip: bool = False
aug_scale_min: float = 1.0 aug_scale_min: float = 1.0
aug_scale_max: float = 1.0 aug_scale_max: float = 1.0
aug_policy: Optional[str] = None
skip_crowd_during_training: bool = True skip_crowd_during_training: bool = True
max_num_instances: int = 100 max_num_instances: int = 100
# Can choose AutoAugment and RandAugment.
# TODO(b/205346436) Support RandAugment.
aug_type: Optional[common.Augmentation] = None
# Keep for backward compatibility. Not used.
aug_policy: Optional[str] = None
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -19,11 +19,13 @@ into (image, labels) tuple for RetinaNet. ...@@ -19,11 +19,13 @@ into (image, labels) tuple for RetinaNet.
""" """
# Import libraries # Import libraries
from absl import logging
import tensorflow as tf import tensorflow as tf
from official.vision.beta.dataloaders import parser from official.vision.beta.dataloaders import parser
from official.vision.beta.dataloaders import utils from official.vision.beta.dataloaders import utils
from official.vision.beta.ops import anchor from official.vision.beta.ops import anchor
from official.vision.beta.ops import augment
from official.vision.beta.ops import box_ops from official.vision.beta.ops import box_ops
from official.vision.beta.ops import preprocess_ops from official.vision.beta.ops import preprocess_ops
...@@ -40,6 +42,7 @@ class Parser(parser.Parser): ...@@ -40,6 +42,7 @@ class Parser(parser.Parser):
anchor_size, anchor_size,
match_threshold=0.5, match_threshold=0.5,
unmatched_threshold=0.5, unmatched_threshold=0.5,
aug_type=None,
aug_rand_hflip=False, aug_rand_hflip=False,
aug_scale_min=1.0, aug_scale_min=1.0,
aug_scale_max=1.0, aug_scale_max=1.0,
...@@ -71,6 +74,8 @@ class Parser(parser.Parser): ...@@ -71,6 +74,8 @@ class Parser(parser.Parser):
unmatched_threshold: `float` number between 0 and 1 representing the unmatched_threshold: `float` number between 0 and 1 representing the
upper-bound threshold to assign negative labels for anchors. An anchor upper-bound threshold to assign negative labels for anchors. An anchor
with a score below the threshold is labeled negative. with a score below the threshold is labeled negative.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment. The latter is not supported, and will raise ValueError.
aug_rand_hflip: `bool`, if True, augment training with random horizontal aug_rand_hflip: `bool`, if True, augment training with random horizontal
flip. flip.
aug_scale_min: `float`, the minimum scale applied to `output_size` for aug_scale_min: `float`, the minimum scale applied to `output_size` for
...@@ -108,7 +113,20 @@ class Parser(parser.Parser): ...@@ -108,7 +113,20 @@ class Parser(parser.Parser):
self._aug_scale_min = aug_scale_min self._aug_scale_min = aug_scale_min
self._aug_scale_max = aug_scale_max self._aug_scale_max = aug_scale_max
# Data Augmentation with AutoAugment. # Data augmentation with AutoAugment or RandAugment.
self._augmenter = None
if aug_type is not None:
if aug_type.type == 'autoaug':
logging.info('Using AutoAugment.')
self._augmenter = augment.AutoAugment(
augmentation_name=aug_type.autoaug.augmentation_name,
cutout_const=aug_type.autoaug.cutout_const,
translate_const=aug_type.autoaug.translate_const)
else:
# TODO(b/205346436) Support RandAugment.
raise ValueError(f'Augmentation policy {aug_type.type} not supported.')
# Deprecated. Data Augmentation with AutoAugment.
self._use_autoaugment = use_autoaugment self._use_autoaugment = use_autoaugment
self._autoaugment_policy_name = autoaugment_policy_name self._autoaugment_policy_name = autoaugment_policy_name
...@@ -138,9 +156,13 @@ class Parser(parser.Parser): ...@@ -138,9 +156,13 @@ class Parser(parser.Parser):
for k, v in attributes.items(): for k, v in attributes.items():
attributes[k] = tf.gather(v, indices) attributes[k] = tf.gather(v, indices)
# Gets original image and its size. # Gets original image.
image = data['image'] image = data['image']
# Apply autoaug or randaug.
if self._augmenter is not None:
image, boxes = self._augmenter.distort_with_boxes(image, boxes)
image_shape = tf.shape(input=image)[0:2] image_shape = tf.shape(input=image)[0:2]
# Normalizes image with mean and std pixel values. # Normalizes image with mean and std pixel values.
......
This diff is collapsed.
...@@ -95,15 +95,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -95,15 +95,7 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
'reduced_cifar10', 'reduced_cifar10',
'svhn', 'svhn',
'reduced_imagenet', 'reduced_imagenet',
] 'detection_v0',
AVAILABLE_POLICIES = [
'v0',
'test',
'simple',
'reduced_cifar10',
'svhn',
'reduced_imagenet',
] ]
def test_autoaugment(self): def test_autoaugment(self):
...@@ -116,6 +108,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -116,6 +108,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual((224, 224, 3), aug_image.shape) self.assertEqual((224, 224, 3), aug_image.shape)
def test_autoaugment_with_bboxes(self):
  """Smoke-tests AutoAugment box-aware distortion for every listed policy.

  Feeds an all-zero uint8 image and two all-ones float32 boxes through
  `distort_with_boxes` and checks that both output shapes equal the input
  shapes.
  """
  expected_image_shape = (224, 224, 3)
  expected_boxes_shape = (2, 4)
  input_image = tf.zeros(expected_image_shape, dtype=tf.uint8)
  input_boxes = tf.ones(expected_boxes_shape, dtype=tf.float32)

  # NOTE(review): the scraped source lost its indentation; the assertions are
  # assumed to run once per policy (inside the loop) -- confirm upstream.
  for policy_name in self.AVAILABLE_POLICIES:
    policy_augmenter = augment.AutoAugment(augmentation_name=policy_name)
    distorted = policy_augmenter.distort_with_boxes(input_image, input_boxes)
    out_image, out_boxes = distorted

    self.assertEqual(expected_image_shape, out_image.shape)
    self.assertEqual(expected_boxes_shape, out_boxes.shape)
def test_randaug(self): def test_randaug(self):
"""Smoke test to be sure there are no syntax errors.""" """Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8) image = tf.zeros((224, 224, 3), dtype=tf.uint8)
...@@ -125,6 +129,17 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -125,6 +129,17 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual((224, 224, 3), aug_image.shape) self.assertEqual((224, 224, 3), aug_image.shape)
def test_randaug_with_bboxes(self):
  """Smoke-tests RandAugment box-aware distortion with default settings.

  Feeds an all-zero uint8 image and two all-ones float32 boxes through
  `distort_with_boxes` and checks that both output shapes equal the input
  shapes.
  """
  expected_image_shape = (224, 224, 3)
  expected_boxes_shape = (2, 4)
  input_image = tf.zeros(expected_image_shape, dtype=tf.uint8)
  input_boxes = tf.ones(expected_boxes_shape, dtype=tf.float32)

  default_augmenter = augment.RandAugment()
  distorted = default_augmenter.distort_with_boxes(input_image, input_boxes)
  out_image, out_boxes = distorted

  self.assertEqual(expected_image_shape, out_image.shape)
  self.assertEqual(expected_boxes_shape, out_boxes.shape)
def test_all_policy_ops(self): def test_all_policy_ops(self):
"""Smoke test to be sure all augmentation functions can execute.""" """Smoke test to be sure all augmentation functions can execute."""
...@@ -135,14 +150,37 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -135,14 +150,37 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
translate_const = 250 translate_const = 250
image = tf.ones((224, 224, 3), dtype=tf.uint8) image = tf.ones((224, 224, 3), dtype=tf.uint8)
bboxes = None
for op_name in augment.NAME_TO_FUNC.keys() - augment.REQUIRE_BOXES_FUNCS:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image, bboxes = func(image, bboxes, *args)
self.assertEqual((224, 224, 3), image.shape)
self.assertIsNone(bboxes)
def test_all_policy_ops_with_bboxes(self):
"""Smoke test to be sure all augmentation functions can execute."""
prob = 1
magnitude = 10
replace_value = [128] * 3
cutout_const = 100
translate_const = 250
image = tf.ones((224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 4), dtype=tf.float32)
for op_name in augment.NAME_TO_FUNC: for op_name in augment.NAME_TO_FUNC:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude, func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const, replace_value, cutout_const,
translate_const) translate_const)
image = func(image, *args) image, bboxes = func(image, bboxes, *args)
self.assertEqual((224, 224, 3), image.shape) self.assertEqual((224, 224, 3), image.shape)
self.assertEqual((2, 4), bboxes.shape)
def test_autoaugment_video(self): def test_autoaugment_video(self):
"""Smoke test with video to be sure there are no syntax errors.""" """Smoke test with video to be sure there are no syntax errors."""
...@@ -154,6 +192,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -154,6 +192,18 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
self.assertEqual((2, 224, 224, 3), aug_image.shape) self.assertEqual((2, 224, 224, 3), aug_image.shape)
def test_autoaugment_video_with_boxes(self):
  """Smoke-tests AutoAugment box-aware distortion on a 4-D (video) input.

  Feeds an all-zero uint8 tensor of shape (2, 224, 224, 3) and all-ones
  float32 boxes of shape (2, 2, 4) through `distort_with_boxes` for every
  listed policy and checks that both output shapes equal the input shapes.
  """
  expected_video_shape = (2, 224, 224, 3)
  expected_boxes_shape = (2, 2, 4)
  input_video = tf.zeros(expected_video_shape, dtype=tf.uint8)
  input_boxes = tf.ones(expected_boxes_shape, dtype=tf.float32)

  # NOTE(review): the scraped source lost its indentation; the assertions are
  # assumed to run once per policy (inside the loop) -- confirm upstream.
  for policy_name in self.AVAILABLE_POLICIES:
    policy_augmenter = augment.AutoAugment(augmentation_name=policy_name)
    distorted = policy_augmenter.distort_with_boxes(input_video, input_boxes)
    out_video, out_boxes = distorted

    self.assertEqual(expected_video_shape, out_video.shape)
    self.assertEqual(expected_boxes_shape, out_boxes.shape)
def test_randaug_video(self): def test_randaug_video(self):
"""Smoke test with video to be sure there are no syntax errors.""" """Smoke test with video to be sure there are no syntax errors."""
image = tf.zeros((2, 224, 224, 3), dtype=tf.uint8) image = tf.zeros((2, 224, 224, 3), dtype=tf.uint8)
...@@ -173,14 +223,48 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase): ...@@ -173,14 +223,48 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
translate_const = 250 translate_const = 250
image = tf.ones((2, 224, 224, 3), dtype=tf.uint8) image = tf.ones((2, 224, 224, 3), dtype=tf.uint8)
bboxes = None
for op_name in augment.NAME_TO_FUNC.keys() - augment.REQUIRE_BOXES_FUNCS:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image, bboxes = func(image, bboxes, *args)
self.assertEqual((2, 224, 224, 3), image.shape)
self.assertIsNone(bboxes)
def test_all_policy_ops_video_with_bboxes(self):
"""Smoke test to be sure all video augmentation functions can execute."""
prob = 1
magnitude = 10
replace_value = [128] * 3
cutout_const = 100
translate_const = 250
image = tf.ones((2, 224, 224, 3), dtype=tf.uint8)
bboxes = tf.ones((2, 2, 4), dtype=tf.float32)
for op_name in augment.NAME_TO_FUNC: for op_name in augment.NAME_TO_FUNC:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude, func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const, replace_value, cutout_const,
translate_const) translate_const)
image = func(image, *args) if op_name in {
'Rotate_BBox',
'ShearX_BBox',
'ShearY_BBox',
'TranslateX_BBox',
'TranslateY_BBox',
'TranslateY_Only_BBoxes',
}:
with self.assertRaises(ValueError):
func(image, bboxes, *args)
else:
image, bboxes = func(image, bboxes, *args)
self.assertEqual((2, 224, 224, 3), image.shape) self.assertEqual((2, 224, 224, 3), image.shape)
self.assertEqual((2, 2, 4), bboxes.shape)
def _generate_test_policy(self): def _generate_test_policy(self):
"""Generate a test policy at random.""" """Generate a test policy at random."""
......
...@@ -119,6 +119,7 @@ class RetinaNetTask(base_task.Task): ...@@ -119,6 +119,7 @@ class RetinaNetTask(base_task.Task):
dtype=params.dtype, dtype=params.dtype,
match_threshold=params.parser.match_threshold, match_threshold=params.parser.match_threshold,
unmatched_threshold=params.parser.unmatched_threshold, unmatched_threshold=params.parser.unmatched_threshold,
aug_type=params.parser.aug_type,
aug_rand_hflip=params.parser.aug_rand_hflip, aug_rand_hflip=params.parser.aug_rand_hflip,
aug_scale_min=params.parser.aug_scale_min, aug_scale_min=params.parser.aug_scale_min,
aug_scale_max=params.parser.aug_scale_max, aug_scale_max=params.parser.aug_scale_max,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment