Commit 1e205552 authored by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 345529749
parent 495fbc4a
...@@ -34,6 +34,7 @@ class DataConfig(cfg.DataConfig): ...@@ -34,6 +34,7 @@ class DataConfig(cfg.DataConfig):
dtype: str = 'float32' dtype: str = 'float32'
shuffle_buffer_size: int = 10000 shuffle_buffer_size: int = 10000
cycle_length: int = 10 cycle_length: int = 10
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -13,11 +13,13 @@ ...@@ -13,11 +13,13 @@
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Classification decoder and parser.""" """Classification decoder and parser."""
from typing import List, Optional
# Import libraries # Import libraries
import tensorflow as tf import tensorflow as tf
from official.vision.beta.dataloaders import decoder from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser from official.vision.beta.dataloaders import parser
from official.vision.beta.ops import augment
from official.vision.beta.ops import preprocess_ops from official.vision.beta.ops import preprocess_ops
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255) MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
...@@ -43,18 +45,20 @@ class Parser(parser.Parser): ...@@ -43,18 +45,20 @@ class Parser(parser.Parser):
"""Parser to parse an image and its annotations into a dictionary of tensors.""" """Parser to parse an image and its annotations into a dictionary of tensors."""
def __init__(self, def __init__(self,
output_size, output_size: List[int],
num_classes, num_classes: float,
aug_rand_hflip=True, aug_rand_hflip: bool = True,
dtype='float32'): aug_policy: Optional[str] = None,
dtype: str = 'float32'):
"""Initializes parameters for parsing annotations in the dataset. """Initializes parameters for parsing annotations in the dataset.
Args: Args:
output_size: `Tenssor` or `list` for [height, width] of output image. The output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divided by the largest feature stride 2^max_level. output_size should be divided by the largest feature stride 2^max_level.
num_classes: `float`, number of classes. num_classes: `float`, number of classes.
aug_rand_hflip: `bool`, if True, augment training with random aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip. horizontal flip.
aug_policy: `str`, augmentation policies. None, 'autoaug', or 'randaug'.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16', dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'. or 'bfloat16'.
""" """
...@@ -69,6 +73,16 @@ class Parser(parser.Parser): ...@@ -69,6 +73,16 @@ class Parser(parser.Parser):
self._dtype = tf.bfloat16 self._dtype = tf.bfloat16
else: else:
raise ValueError('dtype {!r} is not supported!'.format(dtype)) raise ValueError('dtype {!r} is not supported!'.format(dtype))
if aug_policy:
if aug_policy == 'autoaug':
self._augmenter = augment.AutoAugment()
elif aug_policy == 'randaug':
self._augmenter = augment.RandAugment(num_layers=2, magnitude=20)
else:
raise ValueError(
'Augmentation policy {} not supported.'.format(aug_policy))
else:
self._augmenter = None
def _parse_train_data(self, decoded_tensors): def _parse_train_data(self, decoded_tensors):
"""Parses data for training.""" """Parses data for training."""
...@@ -93,6 +107,10 @@ class Parser(parser.Parser): ...@@ -93,6 +107,10 @@ class Parser(parser.Parser):
image = tf.image.resize( image = tf.image.resize(
image, self._output_size, method=tf.image.ResizeMethod.BILINEAR) image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
# Apply autoaug or randaug.
if self._augmenter is not None:
image = self._augmenter.distort(image)
# Normalizes image with mean and std pixel values. # Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image, image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB, offset=MEAN_RGB,
......
This diff is collapsed.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for autoaugment."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.ops import augment
def get_dtype_test_cases():
  """Returns the (test name, dtype) pairs used to parameterize dtype tests.

  Each pair is named after its dtype: uint8, int32, float16 and float32.
  """
  dtypes = (tf.uint8, tf.int32, tf.float16, tf.float32)
  return [(dtype.name, dtype) for dtype in dtypes]
@parameterized.named_parameters(get_dtype_test_cases())
class TransformsTest(parameterized.TestCase, tf.test.TestCase):
  """Basic tests for the fundamental transformations in `augment`.

  The class-level decorator supplies `dtype` to every test method, once per
  entry returned by `get_dtype_test_cases()`.
  """

  def test_to_from_4d(self, dtype):
    """`to_4d` gives a rank-4 tensor and `from_4d` restores the input."""
    shapes = [(10, 10), (10, 10, 10), (10, 10, 10, 10)]
    for shape in shapes:
      image = tf.zeros(shape, dtype=dtype)
      expanded = augment.to_4d(image)
      self.assertEqual(4, tf.rank(expanded))
      # The original rank is simply the number of dimensions in `shape`.
      restored = augment.from_4d(expanded, len(shape))
      self.assertAllEqual(image, restored)

  def test_transform(self, dtype):
    """`transforms=[1] * 8` is expected to turn [[1, 2], [3, 4]] into all 4s."""
    image = tf.constant([[1, 2], [3, 4]], dtype=dtype)
    coefficients = [1] * 8
    result = augment.transform(image, transforms=coefficients)
    self.assertAllEqual(result, [[4, 4], [4, 4]])

  def test_translate(self, dtype):
    """Translating a 4x4 checkerboard by [-1, -1] gives the pinned output."""
    checkerboard = [[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]]
    image = tf.constant(checkerboard, dtype=dtype)
    shifted = augment.translate(image=image, translations=[-1, -1])
    expected = [[1, 0, 1, 1], [0, 1, 0, 0], [1, 0, 1, 1], [1, 0, 1, 1]]
    self.assertAllEqual(shifted, expected)

  def test_translate_shapes(self, dtype):
    """A zero translation leaves all-zero 2-D and 3-D images unchanged."""
    no_shift = [0, 0]
    for shape in [(3, 3), (5, 5), (224, 224, 3)]:
      image = tf.zeros(shape, dtype=dtype)
      self.assertAllEqual(image, augment.translate(image, no_shift))

  def test_translate_invalid_translation(self, dtype):
    """A triple-nested translation list raises a 'rank 1 or 2' TypeError."""
    image = tf.zeros((1, 1), dtype=dtype)
    too_deep = [[[1, 1]]]
    with self.assertRaisesRegex(TypeError, 'rank 1 or 2'):
      _ = augment.translate(image, too_deep)

  def test_rotate(self, dtype):
    """Rotating the 3x3 range by 90 degrees gives the pinned output.

    The expected matrix is the input [[0, 1, 2], [3, 4, 5], [6, 7, 8]] turned
    a quarter turn counter-clockwise.
    """
    values = tf.cast(tf.range(9), dtype)
    image = tf.reshape(values, (3, 3))
    rotated = augment.rotate(image=image, degrees=90.)
    expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]
    self.assertAllEqual(rotated, expected)

  def test_rotate_shapes(self, dtype):
    """A zero-degree rotation leaves all-zero 2-D and 3-D images unchanged."""
    no_rotation = 0.
    for shape in [(3, 3), (5, 5), (224, 224, 3)]:
      image = tf.zeros(shape, dtype=dtype)
      self.assertAllEqual(image, augment.rotate(image, no_rotation))
class AutoaugmentTest(tf.test.TestCase):
  """Smoke tests for the AutoAugment / RandAugment policies and their ops."""

  # Shape of the test image; every augmentation is expected to preserve it.
  _IMAGE_SHAPE = (224, 224, 3)

  def test_autoaugment(self):
    """AutoAugment with default arguments runs and keeps the image shape."""
    image = tf.zeros(self._IMAGE_SHAPE, dtype=tf.uint8)
    augmenter = augment.AutoAugment()
    aug_image = augmenter.distort(image)
    self.assertEqual(self._IMAGE_SHAPE, aug_image.shape)

  def test_randaug(self):
    """RandAugment with default arguments runs and keeps the image shape."""
    image = tf.zeros(self._IMAGE_SHAPE, dtype=tf.uint8)
    augmenter = augment.RandAugment()
    aug_image = augmenter.distort(image)
    self.assertEqual(self._IMAGE_SHAPE, aug_image.shape)

  def test_all_policy_ops(self):
    """Smoke test to be sure all augmentation functions can execute.

    Every op named in `augment.NAME_TO_FUNC` is applied in turn to the same
    image, each one receiving the previous op's output. The shape is checked
    after every op inside a sub-test, so a failure names the offending op
    rather than only reporting the final shape.
    """
    prob = 1
    magnitude = 10
    replace_value = [128] * 3
    cutout_const = 100
    translate_const = 250

    image = tf.ones(self._IMAGE_SHAPE, dtype=tf.uint8)

    for op_name in augment.NAME_TO_FUNC:
      with self.subTest(op_name=op_name):
        func, _, args = augment._parse_policy_info(
            op_name, prob, magnitude, replace_value, cutout_const,
            translate_const)
        image = func(image, *args)
        self.assertEqual(self._IMAGE_SHAPE, image.shape)
# Run every test case in this module when it is executed as a script.
if __name__ == '__main__':
  tf.test.main()
...@@ -81,6 +81,7 @@ class ImageClassificationTask(base_task.Task): ...@@ -81,6 +81,7 @@ class ImageClassificationTask(base_task.Task):
parser = classification_input.Parser( parser = classification_input.Parser(
output_size=input_size[:2], output_size=input_size[:2],
num_classes=num_classes, num_classes=num_classes,
aug_policy=params.aug_policy,
dtype=params.dtype) dtype=params.dtype)
reader = input_reader.InputReader( reader = input_reader.InputReader(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment