Commit 1c79ece9 authored by A. Unique TensorFlower

Merge pull request #10227 from sigeisler:master

PiperOrigin-RevId: 397161611
parents bea8998b 01b21983
@@ -16,7 +16,7 @@
"""Common configurations."""
import dataclasses
from typing import List, Optional

# Import libraries
@@ -60,7 +60,9 @@ class RandAugment(hyperparams.Config):
  magnitude: float = 10
  cutout_const: float = 40
  translate_const: float = 10
  magnitude_std: float = 0.0
  prob_to_apply: Optional[float] = None
  exclude_ops: List[str] = dataclasses.field(default_factory=list)


@dataclasses.dataclass
@@ -71,6 +73,29 @@ class AutoAugment(hyperparams.Config):
  translate_const: float = 250
@dataclasses.dataclass
class RandomErasing(hyperparams.Config):
  """Configuration for RandomErasing."""
  probability: float = 0.25
  min_area: float = 0.02
  max_area: float = 1 / 3
  min_aspect: float = 0.3
  # Type annotations are required, or dataclasses silently drops these fields.
  max_aspect: Optional[float] = None
  min_count: int = 1
  max_count: int = 1
  trials: int = 10


@dataclasses.dataclass
class MixupAndCutmix(hyperparams.Config):
  """Configuration for MixupAndCutmix."""
  mixup_alpha: float = 0.8
  cutmix_alpha: float = 1.0
  prob: float = 1.0
  switch_prob: float = 0.5
  label_smoothing: float = 0.1
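As a quick orientation, the block below is a hedged sketch (not part of this commit) of composing the new config classes; the import path is assumed from the file shown in this hunk.

# Hedged sketch; assumes official/vision/beta/configs/common.py as above.
from official.vision.beta.configs import common

aug = common.Augmentation(
    type='randaug',
    randaug=common.RandAugment(
        magnitude=9, magnitude_std=0.5, exclude_ops=['Cutout']))
erasing = common.RandomErasing(probability=0.25)
mixing = common.MixupAndCutmix(mixup_alpha=0.8, cutmix_alpha=1.0)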
@dataclasses.dataclass
class Augmentation(hyperparams.OneOfConfig):
  """Configuration for input data augmentation.
...
@@ -39,10 +39,13 @@ class DataConfig(cfg.DataConfig):
  aug_rand_hflip: bool = True
  aug_type: Optional[
      common.Augmentation] = None  # Choose from AutoAugment and RandAugment.
  color_jitter: float = 0.
  random_erasing: Optional[common.RandomErasing] = None
  file_type: str = 'tfrecord'
  image_field_key: str = 'image/encoded'
  label_field_key: str = 'image/class/label'
  decode_jpeg_only: bool = True
  mixup_and_cutmix: Optional[common.MixupAndCutmix] = None
  decoder: Optional[common.DataDecoder] = common.DataDecoder()
  # Keep for backward compatibility.
@@ -62,6 +65,7 @@ class ImageClassificationModel(hyperparams.Config):
      use_sync_bn=False)
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
  add_head_batch_norm: bool = False
  kernel_initializer: str = 'random_uniform'
@dataclasses.dataclass
@@ -69,6 +73,7 @@ class Losses(hyperparams.Config):
  one_hot: bool = True
  label_smoothing: float = 0.0
  l2_weight_decay: float = 0.0
  soft_labels: bool = False
@dataclasses.dataclass
...
@@ -69,6 +69,8 @@ class Parser(parser.Parser):
               decode_jpeg_only: bool = True,
               aug_rand_hflip: bool = True,
               aug_type: Optional[common.Augmentation] = None,
               color_jitter: float = 0.,
               random_erasing: Optional[common.RandomErasing] = None,
               is_multilabel: bool = False,
               dtype: str = 'float32'):
    """Initializes parameters for parsing annotations in the dataset.
@@ -85,6 +87,11 @@ class Parser(parser.Parser):
        horizontal flip.
      aug_type: An optional Augmentation object to choose from AutoAugment and
        RandAugment.
      color_jitter: Magnitude of color jitter. If > 0, the value is used to
        generate random scale factors for brightness, contrast and saturation.
        See `preprocess_ops.color_jitter` for more details.
      random_erasing: if not None, augment input image by random erasing. See
        `augment.RandomErasing` for more details.
      is_multilabel: A `bool`, whether or not each example has multiple labels.
      dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
        or 'bfloat16'.
@@ -113,13 +120,27 @@ class Parser(parser.Parser):
          magnitude=aug_type.randaug.magnitude,
          cutout_const=aug_type.randaug.cutout_const,
          translate_const=aug_type.randaug.translate_const,
          prob_to_apply=aug_type.randaug.prob_to_apply,
          exclude_ops=aug_type.randaug.exclude_ops)
      else:
        raise ValueError('Augmentation policy {} not supported.'.format(
            aug_type.type))
    else:
      self._augmenter = None

    self._label_field_key = label_field_key
    self._color_jitter = color_jitter
    if random_erasing:
      self._random_erasing = augment.RandomErasing(
          probability=random_erasing.probability,
          min_area=random_erasing.min_area,
          max_area=random_erasing.max_area,
          min_aspect=random_erasing.min_aspect,
          max_aspect=random_erasing.max_aspect,
          min_count=random_erasing.min_count,
          max_count=random_erasing.max_count,
          trials=random_erasing.trials)
    else:
      self._random_erasing = None
    self._is_multilabel = is_multilabel
    self._decode_jpeg_only = decode_jpeg_only
@@ -173,6 +194,12 @@ class Parser(parser.Parser):
    if self._aug_rand_hflip:
      image = tf.image.random_flip_left_right(image)

    # Color jitter.
    if self._color_jitter > 0:
      image = preprocess_ops.color_jitter(image, self._color_jitter,
                                          self._color_jitter,
                                          self._color_jitter)
    # Resizes image.
    image = tf.image.resize(
        image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
@@ -187,6 +214,10 @@ class Parser(parser.Parser):
        offset=MEAN_RGB,
        scale=STDDEV_RGB)

    # Random erasing after the image has been normalized.
    if self._random_erasing is not None:
      image = self._random_erasing.distort(image)
    # Convert image to self._dtype.
    image = tf.image.convert_image_dtype(image, self._dtype)
...
@@ -56,6 +56,7 @@ def build_classification_model(
      num_classes=model_config.num_classes,
      input_specs=input_specs,
      dropout_rate=model_config.dropout_rate,
      kernel_initializer=model_config.kernel_initializer,
      kernel_regularizer=l2_regularizer,
      add_head_batch_norm=model_config.add_head_batch_norm,
      use_sync_bn=norm_activation_config.use_sync_bn,
...
@@ -12,10 +12,18 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Augmentation policies for enhanced image/video preprocessing.

AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719
RandomErasing Reference: https://arxiv.org/abs/1708.04896
MixupAndCutmix:
  - Mixup: https://arxiv.org/abs/1710.09412
  - Cutmix: https://arxiv.org/abs/1905.04899

RandomErasing, Mixup and Cutmix are inspired by
https://github.com/rwightman/pytorch-image-models
""" """
import math import math
from typing import Any, List, Iterable, Optional, Text, Tuple from typing import Any, List, Iterable, Optional, Text, Tuple
...@@ -295,10 +303,26 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor: ...@@ -295,10 +303,26 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
cutout_center_width = tf.random.uniform( cutout_center_width = tf.random.uniform(
shape=[], minval=0, maxval=image_width, dtype=tf.int32) shape=[], minval=0, maxval=image_width, dtype=tf.int32)
lower_pad = tf.maximum(0, cutout_center_height - pad_size) image = _fill_rectangle(image, cutout_center_width, cutout_center_height,
upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) pad_size, pad_size, replace)
left_pad = tf.maximum(0, cutout_center_width - pad_size)
right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) return image
def _fill_rectangle(image,
                    center_width,
                    center_height,
                    half_width,
                    half_height,
                    replace=None):
  """Fills a rectangular region of `image` with `replace` (noise if None)."""
  image_height = tf.shape(image)[0]
  image_width = tf.shape(image)[1]

  lower_pad = tf.maximum(0, center_height - half_height)
  upper_pad = tf.maximum(0, image_height - center_height - half_height)
  left_pad = tf.maximum(0, center_width - half_width)
  right_pad = tf.maximum(0, image_width - center_width - half_width)
  cutout_shape = [
      image_height - (lower_pad + upper_pad),
@@ -311,9 +335,15 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
      constant_values=1)
  mask = tf.expand_dims(mask, -1)
  mask = tf.tile(mask, [1, 1, 3])
  if replace is None:
    fill = tf.random.normal(tf.shape(image), dtype=image.dtype)
  elif isinstance(replace, tf.Tensor):
    fill = replace
  else:
    fill = tf.ones_like(image, dtype=image.dtype) * replace
  image = tf.where(tf.equal(mask, 0), fill, image)

  return image
@@ -803,11 +833,20 @@ def level_to_arg(cutout_const: float, translate_const: float):
  return args


def _parse_policy_info(name: Text,
                       prob: float,
                       level: float,
                       replace_value: List[int],
                       cutout_const: float,
                       translate_const: float,
                       level_std: float = 0.) -> Tuple[Any, float, Any]:
  """Return the function that corresponds to `name` and update `level` param."""
  func = NAME_TO_FUNC[name]
  if level_std > 0:
    # Jitter the severity around `level`; the spread is set by `level_std`.
    level += tf.random.normal([], stddev=level_std, dtype=tf.float32)
    level = tf.clip_by_value(level, 0., _MAX_LEVEL)
  args = level_to_arg(cutout_const, translate_const)[name](level)

  if name in REPLACE_FUNCS:
@@ -1184,7 +1223,9 @@ class RandAugment(ImageAugment):
               magnitude: float = 10.,
               cutout_const: float = 40.,
               translate_const: float = 100.,
               magnitude_std: float = 0.0,
               prob_to_apply: Optional[float] = None,
               exclude_ops: Optional[List[str]] = None):
    """Applies the RandAugment policy to images.

    Args:
@@ -1196,8 +1237,11 @@ class RandAugment(ImageAugment):
        [5, 10].
      cutout_const: multiplier for applying cutout.
      translate_const: multiplier for applying translation.
      magnitude_std: randomness of the severity as proposed by the authors of
        the timm library.
      prob_to_apply: The probability to apply the selected augmentation at each
        layer.
      exclude_ops: names of augmentation operations to exclude.
    """
    super(RandAugment, self).__init__()
@@ -1212,6 +1256,11 @@ class RandAugment(ImageAugment):
        'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY',
        'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'
    ]
    self.magnitude_std = magnitude_std
    if exclude_ops:
      self.available_ops = [
          op for op in self.available_ops if op not in exclude_ops
      ]
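For reference, a hedged usage sketch of the extended constructor (module path as in this diff; `num_layers` is assumed from the pre-existing constructor):

import tensorflow as tf
from official.vision.beta.ops import augment

augmenter = augment.RandAugment(
    num_layers=2, magnitude=9, magnitude_std=0.5, exclude_ops=['Cutout'])
image = tf.cast(
    tf.random.uniform([224, 224, 3], maxval=256, dtype=tf.int32), tf.uint8)
augmented = augmenter.distort(image)  # same shape and dtype as `image`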
  def distort(self, image: tf.Tensor) -> tf.Tensor:
    """Applies the RandAugment policy to `image`.
@@ -1246,7 +1295,8 @@ class RandAugment(ImageAugment):
              dtype=tf.float32)
          func, _, args = _parse_policy_info(op_name, prob, self.magnitude,
                                             replace_value, self.cutout_const,
                                             self.translate_const,
                                             self.magnitude_std)
          branch_fns.append((
              i,
              # pylint:disable=g-long-lambda
@@ -1267,3 +1317,271 @@ class RandAugment(ImageAugment):
    image = tf.cast(image, dtype=input_image_type)
    return image
class RandomErasing(ImageAugment):
"""Applies RandomErasing to a single image.
Reference: https://arxiv.org/abs/1708.04896
Implementation is inspired by https://github.com/rwightman/pytorch-image-models
"""
def __init__(self,
probability: float = 0.25,
min_area: float = 0.02,
max_area: float = 1 / 3,
min_aspect: float = 0.3,
max_aspect=None,
min_count=1,
max_count=1,
trials=10):
"""Applies RandomErasing to a single image.
Args:
probability (float, optional): Probability of augmenting the image.
Defaults to 0.25.
min_area (float, optional): Minimum area of the random erasing rectangle.
Defaults to 0.02.
max_area (float, optional): Maximum area of the random erasing rectangle.
Defaults to 1/3.
min_aspect (float, optional): Minimum aspect rate of the random erasing
rectangle. Defaults to 0.3.
max_aspect ([type], optional): Maximum aspect rate of the random erasing
rectangle. Defaults to None.
min_count (int, optional): Minimum number of erased rectangles. Defaults
to 1.
max_count (int, optional): Maximum number of erased rectangles. Defaults
to 1.
trials (int, optional): Maximum number of trials to randomly sample a
rectangle that fulfills constraint. Defaults to 10.
"""
self._probability = probability
self._min_area = float(min_area)
self._max_area = float(max_area)
self._min_log_aspect = math.log(min_aspect)
self._max_log_aspect = math.log(max_aspect or 1 / min_aspect)
self._min_count = min_count
self._max_count = max_count
self._trials = trials
def distort(self, image: tf.Tensor) -> tf.Tensor:
"""Applies RandomErasing to single `image`.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
Returns:
tf.Tensor: The augmented version of `image`.
"""
uniform_random = tf.random.uniform(shape=[], minval=0., maxval=1.0)
mirror_cond = tf.less(uniform_random, self._probability)
image = tf.cond(mirror_cond, lambda: self._erase(image), lambda: image)
return image
@tf.function
def _erase(self, image: tf.Tensor) -> tf.Tensor:
"""Erase an area."""
if self._min_count == self._max_count:
  count = self._min_count
else:
  # `maxval` is exclusive, so add 1 to make `max_count` reachable.
  count = tf.random.uniform(
      shape=[],
      minval=int(self._min_count),
      maxval=int(self._max_count) + 1,
      dtype=tf.int32)
image_height = tf.shape(image)[0]
image_width = tf.shape(image)[1]
area = tf.cast(image_width * image_height, tf.float32)
for _ in range(count):
# Workaround since `break` is not supported in tf.function.
is_trial_successful = False
for _ in range(self._trials):
if not is_trial_successful:
erase_area = tf.random.uniform(
shape=[],
minval=area * self._min_area,
maxval=area * self._max_area)
aspect_ratio = tf.math.exp(
tf.random.uniform(
shape=[],
minval=self._min_log_aspect,
maxval=self._max_log_aspect))
half_height = tf.cast(
tf.math.round(tf.math.sqrt(erase_area * aspect_ratio) / 2),
dtype=tf.int32)
half_width = tf.cast(
tf.math.round(tf.math.sqrt(erase_area / aspect_ratio) / 2),
dtype=tf.int32)
if 2 * half_height < image_height and 2 * half_width < image_width:
center_height = tf.random.uniform(
shape=[],
minval=0,
maxval=int(image_height - 2 * half_height),
dtype=tf.int32)
center_width = tf.random.uniform(
shape=[],
minval=0,
maxval=int(image_width - 2 * half_width),
dtype=tf.int32)
image = _fill_rectangle(
image,
center_width,
center_height,
half_width,
half_height,
replace=None)
is_trial_successful = True
return image
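A hedged usage sketch: the parser above applies erasing after normalization, so the noise fill (`replace=None`) roughly matches the data distribution.

import tensorflow as tf
from official.vision.beta.ops import augment

eraser = augment.RandomErasing(probability=1.0, max_count=2)
image = tf.random.normal([224, 224, 3])  # already-normalized image
erased = eraser.distort(image)           # up to 2 rectangles filled with noise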
class MixupAndCutmix:
"""Applies Mixup and/or Cutmix to a batch of images.
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
Implementation is inspired by https://github.com/rwightman/pytorch-image-models
"""
def __init__(self,
mixup_alpha: float = .8,
cutmix_alpha: float = 1.,
prob: float = 1.0,
switch_prob: float = 0.5,
label_smoothing: float = 0.1,
num_classes: int = 1001):
"""Applies Mixup and/or Cutmix to a batch of images.
Args:
mixup_alpha (float, optional): For drawing a random lambda (`lam`) from a
  beta distribution (for each image). If zero, Mixup is deactivated.
  Defaults to 0.8.
cutmix_alpha (float, optional): For drawing a random lambda (`lam`) from a
  beta distribution (for each image). If zero, Cutmix is deactivated.
  Defaults to 1.0.
prob (float, optional): Of augmenting the batch. Defaults to 1.0.
switch_prob (float, optional): Probability of applying Cutmix for the
batch. Defaults to 0.5.
label_smoothing (float, optional): Constant for label smoothing. Defaults
to 0.1.
num_classes (int, optional): Number of classes. Defaults to 1001.
"""
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.mix_prob = prob
self.switch_prob = switch_prob
self.label_smoothing = label_smoothing
self.num_classes = num_classes
self.mode = 'batch'
self.mixup_enabled = True
if self.mixup_alpha and not self.cutmix_alpha:
self.switch_prob = -1
elif not self.mixup_alpha and self.cutmix_alpha:
self.switch_prob = 1
def __call__(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
return self.distort(images, labels)
def distort(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Applies Mixup and/or Cutmix to batch of images and transforms labels.
Args:
images (tf.Tensor): Of shape [batch_size, height, width, 3] representing a
  batch of images.
labels (tf.Tensor): Of shape [batch_size, ] representing the class id for
  each image in the batch.
Returns:
Tuple[tf.Tensor, tf.Tensor]: The augmented version of `image` and
`labels`.
"""
augment_cond = tf.less(
tf.random.uniform(shape=[], minval=0., maxval=1.0), self.mix_prob)
# pylint: disable=g-long-lambda
augment_a = lambda: self._update_labels(*tf.cond(
tf.less(
tf.random.uniform(shape=[], minval=0., maxval=1.0), self.switch_prob
), lambda: self._cutmix(images, labels), lambda: self._mixup(
images, labels)))
augment_b = lambda: (images, self._smooth_labels(labels))
# pylint: enable=g-long-lambda
return tf.cond(augment_cond, augment_a, augment_b)
@staticmethod
def _sample_from_beta(alpha, beta, shape):
# Gamma-ratio trick: X ~ Gamma(alpha, 1), Y ~ Gamma(beta, 1) implies
# X / (X + Y) ~ Beta(alpha, beta).
sample_alpha = tf.random.gamma(shape, alpha)
sample_beta = tf.random.gamma(shape, beta)
return sample_alpha / (sample_alpha + sample_beta)
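The Gamma-ratio identity behind `_sample_from_beta` (a Beta(a, b) sample has mean a / (a + b)) is easy to sanity-check numerically; a hedged sketch:

import tensorflow as tf

a, b = 0.8, 1.2
x = tf.random.gamma([100000], a)
y = tf.random.gamma([100000], b)
lam = x / (x + y)
print(float(tf.reduce_mean(lam)))  # close to a / (a + b) = 0.4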
def _cutmix(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
"""Apply cutmix."""
lam = MixupAndCutmix._sample_from_beta(self.cutmix_alpha, self.cutmix_alpha,
labels.shape)
ratio = tf.math.sqrt(1 - lam)
batch_size = tf.shape(images)[0]
image_height, image_width = tf.shape(images)[1], tf.shape(images)[2]
cut_height = tf.cast(
    ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32)
cut_width = tf.cast(
    ratio * tf.cast(image_width, dtype=tf.float32), dtype=tf.int32)
random_center_height = tf.random.uniform(
shape=[batch_size], minval=0, maxval=image_height, dtype=tf.int32)
random_center_width = tf.random.uniform(
shape=[batch_size], minval=0, maxval=image_width, dtype=tf.int32)
bbox_area = cut_height * cut_width
lam = 1. - bbox_area / (image_height * image_width)
lam = tf.cast(lam, dtype=tf.float32)
# `fn_output_signature` describes the per-example output of `_fill_rectangle`;
# the deprecated `dtype` argument is redundant with it and has been dropped.
images = tf.map_fn(
    lambda x: _fill_rectangle(*x),
    (images, random_center_width, random_center_height, cut_width // 2,
     cut_height // 2, tf.reverse(images, [0])),
    fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=tf.float32))
return images, labels, lam
def _mixup(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha,
labels.shape)
lam = tf.reshape(lam, [-1, 1, 1, 1])
images = lam * images + (1. - lam) * tf.reverse(images, [0])
return images, labels, tf.squeeze(lam)
def _smooth_labels(self, labels: tf.Tensor) -> tf.Tensor:
off_value = self.label_smoothing / self.num_classes
on_value = 1. - self.label_smoothing + off_value
smooth_labels = tf.one_hot(
labels, self.num_classes, on_value=on_value, off_value=off_value)
return smooth_labels
def _update_labels(self, images: tf.Tensor, labels: tf.Tensor,
lam: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
labels_1 = self._smooth_labels(labels)
labels_2 = tf.reverse(labels_1, [0])
lam = tf.reshape(lam, [-1, 1])
labels = lam * labels_1 + (1. - lam) * labels_2
return images, labels
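An end-to-end hedged sketch of the batch-level transform (arbitrary sizes; the soft labels it returns pair with the `soft_labels` loss path added further below):

import tensorflow as tf
from official.vision.beta.ops import augment

mixer = augment.MixupAndCutmix(num_classes=1001, label_smoothing=0.1)
images = tf.random.normal([8, 224, 224, 3])
labels = tf.range(8)                  # sparse class ids
images, soft = mixer(images, labels)  # soft labels: shape [8, 1001]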
@@ -254,5 +254,82 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
    augmenter.distort(image)
class RandomErasingTest(tf.test.TestCase, parameterized.TestCase):
def test_random_erase_replaces_some_pixels(self):
image = tf.zeros((224, 224, 3), dtype=tf.float32)
augmenter = augment.RandomErasing(probability=1., max_count=10)
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertNotEqual(0, tf.reduce_max(aug_image))
class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
def test_mixup_and_cutmix_smoothes_labels(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
num_classes=num_classes, label_smoothing=label_smoothing)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
                        2. / num_classes)  # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
                           1e-4)  # With tolerance
def test_mixup_changes_image(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
                        2. / num_classes)  # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
                           1e-4)  # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
def test_cutmix_changes_image(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
                        2. / num_classes)  # With tolerance
self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
                           1e-4)  # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
if __name__ == '__main__':
  tf.test.main()
...
@@ -15,12 +15,13 @@
"""Preprocessing ops."""
import math
from typing import Optional
from six.moves import range
import tensorflow as tf

from official.vision.beta.ops import augment
from official.vision.beta.ops import box_ops

CENTER_CROP_FRACTION = 0.875
@@ -557,6 +558,107 @@ def random_horizontal_flip(image, normalized_boxes=None, masks=None, seed=1):
  return image, normalized_boxes, masks
def color_jitter(image: tf.Tensor,
brightness: Optional[float] = 0.,
contrast: Optional[float] = 0.,
saturation: Optional[float] = 0.,
seed: Optional[int] = None) -> tf.Tensor:
"""Applies color jitter to an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] and type uint8.
brightness (float, optional): Magnitude for brightness jitter. Defaults to
0.
contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
saturation (float, optional): Magnitude for saturation jitter. Defaults to
0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented `image` of type uint8.
"""
image = tf.cast(image, dtype=tf.uint8)
image = random_brightness(image, brightness, seed=seed)
image = random_contrast(image, contrast, seed=seed)
image = random_saturation(image, saturation, seed=seed)
return image
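A hedged sketch of the new op; the parser change above passes a single magnitude for all three properties:

import tensorflow as tf
from official.vision.beta.ops import preprocess_ops

image = tf.cast(
    tf.random.uniform([224, 224, 3], maxval=256, dtype=tf.int32), tf.uint8)
jittered = preprocess_ops.color_jitter(image, 0.4, 0.4, 0.4)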
def random_brightness(image: tf.Tensor,
brightness: float = 0.,
seed: Optional[int] = None) -> tf.Tensor:
"""Jitters brightness of an image.
Args:
image (tf.Tensor): Of shape [height, width, 3] and type uint8.
brightness (float, optional): Magnitude for brightness jitter. Defaults to
0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented `image` of type uint8.
"""
assert brightness >= 0, '`brightness` must be non-negative'
brightness = tf.random.uniform([],
max(0, 1 - brightness),
1 + brightness,
seed=seed,
dtype=tf.float32)
return augment.brightness(image, brightness)
def random_contrast(image: tf.Tensor,
contrast: float = 0.,
seed: Optional[int] = None) -> tf.Tensor:
"""Jitters contrast of an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] and type uint8.
contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented `image` of type uint8.
"""
assert contrast >= 0, '`contrast` must be non-negative'
contrast = tf.random.uniform([],
max(0, 1 - contrast),
1 + contrast,
seed=seed,
dtype=tf.float32)
return augment.contrast(image, contrast)
def random_saturation(image: tf.Tensor,
saturation: float = 0.,
seed: Optional[int] = None) -> tf.Tensor:
"""Jitters saturation of an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] and type uint8.
saturation (float, optional): Magnitude for saturation jitter. Defaults to
0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented `image` of type uint8.
"""
assert saturation >= 0, '`saturation` must be non-negative'
saturation = tf.random.uniform([],
max(0, 1 - saturation),
1 + saturation,
seed=seed,
dtype=tf.float32)
return _saturation(image, saturation)
def _saturation(image: tf.Tensor,
saturation: Optional[float] = 0.) -> tf.Tensor:
return augment.blend(
tf.repeat(tf.image.rgb_to_grayscale(image), 3, axis=-1), image,
saturation)
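`_saturation` blends the grayscale image with the original, the same construction torchvision's ColorJitter uses; a hedged check of the endpoints:

import tensorflow as tf
from official.vision.beta.ops import preprocess_ops

image = tf.cast(
    tf.random.uniform([4, 4, 3], maxval=256, dtype=tf.int32), tf.uint8)
gray = preprocess_ops._saturation(image, 0.)  # factor 0: fully desaturated
same = preprocess_ops._saturation(image, 1.)  # factor 1: ~original image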
def random_crop_image_with_boxes_and_labels(img, boxes, labels, min_scale,
                                            aspect_ratio_range,
                                            min_overlap_params, max_retry):
...
@@ -197,6 +197,19 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
    _ = preprocess_ops.random_crop_image_v2(
        image_bytes, tf.constant([input_height, input_width, 3], tf.int32))
@parameterized.parameters((400, 600, 0), (400, 600, 0.4), (600, 400, 1.4))
def testColorJitter(self, input_height, input_width, color_jitter):
image = tf.convert_to_tensor(np.random.rand(input_height, input_width, 3))
jittered_image = preprocess_ops.color_jitter(image, color_jitter,
color_jitter, color_jitter)
assert jittered_image.shape == image.shape
@parameterized.parameters((400, 600, 0), (400, 600, 0.4), (600, 400, 1))
def testSaturation(self, input_height, input_width, saturation):
image = tf.convert_to_tensor(np.random.rand(input_height, input_width, 3))
jittered_image = preprocess_ops._saturation(image, saturation)
assert jittered_image.shape == image.shape
@parameterized.parameters((640, 640, 20), (1280, 1280, 30))
def test_random_crop(self, input_height, input_width, num_boxes):
  image = tf.convert_to_tensor(np.random.rand(input_height, input_width, 3))
...

# Vision Transformer (ViT) and Data-Efficient Image Transformer (DEIT)

**DISCLAIMER**: This implementation is still under development. No support will
be provided during the development phase.

- [![ViT Paper](http://img.shields.io/badge/Paper-arXiv.2010.11929-B3181B?logo=arXiv)](https://arxiv.org/abs/2010.11929)
- [![DEIT Paper](http://img.shields.io/badge/Paper-arXiv.2012.12877-B3181B?logo=arXiv)](https://arxiv.org/abs/2012.12877)

This repository contains implementations of Vision Transformer (ViT) and
Data-Efficient Image Transformer (DEIT) in TensorFlow 2.

* Paper titles:
  - [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf).
  - [Training data-efficient image transformers & distillation through attention](https://arxiv.org/pdf/2012.12877.pdf).
@@ -42,6 +42,8 @@ class VisionTransformer(hyperparams.Config):
  hidden_size: int = 1
  patch_size: int = 16
  transformer: Transformer = Transformer()
  init_stochastic_depth_rate: float = 0.0
  original_init: bool = True
@dataclasses.dataclass
...
@@ -44,6 +44,7 @@ class ImageClassificationModel(hyperparams.Config):
      use_sync_bn=False)
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
  add_head_batch_norm: bool = False
  kernel_initializer: str = 'random_uniform'
@dataclasses.dataclass
@@ -51,6 +52,7 @@ class Losses(hyperparams.Config):
  one_hot: bool = True
  label_smoothing: float = 0.0
  l2_weight_decay: float = 0.0
  soft_labels: bool = False
@dataclasses.dataclass
@@ -79,6 +81,87 @@ task_factory.register_task_cls(ImageClassificationTask)(
    image_classification.ImageClassificationTask)
@exp_factory.register_config_factory('deit_imagenet_pretrain')
def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096  # Originally 1024; 4096 works better on TPU v3-32.
eval_batch_size = 4096  # Originally 1024; 4096 works better on TPU v3-32.
num_classes = 1001
label_smoothing = 0.1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=num_classes,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(
l2_weight_decay=0.0,
label_smoothing=label_smoothing,
one_hot=False,
soft_labels=True),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
aug_type=common.Augmentation(
type='randaug',
randaug=common.RandAugment(
magnitude=9, exclude_ops=['Cutout'])),
mixup_and_cutmix=common.MixupAndCutmix(
label_smoothing=label_smoothing)),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
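For reference, a hedged sketch of looking up the registered experiment (assuming the Model Garden's standard `exp_factory` helper):

from official.core import exp_factory

config = exp_factory.get_exp_config('deit_imagenet_pretrain')
config.task.train_data.global_batch_size = 1024  # e.g., scale down for fewer hosts
config.validate()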
@exp_factory.register_config_factory('vit_imagenet_pretrain')
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
  """Image classification on imagenet with vision transformer."""
@@ -90,6 +173,7 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
          model=ImageClassificationModel(
              num_classes=1001,
              input_size=[224, 224, 3],
              kernel_initializer='zeros',
              backbone=backbones.Backbone(
                  type='vit',
                  vit=backbones.VisionTransformer(
@@ -116,12 +200,13 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
              'adamw': {
                  'weight_decay_rate': 0.3,
                  'include_in_weight_decay': r'.*(kernel|weight):0$',
                  'gradient_clip_norm': 0.0
              }
          },
          'learning_rate': {
              'type': 'cosine',
              'cosine': {
                  'initial_learning_rate': 0.003 * train_batch_size / 4096,
                  'decay_steps': 300 * steps_per_epoch,
              }
          },
...
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based TransformerEncoder block layer."""
import tensorflow as tf
from official.nlp import keras_nlp
from official.vision.beta.modeling.layers.nn_layers import StochasticDepth
class TransformerEncoderBlock(keras_nlp.layers.TransformerEncoderBlock):
"""TransformerEncoderBlock layer with stochastic depth."""
def __init__(self, *args, stochastic_depth_drop_rate=0.0, **kwargs):
"""Initializes TransformerEncoderBlock."""
super().__init__(*args, **kwargs)
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
def build(self, input_shape):
if self._stochastic_depth_drop_rate:
self._stochastic_depth = StochasticDepth(self._stochastic_depth_drop_rate)
else:
self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)
super().build(input_shape)
def get_config(self):
config = {"stochastic_depth_drop_rate": self._stochastic_depth_drop_rate}
base_config = super().get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None):
"""Transformer self-attention encoder block call."""
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
raise ValueError("Unexpected inputs to %s with length at %d" %
(self.__class__, len(inputs)))
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
if self._output_range:
if self._norm_first:
source_tensor = input_tensor[:, 0:self._output_range, :]
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor[:, 0:self._output_range, :]
if attention_mask is not None:
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor
if key_value is None:
key_value = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
attention_output = source_tensor + self._stochastic_depth(
attention_output, training=training)
else:
attention_output = self._attention_layer_norm(
target_tensor +
self._stochastic_depth(attention_output, training=training))
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
inner_output = self._intermediate_dense(attention_output)
inner_output = self._intermediate_activation_layer(inner_output)
inner_output = self._inner_dropout_layer(inner_output)
layer_output = self._output_dense(inner_output)
layer_output = self._output_dropout(layer_output)
if self._norm_first:
return source_attention_output + self._stochastic_depth(
layer_output, training=training)
# During mixed precision training, layer norm output is always fp32 for now.
# Casts fp32 for the subsequent add.
layer_output = tf.cast(layer_output, tf.float32)
return self._output_layer_norm(
layer_output +
self._stochastic_depth(attention_output, training=training))
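To make the `_stochastic_depth` hook concrete: stochastic depth drops a residual branch per example during training and rescales the kept branches so the expected output is unchanged. A minimal self-contained sketch of the assumed behavior (not the Model Garden `StochasticDepth` implementation):

import tensorflow as tf

def stochastic_depth(x: tf.Tensor, drop_rate: float,
                     training: bool) -> tf.Tensor:
  # At inference the branch passes through unchanged.
  if not training or drop_rate == 0.0:
    return x
  keep_prob = 1.0 - drop_rate
  # One Bernoulli draw per example, broadcast over the remaining dims.
  shape = [tf.shape(x)[0]] + [1] * (x.shape.rank - 1)
  mask = tf.floor(keep_prob + tf.random.uniform(shape, dtype=x.dtype))
  return x / keep_prob * mask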
@@ -17,17 +17,24 @@
import tensorflow as tf

from official.modeling import activations
from official.nlp import keras_nlp
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_layers
from official.vision.beta.projects.vit.modeling import nn_blocks

layers = tf.keras.layers

VIT_SPECS = {
    'vit-ti16':
        dict(
            hidden_size=192,
            patch_size=16,
            transformer=dict(mlp_dim=768, num_heads=3, num_layers=12),
        ),
    'vit-s16':
        dict(
            hidden_size=384,
            patch_size=16,
            transformer=dict(mlp_dim=1536, num_heads=6, num_layers=12),
        ),
    'vit-b16':
        dict(
@@ -112,6 +119,8 @@ class Encoder(tf.keras.layers.Layer):
               attention_dropout_rate=0.1,
               kernel_regularizer=None,
               inputs_positions=None,
               init_stochastic_depth_rate=0.0,
               kernel_initializer='glorot_uniform',
               **kwargs):
    super().__init__(**kwargs)
    self._num_layers = num_layers
@@ -121,6 +130,8 @@ class Encoder(tf.keras.layers.Layer):
    self._attention_dropout_rate = attention_dropout_rate
    self._kernel_regularizer = kernel_regularizer
    self._inputs_positions = inputs_positions
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._kernel_initializer = kernel_initializer
  def build(self, input_shape):
    self._pos_embed = AddPositionEmbs(
@@ -131,15 +142,18 @@ class Encoder(tf.keras.layers.Layer):
    self._encoder_layers = []
    # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
    # https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
    for i in range(self._num_layers):
      encoder_layer = nn_blocks.TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          kernel_initializer=self._kernel_initializer,
          norm_first=True,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i + 1, self._num_layers),
          norm_epsilon=1e-6)
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
@@ -164,12 +178,14 @@ class VisionTransformer(tf.keras.Model):
               num_layers=12,
               attention_dropout_rate=0.0,
               dropout_rate=0.1,
               init_stochastic_depth_rate=0.0,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               patch_size=16,
               hidden_size=768,
               representation_size=0,
               classifier='token',
               kernel_regularizer=None,
               original_init=True):
"""VisionTransformer initialization function.""" """VisionTransformer initialization function."""
inputs = tf.keras.Input(shape=input_specs.shape[1:]) inputs = tf.keras.Input(shape=input_specs.shape[1:])
...@@ -178,7 +194,8 @@ class VisionTransformer(tf.keras.Model): ...@@ -178,7 +194,8 @@ class VisionTransformer(tf.keras.Model):
kernel_size=patch_size, kernel_size=patch_size,
strides=patch_size, strides=patch_size,
padding='valid', padding='valid',
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
            inputs)
    if tf.keras.backend.image_data_format() == 'channels_last':
      rows_axis, cols_axis = (1, 2)
@@ -203,7 +220,10 @@ class VisionTransformer(tf.keras.Model):
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='glorot_uniform' if original_init else dict(
            class_name='TruncatedNormal', config=dict(stddev=.02)),
        init_stochastic_depth_rate=init_stochastic_depth_rate)(
            x)

    if classifier == 'token':
@@ -215,7 +235,8 @@ class VisionTransformer(tf.keras.Model):
      x = tf.keras.layers.Dense(
          representation_size,
          kernel_regularizer=kernel_regularizer,
          name='pre_logits',
          kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
              x)
      x = tf.nn.tanh(x)
    else:
@@ -247,9 +268,11 @@ def build_vit(input_specs,
      num_layers=backbone_cfg.transformer.num_layers,
      attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
      dropout_rate=backbone_cfg.transformer.dropout_rate,
      init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
      input_specs=input_specs,
      patch_size=backbone_cfg.patch_size,
      hidden_size=backbone_cfg.hidden_size,
      representation_size=backbone_cfg.representation_size,
      classifier=backbone_cfg.classifier,
      kernel_regularizer=l2_regularizer,
      original_init=backbone_cfg.original_init)
@@ -58,7 +58,7 @@ class ImageClassificationTask(cfg.TaskConfig):
@exp_factory.register_config_factory('darknet_classification')
def darknet_classification() -> cfg.ExperimentConfig:
  """Image classification general."""
  return cfg.ExperimentConfig(
      task=ImageClassificationTask(),
...
@@ -26,6 +26,7 @@ from official.vision.beta.dataloaders import classification_input
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders import tfds_factory
from official.vision.beta.modeling import factory
from official.vision.beta.ops import augment
@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
@@ -103,14 +104,26 @@ class ImageClassificationTask(base_task.Task):
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_type=params.aug_type,
        color_jitter=params.color_jitter,
        random_erasing=params.random_erasing,
        is_multilabel=is_multilabel,
        dtype=params.dtype)

    postprocess_fn = None
    if params.mixup_and_cutmix:
      postprocess_fn = augment.MixupAndCutmix(
          mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
          cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
          prob=params.mixup_and_cutmix.prob,
          label_smoothing=params.mixup_and_cutmix.label_smoothing,
          num_classes=num_classes)
    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn)
    dataset = reader.read(input_context=input_context)
@@ -140,6 +153,9 @@ class ImageClassificationTask(base_task.Task):
          model_outputs,
          from_logits=True,
          label_smoothing=losses_config.label_smoothing)
    elif losses_config.soft_labels:
      total_loss = tf.nn.softmax_cross_entropy_with_logits(
          labels, model_outputs)
    else:
      total_loss = tf.keras.losses.sparse_categorical_crossentropy(
          labels, model_outputs, from_logits=True)
@@ -161,7 +177,8 @@ class ImageClassificationTask(base_task.Task):
    is_multilabel = self.task_config.train_data.is_multilabel
    if not is_multilabel:
      k = self.task_config.evaluation.top_k
      if (self.task_config.losses.one_hot or
          self.task_config.losses.soft_labels):
        metrics = [
            tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
            tf.keras.metrics.TopKCategoricalAccuracy(
@@ -223,7 +240,9 @@ class ImageClassificationTask(base_task.Task):
      # Computes per-replica loss.
      loss = self.build_losses(
          model_outputs=outputs,
          labels=labels,
          aux_losses=model.losses)
      # Scales loss as the default gradients allreduce performs sum inside the
      # optimizer.
      scaled_loss = loss / num_replicas
@@ -266,13 +285,17 @@ class ImageClassificationTask(base_task.Task):
      A dictionary of logs.
    """
    features, labels = inputs
    one_hot = self.task_config.losses.one_hot
    soft_labels = self.task_config.losses.soft_labels
    is_multilabel = self.task_config.train_data.is_multilabel
    if (one_hot or soft_labels) and not is_multilabel:
      labels = tf.one_hot(labels, self.task_config.model.num_classes)

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    loss = self.build_losses(
        model_outputs=outputs,
        labels=labels,
        aux_losses=model.losses)

    logs = {self.loss: loss}
...