"docs/vscode:/vscode.git/clone" did not exist on "0ce29733d2676713ae04ed8bb030215af1f06b33"
Commit 40cd0a26 authored by Simon Geisler

deit without repeated aug and distillation

parent 3db445c7
......@@ -15,7 +15,7 @@
# Lint as: python3
"""Common configurations."""
from typing import Optional
from typing import Optional, List
# Import libraries
import dataclasses
......@@ -32,6 +32,7 @@ class RandAugment(hyperparams.Config):
cutout_const: float = 40
translate_const: float = 10
prob_to_apply: Optional[float] = None
exclude_ops: List[str] = dataclasses.field(default_factory=list)
@dataclasses.dataclass
......@@ -42,6 +43,30 @@ class AutoAugment(hyperparams.Config):
translate_const: float = 250
@dataclasses.dataclass
class RandomErasing(hyperparams.Config):
"""Configuration for RandomErasing."""
probability: float = 0.25
min_area: float = 0.02
max_area: float = 1 / 3
min_aspect: float = 0.3
max_aspect: Optional[float] = None
min_count: int = 1
max_count: int = 1
trials: int = 10
@dataclasses.dataclass
class MixupAndCutmix(hyperparams.Config):
"""Configuration for MixupAndCutmix."""
mixup_alpha: float = .8
cutmix_alpha: float = 1.
prob: float = 1.0
switch_prob: float = 0.5
label_smoothing: float = 0.1
num_classes: int = 1000
@dataclasses.dataclass
class Augmentation(hyperparams.OneOfConfig):
"""Configuration for input data augmentation.
......
......@@ -40,10 +40,13 @@ class DataConfig(cfg.DataConfig):
aug_rand_hflip: bool = True
aug_type: Optional[
common.Augmentation] = None # Choose from AutoAugment and RandAugment.
color_jitter: float = 0.
random_erasing: Optional[common.RandomErasing] = None
file_type: str = 'tfrecord'
image_field_key: str = 'image/encoded'
label_field_key: str = 'image/class/label'
decode_jpeg_only: bool = True
mixup_and_cutmix: Optional[common.MixupAndCutmix] = None
# Keep for backward compatibility.
aug_policy: Optional[str] = None # None, 'autoaug', or 'randaug'.
......@@ -62,6 +65,7 @@ class ImageClassificationModel(hyperparams.Config):
use_sync_bn=False)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm: bool = False
kernel_initializer: str = 'random_uniform'
@dataclasses.dataclass
......@@ -69,6 +73,7 @@ class Losses(hyperparams.Config):
one_hot: bool = True
label_smoothing: float = 0.0
l2_weight_decay: float = 0.0
soft_labels: bool = False
@dataclasses.dataclass
......
......@@ -69,6 +69,8 @@ class Parser(parser.Parser):
decode_jpeg_only: bool = True,
aug_rand_hflip: bool = True,
aug_type: Optional[common.Augmentation] = None,
color_jitter: float = 0.,
random_erasing: Optional[common.RandomErasing] = None,
is_multilabel: bool = False,
dtype: str = 'float32'):
"""Initializes parameters for parsing annotations in the dataset.
......@@ -85,6 +87,7 @@ class Parser(parser.Parser):
horizontal flip.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment.
color_jitter: If greater than 0, the input image will be augmented with
  color jitter of this magnitude.
random_erasing: An optional RandomErasing object. If set, random erasing
  is applied to the image after normalization.
is_multilabel: A `bool`, whether or not each example has multiple labels.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
......@@ -113,13 +116,28 @@ class Parser(parser.Parser):
magnitude=aug_type.randaug.magnitude,
cutout_const=aug_type.randaug.cutout_const,
translate_const=aug_type.randaug.translate_const,
prob_to_apply=aug_type.randaug.prob_to_apply)
prob_to_apply=aug_type.randaug.prob_to_apply,
exclude_ops=aug_type.randaug.exclude_ops)
else:
raise ValueError('Augmentation policy {} not supported.'.format(
aug_type.type))
else:
self._augmenter = None
self._label_field_key = label_field_key
self._color_jitter = color_jitter
if random_erasing:
self._random_erasing = augment.RandomErasing(
probability=random_erasing.probability,
min_area=random_erasing.min_area,
max_area=random_erasing.max_area,
min_aspect=random_erasing.min_aspect,
max_aspect=random_erasing.max_aspect,
min_count=random_erasing.min_count,
max_count=random_erasing.max_count,
trials=random_erasing.trials)
else:
self._random_erasing = None
self._is_multilabel = is_multilabel
self._decode_jpeg_only = decode_jpeg_only
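# Hypothetical wiring sketch (not part of this diff): the classification
# dataloader is expected to forward the new DataConfig fields into this
# parser, e.g.
#
#   parser = Parser(output_size=params.output_size,
#                   color_jitter=params.color_jitter,
#                   random_erasing=params.random_erasing,
#                   aug_type=params.aug_type)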
......@@ -213,11 +231,20 @@ class Parser(parser.Parser):
image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
image.set_shape([self._output_size[0], self._output_size[1], 3])
# Color jitter.
if self._color_jitter > 0:
image = preprocess_ops.color_jitter(
image, self._color_jitter, self._color_jitter, self._color_jitter)
# Normalizes image with mean and std pixel values.
image = preprocess_ops.normalize_image(image,
offset=MEAN_RGB,
scale=STDDEV_RGB)
# Random erasing after the image has been normalized
if self._random_erasing is not None:
image = self._random_erasing.distort(image)
# Convert image to self._dtype.
image = tf.image.convert_image_dtype(image, self._dtype)
......
......@@ -56,6 +56,7 @@ def build_classification_model(
num_classes=model_config.num_classes,
input_specs=input_specs,
dropout_rate=model_config.dropout_rate,
kernel_initializer=model_config.kernel_initializer,
kernel_regularizer=l2_regularizer,
add_head_batch_norm=model_config.add_head_batch_norm,
use_sync_bn=norm_activation_config.use_sync_bn,
......
......@@ -12,10 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""AutoAugment and RandAugment policies for enhanced image/video preprocessing.
"""Augmentation policies for enhanced image/video preprocessing.
AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719
RandomErasing Reference: https://arxiv.org/abs/1708.04896
MixupAndCutmix:
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
RandomErasing, Mixup and Cutmix are inspired by https://github.com/rwightman/pytorch-image-models
"""
import math
from typing import Any, List, Iterable, Optional, Text, Tuple
......@@ -295,10 +302,21 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
cutout_center_width = tf.random.uniform(
shape=[], minval=0, maxval=image_width, dtype=tf.int32)
lower_pad = tf.maximum(0, cutout_center_height - pad_size)
upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
left_pad = tf.maximum(0, cutout_center_width - pad_size)
right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
image = _fill_rectangle(image, cutout_center_width, cutout_center_height,
pad_size, pad_size, replace)
return image
def _fill_rectangle(image, center_width, center_height, half_width,
                    half_height, replace=None):
  """Fills a rectangle in `image`; `replace=None` fills with random noise."""
image_height = tf.shape(image)[0]
image_width = tf.shape(image)[1]
lower_pad = tf.maximum(0, center_height - half_height)
upper_pad = tf.maximum(0, image_height - center_height - half_height)
left_pad = tf.maximum(0, center_width - half_width)
right_pad = tf.maximum(0, image_width - center_width - half_width)
cutout_shape = [
image_height - (lower_pad + upper_pad),
......@@ -311,9 +329,15 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
constant_values=1)
mask = tf.expand_dims(mask, -1)
mask = tf.tile(mask, [1, 1, 3])
image = tf.where(
tf.equal(mask, 0),
tf.ones_like(image, dtype=image.dtype) * replace, image)
if replace is None:
fill = tf.random.normal(tf.shape(image), dtype=image.dtype)
elif isinstance(replace, tf.Tensor):
fill = replace
else:
fill = tf.ones_like(image, dtype=image.dtype) * replace
image = tf.where(tf.equal(mask, 0), fill, image)
return image
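# With this helper, `cutout` above reduces to a single `_fill_rectangle` call,
# and RandomErasing below reuses it with `replace=None` to fill the erased
# rectangle with Gaussian noise. A sketch with hypothetical values:
#
#   image = _fill_rectangle(image, center_width=112, center_height=112,
#                           half_width=16, half_height=16, replace=None)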
......@@ -805,9 +829,15 @@ def level_to_arg(cutout_const: float, translate_const: float):
def _parse_policy_info(name: Text, prob: float, level: float,
replace_value: List[int], cutout_const: float,
translate_const: float) -> Tuple[Any, float, Any]:
translate_const: float, level_std: float = 0.
) -> Tuple[Any, float, Any]:
"""Return the function that corresponds to `name` and update `level` param."""
func = NAME_TO_FUNC[name]
if level_std > 0:
level += tf.random.normal([], stddev=level_std, dtype=tf.float32)
level = tf.clip_by_value(level, 0., _MAX_LEVEL)
args = level_to_arg(cutout_const, translate_const)[name](level)
if name in REPLACE_FUNCS:
......@@ -1184,7 +1214,9 @@ class RandAugment(ImageAugment):
magnitude: float = 10.,
cutout_const: float = 40.,
translate_const: float = 100.,
prob_to_apply: Optional[float] = None):
magnitude_std: float = 0.0,
prob_to_apply: Optional[float] = None,
exclude_ops: Optional[List[str]] = None):
"""Applies the RandAugment policy to images.
Args:
......@@ -1196,8 +1228,11 @@ class RandAugment(ImageAugment):
[5, 10].
cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation.
magnitude_std: standard deviation used to randomize the severity
  (magnitude), as proposed by the authors of the timm library.
prob_to_apply: The probability to apply the selected augmentation at each
layer.
exclude_ops: names of augmentation operations to exclude from the
  available ops.
"""
super(RandAugment, self).__init__()
......@@ -1212,6 +1247,9 @@ class RandAugment(ImageAugment):
'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY',
'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'
]
self.magnitude_std = magnitude_std
if exclude_ops:
  self.available_ops = [
      op for op in self.available_ops if op not in exclude_ops]
def distort(self, image: tf.Tensor) -> tf.Tensor:
"""Applies the RandAugment policy to `image`.
......@@ -1246,7 +1284,8 @@ class RandAugment(ImageAugment):
dtype=tf.float32)
func, _, args = _parse_policy_info(op_name, prob, self.magnitude,
replace_value, self.cutout_const,
self.translate_const)
self.translate_const,
self.magnitude_std)
branch_fns.append((
i,
# pylint:disable=g-long-lambda
......@@ -1267,3 +1306,240 @@ class RandAugment(ImageAugment):
image = tf.cast(image, dtype=input_image_type)
return image
class RandomErasing(ImageAugment):
"""Applies RandomErasing to a single image.
Reference: https://arxiv.org/abs/1708.04896
Implementation is inspired by https://github.com/rwightman/pytorch-image-models
"""
def __init__(self, probability: float = 0.25, min_area: float = 0.02,
max_area: float = 1 / 3, min_aspect: float = 0.3,
max_aspect=None, min_count=1, max_count=1, trials=10):
"""Applies RandomErasing to a single image.
Args:
probability (float, optional): Probability of augmenting the image.
Defaults to 0.25.
min_area (float, optional): Minimum area of the random erasing
rectangle. Defaults to 0.02.
max_area (float, optional): Maximum area of the random erasing
rectangle. Defaults to 1/3.
min_aspect (float, optional): Minimum aspect ratio of the random erasing
rectangle. Defaults to 0.3.
max_aspect (float, optional): Maximum aspect ratio of the random
erasing rectangle. Defaults to None.
min_count (int, optional): Minimum number of erased
rectangles. Defaults to 1.
max_count (int, optional): Maximum number of erased
rectangles. Defaults to 1.
trials (int, optional): Maximum number of trials to randomly sample a
rectangle that fulfills the constraints. Defaults to 10.
"""
self._probability = probability
self._min_area = float(min_area)
self._max_area = float(max_area)
self._min_log_aspect = math.log(min_aspect)
self._max_log_aspect = math.log(max_aspect or 1 / min_aspect)
self._min_count = min_count
self._max_count = max_count
self._trials = trials
def distort(self, image: tf.Tensor) -> tf.Tensor:
"""Applies RandomErasing to single `image`.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
Returns:
tf.Tensor: The augmented version of `image`.
"""
uniform_random = tf.random.uniform(shape=[], minval=0., maxval=1.0)
erase_cond = tf.less(uniform_random, self._probability)
image = tf.cond(erase_cond, lambda: self._erase(image), lambda: image)
return image
@tf.function
def _erase(self, image: tf.Tensor) -> tf.Tensor:
count = self._min_count if self._min_count == self._max_count else \
    tf.random.uniform(shape=[], minval=int(self._min_count),
                      maxval=int(self._max_count) + 1,
                      dtype=tf.int32)
image_height = tf.shape(image)[0]
image_width = tf.shape(image)[1]
area = tf.cast(image_width * image_height, tf.float32)
for _ in range(count):
for _ in range(self._trials):
erase_area = tf.random.uniform(shape=[],
minval=area * self._min_area,
maxval=area * self._max_area)
aspect_ratio = tf.math.exp(tf.random.uniform(
shape=[], minval=self._min_log_aspect,
maxval=self._max_log_aspect))
half_height = tf.cast(tf.math.round(tf.math.sqrt(
erase_area * aspect_ratio) / 2), dtype=tf.int32)
half_width = tf.cast(tf.math.round(tf.math.sqrt(
erase_area / aspect_ratio) / 2), dtype=tf.int32)
if 2 * half_height < image_height and 2 * half_width < image_width:
center_height = tf.random.uniform(
shape=[], minval=0, maxval=int(image_height - 2 * half_height),
dtype=tf.int32)
center_width = tf.random.uniform(
shape=[], minval=0, maxval=int(image_width - 2 * half_width),
dtype=tf.int32)
image = _fill_rectangle(image, center_width, center_height,
half_width, half_height, replace=None)
break
return image
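# Usage sketch for RandomErasing, mirroring the unit test added below:
#
#   augmenter = RandomErasing(probability=1., max_count=10)
#   image = augmenter.distort(tf.zeros((224, 224, 3), dtype=tf.float32))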
class MixupAndCutmix:
"""Applies Mixup and/or Cutmix to a batch of images.
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
Implementation is inspired by https://github.com/rwightman/pytorch-image-models
"""
def __init__(self, mixup_alpha: float = .8, cutmix_alpha: float = 1.,
prob: float = 1.0, switch_prob: float = 0.5,
label_smoothing: float = 0.1, num_classes: int = 1001):
"""Applies Mixup and/or Cutmix to a batch of images.
Args:
mixup_alpha (float, optional): For drawing a random lambda (`lam`) from a
beta distribution (for each image). If zero, Mixup is deactivated.
Defaults to .8.
cutmix_alpha (float, optional): For drawing a random lambda (`lam`) from
a beta distribution (for each image). If zero, Cutmix is deactivated.
Defaults to 1..
prob (float, optional): Probability of augmenting the batch. Defaults to 1.0.
switch_prob (float, optional): Probability of applying Cutmix for the
batch. Defaults to 0.5.
label_smoothing (float, optional): Constant for label smoothing. Defaults
to 0.1.
num_classes (int, optional): Number of classes. Defaults to 1001.
"""
self.mixup_alpha = mixup_alpha
self.cutmix_alpha = cutmix_alpha
self.mix_prob = prob
self.switch_prob = switch_prob
self.label_smoothing = label_smoothing
self.num_classes = num_classes
self.mode = 'batch'
self.mixup_enabled = True
if self.mixup_alpha and not self.cutmix_alpha:
self.switch_prob = -1
elif not self.mixup_alpha and self.cutmix_alpha:
self.switch_prob = 1
def __call__(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
return self.distort(images, labels)
def distort(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
"""Applies Mixup and/or Cutmix to batch of `images` and transforms the
`labels` (incl. label smoothing).
Args:
images (tf.Tensor): Of shape [batch_size, height, width, 3] representing
a batch of images.
labels (tf.Tensor): Of shape [batch_size, ] representing the class id for
each image of the batch.
Returns:
Tuple[tf.Tensor, tf.Tensor]: The augmented version of `image` and
`labels`.
"""
augment_cond = tf.less(tf.random.uniform(shape=[], minval=0., maxval=1.0),
self.mix_prob)
return tf.cond(
augment_cond,
lambda: self._update_labels(*tf.cond(
tf.less(tf.random.uniform(
shape=[], minval=0., maxval=1.0), self.switch_prob),
lambda: self._cutmix(images, labels),
lambda: self._mixup(images, labels)
)),
lambda: (images, self._smooth_labels(labels))
)
@staticmethod
def _sample_from_beta(alpha: float, beta: float, shape: tuple):
  # A Beta(alpha, beta) sample equals Ga / (Ga + Gb) with Ga ~ Gamma(alpha, 1)
  # and Gb ~ Gamma(beta, 1); the previous Gamma(1., beta=alpha) draws sampled
  # exponentials instead of the intended beta distribution.
  sample_alpha = tf.random.gamma(shape, alpha)
  sample_beta = tf.random.gamma(shape, beta)
  return sample_alpha / (sample_alpha + sample_beta)
def _cutmix(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
lam = MixupAndCutmix._sample_from_beta(
self.cutmix_alpha, self.cutmix_alpha, labels.shape)
ratio = tf.math.sqrt(1 - lam)
batch_size = tf.shape(images)[0]
image_height, image_width = tf.shape(images)[1], tf.shape(images)[2]
cut_height = tf.cast(
ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32)
cut_width = tf.cast(
ratio * tf.cast(image_width, dtype=tf.float32), dtype=tf.int32)
random_center_height = tf.random.uniform(
shape=[batch_size], minval=0, maxval=image_height, dtype=tf.int32)
random_center_width = tf.random.uniform(
shape=[batch_size], minval=0, maxval=image_width, dtype=tf.int32)
bbox_area = cut_height * cut_width
lam = 1. - bbox_area / (image_height * image_width)
lam = tf.cast(lam, dtype=tf.float32)
# `dtype` is ignored when `fn_output_signature` is given, so only the
# output signature is passed.
images = tf.map_fn(
    lambda x: _fill_rectangle(*x),
    (images, random_center_width, random_center_height, cut_width // 2,
     cut_height // 2, tf.reverse(images, [0])),
    fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=tf.float32))
return images, labels, lam
def _mixup(self, images: tf.Tensor,
labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
lam = MixupAndCutmix._sample_from_beta(
self.mixup_alpha, self.mixup_alpha, labels.shape)
lam = tf.reshape(lam, [-1, 1, 1, 1])
images = lam * images + (1. - lam) * tf.reverse(images, [0])
return images, labels, tf.squeeze(lam)
def _smooth_labels(self, labels: tf.Tensor) -> tf.Tensor:
off_value = self.label_smoothing / self.num_classes
on_value = 1. - self.label_smoothing + off_value
smooth_labels = tf.one_hot(labels, self.num_classes,
on_value=on_value, off_value=off_value)
return smooth_labels
def _update_labels(self, images: tf.Tensor, labels: tf.Tensor,
lam: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
labels_1 = self._smooth_labels(labels)
labels_2 = tf.reverse(labels_1, [0])
lam = tf.reshape(lam, [-1, 1])
labels = lam * labels_1 + (1. - lam) * labels_2
return images, labels
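# Usage sketch for MixupAndCutmix; note that it operates on a whole batch and
# returns smoothed/mixed one-hot labels of shape [batch_size, num_classes]:
#
#   augmenter = MixupAndCutmix(num_classes=1000, label_smoothing=0.1)
#   images = tf.random.normal((8, 224, 224, 3))
#   labels = tf.range(8)
#   images, labels = augmenter.distort(images, labels)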
......@@ -254,5 +254,82 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
augmenter.distort(image)
class RandomErasingTest(tf.test.TestCase, parameterized.TestCase):
def test_random_erase_replaces_some_pixels(self):
image = tf.zeros((224, 224, 3), dtype=tf.float32)
augmenter = augment.RandomErasing(probability=1., max_count=10)
aug_image = augmenter.distort(image)
self.assertEqual((224, 224, 3), aug_image.shape)
self.assertLess(0, tf.reduce_max(aug_image))
class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
def test_mixup_and_cutmix_smoothes_labels(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
num_classes=num_classes, label_smoothing=label_smoothing)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(
aug_labels, 1. - label_smoothing + 2. / num_classes) # With tolerance
self.assertAllGreaterEqual(
aug_labels, label_smoothing / num_classes - 1e-4) # With tolerance
def test_mixup_changes_image(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(
aug_labels, 1. - label_smoothing + 2. / num_classes) # With tolerance
self.assertAllGreaterEqual(
aug_labels, label_smoothing / num_classes - 1e-4) # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
def test_cutmix_changes_image(self):
batch_size = 12
num_classes = 1000
label_smoothing = 0.1
images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
labels = tf.range(batch_size)
augmenter = augment.MixupAndCutmix(
mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)
aug_images, aug_labels = augmenter.distort(images, labels)
self.assertEqual(images.shape, aug_images.shape)
self.assertEqual(images.dtype, aug_images.dtype)
self.assertEqual([batch_size, num_classes], aug_labels.shape)
self.assertAllLessEqual(
aug_labels, 1. - label_smoothing + 2. / num_classes) # With tolerance
self.assertAllGreaterEqual(
aug_labels, label_smoothing / num_classes - 1e-4) # With tolerance
self.assertFalse(tf.math.reduce_all(images == aug_images))
if __name__ == '__main__':
tf.test.main()
......@@ -15,10 +15,12 @@
"""Preprocessing ops."""
import math
from typing import Optional
from six.moves import range
import tensorflow as tf
from official.vision.beta.ops import box_ops
from official.vision.beta.ops import augment
CENTER_CROP_FRACTION = 0.875
......@@ -555,3 +557,84 @@ def random_horizontal_flip(image, normalized_boxes=None, masks=None, seed=1):
lambda: masks)
return image, normalized_boxes, masks
def color_jitter(image: tf.Tensor, brightness: Optional[float] = 0.,
contrast: Optional[float] = 0.,
saturation: Optional[float] = 0.,
seed: Optional[int] = None) -> tf.Tensor:
"""Applies color jitter to an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
brightness (float, optional): Magnitude for brightness jitter.
Defaults to 0.
contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
saturation (float, optional): Magnitude for saturation jitter.
Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented version of `image`.
"""
image = random_brightness(image, brightness, seed=seed)
image = random_contrast(image, contrast, seed=seed)
image = random_saturation(image, saturation, seed=seed)
return image
def random_brightness(image: tf.Tensor, brightness: Optional[float] = 0.,
seed: Optional[int] = None) -> tf.Tensor:
"""Jitters brightness of an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
brightness (float, optional): Magnitude for brightness jitter.
Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented version of `image`.
"""
assert brightness >= 0 and brightness <= 1., '`brightness` must be in [0, 1]'
brightness = tf.random.uniform(
[], max(0, 1 - brightness), 1 + brightness, seed=seed)
return augment.brightness(image, brightness)
def random_contrast(image: tf.Tensor, contrast: Optional[float] = 0.,
seed: Optional[int] = None) -> tf.Tensor:
"""Jitters contrast of an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
contrast (float, optional): Magnitude for contrast jitter.
Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented version of `image`.
"""
assert contrast >= 0 and contrast <= 1., '`contrast` must be in [0, 1]'
contrast = tf.random.uniform(
[], max(0, 1 - contrast), 1 + contrast, seed=seed)
return augment.contrast(image, contrast)
def random_saturation(image: tf.Tensor, saturation: Optional[float] = 0.,
seed: Optional[int] = None) -> tf.Tensor:
"""Jitters saturation of an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
saturation (float, optional): Magnitude for saturation jitter.
Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented version of `image`.
"""
assert saturation >= 0 and saturation <= 1., '`saturation` must be in [0, 1]'
saturation = tf.random.uniform(
[], max(0, 1 - saturation), 1 + saturation, seed=seed)
return augment.blend(tf.image.rgb_to_grayscale(image), image, saturation)
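# Usage sketch: a single magnitude drives all three jitters, matching the
# `color_jitter=0.4` used by the DEIT experiment configs in this commit:
#
#   image = color_jitter(image, brightness=0.4, contrast=0.4, saturation=0.4)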
# Vision Transformer (ViT)
# Vision Transformer (ViT) and Data-Efficient Image Transformer (DEIT)
**DISCLAIMER**: This implementation is still under development. No support will
be provided during the development phase.
[![Paper](http://img.shields.io/badge/Paper-arXiv.2010.11929-B3181B?logo=arXiv)](https://arxiv.org/abs/2010.11929)
- [![ViT Paper](http://img.shields.io/badge/Paper-arXiv.2010.11929-B3181B?logo=arXiv)](https://arxiv.org/abs/2010.11929)
- [![DEIT Paper](http://img.shields.io/badge/Paper-arXiv.2012.12877-B3181B?logo=arXiv)](https://arxiv.org/abs/2012.12877)
This repository is the implementations of Vision Transformer (ViT) in
This repository contains the implementations of Vision Transformer (ViT) and Data-Efficient Image Transformer (DEIT) in
TensorFlow 2.
* Paper title:
[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf).
\ No newline at end of file
- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf).
- [Training data-efficient image transformers & distillation through attention](https://arxiv.org/pdf/2012.12877.pdf).
......@@ -42,6 +42,8 @@ class VisionTransformer(hyperparams.Config):
hidden_size: int = 1
patch_size: int = 16
transformer: Transformer = Transformer()
init_stochastic_depth_rate: float = 0.0
original_init: bool = True
@dataclasses.dataclass
......
......@@ -44,6 +44,7 @@ class ImageClassificationModel(hyperparams.Config):
use_sync_bn=False)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm: bool = False
kernel_initializer: str = 'random_uniform'
@dataclasses.dataclass
......@@ -51,6 +52,7 @@ class Losses(hyperparams.Config):
one_hot: bool = True
label_smoothing: float = 0.0
l2_weight_decay: float = 0.0
soft_labels: bool = False
@dataclasses.dataclass
......@@ -79,6 +81,843 @@ task_factory.register_task_cls(ImageClassificationTask)(
image_classification.ImageClassificationTask)
@exp_factory.register_config_factory('deit_imagenet_pretrain_noaug')
def image_classification_imagenet_deit_imagenet_pretrain_noaug() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # 1024
eval_batch_size = 4096 # 1024
repeated_aug = 1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES * repeated_aug // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(l2_weight_decay=0.0, label_smoothing=0.1),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
# repeated_aug=repeated_aug,
color_jitter=0.4),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
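# A registered experiment can then be instantiated by name; a sketch using
# the standard model-garden factory lookup (assuming
# `official.core.exp_factory` is importable in this setup):
#
#   from official.core import exp_factory
#   config = exp_factory.get_exp_config('deit_imagenet_pretrain_noaug')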
@exp_factory.register_config_factory('deit_imagenet_pretrain_noaug_sd')
def image_classification_imagenet_deit_imagenet_pretrain_noaug_sd() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # 1024
eval_batch_size = 4096 # 1024
repeated_aug = 1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES * repeated_aug // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(l2_weight_decay=0.0, label_smoothing=0.1),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
# repeated_aug=repeated_aug,
color_jitter=0.4),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('deit_imagenet_pretrain_sd_mixupandcutmix')
def image_classification_imagenet_deit_imagenet_pretrain_sd_mixupandcutmix() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # 1024
eval_batch_size = 4096 # 1024
repeated_aug = 1
num_classes = 1001
label_smoothing = 0.1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES * repeated_aug // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=num_classes,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(l2_weight_decay=0.0, label_smoothing=label_smoothing,
one_hot=False, soft_labels=True),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
# repeated_aug=repeated_aug,
color_jitter=0.4,
mixup_and_cutmix=common.MixupAndCutmix(
num_classes=num_classes,
label_smoothing=label_smoothing
)),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('deit_imagenet_pretrain_sd_erase')
def image_classification_imagenet_deit_imagenet_pretrain_sd_erase() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # 1024
eval_batch_size = 4096 # 1024
repeated_aug = 1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES * repeated_aug // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(l2_weight_decay=0.0, label_smoothing=0.1),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
# repeated_aug=repeated_aug,
color_jitter=0.4,
random_erasing=common.RandomErasing()),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('deit_imagenet_pretrain_sd_erase_randa')
def image_classification_imagenet_deit_imagenet_pretrain_sd_erase_randa() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # 1024
eval_batch_size = 4096 # 1024
repeated_aug = 1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES * repeated_aug // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(l2_weight_decay=0.0, label_smoothing=0.1),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
# repeated_aug=repeated_aug,
color_jitter=0.4,
random_erasing=common.RandomErasing(),
aug_type=common.Augmentation(
type='randaug', randaug=common.RandAugment(
magnitude=9, exclude_ops=['Cutout']))),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('deit_imagenet_pretrain_sd_erase_randa_mixupandcutmix')
def image_classification_imagenet_deit_imagenet_pretrain_sd_erase_randa_mixupandcutmix() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # 1024
eval_batch_size = 4096 # 1024
repeated_aug = 1
num_classes = 1001
label_smoothing = 0.1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES * repeated_aug // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=num_classes,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(l2_weight_decay=0.0, label_smoothing=label_smoothing,
one_hot=False, soft_labels=True),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
# repeated_aug=repeated_aug,
color_jitter=0.4,
random_erasing=common.RandomErasing(),
aug_type=common.Augmentation(
type='randaug', randaug=common.RandAugment(
magnitude=9, exclude_ops=['Cutout'])),
mixup_and_cutmix=common.MixupAndCutmix(
num_classes=num_classes,
label_smoothing=label_smoothing
)),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('deit_imagenet_pretrain_sd_erase_randa_mixup')
def image_classification_imagenet_deit_imagenet_pretrain_sd_erase_randa_mixup() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # 1024
eval_batch_size = 4096 # 1024
repeated_aug = 1
num_classes = 1001
label_smoothing = 0.1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES * repeated_aug // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=num_classes,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(l2_weight_decay=0.0, label_smoothing=label_smoothing,
one_hot=False, soft_labels=True),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
# repeated_aug=repeated_aug,
color_jitter=0.4,
random_erasing=common.RandomErasing(),
aug_type=common.Augmentation(
type='randaug', randaug=common.RandAugment(
magnitude=9, exclude_ops=['Cutout'])),
mixup_and_cutmix=common.MixupAndCutmix(
num_classes=num_classes,
label_smoothing=label_smoothing,
cutmix_alpha=0
)),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('deit_imagenet_pretrain_sd_erase_randa_cutmix')
def image_classification_imagenet_deit_imagenet_pretrain_sd_erase_randa_cutmix() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # 1024
eval_batch_size = 4096 # 1024
repeated_aug = 1
num_classes = 1001
label_smoothing = 0.1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES * repeated_aug // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=num_classes,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(l2_weight_decay=0.0, label_smoothing=label_smoothing,
one_hot=False, soft_labels=True),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
# repeated_aug=repeated_aug,
color_jitter=0.4,
random_erasing=common.RandomErasing(),
aug_type=common.Augmentation(
type='randaug', randaug=common.RandAugment(
magnitude=9, exclude_ops=['Cutout'])),
mixup_and_cutmix=common.MixupAndCutmix(
num_classes=num_classes,
label_smoothing=label_smoothing,
mixup_alpha=0
)),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('deit_imagenet_pretrain_sd_erase_randa_mixupandcutmix_sanity')
def image_classification_imagenet_deit_imagenet_pretrain_sd_erase_randa_mixupandcutmix_sanity() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # 1024
eval_batch_size = 4096 # 1024
repeated_aug = 1
num_classes = 1001
label_smoothing = 0.1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES * repeated_aug // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=num_classes,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(l2_weight_decay=0.0, label_smoothing=label_smoothing,
one_hot=False, soft_labels=True),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
# repeated_aug=repeated_aug,
color_jitter=0.4,
random_erasing=common.RandomErasing(),
aug_type=common.Augmentation(
type='randaug', randaug=common.RandAugment(
magnitude=9, exclude_ops=['Cutout'])),
mixup_and_cutmix=common.MixupAndCutmix(
num_classes=num_classes,
label_smoothing=label_smoothing,
prob=0,
)),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('deit_imagenet_pretrain_sd_randacomplete')
def image_classification_imagenet_deit_imagenet_pretrain_sd_randacomplete() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096 # 1024
eval_batch_size = 4096 # 1024
repeated_aug = 1
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES * repeated_aug // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
model_name='vit-b16',
representation_size=768,
init_stochastic_depth_rate=0.1,
original_init=False,
transformer=backbones.Transformer(
dropout_rate=0.0, attention_dropout_rate=0.0)))),
losses=Losses(l2_weight_decay=0.0, label_smoothing=0.1),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size,
# repeated_aug=repeated_aug,
color_jitter=0.4,
aug_type=common.Augmentation(
type='randaug', randaug=common.RandAugment(magnitude=9))),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.05,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.0005 * train_batch_size / 512,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 5 * steps_per_epoch,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('vit_imagenet_pretrain_deitinit')
def image_classification_imagenet_vit_pretrain_deitinit() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
train_batch_size = 4096
eval_batch_size = 4096
steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
config = cfg.ExperimentConfig(
task=ImageClassificationTask(
model=ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
original_init=False,
model_name='vit-b16',
representation_size=768))),
losses=Losses(l2_weight_decay=0.0),
train_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
is_training=True,
global_batch_size=train_batch_size),
validation_data=DataConfig(
input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
is_training=False,
global_batch_size=eval_batch_size)),
trainer=cfg.TrainerConfig(
steps_per_loop=steps_per_epoch,
summary_interval=steps_per_epoch,
checkpoint_interval=steps_per_epoch,
train_steps=300 * steps_per_epoch,
validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
validation_interval=steps_per_epoch,
optimizer_config=optimization.OptimizationConfig({
'optimizer': {
'type': 'adamw',
'adamw': {
'weight_decay_rate': 0.3,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.003 * train_batch_size / 4096,
'decay_steps': 300 * steps_per_epoch,
}
},
'warmup': {
'type': 'linear',
'linear': {
'warmup_steps': 10000,
'warmup_learning_rate': 0
}
}
})),
restrictions=[
'task.train_data.is_training != None',
'task.validation_data.is_training != None'
])
return config
@exp_factory.register_config_factory('vit_imagenet_pretrain')
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
"""Image classification on imagenet with vision transformer."""
......@@ -90,6 +929,7 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
model=ImageClassificationModel(
num_classes=1001,
input_size=[224, 224, 3],
kernel_initializer='zeros',
backbone=backbones.Backbone(
type='vit',
vit=backbones.VisionTransformer(
......@@ -116,12 +956,13 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
'adamw': {
'weight_decay_rate': 0.3,
'include_in_weight_decay': r'.*(kernel|weight):0$',
'gradient_clip_norm': 0.0
}
},
'learning_rate': {
'type': 'cosine',
'cosine': {
'initial_learning_rate': 0.003,
'initial_learning_rate': 0.003 * train_batch_size / 4096,
'decay_steps': 300 * steps_per_epoch,
}
},
......
from official.vision.beta.projects.vit.modeling.layers.vit_transformer_encoder_block import TransformerEncoderBlock
\ No newline at end of file
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based TransformerEncoder block layer."""
import tensorflow as tf
from official.vision.beta.modeling.layers.nn_layers import StochasticDepth
@tf.keras.utils.register_keras_serializable(package="Vision")
class TransformerEncoderBlock(tf.keras.layers.Layer):
"""TransformerEncoderBlock layer.
This layer implements the Transformer Encoder from
"Attention Is All You Need". (https://arxiv.org/abs/1706.03762),
which combines a `tf.keras.layers.MultiHeadAttention` layer with a
two-layer feedforward network. Here we add support for stochastic depth.
References:
[Attention Is All You Need](https://arxiv.org/abs/1706.03762)
[BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding](https://arxiv.org/abs/1810.04805)
"""
def __init__(self,
num_attention_heads,
inner_dim,
inner_activation,
output_range=None,
kernel_initializer="glorot_uniform",
bias_initializer="zeros",
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None,
bias_constraint=None,
use_bias=True,
norm_first=False,
norm_epsilon=1e-12,
output_dropout=0.0,
attention_dropout=0.0,
inner_dropout=0.0,
stochastic_depth_drop_rate=0.0,
attention_initializer=None,
attention_axes=None,
**kwargs):
"""Initializes `TransformerEncoderBlock`.
Args:
num_attention_heads: Number of attention heads.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network.
output_range: the sequence output range, [0, output_range) for slicing the
target sequence. `None` means the target sequence is not sliced.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set False, output of attention and intermediate dense
layers is normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: Dropout probability for within the attention layer.
inner_dropout: Dropout probability for the first Dense layer in a
two-layer feedforward network.
stochastic_depth_drop_rate: Drop probability for the stochastic depth
regularization.
attention_initializer: Initializer for kernels of attention layers. If set
`None`, attention layers use kernel_initializer as initializer for
kernel.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes except batch, heads, and features.
**kwargs: keyword arguments.
"""
super().__init__(**kwargs)
self._num_heads = num_attention_heads
self._inner_dim = inner_dim
self._inner_activation = inner_activation
self._attention_dropout = attention_dropout
self._attention_dropout_rate = attention_dropout
self._output_dropout = output_dropout
self._output_dropout_rate = output_dropout
self._output_range = output_range
self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
self._bias_initializer = tf.keras.initializers.get(bias_initializer)
self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
self._activity_regularizer = tf.keras.regularizers.get(activity_regularizer)
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._use_bias = use_bias
self._norm_first = norm_first
self._norm_epsilon = norm_epsilon
self._inner_dropout = inner_dropout
self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
if attention_initializer:
self._attention_initializer = tf.keras.initializers.get(
attention_initializer)
else:
self._attention_initializer = self._kernel_initializer
self._attention_axes = attention_axes
def build(self, input_shape):
if isinstance(input_shape, tf.TensorShape):
input_tensor_shape = input_shape
elif isinstance(input_shape, (list, tuple)):
input_tensor_shape = tf.TensorShape(input_shape[0])
else:
raise ValueError(
"The type of input shape argument is not supported, got: %s" %
type(input_shape))
einsum_equation = "abc,cd->abd"
if len(input_tensor_shape.as_list()) > 3:
einsum_equation = "...bc,cd->...bd"
hidden_size = input_tensor_shape[-1]
if hidden_size % self._num_heads != 0:
raise ValueError(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)" % (hidden_size, self._num_heads))
self._attention_head_size = int(hidden_size // self._num_heads)
common_kwargs = dict(
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint)
self._attention_layer = tf.keras.layers.MultiHeadAttention(
num_heads=self._num_heads,
key_dim=self._attention_head_size,
dropout=self._attention_dropout,
use_bias=self._use_bias,
kernel_initializer=self._attention_initializer,
attention_axes=self._attention_axes,
name="self_attention",
**common_kwargs)
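    # Note (editorial): `_attention_dropout` and `_output_dropout` start out as
    # float rates in __init__ and are rebound to Dropout layers below;
    # get_config() reads the preserved `_attention_dropout_rate` and
    # `_output_dropout_rate` attributes instead.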
self._attention_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
self._attention_layer_norm = (
tf.keras.layers.LayerNormalization(
name="self_attention_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32))
self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=(None, self._inner_dim),
bias_axes="d",
kernel_initializer=self._kernel_initializer,
name="intermediate",
**common_kwargs)
policy = tf.keras.mixed_precision.global_policy()
if policy.name == "mixed_bfloat16":
# bfloat16 causes BERT with the LAMB optimizer to not converge
# as well, so we use float32.
# TODO(b/154538392): Investigate this.
policy = tf.float32
self._intermediate_activation_layer = tf.keras.layers.Activation(
self._inner_activation, dtype=policy)
self._inner_dropout_layer = tf.keras.layers.Dropout(
rate=self._inner_dropout)
self._output_dense = tf.keras.layers.experimental.EinsumDense(
einsum_equation,
output_shape=(None, hidden_size),
bias_axes="d",
name="output",
kernel_initializer=self._kernel_initializer,
**common_kwargs)
self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
# Use float32 in layernorm for numeric stability.
self._output_layer_norm = tf.keras.layers.LayerNormalization(
name="output_layer_norm",
axis=-1,
epsilon=self._norm_epsilon,
dtype=tf.float32)
    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = StochasticDepth(
          self._stochastic_depth_drop_rate)
    else:
      # Identity pass-through keeps the residual call sites in `call` uniform
      # when stochastic depth is disabled.
      self._stochastic_depth = lambda x, *args, **kwargs: x
super(TransformerEncoderBlock, self).build(input_shape)
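  # Example of the head partitioning enforced in build(): with hidden_size=384
  # and num_attention_heads=6, each head attends in a key_dim = 384 // 6 = 64
  # subspace; hidden sizes not divisible by the head count raise the
  # ValueError above.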
def get_config(self):
config = {
"num_attention_heads":
self._num_heads,
"inner_dim":
self._inner_dim,
"inner_activation":
self._inner_activation,
"output_dropout":
self._output_dropout_rate,
"attention_dropout":
self._attention_dropout_rate,
"output_range":
self._output_range,
"kernel_initializer":
tf.keras.initializers.serialize(self._kernel_initializer),
"bias_initializer":
tf.keras.initializers.serialize(self._bias_initializer),
"kernel_regularizer":
tf.keras.regularizers.serialize(self._kernel_regularizer),
"bias_regularizer":
tf.keras.regularizers.serialize(self._bias_regularizer),
"activity_regularizer":
tf.keras.regularizers.serialize(self._activity_regularizer),
"kernel_constraint":
tf.keras.constraints.serialize(self._kernel_constraint),
"bias_constraint":
tf.keras.constraints.serialize(self._bias_constraint),
"use_bias":
self._use_bias,
"norm_first":
self._norm_first,
"norm_epsilon":
self._norm_epsilon,
"inner_dropout":
self._inner_dropout,
"stochastic_depth_drop_rate":
self._stochastic_depth_drop_rate,
"attention_initializer":
tf.keras.initializers.serialize(self._attention_initializer),
"attention_axes": self._attention_axes,
}
base_config = super(TransformerEncoderBlock, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
def call(self, inputs, training=None):
"""Transformer self-attention encoder block call.
Args:
inputs: a single tensor or a list of tensors.
`input tensor` as the single sequence of embeddings.
[`input tensor`, `attention mask`] to have the additional attention
mask.
        [`query tensor`, `key value tensor`, `attention mask`] to have separate
          input streams for the query and key/value to the multi-head
          attention.
Returns:
An output tensor with the same dimensions as input/query tensor.
"""
if isinstance(inputs, (list, tuple)):
if len(inputs) == 2:
input_tensor, attention_mask = inputs
key_value = None
elif len(inputs) == 3:
input_tensor, key_value, attention_mask = inputs
else:
raise ValueError("Unexpected inputs to %s with length at %d" %
(self.__class__, len(inputs)))
else:
input_tensor, key_value, attention_mask = (inputs, None, None)
    # Stochastic depth is only active during training and when a drop rate was
    # configured.
    with_stochastic_depth = bool(training and self._stochastic_depth_drop_rate)
if self._output_range:
if self._norm_first:
source_tensor = input_tensor[:, 0:self._output_range, :]
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor[:, 0:self._output_range, :]
if attention_mask is not None:
attention_mask = attention_mask[:, 0:self._output_range, :]
else:
if self._norm_first:
source_tensor = input_tensor
input_tensor = self._attention_layer_norm(input_tensor)
if key_value is not None:
key_value = self._attention_layer_norm(key_value)
target_tensor = input_tensor
if key_value is None:
key_value = input_tensor
attention_output = self._attention_layer(
query=target_tensor, value=key_value, attention_mask=attention_mask)
attention_output = self._attention_dropout(attention_output)
if self._norm_first:
attention_output = source_tensor + self._stochastic_depth(
attention_output, training=with_stochastic_depth)
else:
attention_output = self._attention_layer_norm(
target_tensor +
self._stochastic_depth(attention_output, training=with_stochastic_depth)
)
if self._norm_first:
source_attention_output = attention_output
attention_output = self._output_layer_norm(attention_output)
inner_output = self._intermediate_dense(attention_output)
inner_output = self._intermediate_activation_layer(inner_output)
inner_output = self._inner_dropout_layer(inner_output)
layer_output = self._output_dense(inner_output)
layer_output = self._output_dropout(layer_output)
if self._norm_first:
return source_attention_output + self._stochastic_depth(
layer_output, training=with_stochastic_depth)
# During mixed precision training, layer norm output is always fp32 for now.
    # Cast to fp32 for the subsequent add.
layer_output = tf.cast(layer_output, tf.float32)
return self._output_layer_norm(
layer_output
+ self._stochastic_depth(attention_output, training=with_stochastic_depth)
)
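# Hedged usage sketch (not part of this commit; shapes and values are
# illustrative): the block maps [batch, seq_len, hidden] to the same shape.
#   block = TransformerEncoderBlock(
#       num_attention_heads=6, inner_dim=1536, inner_activation='gelu',
#       norm_first=True, stochastic_depth_drop_rate=0.1, norm_epsilon=1e-6)
#   x = tf.ones([2, 197, 384])   # 196 patches + 1 class token for ViT-S/16
#   y = block(x, training=True)  # -> [2, 197, 384]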
......@@ -19,6 +19,8 @@ import tensorflow as tf
from official.modeling import activations
from official.nlp import keras_nlp
from official.vision.beta.modeling.backbones import factory
from official.vision.beta.modeling.layers import nn_layers
from official.vision.beta.projects.vit.modeling.layers import TransformerEncoderBlock
layers = tf.keras.layers
......@@ -29,6 +31,18 @@ VIT_SPECS = {
patch_size=16,
transformer=dict(mlp_dim=1, num_heads=1, num_layers=1),
),
    'vit-ti16':
        dict(
            hidden_size=192,
            patch_size=16,
            transformer=dict(mlp_dim=768, num_heads=3, num_layers=12),
        ),
    'vit-s16':
        dict(
            hidden_size=384,
            patch_size=16,
            transformer=dict(mlp_dim=1536, num_heads=6, num_layers=12),
        ),
'vit-b16':
dict(
hidden_size=768,
......@@ -112,6 +126,8 @@ class Encoder(tf.keras.layers.Layer):
attention_dropout_rate=0.1,
kernel_regularizer=None,
inputs_positions=None,
init_stochastic_depth_rate=0.0,
kernel_initializer='glorot_uniform',
**kwargs):
super().__init__(**kwargs)
self._num_layers = num_layers
......@@ -121,6 +137,8 @@ class Encoder(tf.keras.layers.Layer):
self._attention_dropout_rate = attention_dropout_rate
self._kernel_regularizer = kernel_regularizer
self._inputs_positions = inputs_positions
self._init_stochastic_depth_rate = init_stochastic_depth_rate
self._kernel_initializer = kernel_initializer
def build(self, input_shape):
self._pos_embed = AddPositionEmbs(
......@@ -131,15 +149,18 @@ class Encoder(tf.keras.layers.Layer):
self._encoder_layers = []
# Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
# https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
    for i in range(self._num_layers):
      encoder_layer = TransformerEncoderBlock(
inner_activation=activations.gelu,
num_attention_heads=self._num_heads,
inner_dim=self._mlp_dim,
output_dropout=self._dropout_rate,
attention_dropout=self._attention_dropout_rate,
kernel_regularizer=self._kernel_regularizer,
kernel_initializer=self._kernel_initializer,
norm_first=True,
stochastic_depth_drop_rate=nn_layers.get_stochastic_depth_rate(
self._init_stochastic_depth_rate, i, self._num_layers - 1),
norm_epsilon=1e-6)
self._encoder_layers.append(encoder_layer)
self._norm = layers.LayerNormalization(epsilon=1e-6)
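    # Sketch of the layer-wise schedule above (assuming
    # nn_layers.get_stochastic_depth_rate(init, i, n) returns init * i / n):
    # with init_stochastic_depth_rate=0.1 and num_layers=12, layer 0 keeps
    # rate 0.0 and layer 11 reaches the full 0.1, so deeper blocks are skipped
    # more often during training.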
......@@ -164,12 +185,14 @@ class VisionTransformer(tf.keras.Model):
num_layers=12,
attention_dropout_rate=0.0,
dropout_rate=0.1,
init_stochastic_depth_rate=0.0,
input_specs=layers.InputSpec(shape=[None, None, None, 3]),
patch_size=16,
hidden_size=768,
representation_size=0,
classifier='token',
               kernel_regularizer=None,
               original_init=True):
"""VisionTransformer initialization function."""
inputs = tf.keras.Input(shape=input_specs.shape[1:])
......@@ -178,7 +201,8 @@ class VisionTransformer(tf.keras.Model):
kernel_size=patch_size,
strides=patch_size,
padding='valid',
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
inputs)
if tf.keras.backend.image_data_format() == 'channels_last':
rows_axis, cols_axis = (1, 2)
......@@ -203,7 +227,10 @@ class VisionTransformer(tf.keras.Model):
num_heads=num_heads,
dropout_rate=dropout_rate,
attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer,
        kernel_initializer='glorot_uniform' if original_init else dict(
            class_name='TruncatedNormal', config=dict(stddev=.02)),
        init_stochastic_depth_rate=init_stochastic_depth_rate)(
x)
if classifier == 'token':
......@@ -215,7 +242,8 @@ class VisionTransformer(tf.keras.Model):
x = tf.keras.layers.Dense(
representation_size,
kernel_regularizer=kernel_regularizer,
          name='pre_logits',
          kernel_initializer='lecun_normal' if original_init else 'he_uniform')(
x)
x = tf.nn.tanh(x)
else:
......@@ -225,7 +253,8 @@ class VisionTransformer(tf.keras.Model):
tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
}
    super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)
@factory.register_backbone_builder('vit')
......@@ -247,9 +276,11 @@ def build_vit(input_specs,
num_layers=backbone_cfg.transformer.num_layers,
attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
dropout_rate=backbone_cfg.transformer.dropout_rate,
init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
input_specs=input_specs,
patch_size=backbone_cfg.patch_size,
hidden_size=backbone_cfg.hidden_size,
representation_size=backbone_cfg.representation_size,
classifier=backbone_cfg.classifier,
      kernel_regularizer=l2_regularizer,
      original_init=backbone_cfg.original_init)
......@@ -58,7 +58,7 @@ class ImageClassificationTask(cfg.TaskConfig):
@exp_factory.register_config_factory('darknet_classification')
def darknet_classification() -> cfg.ExperimentConfig:
"""Image classification general."""
return cfg.ExperimentConfig(
task=ImageClassificationTask(),
......
......@@ -26,6 +26,7 @@ from official.vision.beta.dataloaders import classification_input
from official.vision.beta.dataloaders import input_reader_factory
from official.vision.beta.dataloaders import tfds_factory
from official.vision.beta.modeling import factory
from official.vision.beta.ops import augment
@task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
......@@ -103,14 +104,27 @@ class ImageClassificationTask(base_task.Task):
decode_jpeg_only=params.decode_jpeg_only,
aug_rand_hflip=params.aug_rand_hflip,
aug_type=params.aug_type,
color_jitter=params.color_jitter,
random_erasing=params.random_erasing,
is_multilabel=is_multilabel,
dtype=params.dtype)
postprocess_fn = None
if params.mixup_and_cutmix:
      postprocess_fn = augment.MixupAndCutmix(
          mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
          cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
          prob=params.mixup_and_cutmix.prob,
          label_smoothing=params.mixup_and_cutmix.label_smoothing,
          num_classes=params.mixup_and_cutmix.num_classes)
reader = input_reader_factory.input_reader_generator(
params,
dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
decoder_fn=decoder.decode,
        parser_fn=parser.parse_fn(params.is_training),
        postprocess_fn=postprocess_fn)
dataset = reader.read(input_context=input_context)
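    # Hedged sketch of what the post-processor does (assuming
    # augment.MixupAndCutmix is callable on a batch, as wired above): sparse
    # labels come back as smoothed soft-label distributions.
    #   mixer = augment.MixupAndCutmix(mixup_alpha=0.8, cutmix_alpha=1.0,
    #                                  label_smoothing=0.1, num_classes=1000)
    #   images, labels = mixer(images, labels)  # labels -> [batch, 1000]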
......@@ -119,12 +133,15 @@ class ImageClassificationTask(base_task.Task):
def build_losses(self,
labels: tf.Tensor,
model_outputs: tf.Tensor,
is_validation: bool,
aux_losses: Optional[Any] = None) -> tf.Tensor:
"""Builds sparse categorical cross entropy loss.
Args:
labels: Input groundtruth labels.
model_outputs: Output logits of the classifier.
      is_validation: Whether the loss is computed for validation. Some
        augmentations need custom soft labels during training, while
        validation should remain unchanged.
      aux_losses: The auxiliary loss tensors, i.e. `losses` in tf.keras.Model.
Returns:
......@@ -134,12 +151,19 @@ class ImageClassificationTask(base_task.Task):
is_multilabel = self.task_config.train_data.is_multilabel
if not is_multilabel:
      # Some augmentations need custom soft labels in training, but validation
      # should remain unchanged.
      if losses_config.one_hot or is_validation:
total_loss = tf.keras.losses.categorical_crossentropy(
labels,
model_outputs,
from_logits=True,
label_smoothing=losses_config.label_smoothing)
      elif losses_config.soft_labels:
        total_loss = tf.nn.softmax_cross_entropy_with_logits(
            labels=labels, logits=model_outputs)
else:
total_loss = tf.keras.losses.sparse_categorical_crossentropy(
labels, model_outputs, from_logits=True)
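      # Illustration of the soft-label branch above (values are made up):
      # mixed targets are dense distributions, so they feed the cross entropy
      # directly instead of being re-smoothed one-hot vectors.
      #   labels = [[0.7, 0.2, 0.1]]
      #   loss = tf.nn.softmax_cross_entropy_with_logits(
      #       labels=labels, logits=model_outputs)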
......@@ -161,7 +185,8 @@ class ImageClassificationTask(base_task.Task):
is_multilabel = self.task_config.train_data.is_multilabel
if not is_multilabel:
k = self.task_config.evaluation.top_k
      if (self.task_config.losses.one_hot
          or self.task_config.losses.soft_labels):
metrics = [
tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
tf.keras.metrics.TopKCategoricalAccuracy(
......@@ -222,8 +247,8 @@ class ImageClassificationTask(base_task.Task):
lambda x: tf.cast(x, tf.float32), outputs)
# Computes per-replica loss.
      loss = self.build_losses(model_outputs=outputs, labels=labels,
                               is_validation=False, aux_losses=model.losses)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss = loss / num_replicas
......@@ -266,14 +291,16 @@ class ImageClassificationTask(base_task.Task):
A dictionary of logs.
"""
features, labels = inputs
one_hot = self.task_config.losses.one_hot
soft_labels = self.task_config.losses.soft_labels
is_multilabel = self.task_config.train_data.is_multilabel
    if (one_hot or soft_labels) and not is_multilabel:
labels = tf.one_hot(labels, self.task_config.model.num_classes)
outputs = self.inference_step(features, model)
outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
    loss = self.build_losses(model_outputs=outputs, labels=labels,
                             is_validation=True, aux_losses=model.losses)
logs = {self.loss: loss}
if metrics:
......