"...python/git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "82e6c3a65ab3701c3ef498bc51fbe447e8c6cbe5"
Commit 55c284c8 authored by A. Unique TensorFlower's avatar A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 345529749
parent 318fa0af
......@@ -34,6 +34,7 @@ class DataConfig(cfg.DataConfig):
dtype: str = 'float32'
shuffle_buffer_size: int = 10000
cycle_length: int = 10
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'
@dataclasses.dataclass
......
......@@ -13,11 +13,13 @@
# limitations under the License.
# ==============================================================================
"""Classification decoder and parser."""
from typing import List, Optional
# Import libraries
import tensorflow as tf
from official.vision.beta.dataloaders import decoder
from official.vision.beta.dataloaders import parser
from official.vision.beta.ops import augment
from official.vision.beta.ops import preprocess_ops
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
......@@ -43,18 +45,20 @@ class Parser(parser.Parser):
"""Parser to parse an image and its annotations into a dictionary of tensors."""
def __init__(self,
output_size,
num_classes,
aug_rand_hflip=True,
dtype='float32'):
output_size: List[int],
num_classes: float,
aug_rand_hflip: bool = True,
aug_policy: Optional[str] = None,
dtype: str = 'float32'):
"""Initializes parameters for parsing annotations in the dataset.
Args:
output_size: `Tenssor` or `list` for [height, width] of output image. The
output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divided by the largest feature stride 2^max_level.
num_classes: `float`, number of classes.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
aug_policy: `str`, augmentation policies. None, 'autoaug', or 'randaug'.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
"""
......@@ -69,6 +73,16 @@ class Parser(parser.Parser):
self._dtype = tf.bfloat16
else:
raise ValueError('dtype {!r} is not supported!'.format(dtype))
if aug_policy:
if aug_policy == 'autoaug':
self._augmenter = augment.AutoAugment()
elif aug_policy == 'randaug':
self._augmenter = augment.RandAugment(num_layers=2, magnitude=20)
else:
raise ValueError(
'Augmentation policy {} not supported.'.format(aug_policy))
else:
self._augmenter = None
def _parse_train_data(self, decoded_tensors):
"""Parses data for training."""
......@@ -93,6 +107,10 @@ class Parser(parser.Parser):
image = tf.image.resize(
image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
# Apply autoaug or randaug.
if self._augmenter is not None:
image = self._augmenter.distort(image)
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB,
......
This diff is collapsed.
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for autoaugment."""
from __future__ import absolute_import
from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
from absl.testing import parameterized
import tensorflow as tf
from official.vision.beta.ops import augment
def get_dtype_test_cases():
return [
('uint8', tf.uint8),
('int32', tf.int32),
('float16', tf.float16),
('float32', tf.float32),
]
@parameterized.named_parameters(get_dtype_test_cases())
class TransformsTest(parameterized.TestCase, tf.test.TestCase):
"""Basic tests for fundamental transformations."""
def test_to_from_4d(self, dtype):
for shape in [(10, 10), (10, 10, 10), (10, 10, 10, 10)]:
original_ndims = len(shape)
image = tf.zeros(shape, dtype=dtype)
image_4d = augment.to_4d(image)
self.assertEqual(4, tf.rank(image_4d))
self.assertAllEqual(image, augment.from_4d(image_4d, original_ndims))
def test_transform(self, dtype):
image = tf.constant([[1, 2], [3, 4]], dtype=dtype)
self.assertAllEqual(
augment.transform(image, transforms=[1] * 8), [[4, 4], [4, 4]])
def test_translate(self, dtype):
image = tf.constant(
[[1, 0, 1, 0], [0, 1, 0, 1], [1, 0, 1, 0], [0, 1, 0, 1]], dtype=dtype)
translations = [-1, -1]
translated = augment.translate(image=image, translations=translations)
expected = [[1, 0, 1, 1], [0, 1, 0, 0], [1, 0, 1, 1], [1, 0, 1, 1]]
self.assertAllEqual(translated, expected)
def test_translate_shapes(self, dtype):
translation = [0, 0]
for shape in [(3, 3), (5, 5), (224, 224, 3)]:
image = tf.zeros(shape, dtype=dtype)
self.assertAllEqual(image, augment.translate(image, translation))
def test_translate_invalid_translation(self, dtype):
image = tf.zeros((1, 1), dtype=dtype)
invalid_translation = [[[1, 1]]]
with self.assertRaisesRegex(TypeError, 'rank 1 or 2'):
_ = augment.translate(image, invalid_translation)
def test_rotate(self, dtype):
image = tf.reshape(tf.cast(tf.range(9), dtype), (3, 3))
rotation = 90.
transformed = augment.rotate(image=image, degrees=rotation)
expected = [[2, 5, 8], [1, 4, 7], [0, 3, 6]]
self.assertAllEqual(transformed, expected)
def test_rotate_shapes(self, dtype):
degrees = 0.
for shape in [(3, 3), (5, 5), (224, 224, 3)]:
image = tf.zeros(shape, dtype=dtype)
self.assertAllEqual(image, augment.rotate(image, degrees))
class AutoaugmentTest(tf.test.TestCase):
def test_autoaugment(self):
"""Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
augmenter = augment.AutoAugment()
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
def test_randaug(self):
"""Smoke test to be sure there are no syntax errors."""
image = tf.zeros((224, 224, 3), dtype=tf.uint8)
augmenter = augment.RandAugment()
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
def test_all_policy_ops(self):
"""Smoke test to be sure all augmentation functions can execute."""
prob = 1
magnitude = 10
replace_value = [128] * 3
cutout_const = 100
translate_const = 250
image = tf.ones((224, 224, 3), dtype=tf.uint8)
for op_name in augment.NAME_TO_FUNC:
func, _, args = augment._parse_policy_info(op_name, prob, magnitude,
replace_value, cutout_const,
translate_const)
image = func(image, *args)
self.assertEqual((224, 224, 3), image.shape)
if __name__ == '__main__':
tf.test.main()
......@@ -81,6 +81,7 @@ class ImageClassificationTask(base_task.Task):
parser = classification_input.Parser(
output_size=input_size[:2],
num_classes=num_classes,
aug_policy=params.aug_policy,
dtype=params.dtype)
reader = input_reader.InputReader(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment