Commit 40cd0a26 authored by Simon Geisler

deit without repeated aug and distillation

parent 3db445c7
@@ -15,7 +15,7 @@
# Lint as: python3
"""Common configurations."""
from typing import Optional, List
# Import libraries
import dataclasses
@@ -32,6 +32,7 @@ class RandAugment(hyperparams.Config):
  cutout_const: float = 40
  translate_const: float = 10
  prob_to_apply: Optional[float] = None
  exclude_ops: List[str] = dataclasses.field(default_factory=list)

@dataclasses.dataclass
@@ -42,6 +43,30 @@ class AutoAugment(hyperparams.Config):
  translate_const: float = 250
@dataclasses.dataclass
class RandomErasing(hyperparams.Config):
  """Configuration for RandomErasing."""
  probability: float = 0.25
  min_area: float = 0.02
  max_area: float = 1 / 3
  min_aspect: float = 0.3
  max_aspect: Optional[float] = None
  min_count: int = 1
  max_count: int = 1
  trials: int = 10

@dataclasses.dataclass
class MixupAndCutmix(hyperparams.Config):
  """Configuration for MixupAndCutmix."""
  mixup_alpha: float = .8
  cutmix_alpha: float = 1.
  prob: float = 1.0
  switch_prob: float = 0.5
  label_smoothing: float = 0.1
  num_classes: int = 1000

@dataclasses.dataclass
class Augmentation(hyperparams.OneOfConfig):
  """Configuration for input data augmentation.
...
@@ -40,10 +40,13 @@ class DataConfig(cfg.DataConfig):
  aug_rand_hflip: bool = True
  aug_type: Optional[
      common.Augmentation] = None  # Choose from AutoAugment and RandAugment.
  color_jitter: float = 0.
  random_erasing: Optional[common.RandomErasing] = None
  file_type: str = 'tfrecord'
  image_field_key: str = 'image/encoded'
  label_field_key: str = 'image/class/label'
  decode_jpeg_only: bool = True
  mixup_and_cutmix: Optional[common.MixupAndCutmix] = None
  # Keep for backward compatibility.
  aug_policy: Optional[str] = None  # None, 'autoaug', or 'randaug'.
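
For orientation, a minimal sketch of how the new augmentation fields could be combined in a data config (values are illustrative, and the `image_classification` import path is an assumption, not shown in this diff):

```python
from official.vision.beta.configs import common
from official.vision.beta.configs import image_classification as ic_cfg  # assumed path

data_config = ic_cfg.DataConfig(
    aug_type=common.Augmentation(
        type='randaug',
        randaug=common.RandAugment(magnitude=9, exclude_ops=['Cutout'])),
    color_jitter=0.4,  # > 0 enables brightness/contrast/saturation jitter
    random_erasing=common.RandomErasing(probability=0.25),
    mixup_and_cutmix=common.MixupAndCutmix(num_classes=1000))
```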
@@ -62,6 +65,7 @@ class ImageClassificationModel(hyperparams.Config):
      use_sync_bn=False)
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification.
  add_head_batch_norm: bool = False
  kernel_initializer: str = 'random_uniform'

@dataclasses.dataclass
@@ -69,6 +73,7 @@ class Losses(hyperparams.Config):
  one_hot: bool = True
  label_smoothing: float = 0.0
  l2_weight_decay: float = 0.0
  soft_labels: bool = False

@dataclasses.dataclass
...
@@ -69,6 +69,8 @@ class Parser(parser.Parser):
               decode_jpeg_only: bool = True,
               aug_rand_hflip: bool = True,
               aug_type: Optional[common.Augmentation] = None,
               color_jitter: float = 0.,
               random_erasing: Optional[common.RandomErasing] = None,
               is_multilabel: bool = False,
               dtype: str = 'float32'):
    """Initializes parameters for parsing annotations in the dataset.
@@ -85,6 +87,7 @@ class Parser(parser.Parser):
        horizontal flip.
      aug_type: An optional Augmentation object to choose from AutoAugment and
        RandAugment.
      color_jitter: If greater than 0, the input image will be augmented by
        color jitter.
      random_erasing: An optional RandomErasing config; if set, random erasing
        is applied after the image has been normalized.
      is_multilabel: A `bool`, whether or not each example has multiple labels.
      dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
        or 'bfloat16'.
@@ -113,13 +116,28 @@ class Parser(parser.Parser):
            magnitude=aug_type.randaug.magnitude,
            cutout_const=aug_type.randaug.cutout_const,
            translate_const=aug_type.randaug.translate_const,
            prob_to_apply=aug_type.randaug.prob_to_apply,
            exclude_ops=aug_type.randaug.exclude_ops)
      else:
        raise ValueError('Augmentation policy {} not supported.'.format(
            aug_type.type))
    else:
      self._augmenter = None
    self._label_field_key = label_field_key
    self._color_jitter = color_jitter
    if random_erasing:
      self._random_erasing = augment.RandomErasing(
          probability=random_erasing.probability,
          min_area=random_erasing.min_area,
          max_area=random_erasing.max_area,
          min_aspect=random_erasing.min_aspect,
          max_aspect=random_erasing.max_aspect,
          min_count=random_erasing.min_count,
          max_count=random_erasing.max_count,
          trials=random_erasing.trials)
    else:
      self._random_erasing = None
    self._is_multilabel = is_multilabel
    self._decode_jpeg_only = decode_jpeg_only
@@ -213,11 +231,20 @@ class Parser(parser.Parser):
          image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
      image.set_shape([self._output_size[0], self._output_size[1], 3])
    # Color jitter.
    if self._color_jitter > 0:
      image = preprocess_ops.color_jitter(
          image, self._color_jitter, self._color_jitter, self._color_jitter)

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(image,
                                           offset=MEAN_RGB,
                                           scale=STDDEV_RGB)

    # Random erasing after the image has been normalized.
    if self._random_erasing is not None:
      image = self._random_erasing.distort(image)

    # Convert image to self._dtype.
    image = tf.image.convert_image_dtype(image, self._dtype)
...
@@ -56,6 +56,7 @@ def build_classification_model(
      num_classes=model_config.num_classes,
      input_specs=input_specs,
      dropout_rate=model_config.dropout_rate,
      kernel_initializer=model_config.kernel_initializer,
      kernel_regularizer=l2_regularizer,
      add_head_batch_norm=model_config.add_head_batch_norm,
      use_sync_bn=norm_activation_config.use_sync_bn,
...
@@ -12,10 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Augmentation policies for enhanced image/video preprocessing.

AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719
RandomErasing Reference: https://arxiv.org/abs/1708.04896
MixupAndCutmix:
  - Mixup: https://arxiv.org/abs/1710.09412
  - Cutmix: https://arxiv.org/abs/1905.04899

RandomErasing, Mixup and Cutmix are inspired by
https://github.com/rwightman/pytorch-image-models.
"""
import math
from typing import Any, List, Iterable, Optional, Text, Tuple
@@ -295,10 +302,21 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
  cutout_center_width = tf.random.uniform(
      shape=[], minval=0, maxval=image_width, dtype=tf.int32)

  image = _fill_rectangle(image, cutout_center_width, cutout_center_height,
                          pad_size, pad_size, replace)

  return image

def _fill_rectangle(image, center_width, center_height, half_width,
                    half_height, replace=None):
  """Fills a rectangle centered at the given point with `replace` (or noise)."""
  image_height = tf.shape(image)[0]
  image_width = tf.shape(image)[1]

  lower_pad = tf.maximum(0, center_height - half_height)
  upper_pad = tf.maximum(0, image_height - center_height - half_height)
  left_pad = tf.maximum(0, center_width - half_width)
  right_pad = tf.maximum(0, image_width - center_width - half_width)
  cutout_shape = [
      image_height - (lower_pad + upper_pad),
@@ -311,9 +329,15 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
      constant_values=1)
  mask = tf.expand_dims(mask, -1)
  mask = tf.tile(mask, [1, 1, 3])
  if replace is None:
    fill = tf.random.normal(tf.shape(image), dtype=image.dtype)
  elif isinstance(replace, tf.Tensor):
    fill = replace
  else:
    fill = tf.ones_like(image, dtype=image.dtype) * replace
  image = tf.where(tf.equal(mask, 0), fill, image)

  return image
@@ -805,9 +829,15 @@ def level_to_arg(cutout_const: float, translate_const: float):

def _parse_policy_info(name: Text, prob: float, level: float,
                       replace_value: List[int], cutout_const: float,
                       translate_const: float,
                       level_std: float = 0.) -> Tuple[Any, float, Any]:
  """Return the function that corresponds to `name` and update `level` param."""
  func = NAME_TO_FUNC[name]

  if level_std > 0:
    level += tf.random.normal([], stddev=level_std, dtype=tf.float32)
    level = tf.clip_by_value(level, 0., _MAX_LEVEL)

  args = level_to_arg(cutout_const, translate_const)[name](level)
  if name in REPLACE_FUNCS:
@@ -1184,7 +1214,9 @@ class RandAugment(ImageAugment):
               magnitude: float = 10.,
               cutout_const: float = 40.,
               translate_const: float = 100.,
               magnitude_std: float = 0.0,
               prob_to_apply: Optional[float] = None,
               exclude_ops: Optional[List[str]] = None):
"""Applies the RandAugment policy to images. """Applies the RandAugment policy to images.
Args: Args:
...@@ -1196,8 +1228,11 @@ class RandAugment(ImageAugment): ...@@ -1196,8 +1228,11 @@ class RandAugment(ImageAugment):
[5, 10]. [5, 10].
cutout_const: multiplier for applying cutout. cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation. translate_const: multiplier for applying translation.
magnitude_std: randomness of the severity as proposed by the authors of
the timm library.
prob_to_apply: The probability to apply the selected augmentation at each prob_to_apply: The probability to apply the selected augmentation at each
layer. layer.
exclude_ops: exclude selected operations.
""" """
    super(RandAugment, self).__init__()
@@ -1212,6 +1247,9 @@ class RandAugment(ImageAugment):
        'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY',
        'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'
    ]
    self.magnitude_std = magnitude_std
    if exclude_ops:
      self.available_ops = [
          op for op in self.available_ops if op not in exclude_ops]
  def distort(self, image: tf.Tensor) -> tf.Tensor:
    """Applies the RandAugment policy to `image`.
@@ -1246,7 +1284,8 @@ class RandAugment(ImageAugment):
          dtype=tf.float32)
      func, _, args = _parse_policy_info(op_name, prob, self.magnitude,
                                         replace_value, self.cutout_const,
                                         self.translate_const,
                                         self.magnitude_std)
      branch_fns.append((
          i,
          # pylint:disable=g-long-lambda
@@ -1267,3 +1306,240 @@ class RandAugment(ImageAugment):
      image = tf.cast(image, dtype=input_image_type)
    return image
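
A short usage sketch of the new RandAugment knobs (values illustrative):

```python
import tensorflow as tf
from official.vision.beta.ops import augment

augmenter = augment.RandAugment(
    num_layers=2,
    magnitude=9,
    magnitude_std=0.5,       # jitter the per-op severity, timm-style
    exclude_ops=['Cutout'])  # e.g. avoid overlap with RandomErasing
image = tf.random.uniform([224, 224, 3], maxval=256, dtype=tf.int32)
augmented = augmenter.distort(tf.cast(image, tf.uint8))
```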

class RandomErasing(ImageAugment):
  """Applies RandomErasing to a single image.

  Reference: https://arxiv.org/abs/1708.04896

  Implementation is inspired by
  https://github.com/rwightman/pytorch-image-models.
  """
  def __init__(self, probability: float = 0.25, min_area: float = 0.02,
               max_area: float = 1 / 3, min_aspect: float = 0.3,
               max_aspect: Optional[float] = None, min_count: int = 1,
               max_count: int = 1, trials: int = 10):
"""Applies RandomErasing to a single image.
Args:
probability (float, optional): Probability of augmenting the image.
Defaults to 0.25.
min_area (float, optional): Minimum area of the random erasing
rectangle. Defaults to 0.02.
max_area (float, optional): Maximum area of the random erasing
rectangle. Defaults to 1/3.
min_aspect (float, optional): Minimum aspect rate of the random erasing
rectangle. Defaults to 0.3.
max_aspect ([type], optional): Maximum aspect rate of the random
erasing rectangle. Defaults to None.
min_count (int, optional): Minimum number of erased
rectangles. Defaults to 1.
max_count (int, optional): Maximum number of erased
rectangles. Defaults to 1.
trials (int, optional): Maximum number of trials to randomly sample a
rectangle that fulfills constraint. Defaults to 10.
"""
    self._probability = probability
    self._min_area = float(min_area)
    self._max_area = float(max_area)
    self._min_log_aspect = math.log(min_aspect)
    self._max_log_aspect = math.log(max_aspect or 1 / min_aspect)
    self._min_count = min_count
    self._max_count = max_count
    self._trials = trials
  def distort(self, image: tf.Tensor) -> tf.Tensor:
    """Applies RandomErasing to a single `image`.

    Args:
      image (tf.Tensor): Of shape [height, width, 3] representing an image.

    Returns:
      tf.Tensor: The augmented version of `image`.
    """
    uniform_random = tf.random.uniform(shape=[], minval=0., maxval=1.0)
    mirror_cond = tf.less(uniform_random, self._probability)
    image = tf.cond(mirror_cond, lambda: self._erase(image), lambda: image)
    return image
  @tf.function
  def _erase(self, image: tf.Tensor) -> tf.Tensor:
    """Erases up to `max_count` random rectangles from `image`."""
    count = self._min_count if self._min_count == self._max_count else \
        tf.random.uniform(shape=[], minval=int(self._min_count),
                          maxval=int(self._max_count + 1), dtype=tf.int32)

    image_height = tf.shape(image)[0]
    image_width = tf.shape(image)[1]
    area = tf.cast(image_width * image_height, tf.float32)
    for _ in range(count):
      # Rejection sampling: try up to `trials` times to find a rectangle that
      # fits inside the image.
      for _ in range(self._trials):
        erase_area = tf.random.uniform(shape=[],
                                       minval=area * self._min_area,
                                       maxval=area * self._max_area)
        aspect_ratio = tf.math.exp(tf.random.uniform(
            shape=[], minval=self._min_log_aspect,
            maxval=self._max_log_aspect))

        half_height = tf.cast(tf.math.round(tf.math.sqrt(
            erase_area * aspect_ratio) / 2), dtype=tf.int32)
        half_width = tf.cast(tf.math.round(tf.math.sqrt(
            erase_area / aspect_ratio) / 2), dtype=tf.int32)

        if 2 * half_height < image_height and 2 * half_width < image_width:
          center_height = tf.random.uniform(
              shape=[], minval=0, maxval=image_height - 2 * half_height,
              dtype=tf.int32)
          center_width = tf.random.uniform(
              shape=[], minval=0, maxval=image_width - 2 * half_width,
              dtype=tf.int32)

          image = _fill_rectangle(image, center_width, center_height,
                                  half_width, half_height, replace=None)
          break

    return image
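
A quick usage sketch (shapes illustrative). The erased rectangles are filled with Gaussian noise (`_fill_rectangle(..., replace=None)`), which is why the parser applies this after normalization:

```python
import tensorflow as tf
from official.vision.beta.ops import augment

eraser = augment.RandomErasing(probability=1.0, max_count=4)
image = tf.zeros([224, 224, 3], dtype=tf.float32)  # stands in for a normalized image
erased = eraser.distort(image)                     # same shape; some pixels now noise
```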

class MixupAndCutmix:
  """Applies Mixup and/or Cutmix to a batch of images.

  - Mixup: https://arxiv.org/abs/1710.09412
  - Cutmix: https://arxiv.org/abs/1905.04899

  Implementation is inspired by
  https://github.com/rwightman/pytorch-image-models.
  """
  def __init__(self, mixup_alpha: float = .8, cutmix_alpha: float = 1.,
               prob: float = 1.0, switch_prob: float = 0.5,
               label_smoothing: float = 0.1, num_classes: int = 1001):
    """Applies Mixup and/or Cutmix to a batch of images.

    Args:
      mixup_alpha (float, optional): For drawing a random lambda (`lam`) from a
        beta distribution (for each image). If zero, Mixup is deactivated.
        Defaults to .8.
      cutmix_alpha (float, optional): For drawing a random lambda (`lam`) from
        a beta distribution (for each image). If zero, Cutmix is deactivated.
        Defaults to 1.
      prob (float, optional): Probability of augmenting the batch. Defaults to
        1.0.
      switch_prob (float, optional): Probability of applying Cutmix instead of
        Mixup for the batch. Defaults to 0.5.
      label_smoothing (float, optional): Constant for label smoothing. Defaults
        to 0.1.
      num_classes (int, optional): Number of classes. Defaults to 1001.
    """
    self.mixup_alpha = mixup_alpha
    self.cutmix_alpha = cutmix_alpha
    self.mix_prob = prob
    self.switch_prob = switch_prob
    self.label_smoothing = label_smoothing
    self.num_classes = num_classes
    self.mode = 'batch'
    self.mixup_enabled = True

    # If only one of the two augmentations is enabled, force the switch so the
    # disabled branch is never taken.
    if self.mixup_alpha and not self.cutmix_alpha:
      self.switch_prob = -1
    elif not self.mixup_alpha and self.cutmix_alpha:
      self.switch_prob = 1
  def __call__(self, images: tf.Tensor,
               labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    return self.distort(images, labels)
  def distort(self, images: tf.Tensor,
              labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    """Applies Mixup and/or Cutmix to a batch and transforms the labels.

    Label smoothing is included in the label transformation.

    Args:
      images (tf.Tensor): Of shape [batch_size, height, width, 3] representing
        a batch of images.
      labels (tf.Tensor): Of shape [batch_size, ] representing the class id for
        each image of the batch.

    Returns:
      Tuple[tf.Tensor, tf.Tensor]: The augmented versions of `images` and
        `labels`.
    """
    augment_cond = tf.less(tf.random.uniform(shape=[], minval=0., maxval=1.0),
                           self.mix_prob)
    return tf.cond(
        augment_cond,
        lambda: self._update_labels(*tf.cond(
            tf.less(tf.random.uniform(
                shape=[], minval=0., maxval=1.0), self.switch_prob),
            lambda: self._cutmix(images, labels),
            lambda: self._mixup(images, labels))),
        lambda: (images, self._smooth_labels(labels)))
  @staticmethod
  def _sample_from_beta(alpha: float, beta: float, shape: tuple):
    # If X ~ Gamma(alpha, 1) and Y ~ Gamma(beta, 1), then
    # X / (X + Y) ~ Beta(alpha, beta).
    sample_alpha = tf.random.gamma(shape, alpha)
    sample_beta = tf.random.gamma(shape, beta)
    return sample_alpha / (sample_alpha + sample_beta)
  def _cutmix(self, images: tf.Tensor,
              labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Applies Cutmix; returns the mixed images, labels and per-image lambdas."""
    lam = MixupAndCutmix._sample_from_beta(self.cutmix_alpha, self.cutmix_alpha,
                                           labels.shape)

    ratio = tf.math.sqrt(1 - lam)

    batch_size = tf.shape(images)[0]
    image_height, image_width = tf.shape(images)[1], tf.shape(images)[2]

    cut_height = tf.cast(
        ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32)
    cut_width = tf.cast(
        ratio * tf.cast(image_width, dtype=tf.float32), dtype=tf.int32)

    random_center_height = tf.random.uniform(
        shape=[batch_size], minval=0, maxval=image_height, dtype=tf.int32)
    random_center_width = tf.random.uniform(
        shape=[batch_size], minval=0, maxval=image_width, dtype=tf.int32)

    # Recompute lam from the realized box area so the label mix matches the
    # pixel mix.
    bbox_area = cut_height * cut_width
    lam = 1. - bbox_area / (image_height * image_width)
    lam = tf.cast(lam, dtype=tf.float32)

    images = tf.map_fn(
        lambda x: _fill_rectangle(*x),
        (images, random_center_width, random_center_height, cut_width // 2,
         cut_height // 2, tf.reverse(images, [0])),
        fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=tf.float32))

    return images, labels, lam
  def _mixup(self, images: tf.Tensor,
             labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
    """Applies Mixup; returns the mixed images, labels and per-image lambdas."""
    lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha,
                                           labels.shape)
    lam = tf.reshape(lam, [-1, 1, 1, 1])
    images = lam * images + (1. - lam) * tf.reverse(images, [0])

    return images, labels, tf.squeeze(lam)
  def _smooth_labels(self, labels: tf.Tensor) -> tf.Tensor:
    off_value = self.label_smoothing / self.num_classes
    on_value = 1. - self.label_smoothing + off_value

    smooth_labels = tf.one_hot(
        labels, self.num_classes, on_value=on_value, off_value=off_value)
    return smooth_labels
  def _update_labels(self, images: tf.Tensor, labels: tf.Tensor,
                     lam: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
    labels_1 = self._smooth_labels(labels)
    labels_2 = tf.reverse(labels_1, [0])

    lam = tf.reshape(lam, [-1, 1])
    labels = lam * labels_1 + (1. - lam) * labels_2

    return images, labels
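
An end-to-end sketch of the batch-level API (sizes illustrative). Sparse integer labels come back as `[batch_size, num_classes]` soft targets, which is what the `soft_labels` loss option below consumes:

```python
import tensorflow as tf
from official.vision.beta.ops import augment

augmenter = augment.MixupAndCutmix(
    mixup_alpha=.8, cutmix_alpha=1., num_classes=10)
images = tf.random.normal([8, 32, 32, 3])
labels = tf.random.uniform([8], maxval=10, dtype=tf.int32)

mixed_images, soft_labels = augmenter(images, labels)
# mixed_images: [8, 32, 32, 3]; soft_labels: [8, 10], rows sum to 1.
```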
@@ -254,5 +254,82 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
      augmenter.distort(image)

class RandomErasingTest(tf.test.TestCase, parameterized.TestCase):

  def test_random_erase_replaces_some_pixels(self):
    image = tf.zeros((224, 224, 3), dtype=tf.float32)
    augmenter = augment.RandomErasing(probability=1., max_count=10)

    aug_image = augmenter.distort(image)

    self.assertEqual((224, 224, 3), aug_image.shape)
    self.assertNotEqual(0, tf.reduce_max(aug_image))

class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):

  def test_mixup_and_cutmix_smoothes_labels(self):
    batch_size = 12
    num_classes = 1000
    label_smoothing = 0.1

    images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
    labels = tf.range(batch_size)
    augmenter = augment.MixupAndCutmix(
        num_classes=num_classes, label_smoothing=label_smoothing)

    aug_images, aug_labels = augmenter.distort(images, labels)

    self.assertEqual(images.shape, aug_images.shape)
    self.assertEqual(images.dtype, aug_images.dtype)
    self.assertEqual([batch_size, num_classes], aug_labels.shape)
    self.assertAllLessEqual(
        aug_labels, 1. - label_smoothing + 2. / num_classes)  # With tolerance.
    self.assertAllGreaterEqual(
        aug_labels, label_smoothing / num_classes - 1e-4)  # With tolerance.
  def test_mixup_changes_image(self):
    batch_size = 12
    num_classes = 1000
    label_smoothing = 0.1

    images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
    labels = tf.range(batch_size)
    augmenter = augment.MixupAndCutmix(
        mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)

    aug_images, aug_labels = augmenter.distort(images, labels)

    self.assertEqual(images.shape, aug_images.shape)
    self.assertEqual(images.dtype, aug_images.dtype)
    self.assertEqual([batch_size, num_classes], aug_labels.shape)
    self.assertAllLessEqual(
        aug_labels, 1. - label_smoothing + 2. / num_classes)  # With tolerance.
    self.assertAllGreaterEqual(
        aug_labels, label_smoothing / num_classes - 1e-4)  # With tolerance.
    self.assertFalse(tf.math.reduce_all(images == aug_images))
  def test_cutmix_changes_image(self):
    batch_size = 12
    num_classes = 1000
    label_smoothing = 0.1

    images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
    labels = tf.range(batch_size)
    augmenter = augment.MixupAndCutmix(
        mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)

    aug_images, aug_labels = augmenter.distort(images, labels)

    self.assertEqual(images.shape, aug_images.shape)
    self.assertEqual(images.dtype, aug_images.dtype)
    self.assertEqual([batch_size, num_classes], aug_labels.shape)
    self.assertAllLessEqual(
        aug_labels, 1. - label_smoothing + 2. / num_classes)  # With tolerance.
    self.assertAllGreaterEqual(
        aug_labels, label_smoothing / num_classes - 1e-4)  # With tolerance.
    self.assertFalse(tf.math.reduce_all(images == aug_images))

if __name__ == '__main__':
  tf.test.main()
@@ -15,10 +15,12 @@
"""Preprocessing ops."""
import math
from typing import Optional

from six.moves import range
import tensorflow as tf

from official.vision.beta.ops import box_ops
from official.vision.beta.ops import augment

CENTER_CROP_FRACTION = 0.875
@@ -555,3 +557,84 @@ def random_horizontal_flip(image, normalized_boxes=None, masks=None, seed=1):
          lambda: masks)

  return image, normalized_boxes, masks

def color_jitter(image: tf.Tensor,
                 brightness: float = 0.,
                 contrast: float = 0.,
                 saturation: float = 0.,
                 seed: Optional[int] = None) -> tf.Tensor:
  """Applies color jitter to an image, similarly to torchvision's ColorJitter.

  Args:
    image (tf.Tensor): Of shape [height, width, 3] representing an image.
    brightness (float, optional): Magnitude for brightness jitter. Defaults
      to 0.
    contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
    saturation (float, optional): Magnitude for saturation jitter. Defaults
      to 0.
    seed (int, optional): Random seed. Defaults to None.

  Returns:
    tf.Tensor: The augmented version of `image`.
  """
  image = random_brightness(image, brightness, seed=seed)
  image = random_contrast(image, contrast, seed=seed)
  image = random_saturation(image, saturation, seed=seed)
  return image

def random_brightness(image: tf.Tensor,
                      brightness: float = 0.,
                      seed: Optional[int] = None) -> tf.Tensor:
  """Jitters brightness of an image, similarly to torchvision's ColorJitter.

  Args:
    image (tf.Tensor): Of shape [height, width, 3] representing an image.
    brightness (float, optional): Magnitude for brightness jitter. Defaults
      to 0.
    seed (int, optional): Random seed. Defaults to None.

  Returns:
    tf.Tensor: The augmented version of `image`.
  """
  assert 0. <= brightness <= 1., '`brightness` must be in [0, 1]'
  brightness = tf.random.uniform(
      [], max(0, 1 - brightness), 1 + brightness, seed=seed)
  return augment.brightness(image, brightness)

def random_contrast(image: tf.Tensor,
                    contrast: float = 0.,
                    seed: Optional[int] = None) -> tf.Tensor:
  """Jitters contrast of an image, similarly to torchvision's ColorJitter.

  Args:
    image (tf.Tensor): Of shape [height, width, 3] representing an image.
    contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
    seed (int, optional): Random seed. Defaults to None.

  Returns:
    tf.Tensor: The augmented version of `image`.
  """
  assert 0. <= contrast <= 1., '`contrast` must be in [0, 1]'
  contrast = tf.random.uniform(
      [], max(0, 1 - contrast), 1 + contrast, seed=seed)
  return augment.contrast(image, contrast)

def random_saturation(image: tf.Tensor,
                      saturation: float = 0.,
                      seed: Optional[int] = None) -> tf.Tensor:
  """Jitters saturation of an image, similarly to torchvision's ColorJitter.

  Args:
    image (tf.Tensor): Of shape [height, width, 3] representing an image.
    saturation (float, optional): Magnitude for saturation jitter. Defaults
      to 0.
    seed (int, optional): Random seed. Defaults to None.

  Returns:
    tf.Tensor: The augmented version of `image`.
  """
  assert 0. <= saturation <= 1., '`saturation` must be in [0, 1]'
  saturation = tf.random.uniform(
      [], max(0, 1 - saturation), 1 + saturation, seed=seed)
  return augment.blend(tf.image.rgb_to_grayscale(image), image, saturation)
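
A usage sketch (factors illustrative): each helper draws a multiplier uniformly from `[1 - f, 1 + f]`, mirroring torchvision's ColorJitter semantics:

```python
import tensorflow as tf
from official.vision.beta.ops import preprocess_ops

image = tf.random.uniform([224, 224, 3], maxval=255, dtype=tf.float32)
jittered = preprocess_ops.color_jitter(
    image, brightness=0.4, contrast=0.4, saturation=0.4)
```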
# Vision Transformer (ViT) and Data-Efficient Image Transformer (DEIT)

**DISCLAIMER**: This implementation is still under development. No support will
be provided during the development phase.

- [![ViT Paper](http://img.shields.io/badge/Paper-arXiv.2010.11929-B3181B?logo=arXiv)](https://arxiv.org/abs/2010.11929)
- [![DEIT Paper](http://img.shields.io/badge/Paper-arXiv.2012.12877-B3181B?logo=arXiv)](https://arxiv.org/abs/2012.12877)

This repository contains the implementations of Vision Transformer (ViT) and
Data-Efficient Image Transformer (DEIT) in TensorFlow 2.

* Paper titles:
  - [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf)
  - [Training data-efficient image transformers & distillation through attention](https://arxiv.org/pdf/2012.12877.pdf)
@@ -42,6 +42,8 @@ class VisionTransformer(hyperparams.Config):
  hidden_size: int = 1
  patch_size: int = 16
  transformer: Transformer = Transformer()
  init_stochastic_depth_rate: float = 0.0
  original_init: bool = True

@dataclasses.dataclass
...
from official.vision.beta.projects.vit.modeling.layers.vit_transformer_encoder_block import TransformerEncoderBlock
\ No newline at end of file
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based TransformerEncoder block layer."""
import tensorflow as tf
from official.vision.beta.modeling.layers.nn_layers import StochasticDepth
@tf.keras.utils.register_keras_serializable(package="Vision")
class TransformerEncoderBlock(tf.keras.layers.Layer):
  """TransformerEncoderBlock layer.

  This layer implements the Transformer Encoder from
  "Attention Is All You Need" (https://arxiv.org/abs/1706.03762),
  which combines a `tf.keras.layers.MultiHeadAttention` layer with a
  two-layer feedforward network. Here we add support for stochastic depth.

  References:
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
    [BERT: Pre-training of Deep Bidirectional Transformers for Language
     Understanding](https://arxiv.org/abs/1810.04805)
  """
  def __init__(self,
               num_attention_heads,
               inner_dim,
               inner_activation,
               output_range=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               output_dropout=0.0,
               attention_dropout=0.0,
               inner_dropout=0.0,
               stochastic_depth_drop_rate=0.0,
               attention_initializer=None,
               attention_axes=None,
               **kwargs):
"""Initializes `TransformerEncoderBlock`.
Args:
num_attention_heads: Number of attention heads.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network.
output_range: the sequence output range, [0, output_range) for slicing the
target sequence. `None` means the target sequence is not sliced.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set False, output of attention and intermediate dense
layers is normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: Dropout probability for within the attention layer.
inner_dropout: Dropout probability for the first Dense layer in a
two-layer feedforward network.
stochastic_depth_drop_rate: Dropout propobability for the stochastic depth
regularization.
attention_initializer: Initializer for kernels of attention layers. If set
`None`, attention layers use kernel_initializer as initializer for
kernel.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
**kwargs: keyword arguments/
"""
    super().__init__(**kwargs)
    self._num_heads = num_attention_heads
    self._inner_dim = inner_dim
    self._inner_activation = inner_activation
    self._attention_dropout = attention_dropout
    self._attention_dropout_rate = attention_dropout
    self._output_dropout = output_dropout
    self._output_dropout_rate = output_dropout
    self._output_range = output_range
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._inner_dropout = inner_dropout
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    if attention_initializer:
      self._attention_initializer = tf.keras.initializers.get(
          attention_initializer)
    else:
      self._attention_initializer = self._kernel_initializer
    self._attention_axes = attention_axes
  def build(self, input_shape):
    if isinstance(input_shape, tf.TensorShape):
      input_tensor_shape = input_shape
    elif isinstance(input_shape, (list, tuple)):
      input_tensor_shape = tf.TensorShape(input_shape[0])
    else:
      raise ValueError(
          "The type of input shape argument is not supported, got: %s" %
          type(input_shape))

    einsum_equation = "abc,cd->abd"
    if len(input_tensor_shape.as_list()) > 3:
      einsum_equation = "...bc,cd->...bd"

    hidden_size = input_tensor_shape[-1]
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self._num_heads))
    self._attention_head_size = int(hidden_size // self._num_heads)
    common_kwargs = dict(
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    self._attention_layer = tf.keras.layers.MultiHeadAttention(
        num_heads=self._num_heads,
        key_dim=self._attention_head_size,
        dropout=self._attention_dropout,
        use_bias=self._use_bias,
        kernel_initializer=self._attention_initializer,
        attention_axes=self._attention_axes,
        name="self_attention",
        **common_kwargs)
    self._attention_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
    # Use float32 in layernorm for numeric stability.
    # It is probably safe in mixed_float16, but we haven't validated this yet.
    self._attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32))
    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=(None, self._inner_dim),
        bias_axes="d",
        kernel_initializer=self._kernel_initializer,
        name="intermediate",
        **common_kwargs)
    policy = tf.keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
      # TODO(b/154538392): Investigate this.
      policy = tf.float32
    self._intermediate_activation_layer = tf.keras.layers.Activation(
        self._inner_activation, dtype=policy)
    self._inner_dropout_layer = tf.keras.layers.Dropout(
        rate=self._inner_dropout)
    self._output_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=(None, hidden_size),
        bias_axes="d",
        name="output",
        kernel_initializer=self._kernel_initializer,
        **common_kwargs)
    self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
    # Use float32 in layernorm for numeric stability.
    self._output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype=tf.float32)
    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = StochasticDepth(self._stochastic_depth_drop_rate)
    else:
      # Fall back to a no-op so `call` can apply it unconditionally.
      self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)

    super(TransformerEncoderBlock, self).build(input_shape)
  def get_config(self):
    config = {
        "num_attention_heads": self._num_heads,
        "inner_dim": self._inner_dim,
        "inner_activation": self._inner_activation,
        "output_dropout": self._output_dropout_rate,
        "attention_dropout": self._attention_dropout_rate,
        "output_range": self._output_range,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint),
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "inner_dropout": self._inner_dropout,
        "stochastic_depth_drop_rate": self._stochastic_depth_drop_rate,
        "attention_initializer":
            tf.keras.initializers.serialize(self._attention_initializer),
        "attention_axes": self._attention_axes,
    }
    base_config = super(TransformerEncoderBlock, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))
  def call(self, inputs, training=None):
    """Transformer self-attention encoder block call.

    Args:
      inputs: a single tensor or a list of tensors.
        `input tensor` as the single sequence of embeddings.
        [`input tensor`, `attention mask`] to have the additional attention
        mask.
        [`query tensor`, `key value tensor`, `attention mask`] to have separate
        input streams for the query, and key/value to the multi-head
        attention.
      training: Whether the layer is in training mode; enables dropout and
        stochastic depth.

    Returns:
      An output tensor with the same dimensions as input/query tensor.
    """
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError("Unexpected inputs to %s with length at %d" %
                         (self.__class__, len(inputs)))
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)
    if self._output_range:
      if self._norm_first:
        source_tensor = input_tensor[:, 0:self._output_range, :]
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm(key_value)
      target_tensor = input_tensor[:, 0:self._output_range, :]
      if attention_mask is not None:
        attention_mask = attention_mask[:, 0:self._output_range, :]
    else:
      if self._norm_first:
        source_tensor = input_tensor
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm(key_value)
      target_tensor = input_tensor

    if key_value is None:
      key_value = input_tensor
    attention_output = self._attention_layer(
        query=target_tensor, value=key_value, attention_mask=attention_mask)
    attention_output = self._attention_dropout(attention_output)

    if self._norm_first:
      attention_output = source_tensor + self._stochastic_depth(
          attention_output, training=training)
    else:
      attention_output = self._attention_layer_norm(
          target_tensor +
          self._stochastic_depth(attention_output, training=training))
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output)

    inner_output = self._intermediate_dense(attention_output)
    inner_output = self._intermediate_activation_layer(inner_output)
    inner_output = self._inner_dropout_layer(inner_output)
    layer_output = self._output_dense(inner_output)
    layer_output = self._output_dropout(layer_output)

    if self._norm_first:
      return source_attention_output + self._stochastic_depth(
          layer_output, training=training)

    # During mixed precision training, layer norm output is always fp32 for
    # now. Casts fp32 for the subsequent add. Stochastic depth drops the
    # feedforward branch, keeping the attention output as the residual.
    layer_output = tf.cast(layer_output, tf.float32)
    return self._output_layer_norm(
        attention_output +
        self._stochastic_depth(layer_output, training=training))
@@ -19,6 +19,8 @@ import tensorflow as tf
from official.modeling import activations
from official.nlp import keras_nlp
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_layers
from official.vision.beta.projects.vit.modeling.layers import TransformerEncoderBlock

layers = tf.keras.layers
@@ -29,6 +31,18 @@ VIT_SPECS = {
            patch_size=16,
            transformer=dict(mlp_dim=1, num_heads=1, num_layers=1),
        ),
    'vit-ti16':
        dict(
            hidden_size=192,
            patch_size=16,
            transformer=dict(mlp_dim=768, num_heads=3, num_layers=12),
        ),
    'vit-s16':
        dict(
            hidden_size=384,
            patch_size=16,
            transformer=dict(mlp_dim=1536, num_heads=6, num_layers=12),
        ),
    'vit-b16':
        dict(
            hidden_size=768,
@@ -112,6 +126,8 @@ class Encoder(tf.keras.layers.Layer):
               attention_dropout_rate=0.1,
               kernel_regularizer=None,
               inputs_positions=None,
               init_stochastic_depth_rate=0.0,
               kernel_initializer='glorot_uniform',
               **kwargs):
    super().__init__(**kwargs)
    self._num_layers = num_layers
@@ -121,6 +137,8 @@ class Encoder(tf.keras.layers.Layer):
    self._attention_dropout_rate = attention_dropout_rate
    self._kernel_regularizer = kernel_regularizer
    self._inputs_positions = inputs_positions
    self._init_stochastic_depth_rate = init_stochastic_depth_rate
    self._kernel_initializer = kernel_initializer
  def build(self, input_shape):
    self._pos_embed = AddPositionEmbs(
@@ -131,15 +149,18 @@ class Encoder(tf.keras.layers.Layer):
    self._encoder_layers = []
    # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
    # https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
    for i in range(self._num_layers):
      encoder_layer = TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          kernel_initializer=self._kernel_initializer,
          norm_first=True,
          stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
              self._init_stochastic_depth_rate, i, self._num_layers - 1),
          norm_epsilon=1e-6)
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
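
With the arguments above, the drop rate grows linearly with depth: block `i` gets `init_stochastic_depth_rate * i / (num_layers - 1)`, so the first block is never dropped and the last one uses the full rate (assuming `nn_layers.get_stochastic_depth_rate` implements the usual linear schedule). A standalone sketch:

```python
def linear_stochastic_depth_schedule(init_rate: float, num_layers: int):
  # Per-block drop rates: 0.0 for the first block, init_rate for the last.
  return [init_rate * i / (num_layers - 1) for i in range(num_layers)]

# A depth-12 encoder with init rate 0.1 yields approximately:
# [0.0, 0.009, 0.018, ..., 0.091, 0.1]
```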
@@ -164,12 +185,14 @@ class VisionTransformer(tf.keras.Model):
               num_layers=12,
               attention_dropout_rate=0.0,
               dropout_rate=0.1,
               init_stochastic_depth_rate=0.0,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               patch_size=16,
               hidden_size=768,
               representation_size=0,
               classifier='token',
               kernel_regularizer=None,
               original_init=True):
    """VisionTransformer initialization function."""
    inputs = tf.keras.Input(shape=input_specs.shape[1:])
@@ -178,7 +201,8 @@ class VisionTransformer(tf.keras.Model):
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid',
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
            inputs)
    if tf.keras.backend.image_data_format() == 'channels_last':
      rows_axis, cols_axis = (1, 2)
@@ -203,7 +227,10 @@ class VisionTransformer(tf.keras.Model):
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='glorot_uniform' if original_init else dict(
            class_name='TruncatedNormal', config=dict(stddev=.02)),
        init_stochastic_depth_rate=init_stochastic_depth_rate)(
            x)

    if classifier == 'token':
@@ -215,7 +242,8 @@ class VisionTransformer(tf.keras.Model):
      x = tf.keras.layers.Dense(
          representation_size,
          kernel_regularizer=kernel_regularizer,
          name='pre_logits',
          kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
              x)
      x = tf.nn.tanh(x)
    else:
@@ -225,7 +253,8 @@ class VisionTransformer(tf.keras.Model):
        tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
    }

    super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)

@factory.register_backbone_builder('vit')
@@ -247,9 +276,11 @@ def build_vit(input_specs,
      num_layers=backbone_cfg.transformer.num_layers,
      attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
      dropout_rate=backbone_cfg.transformer.dropout_rate,
      init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
      input_specs=input_specs,
      patch_size=backbone_cfg.patch_size,
      hidden_size=backbone_cfg.hidden_size,
      representation_size=backbone_cfg.representation_size,
      classifier=backbone_cfg.classifier,
      kernel_regularizer=l2_regularizer,
      original_init=backbone_cfg.original_init)
@@ -58,7 +58,7 @@ class ImageClassificationTask(cfg.TaskConfig):

@exp_factory.register_config_factory('darknet_classification')
def darknet_classification() -> cfg.ExperimentConfig:
  """Image classification general."""
  return cfg.ExperimentConfig(
      task=ImageClassificationTask(),
...
@@ -26,6 +26,7 @@ from official.vision.beta.dataloaders import classification_input
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders import tfds_factory
from official.vision.beta.modeling import factory
from official.vision.beta.ops import augment

@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
@@ -103,14 +104,27 @@ class ImageClassificationTask(base_task.Task):
        decode_jpeg_only=params.decode_jpeg_only,
        aug_rand_hflip=params.aug_rand_hflip,
        aug_type=params.aug_type,
        color_jitter=params.color_jitter,
        random_erasing=params.random_erasing,
        is_multilabel=is_multilabel,
        dtype=params.dtype)

    postprocess_fn = None
    if params.mixup_and_cutmix:
      postprocess_fn = augment.MixupAndCutmix(
          mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
          cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
          prob=params.mixup_and_cutmix.prob,
          switch_prob=params.mixup_and_cutmix.switch_prob,
          label_smoothing=params.mixup_and_cutmix.label_smoothing,
          num_classes=params.mixup_and_cutmix.num_classes)

    reader = input_reader_factory.input_reader_generator(
        params,
        dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
        decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn)

    dataset = reader.read(input_context=input_context)
@@ -119,12 +133,15 @@ class ImageClassificationTask(base_task.Task):
  def build_losses(self,
                   labels: tf.Tensor,
                   model_outputs: tf.Tensor,
                   is_validation: bool,
                   aux_losses: Optional[Any] = None) -> tf.Tensor:
    """Builds sparse categorical cross entropy loss.

    Args:
      labels: Input groundtruth labels.
      model_outputs: Output logits of the classifier.
      is_validation: Whether the loss is computed for validation; some
        augmentations need custom soft labels in training, while validation
        labels should remain unchanged.
      aux_losses: The auxiliary loss tensors, i.e. `losses` in tf.keras.Model.

    Returns:
@@ -134,12 +151,19 @@ class ImageClassificationTask(base_task.Task):
    is_multilabel = self.task_config.train_data.is_multilabel

    if not is_multilabel:
      # Some augmentations need custom soft labels during training, but
      # validation labels should remain unchanged.
      if losses_config.one_hot or is_validation:
        total_loss = tf.keras.losses.categorical_crossentropy(
            labels,
            model_outputs,
            from_logits=True,
            label_smoothing=losses_config.label_smoothing)
      elif losses_config.soft_labels:
        total_loss = tf.nn.softmax_cross_entropy_with_logits(
            labels, model_outputs)
      else:
        total_loss = tf.keras.losses.sparse_categorical_crossentropy(
            labels, model_outputs, from_logits=True)
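
The split matters because MixupAndCutmix already bakes smoothing and mixing into the targets, so the soft-label path must not re-smooth them. A minimal sketch of the two paths with toy tensors:

```python
import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])
soft_targets = tf.constant([[0.7, 0.25, 0.05]])  # e.g. produced by MixupAndCutmix

# Training with soft labels: plain cross entropy against the mixed targets.
loss_soft = tf.nn.softmax_cross_entropy_with_logits(soft_targets, logits)

# Validation: one-hot targets, smoothing applied inside the loss.
one_hot = tf.one_hot([0], depth=3)
loss_val = tf.keras.losses.categorical_crossentropy(
    one_hot, logits, from_logits=True, label_smoothing=0.1)
```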
@@ -161,7 +185,8 @@ class ImageClassificationTask(base_task.Task):
    is_multilabel = self.task_config.train_data.is_multilabel
    if not is_multilabel:
      k = self.task_config.evaluation.top_k
      if (self.task_config.losses.one_hot
          or self.task_config.losses.soft_labels):
        metrics = [
            tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
            tf.keras.metrics.TopKCategoricalAccuracy(
lambda x: tf.cast(x, tf.float32), outputs) lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss. # Computes per-replica loss.
loss = self.build_losses( loss = self.build_losses(model_outputs=outputs, labels=labels,
model_outputs=outputs, labels=labels, aux_losses=model.losses) is_validation=False, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the # Scales loss as the default gradients allreduce performs sum inside the
# optimizer. # optimizer.
scaled_loss = loss / num_replicas scaled_loss = loss / num_replicas
@@ -266,14 +291,16 @@ class ImageClassificationTask(base_task.Task):
      A dictionary of logs.
    """
    features, labels = inputs
    one_hot = self.task_config.losses.one_hot
    soft_labels = self.task_config.losses.soft_labels
    is_multilabel = self.task_config.train_data.is_multilabel
    if (one_hot or soft_labels) and not is_multilabel:
      labels = tf.one_hot(labels, self.task_config.model.num_classes)

    outputs = self.inference_step(features, model)
    outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    loss = self.build_losses(model_outputs=outputs, labels=labels,
                             is_validation=True, aux_losses=model.losses)

    logs = {self.loss: loss}
    if metrics:
...