Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
40cd0a26
Commit
40cd0a26
authored
Aug 13, 2021
by
Simon Geisler
Browse files
deit without repeated aug and distillation
parent
3db445c7
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
1763 additions
and
34 deletions
+1763
-34
official/vision/beta/configs/common.py
official/vision/beta/configs/common.py
+26
-1
official/vision/beta/configs/image_classification.py
official/vision/beta/configs/image_classification.py
+5
-0
official/vision/beta/dataloaders/classification_input.py
official/vision/beta/dataloaders/classification_input.py
+28
-1
official/vision/beta/modeling/factory.py
official/vision/beta/modeling/factory.py
+1
-0
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+287
-11
official/vision/beta/ops/augment_test.py
official/vision/beta/ops/augment_test.py
+77
-0
official/vision/beta/ops/preprocess_ops.py
official/vision/beta/ops/preprocess_ops.py
+83
-0
official/vision/beta/projects/vit/README.md
official/vision/beta/projects/vit/README.md
+6
-4
official/vision/beta/projects/vit/configs/backbones.py
official/vision/beta/projects/vit/configs/backbones.py
+2
-0
official/vision/beta/projects/vit/configs/image_classification.py
.../vision/beta/projects/vit/configs/image_classification.py
+842
-1
official/vision/beta/projects/vit/modeling/layers/__init__.py
...cial/vision/beta/projects/vit/modeling/layers/__init__.py
+1
-0
official/vision/beta/projects/vit/modeling/layers/vit_transformer_encoder_block.py
...ects/vit/modeling/layers/vit_transformer_encoder_block.py
+331
-0
official/vision/beta/projects/vit/modeling/vit.py
official/vision/beta/projects/vit/modeling/vit.py
+39
-8
official/vision/beta/projects/yolo/configs/darknet_classification.py
...sion/beta/projects/yolo/configs/darknet_classification.py
+1
-1
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+34
-7
No files found.
official/vision/beta/configs/common.py
View file @
40cd0a26
...
...
@@ -15,7 +15,7 @@
# Lint as: python3
"""Common configurations."""
from
typing
import
Optional
from
typing
import
Optional
,
List
# Import libraries
import
dataclasses
...
...
@@ -32,6 +32,7 @@ class RandAugment(hyperparams.Config):
cutout_const
:
float
=
40
translate_const
:
float
=
10
prob_to_apply
:
Optional
[
float
]
=
None
exclude_ops
:
List
[
str
]
=
dataclasses
.
field
(
default_factory
=
list
)
@
dataclasses
.
dataclass
...
...
@@ -42,6 +43,30 @@ class AutoAugment(hyperparams.Config):
translate_const
:
float
=
250
@
dataclasses
.
dataclass
class
RandomErasing
(
hyperparams
.
Config
):
"""Configuration for RandomErasing."""
probability
:
float
=
0.25
min_area
:
float
=
0.02
max_area
:
float
=
1
/
3
min_aspect
:
float
=
0.3
max_aspect
=
None
min_count
=
1
max_count
=
1
trials
=
10
@
dataclasses
.
dataclass
class
MixupAndCutmix
(
hyperparams
.
Config
):
"""Configuration for MixupAndCutmix."""
mixup_alpha
:
float
=
.
8
cutmix_alpha
:
float
=
1.
prob
:
float
=
1.0
switch_prob
:
float
=
0.5
label_smoothing
:
float
=
0.1
num_classes
:
int
=
1000
@
dataclasses
.
dataclass
class
Augmentation
(
hyperparams
.
OneOfConfig
):
"""Configuration for input data augmentation.
...
...
official/vision/beta/configs/image_classification.py
View file @
40cd0a26
...
...
@@ -40,10 +40,13 @@ class DataConfig(cfg.DataConfig):
aug_rand_hflip
:
bool
=
True
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Choose from AutoAugment and RandAugment.
color_jitter
:
float
=
0.
random_erasing
:
Optional
[
common
.
RandomErasing
]
=
None
file_type
:
str
=
'tfrecord'
image_field_key
:
str
=
'image/encoded'
label_field_key
:
str
=
'image/class/label'
decode_jpeg_only
:
bool
=
True
mixup_and_cutmix
:
Optional
[
common
.
MixupAndCutmix
]
=
None
# Keep for backward compatibility.
aug_policy
:
Optional
[
str
]
=
None
# None, 'autoaug', or 'randaug'.
...
...
@@ -62,6 +65,7 @@ class ImageClassificationModel(hyperparams.Config):
use_sync_bn
=
False
)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm
:
bool
=
False
kernel_initializer
:
str
=
'random_uniform'
@
dataclasses
.
dataclass
...
...
@@ -69,6 +73,7 @@ class Losses(hyperparams.Config):
one_hot
:
bool
=
True
label_smoothing
:
float
=
0.0
l2_weight_decay
:
float
=
0.0
soft_labels
:
bool
=
False
@
dataclasses
.
dataclass
...
...
official/vision/beta/dataloaders/classification_input.py
View file @
40cd0a26
...
...
@@ -69,6 +69,8 @@ class Parser(parser.Parser):
decode_jpeg_only
:
bool
=
True
,
aug_rand_hflip
:
bool
=
True
,
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
,
color_jitter
:
float
=
0.
,
random_erasing
:
Optional
[
common
.
RandomErasing
]
=
None
,
is_multilabel
:
bool
=
False
,
dtype
:
str
=
'float32'
):
"""Initializes parameters for parsing annotations in the dataset.
...
...
@@ -85,6 +87,7 @@ class Parser(parser.Parser):
horizontal flip.
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment.
color_jitter: if > 0 the input image will be augmented by color jitter.
is_multilabel: A `bool`, whether or not each example has multiple labels.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
...
...
@@ -113,13 +116,28 @@ class Parser(parser.Parser):
magnitude
=
aug_type
.
randaug
.
magnitude
,
cutout_const
=
aug_type
.
randaug
.
cutout_const
,
translate_const
=
aug_type
.
randaug
.
translate_const
,
prob_to_apply
=
aug_type
.
randaug
.
prob_to_apply
)
prob_to_apply
=
aug_type
.
randaug
.
prob_to_apply
,
exclude_ops
=
aug_type
.
randaug
.
exclude_ops
)
else
:
raise
ValueError
(
'Augmentation policy {} not supported.'
.
format
(
aug_type
.
type
))
else
:
self
.
_augmenter
=
None
self
.
_label_field_key
=
label_field_key
self
.
_color_jitter
=
color_jitter
if
random_erasing
:
self
.
_random_erasing
=
augment
.
RandomErasing
(
probability
=
random_erasing
.
probability
,
min_area
=
random_erasing
.
min_area
,
max_area
=
random_erasing
.
max_area
,
min_aspect
=
random_erasing
.
min_aspect
,
max_aspect
=
random_erasing
.
max_aspect
,
min_count
=
random_erasing
.
min_count
,
max_count
=
random_erasing
.
max_count
,
trials
=
random_erasing
.
trials
)
else
:
self
.
_random_erasing
=
None
self
.
_is_multilabel
=
is_multilabel
self
.
_decode_jpeg_only
=
decode_jpeg_only
...
...
@@ -213,11 +231,20 @@ class Parser(parser.Parser):
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
.
set_shape
([
self
.
_output_size
[
0
],
self
.
_output_size
[
1
],
3
])
# Color jitter.
if
self
.
_color_jitter
>
0
:
image
=
preprocess_ops
.
color_jitter
(
image
,
self
.
_color_jitter
,
self
.
_color_jitter
,
self
.
_color_jitter
)
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
,
offset
=
MEAN_RGB
,
scale
=
STDDEV_RGB
)
# Random erasing after the image has been normalized
if
self
.
_random_erasing
is
not
None
:
image
=
self
.
_random_erasing
.
distort
(
image
)
# Convert image to self._dtype.
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
...
...
official/vision/beta/modeling/factory.py
View file @
40cd0a26
...
...
@@ -56,6 +56,7 @@ def build_classification_model(
num_classes
=
model_config
.
num_classes
,
input_specs
=
input_specs
,
dropout_rate
=
model_config
.
dropout_rate
,
kernel_initializer
=
model_config
.
kernel_initializer
,
kernel_regularizer
=
l2_regularizer
,
add_head_batch_norm
=
model_config
.
add_head_batch_norm
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
...
...
official/vision/beta/ops/augment.py
View file @
40cd0a26
...
...
@@ -12,10 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Auto
Augment
and RandAugment
policies for enhanced image/video preprocessing.
"""Augment
ation
policies for enhanced image/video preprocessing.
AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719
RandomErasing Reference: https://arxiv.org/abs/1708.04896
MixupAndCutmix:
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
RandomErasing, Mixup and Cutmix are inspired by https://github.com/rwightman/pytorch-image-models
"""
import
math
from
typing
import
Any
,
List
,
Iterable
,
Optional
,
Text
,
Tuple
...
...
@@ -295,10 +302,21 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
cutout_center_width
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
image_width
,
dtype
=
tf
.
int32
)
lower_pad
=
tf
.
maximum
(
0
,
cutout_center_height
-
pad_size
)
upper_pad
=
tf
.
maximum
(
0
,
image_height
-
cutout_center_height
-
pad_size
)
left_pad
=
tf
.
maximum
(
0
,
cutout_center_width
-
pad_size
)
right_pad
=
tf
.
maximum
(
0
,
image_width
-
cutout_center_width
-
pad_size
)
image
=
_fill_rectangle
(
image
,
cutout_center_width
,
cutout_center_height
,
pad_size
,
pad_size
,
replace
)
return
image
def
_fill_rectangle
(
image
,
center_width
,
center_height
,
half_width
,
half_height
,
replace
=
None
):
image_height
=
tf
.
shape
(
image
)[
0
]
image_width
=
tf
.
shape
(
image
)[
1
]
lower_pad
=
tf
.
maximum
(
0
,
center_height
-
half_height
)
upper_pad
=
tf
.
maximum
(
0
,
image_height
-
center_height
-
half_height
)
left_pad
=
tf
.
maximum
(
0
,
center_width
-
half_width
)
right_pad
=
tf
.
maximum
(
0
,
image_width
-
center_width
-
half_width
)
cutout_shape
=
[
image_height
-
(
lower_pad
+
upper_pad
),
...
...
@@ -311,9 +329,15 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
constant_values
=
1
)
mask
=
tf
.
expand_dims
(
mask
,
-
1
)
mask
=
tf
.
tile
(
mask
,
[
1
,
1
,
3
])
image
=
tf
.
where
(
tf
.
equal
(
mask
,
0
),
tf
.
ones_like
(
image
,
dtype
=
image
.
dtype
)
*
replace
,
image
)
if
replace
is
None
:
fill
=
tf
.
random
.
normal
(
tf
.
shape
(
image
),
dtype
=
image
.
dtype
)
elif
isinstance
(
replace
,
tf
.
Tensor
):
fill
=
replace
else
:
fill
=
tf
.
ones_like
(
image
,
dtype
=
image
.
dtype
)
*
replace
image
=
tf
.
where
(
tf
.
equal
(
mask
,
0
),
fill
,
image
)
return
image
...
...
@@ -805,9 +829,15 @@ def level_to_arg(cutout_const: float, translate_const: float):
def
_parse_policy_info
(
name
:
Text
,
prob
:
float
,
level
:
float
,
replace_value
:
List
[
int
],
cutout_const
:
float
,
translate_const
:
float
)
->
Tuple
[
Any
,
float
,
Any
]:
translate_const
:
float
,
level_std
:
float
=
0.
)
->
Tuple
[
Any
,
float
,
Any
]:
"""Return the function that corresponds to `name` and update `level` param."""
func
=
NAME_TO_FUNC
[
name
]
if
level_std
>
0
:
level
+=
tf
.
random
.
normal
([],
dtype
=
tf
.
float32
)
level
=
tf
.
clip_by_value
(
level
,
0.
,
_MAX_LEVEL
)
args
=
level_to_arg
(
cutout_const
,
translate_const
)[
name
](
level
)
if
name
in
REPLACE_FUNCS
:
...
...
@@ -1184,7 +1214,9 @@ class RandAugment(ImageAugment):
magnitude
:
float
=
10.
,
cutout_const
:
float
=
40.
,
translate_const
:
float
=
100.
,
prob_to_apply
:
Optional
[
float
]
=
None
):
magnitude_std
:
float
=
0.0
,
prob_to_apply
:
Optional
[
float
]
=
None
,
exclude_ops
:
List
[
str
]
=
[]):
"""Applies the RandAugment policy to images.
Args:
...
...
@@ -1196,8 +1228,11 @@ class RandAugment(ImageAugment):
[5, 10].
cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation.
magnitude_std: randomness of the severity as proposed by the authors of
the timm library.
prob_to_apply: The probability to apply the selected augmentation at each
layer.
exclude_ops: exclude selected operations.
"""
super
(
RandAugment
,
self
).
__init__
()
...
...
@@ -1212,6 +1247,9 @@ class RandAugment(ImageAugment):
'Color'
,
'Contrast'
,
'Brightness'
,
'Sharpness'
,
'ShearX'
,
'ShearY'
,
'TranslateX'
,
'TranslateY'
,
'Cutout'
,
'SolarizeAdd'
]
self
.
magnitude_std
=
magnitude_std
self
.
available_ops
=
[
op
for
op
in
self
.
available_ops
if
op
not
in
exclude_ops
]
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Applies the RandAugment policy to `image`.
...
...
@@ -1246,7 +1284,8 @@ class RandAugment(ImageAugment):
dtype
=
tf
.
float32
)
func
,
_
,
args
=
_parse_policy_info
(
op_name
,
prob
,
self
.
magnitude
,
replace_value
,
self
.
cutout_const
,
self
.
translate_const
)
self
.
translate_const
,
self
.
magnitude_std
)
branch_fns
.
append
((
i
,
# pylint:disable=g-long-lambda
...
...
@@ -1267,3 +1306,240 @@ class RandAugment(ImageAugment):
image
=
tf
.
cast
(
image
,
dtype
=
input_image_type
)
return
image
class
RandomErasing
(
ImageAugment
):
"""Applies RandomErasing to a single image.
Reference: https://arxiv.org/abs/1708.04896
Implementaion is inspired by https://github.com/rwightman/pytorch-image-models
"""
def
__init__
(
self
,
probability
:
float
=
0.25
,
min_area
:
float
=
0.02
,
max_area
:
float
=
1
/
3
,
min_aspect
:
float
=
0.3
,
max_aspect
=
None
,
min_count
=
1
,
max_count
=
1
,
trials
=
10
):
"""Applies RandomErasing to a single image.
Args:
probability (float, optional): Probability of augmenting the image.
Defaults to 0.25.
min_area (float, optional): Minimum area of the random erasing
rectangle. Defaults to 0.02.
max_area (float, optional): Maximum area of the random erasing
rectangle. Defaults to 1/3.
min_aspect (float, optional): Minimum aspect rate of the random erasing
rectangle. Defaults to 0.3.
max_aspect ([type], optional): Maximum aspect rate of the random
erasing rectangle. Defaults to None.
min_count (int, optional): Minimum number of erased
rectangles. Defaults to 1.
max_count (int, optional): Maximum number of erased
rectangles. Defaults to 1.
trials (int, optional): Maximum number of trials to randomly sample a
rectangle that fulfills constraint. Defaults to 10.
"""
self
.
_probability
=
probability
self
.
_min_area
=
float
(
min_area
)
self
.
_max_area
=
float
(
max_area
)
self
.
_min_log_aspect
=
math
.
log
(
min_aspect
)
self
.
_max_log_aspect
=
math
.
log
(
max_aspect
or
1
/
min_aspect
)
self
.
_min_count
=
min_count
self
.
_max_count
=
max_count
self
.
_trials
=
trials
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Applies RandomErasing to single `image`.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
Returns:
tf.Tensor: The augmented version of `image`.
"""
uniform_random
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0.
,
maxval
=
1.0
)
mirror_cond
=
tf
.
less
(
uniform_random
,
.
5
)
tf
.
cond
(
mirror_cond
,
self
.
_erase
,
lambda
:
image
)
return
image
@
tf
.
function
def
_erase
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
count
=
self
.
_min_count
if
self
.
_min_count
==
self
.
_max_count
else
\
tf
.
random
.
uniform
(
shape
=
[],
minval
=
int
(
self
.
_min_count
),
maxval
=
int
(
self
.
_max_count
-
self
.
_min_count
+
1
),
dtype
=
tf
.
int32
)
image_height
=
tf
.
shape
(
image
)[
0
]
image_width
=
tf
.
shape
(
image
)[
1
]
area
=
tf
.
cast
(
image_width
*
image_height
,
tf
.
float32
)
for
_
in
range
(
count
):
for
_
in
range
(
self
.
_trials
):
erase_area
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
area
*
self
.
_min_area
,
maxval
=
area
*
self
.
_max_area
)
aspect_ratio
=
tf
.
math
.
exp
(
tf
.
random
.
uniform
(
shape
=
[],
minval
=
self
.
_min_log_aspect
,
maxval
=
self
.
_max_log_aspect
))
half_height
=
tf
.
cast
(
tf
.
math
.
round
(
tf
.
math
.
sqrt
(
erase_area
*
aspect_ratio
)
/
2
),
dtype
=
tf
.
int32
)
half_width
=
tf
.
cast
(
tf
.
math
.
round
(
tf
.
math
.
sqrt
(
erase_area
/
aspect_ratio
)
/
2
),
dtype
=
tf
.
int32
)
if
2
*
half_height
<
image_height
and
2
*
half_width
<
image_width
:
center_height
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
int
(
image_height
-
2
*
half_height
),
dtype
=
tf
.
int32
)
center_width
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
int
(
image_width
-
2
*
half_width
),
dtype
=
tf
.
int32
)
image
=
_fill_rectangle
(
image
,
center_width
,
center_height
,
half_width
,
half_height
,
replace
=
None
)
break
return
image
class
MixupAndCutmix
:
"""Applies Mixup and/or Cutmix to a batch of images.
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
Implementaion is inspired by https://github.com/rwightman/pytorch-image-models
"""
def
__init__
(
self
,
mixup_alpha
:
float
=
.
8
,
cutmix_alpha
:
float
=
1.
,
prob
:
float
=
1.0
,
switch_prob
:
float
=
0.5
,
label_smoothing
:
float
=
0.1
,
num_classes
:
int
=
1001
):
"""Applies Mixup and/or Cutmix to a batch of images.
Args:
mixup_alpha (float, optional): For drawing a random lambda (`lam`) from a
beta distribution (for each image). If zero Mixup is deactivated.
Defaults to .8.
cutmix_alpha (float, optional): For drawing a random lambda (`lam`) from
a beta distribution (for each image). If zero Cutmix is deactivated.
Defaults to 1..
prob (float, optional): Of augmenting the batch. Defaults to 1.0.
switch_prob (float, optional): Probability of applying Cutmix for the
batch. Defaults to 0.5.
label_smoothing (float, optional): Constant for label smoothing. Defaults
to 0.1.
num_classes (int, optional): Number of classes. Defaults to 1001.
"""
self
.
mixup_alpha
=
mixup_alpha
self
.
cutmix_alpha
=
cutmix_alpha
self
.
mix_prob
=
prob
self
.
switch_prob
=
switch_prob
self
.
label_smoothing
=
label_smoothing
self
.
num_classes
=
num_classes
self
.
mode
=
'batch'
self
.
mixup_enabled
=
True
if
self
.
mixup_alpha
and
not
self
.
cutmix_alpha
:
self
.
switch_prob
=
-
1
elif
not
self
.
mixup_alpha
and
self
.
cutmix_alpha
:
self
.
switch_prob
=
1
def
__call__
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
return
self
.
distort
(
images
,
labels
)
def
distort
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
"""Applies Mixup and/or Cutmix to batch of `images` and transforms the
`labels` (incl. label smoothing).
Args:
images (tf.Tensor): Of shape [batch_size,height, width, 3] representing
a batch of image.
labels (tf.Tensor): Of shape [batch_size, ] representing the class id for
each image of the batch.
Returns:
Tuple[tf.Tensor, tf.Tensor]: The augmented version of `image` and
`labels`.
"""
augment_cond
=
tf
.
less
(
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0.
,
maxval
=
1.0
),
self
.
mix_prob
)
return
tf
.
cond
(
augment_cond
,
lambda
:
self
.
_update_labels
(
*
tf
.
cond
(
tf
.
less
(
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0.
,
maxval
=
1.0
),
self
.
switch_prob
),
lambda
:
self
.
_cutmix
(
images
,
labels
),
lambda
:
self
.
_mixup
(
images
,
labels
)
)),
lambda
:
(
images
,
self
.
_smooth_labels
(
labels
))
)
@
staticmethod
def
_sample_from_beta
(
alpha
:
float
,
beta
:
float
,
shape
:
tuple
):
sample_alpha
=
tf
.
random
.
gamma
(
shape
,
1.
,
beta
=
alpha
)
sample_beta
=
tf
.
random
.
gamma
(
shape
,
1.
,
beta
=
beta
)
return
sample_alpha
/
(
sample_alpha
+
sample_beta
)
def
_cutmix
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
,
tf
.
Tensor
]:
lam
=
MixupAndCutmix
.
_sample_from_beta
(
self
.
cutmix_alpha
,
self
.
cutmix_alpha
,
labels
.
shape
)
ratio
=
tf
.
math
.
sqrt
(
1
-
lam
)
batch_size
=
tf
.
shape
(
images
)[
0
]
image_height
,
image_width
=
tf
.
shape
(
images
)[
1
],
tf
.
shape
(
images
)[
2
]
cut_height
=
tf
.
cast
(
ratio
*
tf
.
cast
(
image_height
,
dtype
=
tf
.
float32
),
dtype
=
tf
.
int32
)
cut_width
=
tf
.
cast
(
ratio
*
tf
.
cast
(
image_height
,
dtype
=
tf
.
float32
),
dtype
=
tf
.
int32
)
random_center_height
=
tf
.
random
.
uniform
(
shape
=
[
batch_size
],
minval
=
0
,
maxval
=
image_height
,
dtype
=
tf
.
int32
)
random_center_width
=
tf
.
random
.
uniform
(
shape
=
[
batch_size
],
minval
=
0
,
maxval
=
image_width
,
dtype
=
tf
.
int32
)
bbox_area
=
cut_height
*
cut_width
lam
=
1.
-
bbox_area
/
(
image_height
*
image_width
)
lam
=
tf
.
cast
(
lam
,
dtype
=
tf
.
float32
)
images
=
tf
.
map_fn
(
lambda
x
:
_fill_rectangle
(
*
x
),
(
images
,
random_center_width
,
random_center_height
,
cut_width
//
2
,
cut_height
//
2
,
tf
.
reverse
(
images
,
[
0
])),
dtype
=
(
tf
.
float32
,
tf
.
int32
,
tf
.
int32
,
tf
.
int32
,
tf
.
int32
,
tf
.
float32
),
fn_output_signature
=
tf
.
TensorSpec
(
images
.
shape
[
1
:],
dtype
=
tf
.
float32
))
return
images
,
labels
,
lam
def
_mixup
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
,
tf
.
Tensor
]:
lam
=
MixupAndCutmix
.
_sample_from_beta
(
self
.
mixup_alpha
,
self
.
mixup_alpha
,
labels
.
shape
)
lam
=
tf
.
reshape
(
lam
,
[
-
1
,
1
,
1
,
1
])
images
=
lam
*
images
+
(
1.
-
lam
)
*
tf
.
reverse
(
images
,
[
0
])
return
images
,
labels
,
tf
.
squeeze
(
lam
)
def
_smooth_labels
(
self
,
labels
:
tf
.
Tensor
)
->
tf
.
Tensor
:
off_value
=
self
.
label_smoothing
/
self
.
num_classes
on_value
=
1.
-
self
.
label_smoothing
+
off_value
smooth_labels
=
tf
.
one_hot
(
labels
,
self
.
num_classes
,
on_value
=
on_value
,
off_value
=
off_value
)
return
smooth_labels
def
_update_labels
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
,
lam
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
labels_1
=
self
.
_smooth_labels
(
labels
)
labels_2
=
tf
.
reverse
(
labels_1
,
[
0
])
lam
=
tf
.
reshape
(
lam
,
[
-
1
,
1
])
labels
=
lam
*
labels_1
+
(
1.
-
lam
)
*
labels_2
return
images
,
labels
official/vision/beta/ops/augment_test.py
View file @
40cd0a26
...
...
@@ -254,5 +254,82 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
augmenter
.
distort
(
image
)
class
RandomErasingTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
test_random_erase_replaces_some_pixels
(
self
):
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
float32
)
augmenter
=
augment
.
RandomErasing
(
probability
=
1.
,
max_count
=
10
)
aug_image
=
augmenter
.
distort
(
image
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertLess
(
0
,
tf
.
reduce_max
(
aug_image
.
shape
))
class
MixupAndCutmixTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
test_mixup_and_cutmix_smoothes_labels
(
self
):
batch_size
=
12
num_classes
=
1000
label_smoothing
=
0.1
images
=
tf
.
random
.
normal
((
batch_size
,
224
,
224
,
3
),
dtype
=
tf
.
float32
)
labels
=
tf
.
range
(
batch_size
)
augmenter
=
augment
.
MixupAndCutmix
(
num_classes
=
num_classes
,
label_smoothing
=
label_smoothing
)
aug_images
,
aug_labels
=
augmenter
.
distort
(
images
,
labels
)
self
.
assertEqual
(
images
.
shape
,
aug_images
.
shape
)
self
.
assertEqual
(
images
.
dtype
,
aug_images
.
dtype
)
self
.
assertEqual
([
batch_size
,
num_classes
],
aug_labels
.
shape
)
self
.
assertAllLessEqual
(
aug_labels
,
1.
-
label_smoothing
+
2.
/
num_classes
)
# With tolerance
self
.
assertAllGreaterEqual
(
aug_labels
,
label_smoothing
/
num_classes
-
1e4
)
# With tolerance
def
test_mixup_changes_image
(
self
):
batch_size
=
12
num_classes
=
1000
label_smoothing
=
0.1
images
=
tf
.
random
.
normal
((
batch_size
,
224
,
224
,
3
),
dtype
=
tf
.
float32
)
labels
=
tf
.
range
(
batch_size
)
augmenter
=
augment
.
MixupAndCutmix
(
mixup_alpha
=
1.
,
cutmix_alpha
=
0.
,
num_classes
=
num_classes
)
aug_images
,
aug_labels
=
augmenter
.
distort
(
images
,
labels
)
self
.
assertEqual
(
images
.
shape
,
aug_images
.
shape
)
self
.
assertEqual
(
images
.
dtype
,
aug_images
.
dtype
)
self
.
assertEqual
([
batch_size
,
num_classes
],
aug_labels
.
shape
)
self
.
assertAllLessEqual
(
aug_labels
,
1.
-
label_smoothing
+
2.
/
num_classes
)
# With tolerance
self
.
assertAllGreaterEqual
(
aug_labels
,
label_smoothing
/
num_classes
-
1e4
)
# With tolerance
self
.
assertTrue
(
not
tf
.
math
.
reduce_all
(
images
==
aug_images
))
def
test_cutmix_changes_image
(
self
):
batch_size
=
12
num_classes
=
1000
label_smoothing
=
0.1
images
=
tf
.
random
.
normal
((
batch_size
,
224
,
224
,
3
),
dtype
=
tf
.
float32
)
labels
=
tf
.
range
(
batch_size
)
augmenter
=
augment
.
MixupAndCutmix
(
mixup_alpha
=
0.
,
cutmix_alpha
=
1.
,
num_classes
=
num_classes
)
aug_images
,
aug_labels
=
augmenter
.
distort
(
images
,
labels
)
self
.
assertEqual
(
images
.
shape
,
aug_images
.
shape
)
self
.
assertEqual
(
images
.
dtype
,
aug_images
.
dtype
)
self
.
assertEqual
([
batch_size
,
num_classes
],
aug_labels
.
shape
)
self
.
assertAllLessEqual
(
aug_labels
,
1.
-
label_smoothing
+
2.
/
num_classes
)
# With tolerance
self
.
assertAllGreaterEqual
(
aug_labels
,
label_smoothing
/
num_classes
-
1e4
)
# With tolerance
self
.
assertTrue
(
not
tf
.
math
.
reduce_all
(
images
==
aug_images
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/ops/preprocess_ops.py
View file @
40cd0a26
...
...
@@ -15,10 +15,12 @@
"""Preprocessing ops."""
import
math
from
typing
import
Optional
from
six.moves
import
range
import
tensorflow
as
tf
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.ops
import
augment
CENTER_CROP_FRACTION
=
0.875
...
...
@@ -555,3 +557,84 @@ def random_horizontal_flip(image, normalized_boxes=None, masks=None, seed=1):
lambda
:
masks
)
return
image
,
normalized_boxes
,
masks
def
color_jitter
(
image
:
tf
.
Tensor
,
brightness
:
Optional
[
float
]
=
0.
,
contrast
:
Optional
[
float
]
=
0.
,
saturation
:
Optional
[
float
]
=
0.
,
seed
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
"""Applies color jitter to an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
brightness (float, optional): Magnitude for brightness jitter.
Defaults to 0.
contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
saturation (float, optional): Magnitude for saturation jitter.
Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented version of `image`.
"""
image
=
random_brightness
(
image
,
brightness
,
seed
=
seed
)
image
=
random_contrast
(
image
,
contrast
,
seed
=
seed
)
image
=
random_saturation
(
image
,
saturation
,
seed
=
seed
)
return
image
def
random_brightness
(
image
:
tf
.
Tensor
,
brightness
:
Optional
[
float
]
=
0.
,
seed
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
"""Jitters brightness of an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
brightness (float, optional): Magnitude for brightness jitter.
Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented version of `image`.
"""
assert
brightness
>=
0
and
brightness
<=
1.
,
'`brightness` must be in [0, 1]'
brightness
=
tf
.
random
.
uniform
(
[],
max
(
0
,
1
-
brightness
),
1
+
brightness
,
seed
=
seed
)
return
augment
.
brightness
(
image
,
brightness
)
def
random_contrast
(
image
:
tf
.
Tensor
,
contrast
:
Optional
[
float
]
=
0.
,
seed
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
"""Jitters contrast of an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
contrast (float, optional): Magnitude for contrast jitter.
Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented version of `image`.
"""
assert
contrast
>=
0
and
contrast
<=
1.
,
'`contrast` must be in [0, 1]'
contrast
=
tf
.
random
.
uniform
(
[],
max
(
0
,
1
-
contrast
),
1
+
contrast
,
seed
=
seed
)
return
augment
.
contrast
(
image
,
contrast
)
def
random_saturation
(
image
:
tf
.
Tensor
,
saturation
:
Optional
[
float
]
=
0.
,
seed
:
Optional
[
int
]
=
None
)
->
tf
.
Tensor
:
"""Jitters saturation of an image, similarly to torchvision`s ColorJitter.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
saturation (float, optional): Magnitude for saturation jitter.
Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
Returns:
tf.Tensor: The augmented version of `image`.
"""
assert
saturation
>=
0
and
saturation
<=
1.
,
'`saturation` must be in [0, 1]'
saturation
=
tf
.
random
.
uniform
(
[],
max
(
0
,
1
-
saturation
),
1
+
saturation
,
seed
=
seed
)
return
augment
.
blend
(
tf
.
image
.
rgb_to_grayscale
(
image
),
image
,
saturation
)
official/vision/beta/projects/vit/README.md
View file @
40cd0a26
# Vision Transformer (ViT)
# Vision Transformer (ViT)
and Data-Efficient Image Transformer (DEIT)
**DISCLAIMER**
: This implementation is still under development. No support will
be provided during the development phase.
[

](https://arxiv.org/abs/2010.11929)
-
[

](https://arxiv.org/abs/2010.11929)
-
[

](https://arxiv.org/abs/2012.12877)
This repository is the implementations of Vision Transformer (ViT) in
This repository is the implementations of Vision Transformer (ViT)
and Data-Efficient Image Transformer (DEIT)
in
TensorFlow 2.
*
Paper title:
[
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
](
https://arxiv.org/pdf/2010.11929.pdf
)
.
\ No newline at end of file
-
[
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
](
https://arxiv.org/pdf/2010.11929.pdf
)
.
-
[
Training data-efficient image transformers & distillation through attention
](
https://arxiv.org/pdf/2012.12877.pdf
)
.
official/vision/beta/projects/vit/configs/backbones.py
View file @
40cd0a26
...
...
@@ -42,6 +42,8 @@ class VisionTransformer(hyperparams.Config):
hidden_size
:
int
=
1
patch_size
:
int
=
16
transformer
:
Transformer
=
Transformer
()
init_stochastic_depth_rate
:
float
=
0.0
original_init
:
bool
=
True
@
dataclasses
.
dataclass
...
...
official/vision/beta/projects/vit/configs/image_classification.py
View file @
40cd0a26
...
...
@@ -44,6 +44,7 @@ class ImageClassificationModel(hyperparams.Config):
use_sync_bn
=
False
)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm
:
bool
=
False
kernel_initializer
:
str
=
'random_uniform'
@
dataclasses
.
dataclass
...
...
@@ -51,6 +52,7 @@ class Losses(hyperparams.Config):
one_hot
:
bool
=
True
label_smoothing
:
float
=
0.0
l2_weight_decay
:
float
=
0.0
soft_labels
:
bool
=
False
@
dataclasses
.
dataclass
...
...
@@ -79,6 +81,843 @@ task_factory.register_task_cls(ImageClassificationTask)(
image_classification
.
ImageClassificationTask
)
@
exp_factory
.
register_config_factory
(
'deit_imagenet_pretrain_noaug'
)
def
image_classification_imagenet_deit_imagenet_pretrain_noaug
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
train_batch_size
=
4096
# 1024
eval_batch_size
=
4096
# 1024
repeated_aug
=
1
steps_per_epoch
=
IMAGENET_TRAIN_EXAMPLES
*
repeated_aug
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(
model
=
ImageClassificationModel
(
num_classes
=
1001
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
model_name
=
'vit-b16'
,
representation_size
=
768
,
init_stochastic_depth_rate
=
0
,
original_init
=
False
,
transformer
=
backbones
.
Transformer
(
dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
)))),
losses
=
Losses
(
l2_weight_decay
=
0.0
,
label_smoothing
=
0.1
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
# repeated_aug=repeated_aug,
color_jitter
=
0.4
),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
300
*
steps_per_epoch
,
validation_steps
=
IMAGENET_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.05
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.0005
*
train_batch_size
/
512
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
5
*
steps_per_epoch
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'deit_imagenet_pretrain_noaug_sd'
)
def
image_classification_imagenet_deit_imagenet_pretrain_noaug_sd
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
train_batch_size
=
4096
# 1024
eval_batch_size
=
4096
# 1024
repeated_aug
=
1
steps_per_epoch
=
IMAGENET_TRAIN_EXAMPLES
*
repeated_aug
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(
model
=
ImageClassificationModel
(
num_classes
=
1001
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
model_name
=
'vit-b16'
,
representation_size
=
768
,
init_stochastic_depth_rate
=
0.1
,
original_init
=
False
,
transformer
=
backbones
.
Transformer
(
dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
)))),
losses
=
Losses
(
l2_weight_decay
=
0.0
,
label_smoothing
=
0.1
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
# repeated_aug=repeated_aug,
color_jitter
=
0.4
),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
300
*
steps_per_epoch
,
validation_steps
=
IMAGENET_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.05
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.0005
*
train_batch_size
/
512
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
5
*
steps_per_epoch
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'deit_imagenet_pretrain_sd_mixupandcutmix'
)
def
image_classification_imagenet_deit_imagenet_pretrain_sd_mixupandcutmix
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
train_batch_size
=
4096
# 1024
eval_batch_size
=
4096
# 1024
repeated_aug
=
1
num_classes
=
1001
label_smoothing
=
0.1
steps_per_epoch
=
IMAGENET_TRAIN_EXAMPLES
*
repeated_aug
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(
model
=
ImageClassificationModel
(
num_classes
=
num_classes
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
model_name
=
'vit-b16'
,
representation_size
=
768
,
init_stochastic_depth_rate
=
0.1
,
original_init
=
False
,
transformer
=
backbones
.
Transformer
(
dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
)))),
losses
=
Losses
(
l2_weight_decay
=
0.0
,
label_smoothing
=
label_smoothing
,
one_hot
=
False
,
soft_labels
=
True
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
# repeated_aug=repeated_aug,
color_jitter
=
0.4
,
mixup_and_cutmix
=
common
.
MixupAndCutmix
(
num_classes
=
num_classes
,
label_smoothing
=
label_smoothing
)),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
300
*
steps_per_epoch
,
validation_steps
=
IMAGENET_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.05
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.0005
*
train_batch_size
/
512
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
5
*
steps_per_epoch
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'deit_imagenet_pretrain_sd_erase'
)
def
image_classification_imagenet_deit_imagenet_pretrain_sd_erase
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
train_batch_size
=
4096
# 1024
eval_batch_size
=
4096
# 1024
repeated_aug
=
1
steps_per_epoch
=
IMAGENET_TRAIN_EXAMPLES
*
repeated_aug
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(
model
=
ImageClassificationModel
(
num_classes
=
1001
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
model_name
=
'vit-b16'
,
representation_size
=
768
,
init_stochastic_depth_rate
=
0.1
,
original_init
=
False
,
transformer
=
backbones
.
Transformer
(
dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
)))),
losses
=
Losses
(
l2_weight_decay
=
0.0
,
label_smoothing
=
0.1
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
# repeated_aug=repeated_aug,
color_jitter
=
0.4
,
random_erasing
=
common
.
RandomErasing
()),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
300
*
steps_per_epoch
,
validation_steps
=
IMAGENET_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.05
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.0005
*
train_batch_size
/
512
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
5
*
steps_per_epoch
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'deit_imagenet_pretrain_sd_erase_randa'
)
def
image_classification_imagenet_deit_imagenet_pretrain_sd_erase_randa
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
train_batch_size
=
4096
# 1024
eval_batch_size
=
4096
# 1024
repeated_aug
=
1
steps_per_epoch
=
IMAGENET_TRAIN_EXAMPLES
*
repeated_aug
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(
model
=
ImageClassificationModel
(
num_classes
=
1001
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
model_name
=
'vit-b16'
,
representation_size
=
768
,
init_stochastic_depth_rate
=
0.1
,
original_init
=
False
,
transformer
=
backbones
.
Transformer
(
dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
)))),
losses
=
Losses
(
l2_weight_decay
=
0.0
,
label_smoothing
=
0.1
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
# repeated_aug=repeated_aug,
color_jitter
=
0.4
,
random_erasing
=
common
.
RandomErasing
(),
aug_type
=
common
.
Augmentation
(
type
=
'randaug'
,
randaug
=
common
.
RandAugment
(
magnitude
=
9
,
exclude_ops
=
[
'Cutout'
]))),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
300
*
steps_per_epoch
,
validation_steps
=
IMAGENET_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.05
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.0005
*
train_batch_size
/
512
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
5
*
steps_per_epoch
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'deit_imagenet_pretrain_sd_erase_randa_mixupandcutmix'
)
def
image_classification_imagenet_deit_imagenet_pretrain_sd_erase_randa_mixupandcutmix
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
train_batch_size
=
4096
# 1024
eval_batch_size
=
4096
# 1024
repeated_aug
=
1
num_classes
=
1001
label_smoothing
=
0.1
steps_per_epoch
=
IMAGENET_TRAIN_EXAMPLES
*
repeated_aug
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(
model
=
ImageClassificationModel
(
num_classes
=
num_classes
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
model_name
=
'vit-b16'
,
representation_size
=
768
,
init_stochastic_depth_rate
=
0.1
,
original_init
=
False
,
transformer
=
backbones
.
Transformer
(
dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
)))),
losses
=
Losses
(
l2_weight_decay
=
0.0
,
label_smoothing
=
label_smoothing
,
one_hot
=
False
,
soft_labels
=
True
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
# repeated_aug=repeated_aug,
color_jitter
=
0.4
,
random_erasing
=
common
.
RandomErasing
(),
aug_type
=
common
.
Augmentation
(
type
=
'randaug'
,
randaug
=
common
.
RandAugment
(
magnitude
=
9
,
exclude_ops
=
[
'Cutout'
])),
mixup_and_cutmix
=
common
.
MixupAndCutmix
(
num_classes
=
num_classes
,
label_smoothing
=
label_smoothing
)),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
300
*
steps_per_epoch
,
validation_steps
=
IMAGENET_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.05
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.0005
*
train_batch_size
/
512
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
5
*
steps_per_epoch
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'deit_imagenet_pretrain_sd_erase_randa_mixup'
)
def
image_classification_imagenet_deit_imagenet_pretrain_sd_erase_randa_mixup
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
train_batch_size
=
4096
# 1024
eval_batch_size
=
4096
# 1024
repeated_aug
=
1
num_classes
=
1001
label_smoothing
=
0.1
steps_per_epoch
=
IMAGENET_TRAIN_EXAMPLES
*
repeated_aug
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(
model
=
ImageClassificationModel
(
num_classes
=
num_classes
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
model_name
=
'vit-b16'
,
representation_size
=
768
,
init_stochastic_depth_rate
=
0.1
,
original_init
=
False
,
transformer
=
backbones
.
Transformer
(
dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
)))),
losses
=
Losses
(
l2_weight_decay
=
0.0
,
label_smoothing
=
label_smoothing
,
one_hot
=
False
,
soft_labels
=
True
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
# repeated_aug=repeated_aug,
color_jitter
=
0.4
,
random_erasing
=
common
.
RandomErasing
(),
aug_type
=
common
.
Augmentation
(
type
=
'randaug'
,
randaug
=
common
.
RandAugment
(
magnitude
=
9
,
exclude_ops
=
[
'Cutout'
])),
mixup_and_cutmix
=
common
.
MixupAndCutmix
(
num_classes
=
num_classes
,
label_smoothing
=
label_smoothing
,
cutmix_alpha
=
0
)),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
300
*
steps_per_epoch
,
validation_steps
=
IMAGENET_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.05
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.0005
*
train_batch_size
/
512
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
5
*
steps_per_epoch
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'deit_imagenet_pretrain_sd_erase_randa_cutmix'
)
def
image_classification_imagenet_deit_imagenet_pretrain_sd_erase_randa_cutmix
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
train_batch_size
=
4096
# 1024
eval_batch_size
=
4096
# 1024
repeated_aug
=
1
num_classes
=
1001
label_smoothing
=
0.1
steps_per_epoch
=
IMAGENET_TRAIN_EXAMPLES
*
repeated_aug
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(
model
=
ImageClassificationModel
(
num_classes
=
num_classes
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
model_name
=
'vit-b16'
,
representation_size
=
768
,
init_stochastic_depth_rate
=
0.1
,
original_init
=
False
,
transformer
=
backbones
.
Transformer
(
dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
)))),
losses
=
Losses
(
l2_weight_decay
=
0.0
,
label_smoothing
=
label_smoothing
,
one_hot
=
False
,
soft_labels
=
True
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
# repeated_aug=repeated_aug,
color_jitter
=
0.4
,
random_erasing
=
common
.
RandomErasing
(),
aug_type
=
common
.
Augmentation
(
type
=
'randaug'
,
randaug
=
common
.
RandAugment
(
magnitude
=
9
,
exclude_ops
=
[
'Cutout'
])),
mixup_and_cutmix
=
common
.
MixupAndCutmix
(
num_classes
=
num_classes
,
label_smoothing
=
label_smoothing
,
mixup_alpha
=
0
)),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
300
*
steps_per_epoch
,
validation_steps
=
IMAGENET_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.05
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.0005
*
train_batch_size
/
512
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
5
*
steps_per_epoch
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'deit_imagenet_pretrain_sd_erase_randa_mixupandcutmix_sanity'
)
def
image_classification_imagenet_deit_imagenet_pretrain_sd_erase_randa_mixupandcutmix_sanity
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
train_batch_size
=
4096
# 1024
eval_batch_size
=
4096
# 1024
repeated_aug
=
1
num_classes
=
1001
label_smoothing
=
0.1
steps_per_epoch
=
IMAGENET_TRAIN_EXAMPLES
*
repeated_aug
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(
model
=
ImageClassificationModel
(
num_classes
=
num_classes
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
model_name
=
'vit-b16'
,
representation_size
=
768
,
init_stochastic_depth_rate
=
0.1
,
original_init
=
False
,
transformer
=
backbones
.
Transformer
(
dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
)))),
losses
=
Losses
(
l2_weight_decay
=
0.0
,
label_smoothing
=
label_smoothing
,
one_hot
=
False
,
soft_labels
=
True
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
# repeated_aug=repeated_aug,
color_jitter
=
0.4
,
random_erasing
=
common
.
RandomErasing
(),
aug_type
=
common
.
Augmentation
(
type
=
'randaug'
,
randaug
=
common
.
RandAugment
(
magnitude
=
9
,
exclude_ops
=
[
'Cutout'
])),
mixup_and_cutmix
=
common
.
MixupAndCutmix
(
num_classes
=
num_classes
,
label_smoothing
=
label_smoothing
,
prob
=
0
,
)),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
300
*
steps_per_epoch
,
validation_steps
=
IMAGENET_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.05
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.0005
*
train_batch_size
/
512
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
5
*
steps_per_epoch
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'deit_imagenet_pretrain_sd_randacomplete'
)
def
image_classification_imagenet_deit_imagenet_pretrain_sd_randacomplete
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
train_batch_size
=
4096
# 1024
eval_batch_size
=
4096
# 1024
repeated_aug
=
1
steps_per_epoch
=
IMAGENET_TRAIN_EXAMPLES
*
repeated_aug
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(
model
=
ImageClassificationModel
(
num_classes
=
1001
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
model_name
=
'vit-b16'
,
representation_size
=
768
,
init_stochastic_depth_rate
=
0.1
,
original_init
=
False
,
transformer
=
backbones
.
Transformer
(
dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
)))),
losses
=
Losses
(
l2_weight_decay
=
0.0
,
label_smoothing
=
0.1
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
,
# # repeated_aug=repeated_aug,
color_jitter
=
0.4
,
aug_type
=
common
.
Augmentation
(
type
=
'randaug'
,
randaug
=
common
.
RandAugment
(
magnitude
=
9
))),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
300
*
steps_per_epoch
,
validation_steps
=
IMAGENET_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.05
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.0005
*
train_batch_size
/
512
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
5
*
steps_per_epoch
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'vit_imagenet_pretrain_deitinit'
)
def
image_classification_imagenet_vit_pretrain_deitinit
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
train_batch_size
=
4096
eval_batch_size
=
4096
steps_per_epoch
=
IMAGENET_TRAIN_EXAMPLES
//
train_batch_size
config
=
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(
model
=
ImageClassificationModel
(
num_classes
=
1001
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
original_init
=
False
,
model_name
=
'vit-b16'
,
representation_size
=
768
))),
losses
=
Losses
(
l2_weight_decay
=
0.0
),
train_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'train*'
),
is_training
=
True
,
global_batch_size
=
train_batch_size
),
validation_data
=
DataConfig
(
input_path
=
os
.
path
.
join
(
IMAGENET_INPUT_PATH_BASE
,
'valid*'
),
is_training
=
False
,
global_batch_size
=
eval_batch_size
)),
trainer
=
cfg
.
TrainerConfig
(
steps_per_loop
=
steps_per_epoch
,
summary_interval
=
steps_per_epoch
,
checkpoint_interval
=
steps_per_epoch
,
train_steps
=
300
*
steps_per_epoch
,
validation_steps
=
IMAGENET_VAL_EXAMPLES
//
eval_batch_size
,
validation_interval
=
steps_per_epoch
,
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
'adamw'
:
{
'weight_decay_rate'
:
0.3
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.003
*
train_batch_size
/
4096
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
'warmup'
:
{
'type'
:
'linear'
,
'linear'
:
{
'warmup_steps'
:
10000
,
'warmup_learning_rate'
:
0
}
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
return
config
@
exp_factory
.
register_config_factory
(
'vit_imagenet_pretrain'
)
def
image_classification_imagenet_vit_pretrain
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
...
...
@@ -90,6 +929,7 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
model
=
ImageClassificationModel
(
num_classes
=
1001
,
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
...
...
@@ -116,12 +956,13 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
'adamw'
:
{
'weight_decay_rate'
:
0.3
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
},
'learning_rate'
:
{
'type'
:
'cosine'
,
'cosine'
:
{
'initial_learning_rate'
:
0.003
,
'initial_learning_rate'
:
0.003
*
train_batch_size
/
4096
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
},
...
...
official/vision/beta/projects/vit/modeling/layers/__init__.py
0 → 100644
View file @
40cd0a26
from
official.vision.beta.projects.vit.modeling.layers.vit_transformer_encoder_block
import
TransformerEncoderBlock
\ No newline at end of file
official/vision/beta/projects/vit/modeling/layers/vit_transformer_encoder_block.py
0 → 100644
View file @
40cd0a26
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based TransformerEncoder block layer."""
import
tensorflow
as
tf
from
official.vision.beta.modeling.layers.nn_layers
import
StochasticDepth
@
tf
.
keras
.
utils
.
register_keras_serializable
(
package
=
"Vision"
)
class
TransformerEncoderBlock
(
tf
.
keras
.
layers
.
Layer
):
"""TransformerEncoderBlock layer.
This layer implements the Transformer Encoder from
"Attention Is All You Need". (https://arxiv.org/abs/1706.03762),
which combines a `tf.keras.layers.MultiHeadAttention` layer with a
two-layer feedforward network. Here we ass support for stochastic depth.
References:
[Attention Is All You Need](https://arxiv.org/abs/1706.03762)
[BERT: Pre-training of Deep Bidirectional Transformers for Language
Understanding](https://arxiv.org/abs/1810.04805)
"""
def
__init__
(
self
,
num_attention_heads
,
inner_dim
,
inner_activation
,
output_range
=
None
,
kernel_initializer
=
"glorot_uniform"
,
bias_initializer
=
"zeros"
,
kernel_regularizer
=
None
,
bias_regularizer
=
None
,
activity_regularizer
=
None
,
kernel_constraint
=
None
,
bias_constraint
=
None
,
use_bias
=
True
,
norm_first
=
False
,
norm_epsilon
=
1e-12
,
output_dropout
=
0.0
,
attention_dropout
=
0.0
,
inner_dropout
=
0.0
,
stochastic_depth_drop_rate
=
0.0
,
attention_initializer
=
None
,
attention_axes
=
None
,
**
kwargs
):
"""Initializes `TransformerEncoderBlock`.
Args:
num_attention_heads: Number of attention heads.
inner_dim: The output dimension of the first Dense layer in a two-layer
feedforward network.
inner_activation: The activation for the first Dense layer in a two-layer
feedforward network.
output_range: the sequence output range, [0, output_range) for slicing the
target sequence. `None` means the target sequence is not sliced.
kernel_initializer: Initializer for dense layer kernels.
bias_initializer: Initializer for dense layer biases.
kernel_regularizer: Regularizer for dense layer kernels.
bias_regularizer: Regularizer for dense layer biases.
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer kernels.
use_bias: Whether to enable use_bias in attention layer. If set False,
use_bias in attention layer is disabled.
norm_first: Whether to normalize inputs to attention and intermediate
dense layers. If set False, output of attention and intermediate dense
layers is normalized.
norm_epsilon: Epsilon value to initialize normalization layers.
output_dropout: Dropout probability for the post-attention and output
dropout.
attention_dropout: Dropout probability for within the attention layer.
inner_dropout: Dropout probability for the first Dense layer in a
two-layer feedforward network.
stochastic_depth_drop_rate: Dropout propobability for the stochastic depth
regularization.
attention_initializer: Initializer for kernels of attention layers. If set
`None`, attention layers use kernel_initializer as initializer for
kernel.
attention_axes: axes over which the attention is applied. `None` means
attention over all axes, but batch, heads, and features.
**kwargs: keyword arguments/
"""
super
().
__init__
(
**
kwargs
)
self
.
_num_heads
=
num_attention_heads
self
.
_inner_dim
=
inner_dim
self
.
_inner_activation
=
inner_activation
self
.
_attention_dropout
=
attention_dropout
self
.
_attention_dropout_rate
=
attention_dropout
self
.
_output_dropout
=
output_dropout
self
.
_output_dropout_rate
=
output_dropout
self
.
_output_range
=
output_range
self
.
_kernel_initializer
=
tf
.
keras
.
initializers
.
get
(
kernel_initializer
)
self
.
_bias_initializer
=
tf
.
keras
.
initializers
.
get
(
bias_initializer
)
self
.
_kernel_regularizer
=
tf
.
keras
.
regularizers
.
get
(
kernel_regularizer
)
self
.
_bias_regularizer
=
tf
.
keras
.
regularizers
.
get
(
bias_regularizer
)
self
.
_activity_regularizer
=
tf
.
keras
.
regularizers
.
get
(
activity_regularizer
)
self
.
_kernel_constraint
=
tf
.
keras
.
constraints
.
get
(
kernel_constraint
)
self
.
_bias_constraint
=
tf
.
keras
.
constraints
.
get
(
bias_constraint
)
self
.
_use_bias
=
use_bias
self
.
_norm_first
=
norm_first
self
.
_norm_epsilon
=
norm_epsilon
self
.
_inner_dropout
=
inner_dropout
self
.
_stochastic_depth_drop_rate
=
stochastic_depth_drop_rate
if
attention_initializer
:
self
.
_attention_initializer
=
tf
.
keras
.
initializers
.
get
(
attention_initializer
)
else
:
self
.
_attention_initializer
=
self
.
_kernel_initializer
self
.
_attention_axes
=
attention_axes
def
build
(
self
,
input_shape
):
if
isinstance
(
input_shape
,
tf
.
TensorShape
):
input_tensor_shape
=
input_shape
elif
isinstance
(
input_shape
,
(
list
,
tuple
)):
input_tensor_shape
=
tf
.
TensorShape
(
input_shape
[
0
])
else
:
raise
ValueError
(
"The type of input shape argument is not supported, got: %s"
%
type
(
input_shape
))
einsum_equation
=
"abc,cd->abd"
if
len
(
input_tensor_shape
.
as_list
())
>
3
:
einsum_equation
=
"...bc,cd->...bd"
hidden_size
=
input_tensor_shape
[
-
1
]
if
hidden_size
%
self
.
_num_heads
!=
0
:
raise
ValueError
(
"The input size (%d) is not a multiple of the number of attention "
"heads (%d)"
%
(
hidden_size
,
self
.
_num_heads
))
self
.
_attention_head_size
=
int
(
hidden_size
//
self
.
_num_heads
)
common_kwargs
=
dict
(
bias_initializer
=
self
.
_bias_initializer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
bias_regularizer
=
self
.
_bias_regularizer
,
activity_regularizer
=
self
.
_activity_regularizer
,
kernel_constraint
=
self
.
_kernel_constraint
,
bias_constraint
=
self
.
_bias_constraint
)
self
.
_attention_layer
=
tf
.
keras
.
layers
.
MultiHeadAttention
(
num_heads
=
self
.
_num_heads
,
key_dim
=
self
.
_attention_head_size
,
dropout
=
self
.
_attention_dropout
,
use_bias
=
self
.
_use_bias
,
kernel_initializer
=
self
.
_attention_initializer
,
attention_axes
=
self
.
_attention_axes
,
name
=
"self_attention"
,
**
common_kwargs
)
self
.
_attention_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_output_dropout
)
# Use float32 in layernorm for numeric stability.
# It is probably safe in mixed_float16, but we haven't validated this yet.
self
.
_attention_layer_norm
=
(
tf
.
keras
.
layers
.
LayerNormalization
(
name
=
"self_attention_layer_norm"
,
axis
=-
1
,
epsilon
=
self
.
_norm_epsilon
,
dtype
=
tf
.
float32
))
self
.
_intermediate_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
einsum_equation
,
output_shape
=
(
None
,
self
.
_inner_dim
),
bias_axes
=
"d"
,
kernel_initializer
=
self
.
_kernel_initializer
,
name
=
"intermediate"
,
**
common_kwargs
)
policy
=
tf
.
keras
.
mixed_precision
.
global_policy
()
if
policy
.
name
==
"mixed_bfloat16"
:
# bfloat16 causes BERT with the LAMB optimizer to not converge
# as well, so we use float32.
# TODO(b/154538392): Investigate this.
policy
=
tf
.
float32
self
.
_intermediate_activation_layer
=
tf
.
keras
.
layers
.
Activation
(
self
.
_inner_activation
,
dtype
=
policy
)
self
.
_inner_dropout_layer
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_inner_dropout
)
self
.
_output_dense
=
tf
.
keras
.
layers
.
experimental
.
EinsumDense
(
einsum_equation
,
output_shape
=
(
None
,
hidden_size
),
bias_axes
=
"d"
,
name
=
"output"
,
kernel_initializer
=
self
.
_kernel_initializer
,
**
common_kwargs
)
self
.
_output_dropout
=
tf
.
keras
.
layers
.
Dropout
(
rate
=
self
.
_output_dropout
)
# Use float32 in layernorm for numeric stability.
self
.
_output_layer_norm
=
tf
.
keras
.
layers
.
LayerNormalization
(
name
=
"output_layer_norm"
,
axis
=-
1
,
epsilon
=
self
.
_norm_epsilon
,
dtype
=
tf
.
float32
)
if
self
.
_stochastic_depth_drop_rate
:
self
.
_stochastic_depth
=
StochasticDepth
(
self
.
_stochastic_depth_drop_rate
)
else
:
self
.
_stochastic_depth
=
None
super
(
TransformerEncoderBlock
,
self
).
build
(
input_shape
)
def
get_config
(
self
):
config
=
{
"num_attention_heads"
:
self
.
_num_heads
,
"inner_dim"
:
self
.
_inner_dim
,
"inner_activation"
:
self
.
_inner_activation
,
"output_dropout"
:
self
.
_output_dropout_rate
,
"attention_dropout"
:
self
.
_attention_dropout_rate
,
"output_range"
:
self
.
_output_range
,
"kernel_initializer"
:
tf
.
keras
.
initializers
.
serialize
(
self
.
_kernel_initializer
),
"bias_initializer"
:
tf
.
keras
.
initializers
.
serialize
(
self
.
_bias_initializer
),
"kernel_regularizer"
:
tf
.
keras
.
regularizers
.
serialize
(
self
.
_kernel_regularizer
),
"bias_regularizer"
:
tf
.
keras
.
regularizers
.
serialize
(
self
.
_bias_regularizer
),
"activity_regularizer"
:
tf
.
keras
.
regularizers
.
serialize
(
self
.
_activity_regularizer
),
"kernel_constraint"
:
tf
.
keras
.
constraints
.
serialize
(
self
.
_kernel_constraint
),
"bias_constraint"
:
tf
.
keras
.
constraints
.
serialize
(
self
.
_bias_constraint
),
"use_bias"
:
self
.
_use_bias
,
"norm_first"
:
self
.
_norm_first
,
"norm_epsilon"
:
self
.
_norm_epsilon
,
"inner_dropout"
:
self
.
_inner_dropout
,
"stochastic_depth_drop_rate"
:
self
.
_stochastic_depth_drop_rate
,
"attention_initializer"
:
tf
.
keras
.
initializers
.
serialize
(
self
.
_attention_initializer
),
"attention_axes"
:
self
.
_attention_axes
,
}
base_config
=
super
(
TransformerEncoderBlock
,
self
).
get_config
()
return
dict
(
list
(
base_config
.
items
())
+
list
(
config
.
items
()))
def
call
(
self
,
inputs
,
training
=
None
):
"""Transformer self-attention encoder block call.
Args:
inputs: a single tensor or a list of tensors.
`input tensor` as the single sequence of embeddings.
[`input tensor`, `attention mask`] to have the additional attention
mask.
[`query tensor`, `key value tensor`, `attention mask`] to have separate
input streams for the query, and key/value to the multi-head
attention.
Returns:
An output tensor with the same dimensions as input/query tensor.
"""
if
isinstance
(
inputs
,
(
list
,
tuple
)):
if
len
(
inputs
)
==
2
:
input_tensor
,
attention_mask
=
inputs
key_value
=
None
elif
len
(
inputs
)
==
3
:
input_tensor
,
key_value
,
attention_mask
=
inputs
else
:
raise
ValueError
(
"Unexpected inputs to %s with length at %d"
%
(
self
.
__class__
,
len
(
inputs
)))
else
:
input_tensor
,
key_value
,
attention_mask
=
(
inputs
,
None
,
None
)
with_stochastic_depth
=
training
and
self
.
_stochastic_depth
if
self
.
_output_range
:
if
self
.
_norm_first
:
source_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
input_tensor
=
self
.
_attention_layer_norm
(
input_tensor
)
if
key_value
is
not
None
:
key_value
=
self
.
_attention_layer_norm
(
key_value
)
target_tensor
=
input_tensor
[:,
0
:
self
.
_output_range
,
:]
if
attention_mask
is
not
None
:
attention_mask
=
attention_mask
[:,
0
:
self
.
_output_range
,
:]
else
:
if
self
.
_norm_first
:
source_tensor
=
input_tensor
input_tensor
=
self
.
_attention_layer_norm
(
input_tensor
)
if
key_value
is
not
None
:
key_value
=
self
.
_attention_layer_norm
(
key_value
)
target_tensor
=
input_tensor
if
key_value
is
None
:
key_value
=
input_tensor
attention_output
=
self
.
_attention_layer
(
query
=
target_tensor
,
value
=
key_value
,
attention_mask
=
attention_mask
)
attention_output
=
self
.
_attention_dropout
(
attention_output
)
if
self
.
_norm_first
:
attention_output
=
source_tensor
+
self
.
_stochastic_depth
(
attention_output
,
training
=
with_stochastic_depth
)
else
:
attention_output
=
self
.
_attention_layer_norm
(
target_tensor
+
self
.
_stochastic_depth
(
attention_output
,
training
=
with_stochastic_depth
)
)
if
self
.
_norm_first
:
source_attention_output
=
attention_output
attention_output
=
self
.
_output_layer_norm
(
attention_output
)
inner_output
=
self
.
_intermediate_dense
(
attention_output
)
inner_output
=
self
.
_intermediate_activation_layer
(
inner_output
)
inner_output
=
self
.
_inner_dropout_layer
(
inner_output
)
layer_output
=
self
.
_output_dense
(
inner_output
)
layer_output
=
self
.
_output_dropout
(
layer_output
)
if
self
.
_norm_first
:
return
source_attention_output
+
self
.
_stochastic_depth
(
layer_output
,
training
=
with_stochastic_depth
)
# During mixed precision training, layer norm output is always fp32 for now.
# Casts fp32 for the subsequent add.
layer_output
=
tf
.
cast
(
layer_output
,
tf
.
float32
)
return
self
.
_output_layer_norm
(
layer_output
+
self
.
_stochastic_depth
(
attention_output
,
training
=
with_stochastic_depth
)
)
official/vision/beta/projects/vit/modeling/vit.py
View file @
40cd0a26
...
...
@@ -19,6 +19,8 @@ import tensorflow as tf
from
official.modeling
import
activations
from
official.nlp
import
keras_nlp
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.layers
import
nn_layers
from
official.vision.beta.projects.vit.modeling.layers
import
TransformerEncoderBlock
layers
=
tf
.
keras
.
layers
...
...
@@ -29,6 +31,18 @@ VIT_SPECS = {
patch_size
=
16
,
transformer
=
dict
(
mlp_dim
=
1
,
num_heads
=
1
,
num_layers
=
1
),
),
'vit-ti16'
:
dict
(
hidden_size
=
192
,
patch_size
=
16
,
transformer
=
dict
(
mlp_dim
=
3072
,
num_heads
=
3
,
num_layers
=
12
),
),
'vit-s16'
:
dict
(
hidden_size
=
384
,
patch_size
=
16
,
transformer
=
dict
(
mlp_dim
=
3072
,
num_heads
=
6
,
num_layers
=
12
),
),
'vit-b16'
:
dict
(
hidden_size
=
768
,
...
...
@@ -112,6 +126,8 @@ class Encoder(tf.keras.layers.Layer):
attention_dropout_rate
=
0.1
,
kernel_regularizer
=
None
,
inputs_positions
=
None
,
init_stochastic_depth_rate
=
0.0
,
kernel_initializer
=
'glorot_uniform'
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
self
.
_num_layers
=
num_layers
...
...
@@ -121,6 +137,8 @@ class Encoder(tf.keras.layers.Layer):
self
.
_attention_dropout_rate
=
attention_dropout_rate
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_inputs_positions
=
inputs_positions
self
.
_init_stochastic_depth_rate
=
init_stochastic_depth_rate
self
.
_kernel_initializer
=
kernel_initializer
def
build
(
self
,
input_shape
):
self
.
_pos_embed
=
AddPositionEmbs
(
...
...
@@ -131,15 +149,18 @@ class Encoder(tf.keras.layers.Layer):
self
.
_encoder_layers
=
[]
# Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
# https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
for
_
in
range
(
self
.
_num_layers
):
encoder_layer
=
keras_nlp
.
layers
.
TransformerEncoderBlock
(
for
i
in
range
(
self
.
_num_layers
):
encoder_layer
=
TransformerEncoderBlock
(
inner_activation
=
activations
.
gelu
,
num_attention_heads
=
self
.
_num_heads
,
inner_dim
=
self
.
_mlp_dim
,
output_dropout
=
self
.
_dropout_rate
,
attention_dropout
=
self
.
_attention_dropout_rate
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_initializer
=
self
.
_kernel_initializer
,
norm_first
=
True
,
stochastic_depth_drop_rate
=
nn_layers
.
get_stochastic_depth_rate
(
self
.
_init_stochastic_depth_rate
,
i
,
self
.
_num_layers
-
1
),
norm_epsilon
=
1e-6
)
self
.
_encoder_layers
.
append
(
encoder_layer
)
self
.
_norm
=
layers
.
LayerNormalization
(
epsilon
=
1e-6
)
...
...
@@ -164,12 +185,14 @@ class VisionTransformer(tf.keras.Model):
num_layers
=
12
,
attention_dropout_rate
=
0.0
,
dropout_rate
=
0.1
,
init_stochastic_depth_rate
=
0.0
,
input_specs
=
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
patch_size
=
16
,
hidden_size
=
768
,
representation_size
=
0
,
classifier
=
'token'
,
kernel_regularizer
=
None
):
kernel_regularizer
=
None
,
original_init
=
True
):
"""VisionTransformer initialization function."""
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
...
...
@@ -178,7 +201,8 @@ class VisionTransformer(tf.keras.Model):
kernel_size
=
patch_size
,
strides
=
patch_size
,
padding
=
'valid'
,
kernel_regularizer
=
kernel_regularizer
)(
kernel_regularizer
=
kernel_regularizer
,
kernel_initializer
=
'lecun_normal'
if
original_init
else
'he_uniform'
)(
inputs
)
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
rows_axis
,
cols_axis
=
(
1
,
2
)
...
...
@@ -203,7 +227,10 @@ class VisionTransformer(tf.keras.Model):
num_heads
=
num_heads
,
dropout_rate
=
dropout_rate
,
attention_dropout_rate
=
attention_dropout_rate
,
kernel_regularizer
=
kernel_regularizer
)(
kernel_regularizer
=
kernel_regularizer
,
kernel_initializer
=
'glorot_uniform'
if
original_init
else
dict
(
class_name
=
'TruncatedNormal'
,
config
=
dict
(
stddev
=
.
02
)),
init_stochastic_depth_rate
=
init_stochastic_depth_rate
)(
x
)
if
classifier
==
'token'
:
...
...
@@ -215,7 +242,8 @@ class VisionTransformer(tf.keras.Model):
x
=
tf
.
keras
.
layers
.
Dense
(
representation_size
,
kernel_regularizer
=
kernel_regularizer
,
name
=
'pre_logits'
)(
name
=
'pre_logits'
,
kernel_initializer
=
'lecun_normal'
if
original_init
else
'he_uniform'
)(
x
)
x
=
tf
.
nn
.
tanh
(
x
)
else
:
...
...
@@ -225,7 +253,8 @@ class VisionTransformer(tf.keras.Model):
tf
.
reshape
(
x
,
[
-
1
,
1
,
1
,
representation_size
or
hidden_size
])
}
super
(
VisionTransformer
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
endpoints
)
super
(
VisionTransformer
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
endpoints
)
@
factory
.
register_backbone_builder
(
'vit'
)
...
...
@@ -247,9 +276,11 @@ def build_vit(input_specs,
num_layers
=
backbone_cfg
.
transformer
.
num_layers
,
attention_dropout_rate
=
backbone_cfg
.
transformer
.
attention_dropout_rate
,
dropout_rate
=
backbone_cfg
.
transformer
.
dropout_rate
,
init_stochastic_depth_rate
=
backbone_cfg
.
init_stochastic_depth_rate
,
input_specs
=
input_specs
,
patch_size
=
backbone_cfg
.
patch_size
,
hidden_size
=
backbone_cfg
.
hidden_size
,
representation_size
=
backbone_cfg
.
representation_size
,
classifier
=
backbone_cfg
.
classifier
,
kernel_regularizer
=
l2_regularizer
)
kernel_regularizer
=
l2_regularizer
,
original_init
=
backbone_cfg
.
original_init
)
official/vision/beta/projects/yolo/configs/darknet_classification.py
View file @
40cd0a26
...
...
@@ -58,7 +58,7 @@ class ImageClassificationTask(cfg.TaskConfig):
@
exp_factory
.
register_config_factory
(
'darknet_classification'
)
def
image
_classification
()
->
cfg
.
ExperimentConfig
:
def
darknet
_classification
()
->
cfg
.
ExperimentConfig
:
"""Image classification general."""
return
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(),
...
...
official/vision/beta/tasks/image_classification.py
View file @
40cd0a26
...
...
@@ -26,6 +26,7 @@ from official.vision.beta.dataloaders import classification_input
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
tfds_factory
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.ops
import
augment
@
task_factory
.
register_task_cls
(
exp_cfg
.
ImageClassificationTask
)
...
...
@@ -103,14 +104,27 @@ class ImageClassificationTask(base_task.Task):
decode_jpeg_only
=
params
.
decode_jpeg_only
,
aug_rand_hflip
=
params
.
aug_rand_hflip
,
aug_type
=
params
.
aug_type
,
color_jitter
=
params
.
color_jitter
,
random_erasing
=
params
.
random_erasing
,
is_multilabel
=
is_multilabel
,
dtype
=
params
.
dtype
)
postprocess_fn
=
None
if
params
.
mixup_and_cutmix
:
postprocess_fn
=
augment
.
MixupAndCutmix
(
mixup_alpha
=
params
.
mixup_and_cutmix
.
mixup_alpha
,
cutmix_alpha
=
params
.
mixup_and_cutmix
.
cutmix_alpha
,
prob
=
params
.
mixup_and_cutmix
.
prob
,
label_smoothing
=
params
.
mixup_and_cutmix
.
label_smoothing
,
num_classes
=
params
.
mixup_and_cutmix
.
num_classes
)
reader
=
input_reader_factory
.
input_reader_generator
(
params
,
dataset_fn
=
dataset_fn
.
pick_dataset_fn
(
params
.
file_type
),
decoder_fn
=
decoder
.
decode
,
parser_fn
=
parser
.
parse_fn
(
params
.
is_training
))
parser_fn
=
parser
.
parse_fn
(
params
.
is_training
),
postprocess_fn
=
postprocess_fn
)
dataset
=
reader
.
read
(
input_context
=
input_context
)
...
...
@@ -119,12 +133,15 @@ class ImageClassificationTask(base_task.Task):
def
build_losses
(
self
,
labels
:
tf
.
Tensor
,
model_outputs
:
tf
.
Tensor
,
is_validation
:
bool
,
aux_losses
:
Optional
[
Any
]
=
None
)
->
tf
.
Tensor
:
"""Builds sparse categorical cross entropy loss.
Args:
labels: Input groundtruth labels.
model_outputs: Output logits of the classifier.
is_validation: To handle that some augmentations need custom soft labels
while the validation should remain unchainged.
aux_losses: The auxiliarly loss tensors, i.e. `losses` in tf.keras.Model.
Returns:
...
...
@@ -134,12 +151,19 @@ class ImageClassificationTask(base_task.Task):
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
if
not
is_multilabel
:
if
losses_config
.
one_hot
:
# Some augmentation need custom soft labels in training, but validation
# should remain unchainged
if
losses_config
.
one_hot
or
is_validation
:
total_loss
=
tf
.
keras
.
losses
.
categorical_crossentropy
(
labels
,
model_outputs
,
from_logits
=
True
,
label_smoothing
=
losses_config
.
label_smoothing
)
elif
losses_config
.
soft_labels
:
total_loss
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
labels
,
model_outputs
)
else
:
total_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
model_outputs
,
from_logits
=
True
)
...
...
@@ -161,7 +185,8 @@ class ImageClassificationTask(base_task.Task):
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
if
not
is_multilabel
:
k
=
self
.
task_config
.
evaluation
.
top_k
if
self
.
task_config
.
losses
.
one_hot
:
if
(
self
.
task_config
.
losses
.
one_hot
or
self
.
task_config
.
losses
.
soft_labels
):
metrics
=
[
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
...
...
@@ -222,8 +247,8 @@ class ImageClassificationTask(base_task.Task):
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
# Computes per-replica loss.
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labe
ls
,
aux_losses
=
model
.
losses
)
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
is_validation
=
Fa
ls
e
,
aux_losses
=
model
.
losses
)
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
scaled_loss
=
loss
/
num_replicas
...
...
@@ -266,14 +291,16 @@ class ImageClassificationTask(base_task.Task):
A dictionary of logs.
"""
features
,
labels
=
inputs
one_hot
=
self
.
task_config
.
losses
.
one_hot
soft_labels
=
self
.
task_config
.
losses
.
soft_labels
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
if
self
.
task_config
.
losses
.
one_hot
and
not
is_multilabel
:
if
(
one_hot
or
soft_labels
)
and
not
is_multilabel
:
labels
=
tf
.
one_hot
(
labels
,
self
.
task_config
.
model
.
num_classes
)
outputs
=
self
.
inference_step
(
features
,
model
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
is_validation
=
True
,
aux_losses
=
model
.
losses
)
logs
=
{
self
.
loss
:
loss
}
if
metrics
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment