Commit 40cd0a26, authored Aug 13, 2021 by Simon Geisler
Parent: 3db445c7

    deit without repeated aug and distillation

Showing 15 changed files with 1763 additions and 34 deletions (+1763, -34)
official/vision/beta/configs/common.py  (+26, -1)
official/vision/beta/configs/image_classification.py  (+5, -0)
official/vision/beta/dataloaders/classification_input.py  (+28, -1)
official/vision/beta/modeling/factory.py  (+1, -0)
official/vision/beta/ops/augment.py  (+287, -11)
official/vision/beta/ops/augment_test.py  (+77, -0)
official/vision/beta/ops/preprocess_ops.py  (+83, -0)
official/vision/beta/projects/vit/README.md  (+6, -4)
official/vision/beta/projects/vit/configs/backbones.py  (+2, -0)
official/vision/beta/projects/vit/configs/image_classification.py  (+842, -1)
official/vision/beta/projects/vit/modeling/layers/__init__.py  (+1, -0)
official/vision/beta/projects/vit/modeling/layers/vit_transformer_encoder_block.py  (+331, -0)
official/vision/beta/projects/vit/modeling/vit.py  (+39, -8)
official/vision/beta/projects/yolo/configs/darknet_classification.py  (+1, -1)
official/vision/beta/tasks/image_classification.py  (+34, -7)
official/vision/beta/configs/common.py

@@ -15,7 +15,7 @@
 # Lint as: python3
 """Common configurations."""
-from typing import Optional
+from typing import Optional, List
 # Import libraries
 import dataclasses
@@ -32,6 +32,7 @@ class RandAugment(hyperparams.Config):
   cutout_const: float = 40
   translate_const: float = 10
   prob_to_apply: Optional[float] = None
+  exclude_ops: List[str] = dataclasses.field(default_factory=list)
 
 
 @dataclasses.dataclass
@@ -42,6 +43,30 @@ class AutoAugment(hyperparams.Config):
   translate_const: float = 250
 
 
+@dataclasses.dataclass
+class RandomErasing(hyperparams.Config):
+  """Configuration for RandomErasing."""
+  probability: float = 0.25
+  min_area: float = 0.02
+  max_area: float = 1 / 3
+  min_aspect: float = 0.3
+  max_aspect: Optional[float] = None
+  min_count: int = 1
+  max_count: int = 1
+  trials: int = 10
+
+
+@dataclasses.dataclass
+class MixupAndCutmix(hyperparams.Config):
+  """Configuration for MixupAndCutmix."""
+  mixup_alpha: float = .8
+  cutmix_alpha: float = 1.
+  prob: float = 1.0
+  switch_prob: float = 0.5
+  label_smoothing: float = 0.1
+  num_classes: int = 1000
+
+
 @dataclasses.dataclass
 class Augmentation(hyperparams.OneOfConfig):
   """Configuration for input data augmentation.
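With these dataclasses in place, the whole timm-style DeiT augmentation recipe becomes configurable. A minimal sketch of constructing the new config objects directly (field values here are illustrative, not a tuned DeiT recipe; `Augmentation` is the existing OneOfConfig with `randaug`/`autoaug` branches):

from official.vision.beta.configs import common

# RandAugment with ops filtered out via the new `exclude_ops` field.
aug = common.Augmentation(
    type='randaug',
    randaug=common.RandAugment(magnitude=9, exclude_ops=['Cutout']))

# timm-inspired RandomErasing and Mixup/Cutmix settings.
erasing = common.RandomErasing(probability=0.25, max_area=1 / 3)
mixing = common.MixupAndCutmix(
    mixup_alpha=.8, cutmix_alpha=1., label_smoothing=0.1, num_classes=1000)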
official/vision/beta/configs/image_classification.py

@@ -40,10 +40,13 @@ class DataConfig(cfg.DataConfig):
   aug_rand_hflip: bool = True
   aug_type: Optional[
       common.Augmentation] = None  # Choose from AutoAugment and RandAugment.
+  color_jitter: float = 0.
+  random_erasing: Optional[common.RandomErasing] = None
   file_type: str = 'tfrecord'
   image_field_key: str = 'image/encoded'
   label_field_key: str = 'image/class/label'
   decode_jpeg_only: bool = True
+  mixup_and_cutmix: Optional[common.MixupAndCutmix] = None
 
   # Keep for backward compatibility.
   aug_policy: Optional[str] = None  # None, 'autoaug', or 'randaug'.
@@ -62,6 +65,7 @@ class ImageClassificationModel(hyperparams.Config):
       use_sync_bn=False)
   # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
   add_head_batch_norm: bool = False
+  kernel_initializer: str = 'random_uniform'
 
 
 @dataclasses.dataclass
@@ -69,6 +73,7 @@ class Losses(hyperparams.Config):
   one_hot: bool = True
   label_smoothing: float = 0.0
   l2_weight_decay: float = 0.0
+  soft_labels: bool = False
 
 
 @dataclasses.dataclass
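The new `DataConfig` fields wire the augmentations into an experiment. A minimal sketch of a fully populated train-data config (values illustrative; `is_training` comes from the `cfg.DataConfig` base class):

from official.vision.beta.configs import common
from official.vision.beta.configs import image_classification as ic_cfg

data = ic_cfg.DataConfig(
    is_training=True,
    aug_type=common.Augmentation(
        type='randaug', randaug=common.RandAugment(magnitude=9)),
    color_jitter=0.4,
    random_erasing=common.RandomErasing(probability=0.25),
    mixup_and_cutmix=common.MixupAndCutmix(num_classes=1000))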
official/vision/beta/dataloaders/classification_input.py

@@ -69,6 +69,8 @@ class Parser(parser.Parser):
                decode_jpeg_only: bool = True,
                aug_rand_hflip: bool = True,
                aug_type: Optional[common.Augmentation] = None,
+               color_jitter: float = 0.,
+               random_erasing: Optional[common.RandomErasing] = None,
                is_multilabel: bool = False,
                dtype: str = 'float32'):
     """Initializes parameters for parsing annotations in the dataset.
@@ -85,6 +87,7 @@ class Parser(parser.Parser):
         horizontal flip.
       aug_type: An optional Augmentation object to choose from AutoAugment and
         RandAugment.
+      color_jitter: if > 0 the input image will be augmented by color jitter.
       is_multilabel: A `bool`, whether or not each example has multiple labels.
       dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
         or 'bfloat16'.
@@ -113,13 +116,28 @@ class Parser(parser.Parser):
             magnitude=aug_type.randaug.magnitude,
             cutout_const=aug_type.randaug.cutout_const,
             translate_const=aug_type.randaug.translate_const,
-            prob_to_apply=aug_type.randaug.prob_to_apply)
+            prob_to_apply=aug_type.randaug.prob_to_apply,
+            exclude_ops=aug_type.randaug.exclude_ops)
       else:
         raise ValueError('Augmentation policy {} not supported.'.format(
             aug_type.type))
     else:
       self._augmenter = None
     self._label_field_key = label_field_key
+    self._color_jitter = color_jitter
+    if random_erasing:
+      self._random_erasing = augment.RandomErasing(
+          probability=random_erasing.probability,
+          min_area=random_erasing.min_area,
+          max_area=random_erasing.max_area,
+          min_aspect=random_erasing.min_aspect,
+          max_aspect=random_erasing.max_aspect,
+          min_count=random_erasing.min_count,
+          max_count=random_erasing.max_count,
+          trials=random_erasing.trials)
+    else:
+      self._random_erasing = None
     self._is_multilabel = is_multilabel
     self._decode_jpeg_only = decode_jpeg_only
@@ -213,11 +231,20 @@ class Parser(parser.Parser):
         image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
     image.set_shape([self._output_size[0], self._output_size[1], 3])
 
+    # Color jitter.
+    if self._color_jitter > 0:
+      image = preprocess_ops.color_jitter(image, self._color_jitter,
+                                          self._color_jitter,
+                                          self._color_jitter)
+
     # Normalizes image with mean and std pixel values.
     image = preprocess_ops.normalize_image(image,
                                            offset=MEAN_RGB,
                                            scale=STDDEV_RGB)
+
+    # Random erasing after the image has been normalized.
+    if self._random_erasing is not None:
+      image = self._random_erasing.distort(image)
 
     # Convert image to self._dtype.
     image = tf.image.convert_image_dtype(image, self._dtype)
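The resulting train-time order in the parser is: crop/resize, optional AutoAugment/RandAugment, the new color jitter, mean/std normalization, and only then random erasing (which runs on the normalized float image so that its random-normal fill matches the data distribution), followed by the dtype cast. A stand-alone sketch of that tail of the pipeline (the constants are the parser's usual ImageNet mean/std; the input tensor is a stand-in for a decoded crop):

import tensorflow as tf
from official.vision.beta.ops import augment, preprocess_ops

MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
STDDEV_RGB = (0.229 * 255, 0.224 * 255, 0.225 * 255)

image = tf.random.uniform((224, 224, 3), 0, 255)
image = preprocess_ops.color_jitter(image, 0.4, 0.4, 0.4)
image = preprocess_ops.normalize_image(image, offset=MEAN_RGB, scale=STDDEV_RGB)
image = augment.RandomErasing(probability=0.25).distort(image)
image = tf.image.convert_image_dtype(image, tf.float32)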
official/vision/beta/modeling/factory.py

@@ -56,6 +56,7 @@ def build_classification_model(
       num_classes=model_config.num_classes,
       input_specs=input_specs,
       dropout_rate=model_config.dropout_rate,
+      kernel_initializer=model_config.kernel_initializer,
       kernel_regularizer=l2_regularizer,
       add_head_batch_norm=model_config.add_head_batch_norm,
       use_sync_bn=norm_activation_config.use_sync_bn,
official/vision/beta/ops/augment.py

@@ -12,10 +12,17 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""AutoAugment and RandAugment policies for enhanced image/video preprocessing.
+"""Augmentation policies for enhanced image/video preprocessing.
 
 AutoAugment Reference: https://arxiv.org/abs/1805.09501
 RandAugment Reference: https://arxiv.org/abs/1909.13719
+RandomErasing Reference: https://arxiv.org/abs/1708.04896
+MixupAndCutmix:
+  - Mixup: https://arxiv.org/abs/1710.09412
+  - Cutmix: https://arxiv.org/abs/1905.04899
+
+RandomErasing, Mixup and Cutmix are inspired by
+https://github.com/rwightman/pytorch-image-models
 """
 
 import math
 from typing import Any, List, Iterable, Optional, Text, Tuple
@@ -295,10 +302,21 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
   cutout_center_width = tf.random.uniform(
       shape=[], minval=0, maxval=image_width, dtype=tf.int32)
 
-  lower_pad = tf.maximum(0, cutout_center_height - pad_size)
-  upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size)
-  left_pad = tf.maximum(0, cutout_center_width - pad_size)
-  right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size)
+  image = _fill_rectangle(image, cutout_center_width, cutout_center_height,
+                          pad_size, pad_size, replace)
+
+  return image
+
+
+def _fill_rectangle(image, center_width, center_height, half_width,
+                    half_height, replace=None):
+  image_height = tf.shape(image)[0]
+  image_width = tf.shape(image)[1]
+
+  lower_pad = tf.maximum(0, center_height - half_height)
+  upper_pad = tf.maximum(0, image_height - center_height - half_height)
+  left_pad = tf.maximum(0, center_width - half_width)
+  right_pad = tf.maximum(0, image_width - center_width - half_width)
 
   cutout_shape = [
       image_height - (lower_pad + upper_pad),
@@ -311,9 +329,15 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
       constant_values=1)
   mask = tf.expand_dims(mask, -1)
   mask = tf.tile(mask, [1, 1, 3])
-  image = tf.where(
-      tf.equal(mask, 0),
-      tf.ones_like(image, dtype=image.dtype) * replace,
-      image)
+
+  if replace is None:
+    fill = tf.random.normal(tf.shape(image), dtype=image.dtype)
+  elif isinstance(replace, tf.Tensor):
+    fill = replace
+  else:
+    fill = tf.ones_like(image, dtype=image.dtype) * replace
+  image = tf.where(tf.equal(mask, 0), fill, image)
+
   return image
@@ -805,9 +829,15 @@ def level_to_arg(cutout_const: float, translate_const: float):
 
 def _parse_policy_info(name: Text, prob: float, level: float,
                        replace_value: List[int], cutout_const: float,
-                       translate_const: float) -> Tuple[Any, float, Any]:
+                       translate_const: float,
+                       level_std: float = 0.) -> Tuple[Any, float, Any]:
   """Return the function that corresponds to `name` and update `level` param."""
   func = NAME_TO_FUNC[name]
+
+  if level_std > 0:
+    level += tf.random.normal([], stddev=level_std, dtype=tf.float32)
+    level = tf.clip_by_value(level, 0., _MAX_LEVEL)
+
   args = level_to_arg(cutout_const, translate_const)[name](level)
 
   if name in REPLACE_FUNCS:
@@ -1184,7 +1214,9 @@ class RandAugment(ImageAugment):
                magnitude: float = 10.,
                cutout_const: float = 40.,
                translate_const: float = 100.,
-               prob_to_apply: Optional[float] = None):
+               magnitude_std: float = 0.0,
+               prob_to_apply: Optional[float] = None,
+               exclude_ops: List[str] = ()):
     """Applies the RandAugment policy to images.
 
     Args:
@@ -1196,8 +1228,11 @@ class RandAugment(ImageAugment):
         [5, 10].
       cutout_const: multiplier for applying cutout.
       translate_const: multiplier for applying translation.
+      magnitude_std: randomness of the severity as proposed by the authors of
+        the timm library.
       prob_to_apply: The probability to apply the selected augmentation at each
         layer.
+      exclude_ops: exclude selected operations.
     """
 
     super(RandAugment, self).__init__()
@@ -1212,6 +1247,9 @@ class RandAugment(ImageAugment):
         'Color', 'Contrast', 'Brightness', 'Sharpness', 'ShearX', 'ShearY',
         'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'
     ]
+    self.magnitude_std = magnitude_std
+    self.available_ops = [
+        op for op in self.available_ops if op not in exclude_ops
+    ]
 
   def distort(self, image: tf.Tensor) -> tf.Tensor:
     """Applies the RandAugment policy to `image`.
@@ -1246,7 +1284,8 @@ class RandAugment(ImageAugment):
           dtype=tf.float32)
       func, _, args = _parse_policy_info(op_name, prob, self.magnitude,
                                          replace_value, self.cutout_const,
-                                         self.translate_const)
+                                         self.translate_const,
+                                         self.magnitude_std)
       branch_fns.append((
           i,
           # pylint:disable=g-long-lambda
@@ -1267,3 +1306,240 @@ class RandAugment(ImageAugment):
     image = tf.cast(image, dtype=input_image_type)
     return image
+
+
+class RandomErasing(ImageAugment):
+  """Applies RandomErasing to a single image.
+
+  Reference: https://arxiv.org/abs/1708.04896
+
+  Implementation is inspired by
+  https://github.com/rwightman/pytorch-image-models
+  """
+
+  def __init__(self,
+               probability: float = 0.25,
+               min_area: float = 0.02,
+               max_area: float = 1 / 3,
+               min_aspect: float = 0.3,
+               max_aspect=None,
+               min_count=1,
+               max_count=1,
+               trials=10):
+    """Applies RandomErasing to a single image.
+
+    Args:
+      probability (float, optional): Probability of augmenting the image.
+        Defaults to 0.25.
+      min_area (float, optional): Minimum area of the random erasing rectangle.
+        Defaults to 0.02.
+      max_area (float, optional): Maximum area of the random erasing rectangle.
+        Defaults to 1/3.
+      min_aspect (float, optional): Minimum aspect rate of the random erasing
+        rectangle. Defaults to 0.3.
+      max_aspect ([type], optional): Maximum aspect rate of the random erasing
+        rectangle. Defaults to None.
+      min_count (int, optional): Minimum number of erased rectangles. Defaults
+        to 1.
+      max_count (int, optional): Maximum number of erased rectangles. Defaults
+        to 1.
+      trials (int, optional): Maximum number of trials to randomly sample a
+        rectangle that fulfills constraint. Defaults to 10.
+    """
+    self._probability = probability
+    self._min_area = float(min_area)
+    self._max_area = float(max_area)
+    self._min_log_aspect = math.log(min_aspect)
+    self._max_log_aspect = math.log(max_aspect or 1 / min_aspect)
+    self._min_count = min_count
+    self._max_count = max_count
+    self._trials = trials
+
+  def distort(self, image: tf.Tensor) -> tf.Tensor:
+    """Applies RandomErasing to single `image`.
+
+    Args:
+      image (tf.Tensor): Of shape [height, width, 3] representing an image.
+
+    Returns:
+      tf.Tensor: The augmented version of `image`.
+    """
+    uniform_random = tf.random.uniform(shape=[], minval=0., maxval=1.0)
+    mirror_cond = tf.less(uniform_random, self._probability)
+    image = tf.cond(mirror_cond, lambda: self._erase(image), lambda: image)
+    return image
+
+  @tf.function
+  def _erase(self, image: tf.Tensor) -> tf.Tensor:
+    count = self._min_count if self._min_count == self._max_count else \
+        tf.random.uniform(
+            shape=[],
+            minval=int(self._min_count),
+            maxval=int(self._max_count - self._min_count + 1),
+            dtype=tf.int32)
+
+    image_height = tf.shape(image)[0]
+    image_width = tf.shape(image)[1]
+    area = tf.cast(image_width * image_height, tf.float32)
+
+    for _ in range(count):
+      for _ in range(self._trials):
+        erase_area = tf.random.uniform(
+            shape=[],
+            minval=area * self._min_area,
+            maxval=area * self._max_area)
+        aspect_ratio = tf.math.exp(
+            tf.random.uniform(
+                shape=[],
+                minval=self._min_log_aspect,
+                maxval=self._max_log_aspect))
+
+        half_height = tf.cast(
+            tf.math.round(tf.math.sqrt(erase_area * aspect_ratio) / 2),
+            dtype=tf.int32)
+        half_width = tf.cast(
+            tf.math.round(tf.math.sqrt(erase_area / aspect_ratio) / 2),
+            dtype=tf.int32)
+
+        if 2 * half_height < image_height and 2 * half_width < image_width:
+          center_height = tf.random.uniform(
+              shape=[],
+              minval=0,
+              maxval=image_height - 2 * half_height,
+              dtype=tf.int32)
+          center_width = tf.random.uniform(
+              shape=[],
+              minval=0,
+              maxval=image_width - 2 * half_width,
+              dtype=tf.int32)
+
+          image = _fill_rectangle(
+              image,
+              center_width,
+              center_height,
+              half_width,
+              half_height,
+              replace=None)
+
+          break
+
+    return image
+
+
+class MixupAndCutmix:
+  """Applies Mixup and/or Cutmix to a batch of images.
+
+  - Mixup: https://arxiv.org/abs/1710.09412
+  - Cutmix: https://arxiv.org/abs/1905.04899
+
+  Implementation is inspired by
+  https://github.com/rwightman/pytorch-image-models
+  """
+
+  def __init__(self,
+               mixup_alpha: float = .8,
+               cutmix_alpha: float = 1.,
+               prob: float = 1.0,
+               switch_prob: float = 0.5,
+               label_smoothing: float = 0.1,
+               num_classes: int = 1001):
+    """Applies Mixup and/or Cutmix to a batch of images.
+
+    Args:
+      mixup_alpha (float, optional): For drawing a random lambda (`lam`) from a
+        beta distribution (for each image). If zero Mixup is deactivated.
+        Defaults to .8.
+      cutmix_alpha (float, optional): For drawing a random lambda (`lam`) from
+        a beta distribution (for each image). If zero Cutmix is deactivated.
+        Defaults to 1..
+      prob (float, optional): Of augmenting the batch. Defaults to 1.0.
+      switch_prob (float, optional): Probability of applying Cutmix for the
+        batch. Defaults to 0.5.
+      label_smoothing (float, optional): Constant for label smoothing. Defaults
+        to 0.1.
+      num_classes (int, optional): Number of classes. Defaults to 1001.
+    """
+    self.mixup_alpha = mixup_alpha
+    self.cutmix_alpha = cutmix_alpha
+    self.mix_prob = prob
+    self.switch_prob = switch_prob
+    self.label_smoothing = label_smoothing
+    self.num_classes = num_classes
+    self.mode = 'batch'
+    self.mixup_enabled = True
+
+    if self.mixup_alpha and not self.cutmix_alpha:
+      self.switch_prob = -1
+    elif not self.mixup_alpha and self.cutmix_alpha:
+      self.switch_prob = 1
+
+  def __call__(self, images: tf.Tensor,
+               labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    return self.distort(images, labels)
+
+  def distort(self, images: tf.Tensor,
+              labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    """Applies Mixup and/or Cutmix to batch of `images` and transforms the
+    `labels` (incl. label smoothing).
+
+    Args:
+      images (tf.Tensor): Of shape [batch_size, height, width, 3] representing
+        a batch of image.
+      labels (tf.Tensor): Of shape [batch_size, ] representing the class id for
+        each image of the batch.
+
+    Returns:
+      Tuple[tf.Tensor, tf.Tensor]: The augmented version of `image` and
+        `labels`.
+    """
+    augment_cond = tf.less(
+        tf.random.uniform(shape=[], minval=0., maxval=1.0), self.mix_prob)
+    return tf.cond(
+        augment_cond,
+        lambda: self._update_labels(*tf.cond(
+            tf.less(
+                tf.random.uniform(shape=[], minval=0., maxval=1.0),
+                self.switch_prob),
+            lambda: self._cutmix(images, labels),
+            lambda: self._mixup(images, labels))),
+        lambda: (images, self._smooth_labels(labels)))
+
+  @staticmethod
+  def _sample_from_beta(alpha: float, beta: float, shape: tuple):
+    sample_alpha = tf.random.gamma(shape, 1., beta=alpha)
+    sample_beta = tf.random.gamma(shape, 1., beta=beta)
+    return sample_alpha / (sample_alpha + sample_beta)
+
+  def _cutmix(self, images: tf.Tensor,
+              labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+    lam = MixupAndCutmix._sample_from_beta(self.cutmix_alpha,
+                                           self.cutmix_alpha, labels.shape)
+
+    ratio = tf.math.sqrt(1 - lam)
+
+    batch_size = tf.shape(images)[0]
+    image_height, image_width = tf.shape(images)[1], tf.shape(images)[2]
+
+    cut_height = tf.cast(
+        ratio * tf.cast(image_height, dtype=tf.float32), dtype=tf.int32)
+    cut_width = tf.cast(
+        ratio * tf.cast(image_width, dtype=tf.float32), dtype=tf.int32)
+
+    random_center_height = tf.random.uniform(
+        shape=[batch_size], minval=0, maxval=image_height, dtype=tf.int32)
+    random_center_width = tf.random.uniform(
+        shape=[batch_size], minval=0, maxval=image_width, dtype=tf.int32)
+
+    bbox_area = cut_height * cut_width
+    lam = 1. - bbox_area / (image_height * image_width)
+    lam = tf.cast(lam, dtype=tf.float32)
+
+    images = tf.map_fn(
+        lambda x: _fill_rectangle(*x),
+        (images, random_center_width, random_center_height, cut_width // 2,
+         cut_height // 2, tf.reverse(images, [0])),
+        dtype=(tf.float32, tf.int32, tf.int32, tf.int32, tf.int32, tf.float32),
+        fn_output_signature=tf.TensorSpec(images.shape[1:], dtype=tf.float32))
+
+    return images, labels, lam
+
+  def _mixup(self, images: tf.Tensor,
+             labels: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]:
+    lam = MixupAndCutmix._sample_from_beta(self.mixup_alpha, self.mixup_alpha,
+                                           labels.shape)
+    lam = tf.reshape(lam, [-1, 1, 1, 1])
+    images = lam * images + (1. - lam) * tf.reverse(images, [0])
+
+    return images, labels, tf.squeeze(lam)
+
+  def _smooth_labels(self, labels: tf.Tensor) -> tf.Tensor:
+    off_value = self.label_smoothing / self.num_classes
+    on_value = 1. - self.label_smoothing + off_value
+
+    smooth_labels = tf.one_hot(
+        labels, self.num_classes, on_value=on_value, off_value=off_value)
+    return smooth_labels
+
+  def _update_labels(self, images: tf.Tensor, labels: tf.Tensor,
+                     lam: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
+    labels_1 = self._smooth_labels(labels)
+    labels_2 = tf.reverse(labels_1, [0])
+
+    lam = tf.reshape(lam, [-1, 1])
+    labels = lam * labels_1 + (1. - lam) * labels_2
+
+    return images, labels
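End to end, `MixupAndCutmix` consumes a batch of images plus integer labels and returns mixed images plus soft labels: for mixing weight lam, labels become lam * smooth(y) + (1 - lam) * smooth(reverse(y)), so every row remains a probability distribution. A quick usage sketch on a toy batch (values illustrative):

import tensorflow as tf
from official.vision.beta.ops import augment

images = tf.random.normal((8, 32, 32, 3))
labels = tf.range(8) % 10  # toy class ids

mixer = augment.MixupAndCutmix(mixup_alpha=.8, cutmix_alpha=1., num_classes=10)
mixed_images, soft_labels = mixer(images, labels)

# Convex combinations of smoothed one-hot rows still sum to one.
tf.debugging.assert_near(tf.reduce_sum(soft_labels, axis=-1), tf.ones(8))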
official/vision/beta/ops/augment_test.py

@@ -254,5 +254,82 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
       augmenter.distort(image)
 
 
+class RandomErasingTest(tf.test.TestCase, parameterized.TestCase):
+
+  def test_random_erase_replaces_some_pixels(self):
+    image = tf.zeros((224, 224, 3), dtype=tf.float32)
+    augmenter = augment.RandomErasing(probability=1., max_count=10)
+
+    aug_image = augmenter.distort(image)
+
+    self.assertEqual((224, 224, 3), aug_image.shape)
+    self.assertLess(0, tf.reduce_max(aug_image))
+
+
+class MixupAndCutmixTest(tf.test.TestCase, parameterized.TestCase):
+
+  def test_mixup_and_cutmix_smoothes_labels(self):
+    batch_size = 12
+    num_classes = 1000
+    label_smoothing = 0.1
+
+    images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
+    labels = tf.range(batch_size)
+    augmenter = augment.MixupAndCutmix(
+        num_classes=num_classes, label_smoothing=label_smoothing)
+
+    aug_images, aug_labels = augmenter.distort(images, labels)
+
+    self.assertEqual(images.shape, aug_images.shape)
+    self.assertEqual(images.dtype, aug_images.dtype)
+    self.assertEqual([batch_size, num_classes], aug_labels.shape)
+    self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
+                            2. / num_classes)  # With tolerance
+    self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
+                               1e-4)  # With tolerance
+
+  def test_mixup_changes_image(self):
+    batch_size = 12
+    num_classes = 1000
+    label_smoothing = 0.1
+
+    images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
+    labels = tf.range(batch_size)
+    augmenter = augment.MixupAndCutmix(
+        mixup_alpha=1., cutmix_alpha=0., num_classes=num_classes)
+
+    aug_images, aug_labels = augmenter.distort(images, labels)
+
+    self.assertEqual(images.shape, aug_images.shape)
+    self.assertEqual(images.dtype, aug_images.dtype)
+    self.assertEqual([batch_size, num_classes], aug_labels.shape)
+    self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
+                            2. / num_classes)  # With tolerance
+    self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
+                               1e-4)  # With tolerance
+    self.assertTrue(not tf.math.reduce_all(images == aug_images))
+
+  def test_cutmix_changes_image(self):
+    batch_size = 12
+    num_classes = 1000
+    label_smoothing = 0.1
+
+    images = tf.random.normal((batch_size, 224, 224, 3), dtype=tf.float32)
+    labels = tf.range(batch_size)
+    augmenter = augment.MixupAndCutmix(
+        mixup_alpha=0., cutmix_alpha=1., num_classes=num_classes)
+
+    aug_images, aug_labels = augmenter.distort(images, labels)
+
+    self.assertEqual(images.shape, aug_images.shape)
+    self.assertEqual(images.dtype, aug_images.dtype)
+    self.assertEqual([batch_size, num_classes], aug_labels.shape)
+    self.assertAllLessEqual(aug_labels, 1. - label_smoothing +
+                            2. / num_classes)  # With tolerance
+    self.assertAllGreaterEqual(aug_labels, label_smoothing / num_classes -
+                               1e-4)  # With tolerance
+    self.assertTrue(not tf.math.reduce_all(images == aug_images))
+
+
 if __name__ == '__main__':
   tf.test.main()
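The label-range assertions follow from the smoothing arithmetic: with smoothing s and K classes, off_value = s/K and on_value = 1 - s + s/K, and any Mixup/Cutmix output row is a convex combination of these two extremes. For the tests' s = 0.1, K = 1000:

s, K = 0.1, 1000
off_value = s / K              # 0.0001: smallest possible label entry
on_value = 1. - s + off_value  # 0.9001: largest possible label entry

assert on_value <= 1. - s + 2. / K  # 0.9001 <= 0.902, upper test bound
assert off_value >= s / K - 1e-4    # 0.0001 >= 0.0,   lower test bound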
official/vision/beta/ops/preprocess_ops.py

@@ -15,10 +15,12 @@
 """Preprocessing ops."""
 
 import math
+from typing import Optional
 
 from six.moves import range
 import tensorflow as tf
 
 from official.vision.beta.ops import box_ops
+from official.vision.beta.ops import augment
 
 CENTER_CROP_FRACTION = 0.875
@@ -555,3 +557,84 @@ def random_horizontal_flip(image, normalized_boxes=None, masks=None, seed=1):
       lambda: masks)
 
   return image, normalized_boxes, masks
+
+
+def color_jitter(image: tf.Tensor,
+                 brightness: Optional[float] = 0.,
+                 contrast: Optional[float] = 0.,
+                 saturation: Optional[float] = 0.,
+                 seed: Optional[int] = None) -> tf.Tensor:
+  """Applies color jitter to an image, similarly to torchvision's ColorJitter.
+
+  Args:
+    image (tf.Tensor): Of shape [height, width, 3] representing an image.
+    brightness (float, optional): Magnitude for brightness jitter. Defaults
+      to 0.
+    contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
+    saturation (float, optional): Magnitude for saturation jitter. Defaults
+      to 0.
+    seed (int, optional): Random seed. Defaults to None.
+
+  Returns:
+    tf.Tensor: The augmented version of `image`.
+  """
+  image = random_brightness(image, brightness, seed=seed)
+  image = random_contrast(image, contrast, seed=seed)
+  image = random_saturation(image, saturation, seed=seed)
+  return image
+
+
+def random_brightness(image: tf.Tensor,
+                      brightness: Optional[float] = 0.,
+                      seed: Optional[int] = None) -> tf.Tensor:
+  """Jitters brightness of an image, similarly to torchvision's ColorJitter.
+
+  Args:
+    image (tf.Tensor): Of shape [height, width, 3] representing an image.
+    brightness (float, optional): Magnitude for brightness jitter. Defaults
+      to 0.
+    seed (int, optional): Random seed. Defaults to None.
+
+  Returns:
+    tf.Tensor: The augmented version of `image`.
+  """
+  assert brightness >= 0 and brightness <= 1., '`brightness` must be in [0, 1]'
+
+  brightness = tf.random.uniform([],
+                                 max(0, 1 - brightness),
+                                 1 + brightness,
+                                 seed=seed)
+  return augment.brightness(image, brightness)
+
+
+def random_contrast(image: tf.Tensor,
+                    contrast: Optional[float] = 0.,
+                    seed: Optional[int] = None) -> tf.Tensor:
+  """Jitters contrast of an image, similarly to torchvision's ColorJitter.
+
+  Args:
+    image (tf.Tensor): Of shape [height, width, 3] representing an image.
+    contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
+    seed (int, optional): Random seed. Defaults to None.
+
+  Returns:
+    tf.Tensor: The augmented version of `image`.
+  """
+  assert contrast >= 0 and contrast <= 1., '`contrast` must be in [0, 1]'
+
+  contrast = tf.random.uniform([],
+                               max(0, 1 - contrast),
+                               1 + contrast,
+                               seed=seed)
+  return augment.contrast(image, contrast)
+
+
+def random_saturation(image: tf.Tensor,
+                      saturation: Optional[float] = 0.,
+                      seed: Optional[int] = None) -> tf.Tensor:
+  """Jitters saturation of an image, similarly to torchvision's ColorJitter.
+
+  Args:
+    image (tf.Tensor): Of shape [height, width, 3] representing an image.
+    saturation (float, optional): Magnitude for saturation jitter. Defaults
+      to 0.
+    seed (int, optional): Random seed. Defaults to None.
+
+  Returns:
+    tf.Tensor: The augmented version of `image`.
+  """
+  assert saturation >= 0 and saturation <= 1., '`saturation` must be in [0, 1]'
+
+  saturation = tf.random.uniform([],
+                                 max(0, 1 - saturation),
+                                 1 + saturation,
+                                 seed=seed)
+  return augment.blend(tf.image.rgb_to_grayscale(image), image, saturation)
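All three jitter ops follow torchvision's convention: a magnitude m maps to a multiplicative factor drawn from U[max(0, 1 - m), 1 + m], so m = 0 degenerates to the identity (the factor is exactly 1) and the op can be left in the pipeline unconditionally. For example:

import tensorflow as tf
from official.vision.beta.ops import preprocess_ops

image = tf.random.uniform((64, 64, 3), 0, 255)

jittered = preprocess_ops.color_jitter(image, 0.4, 0.4, 0.4)  # factors in U[0.6, 1.4]
unchanged = preprocess_ops.color_jitter(image, 0., 0., 0.)    # factors == 1, identity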
official/vision/beta/projects/vit/README.md

-# Vision Transformer (ViT)
+# Vision Transformer (ViT) and Data-Efficient Image Transformer (DEIT)
 
 **DISCLAIMER**: This implementation is still under development. No support will
 be provided during the development phase.
 
-[![Paper](http://img.shields.io/badge/Paper-arXiv.2010.11929-B3181B?logo=arXiv)](https://arxiv.org/abs/2010.11929)
+- [![Paper](http://img.shields.io/badge/Paper-arXiv.2010.11929-B3181B?logo=arXiv)](https://arxiv.org/abs/2010.11929)
+- [![Paper](http://img.shields.io/badge/Paper-arXiv.2012.12877-B3181B?logo=arXiv)](https://arxiv.org/abs/2012.12877)
 
-This repository is the implementations of Vision Transformer (ViT) in
-TensorFlow 2.
+This repository is the implementation of Vision Transformer (ViT) and
+Data-Efficient Image Transformer (DEIT) in TensorFlow 2.
 
 * Paper title:
-[An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf).
\ No newline at end of file
+- [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf).
+- [Training data-efficient image transformers & distillation through attention](https://arxiv.org/pdf/2012.12877.pdf).
official/vision/beta/projects/vit/configs/backbones.py

@@ -42,6 +42,8 @@ class VisionTransformer(hyperparams.Config):
   hidden_size: int = 1
   patch_size: int = 16
   transformer: Transformer = Transformer()
+  init_stochastic_depth_rate: float = 0.0
+  original_init: bool = True
 
 
 @dataclasses.dataclass
official/vision/beta/projects/vit/configs/image_classification.py

(Large diff collapsed in this view: +842, -1.)
official/vision/beta/projects/vit/modeling/layers/__init__.py  (new file, mode 100644)

from official.vision.beta.projects.vit.modeling.layers.vit_transformer_encoder_block import TransformerEncoderBlock
official/vision/beta/projects/vit/modeling/layers/vit_transformer_encoder_block.py  (new file, mode 100644)

# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Keras-based TransformerEncoder block layer."""

import tensorflow as tf

from official.vision.beta.modeling.layers.nn_layers import StochasticDepth


@tf.keras.utils.register_keras_serializable(package="Vision")
class TransformerEncoderBlock(tf.keras.layers.Layer):
  """TransformerEncoderBlock layer.

  This layer implements the Transformer Encoder from
  "Attention Is All You Need" (https://arxiv.org/abs/1706.03762),
  which combines a `tf.keras.layers.MultiHeadAttention` layer with a
  two-layer feedforward network. Here we add support for stochastic depth.

  References:
    [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
    [BERT: Pre-training of Deep Bidirectional Transformers for Language
    Understanding](https://arxiv.org/abs/1810.04805)
  """

  def __init__(self,
               num_attention_heads,
               inner_dim,
               inner_activation,
               output_range=None,
               kernel_initializer="glorot_uniform",
               bias_initializer="zeros",
               kernel_regularizer=None,
               bias_regularizer=None,
               activity_regularizer=None,
               kernel_constraint=None,
               bias_constraint=None,
               use_bias=True,
               norm_first=False,
               norm_epsilon=1e-12,
               output_dropout=0.0,
               attention_dropout=0.0,
               inner_dropout=0.0,
               stochastic_depth_drop_rate=0.0,
               attention_initializer=None,
               attention_axes=None,
               **kwargs):
    """Initializes `TransformerEncoderBlock`.

    Args:
      num_attention_heads: Number of attention heads.
      inner_dim: The output dimension of the first Dense layer in a two-layer
        feedforward network.
      inner_activation: The activation for the first Dense layer in a two-layer
        feedforward network.
      output_range: the sequence output range, [0, output_range) for slicing
        the target sequence. `None` means the target sequence is not sliced.
      kernel_initializer: Initializer for dense layer kernels.
      bias_initializer: Initializer for dense layer biases.
      kernel_regularizer: Regularizer for dense layer kernels.
      bias_regularizer: Regularizer for dense layer biases.
      activity_regularizer: Regularizer for dense layer activity.
      kernel_constraint: Constraint for dense layer kernels.
      bias_constraint: Constraint for dense layer kernels.
      use_bias: Whether to enable use_bias in attention layer. If set False,
        use_bias in attention layer is disabled.
      norm_first: Whether to normalize inputs to attention and intermediate
        dense layers. If set False, output of attention and intermediate dense
        layers is normalized.
      norm_epsilon: Epsilon value to initialize normalization layers.
      output_dropout: Dropout probability for the post-attention and output
        dropout.
      attention_dropout: Dropout probability for within the attention layer.
      inner_dropout: Dropout probability for the first Dense layer in a
        two-layer feedforward network.
      stochastic_depth_drop_rate: Dropout probability for the stochastic depth
        regularization.
      attention_initializer: Initializer for kernels of attention layers. If
        set `None`, attention layers use kernel_initializer as initializer for
        kernel.
      attention_axes: axes over which the attention is applied. `None` means
        attention over all axes, but batch, heads, and features.
      **kwargs: keyword arguments.
    """
    super().__init__(**kwargs)

    self._num_heads = num_attention_heads
    self._inner_dim = inner_dim
    self._inner_activation = inner_activation
    self._attention_dropout = attention_dropout
    self._attention_dropout_rate = attention_dropout
    self._output_dropout = output_dropout
    self._output_dropout_rate = output_dropout
    self._output_range = output_range
    self._kernel_initializer = tf.keras.initializers.get(kernel_initializer)
    self._bias_initializer = tf.keras.initializers.get(bias_initializer)
    self._kernel_regularizer = tf.keras.regularizers.get(kernel_regularizer)
    self._bias_regularizer = tf.keras.regularizers.get(bias_regularizer)
    self._activity_regularizer = tf.keras.regularizers.get(
        activity_regularizer)
    self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
    self._bias_constraint = tf.keras.constraints.get(bias_constraint)
    self._use_bias = use_bias
    self._norm_first = norm_first
    self._norm_epsilon = norm_epsilon
    self._inner_dropout = inner_dropout
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate
    if attention_initializer:
      self._attention_initializer = tf.keras.initializers.get(
          attention_initializer)
    else:
      self._attention_initializer = self._kernel_initializer
    self._attention_axes = attention_axes

  def build(self, input_shape):
    if isinstance(input_shape, tf.TensorShape):
      input_tensor_shape = input_shape
    elif isinstance(input_shape, (list, tuple)):
      input_tensor_shape = tf.TensorShape(input_shape[0])
    else:
      raise ValueError(
          "The type of input shape argument is not supported, got: %s" %
          type(input_shape))
    einsum_equation = "abc,cd->abd"
    if len(input_tensor_shape.as_list()) > 3:
      einsum_equation = "...bc,cd->...bd"
    hidden_size = input_tensor_shape[-1]
    if hidden_size % self._num_heads != 0:
      raise ValueError(
          "The input size (%d) is not a multiple of the number of attention "
          "heads (%d)" % (hidden_size, self._num_heads))
    self._attention_head_size = int(hidden_size // self._num_heads)
    common_kwargs = dict(
        bias_initializer=self._bias_initializer,
        kernel_regularizer=self._kernel_regularizer,
        bias_regularizer=self._bias_regularizer,
        activity_regularizer=self._activity_regularizer,
        kernel_constraint=self._kernel_constraint,
        bias_constraint=self._bias_constraint)
    self._attention_layer = tf.keras.layers.MultiHeadAttention(
        num_heads=self._num_heads,
        key_dim=self._attention_head_size,
        dropout=self._attention_dropout,
        use_bias=self._use_bias,
        kernel_initializer=self._attention_initializer,
        attention_axes=self._attention_axes,
        name="self_attention",
        **common_kwargs)
    self._attention_dropout = tf.keras.layers.Dropout(
        rate=self._output_dropout)
    # Use float32 in layernorm for numeric stability.
    # It is probably safe in mixed_float16, but we haven't validated this yet.
    self._attention_layer_norm = (
        tf.keras.layers.LayerNormalization(
            name="self_attention_layer_norm",
            axis=-1,
            epsilon=self._norm_epsilon,
            dtype=tf.float32))
    self._intermediate_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=(None, self._inner_dim),
        bias_axes="d",
        kernel_initializer=self._kernel_initializer,
        name="intermediate",
        **common_kwargs)
    policy = tf.keras.mixed_precision.global_policy()
    if policy.name == "mixed_bfloat16":
      # bfloat16 causes BERT with the LAMB optimizer to not converge
      # as well, so we use float32.
      # TODO(b/154538392): Investigate this.
      policy = tf.float32
    self._intermediate_activation_layer = tf.keras.layers.Activation(
        self._inner_activation, dtype=policy)
    self._inner_dropout_layer = tf.keras.layers.Dropout(
        rate=self._inner_dropout)
    self._output_dense = tf.keras.layers.experimental.EinsumDense(
        einsum_equation,
        output_shape=(None, hidden_size),
        bias_axes="d",
        name="output",
        kernel_initializer=self._kernel_initializer,
        **common_kwargs)
    self._output_dropout = tf.keras.layers.Dropout(rate=self._output_dropout)
    # Use float32 in layernorm for numeric stability.
    self._output_layer_norm = tf.keras.layers.LayerNormalization(
        name="output_layer_norm",
        axis=-1,
        epsilon=self._norm_epsilon,
        dtype=tf.float32)
    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = StochasticDepth(
          self._stochastic_depth_drop_rate)
    else:
      # Identity pass-through so the residual adds in `call` also work when
      # stochastic depth is disabled.
      self._stochastic_depth = lambda x, training=None: x

    super(TransformerEncoderBlock, self).build(input_shape)

  def get_config(self):
    config = {
        "num_attention_heads": self._num_heads,
        "inner_dim": self._inner_dim,
        "inner_activation": self._inner_activation,
        "output_dropout": self._output_dropout_rate,
        "attention_dropout": self._attention_dropout_rate,
        "output_range": self._output_range,
        "kernel_initializer":
            tf.keras.initializers.serialize(self._kernel_initializer),
        "bias_initializer":
            tf.keras.initializers.serialize(self._bias_initializer),
        "kernel_regularizer":
            tf.keras.regularizers.serialize(self._kernel_regularizer),
        "bias_regularizer":
            tf.keras.regularizers.serialize(self._bias_regularizer),
        "activity_regularizer":
            tf.keras.regularizers.serialize(self._activity_regularizer),
        "kernel_constraint":
            tf.keras.constraints.serialize(self._kernel_constraint),
        "bias_constraint":
            tf.keras.constraints.serialize(self._bias_constraint),
        "use_bias": self._use_bias,
        "norm_first": self._norm_first,
        "norm_epsilon": self._norm_epsilon,
        "inner_dropout": self._inner_dropout,
        "stochastic_depth_drop_rate": self._stochastic_depth_drop_rate,
        "attention_initializer":
            tf.keras.initializers.serialize(self._attention_initializer),
        "attention_axes": self._attention_axes,
    }
    base_config = super(TransformerEncoderBlock, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, training=None):
    """Transformer self-attention encoder block call.

    Args:
      inputs: a single tensor or a list of tensors.
        `input tensor` as the single sequence of embeddings.
        [`input tensor`, `attention mask`] to have the additional attention
          mask.
        [`query tensor`, `key value tensor`, `attention mask`] to have separate
          input streams for the query, and key/value to the multi-head
          attention.
      training: whether the layer is in training mode.

    Returns:
      An output tensor with the same dimensions as input/query tensor.
    """
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError("Unexpected inputs to %s with length at %d" %
                         (self.__class__, len(inputs)))
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

    with_stochastic_depth = training and self._stochastic_depth

    if self._output_range:
      if self._norm_first:
        source_tensor = input_tensor[:, 0:self._output_range, :]
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm(key_value)
      target_tensor = input_tensor[:, 0:self._output_range, :]
      if attention_mask is not None:
        attention_mask = attention_mask[:, 0:self._output_range, :]
    else:
      if self._norm_first:
        source_tensor = input_tensor
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm(key_value)
      target_tensor = input_tensor

    if key_value is None:
      key_value = input_tensor
    attention_output = self._attention_layer(
        query=target_tensor, value=key_value, attention_mask=attention_mask)
    attention_output = self._attention_dropout(attention_output)
    if self._norm_first:
      attention_output = source_tensor + self._stochastic_depth(
          attention_output, training=with_stochastic_depth)
    else:
      attention_output = self._attention_layer_norm(
          target_tensor +
          self._stochastic_depth(
              attention_output, training=with_stochastic_depth))
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output)
    inner_output = self._intermediate_dense(attention_output)
    inner_output = self._intermediate_activation_layer(inner_output)
    inner_output = self._inner_dropout_layer(inner_output)
    layer_output = self._output_dense(inner_output)
    layer_output = self._output_dropout(layer_output)

    if self._norm_first:
      return source_attention_output + self._stochastic_depth(
          layer_output, training=with_stochastic_depth)

    # During mixed precision training, layer norm output is always fp32 for
    # now. Casts fp32 for the subsequent add.
    layer_output = tf.cast(layer_output, tf.float32)
    return self._output_layer_norm(
        layer_output +
        self._stochastic_depth(attention_output,
                               training=with_stochastic_depth))
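The only non-standard ingredient relative to keras_nlp's TransformerEncoderBlock is the `StochasticDepth` residual wrapper imported from nn_layers. Its contract: during training, drop the entire residual branch with probability `drop_rate` and rescale survivors by 1/(1 - drop_rate) so the expected activation is unchanged; at inference it is the identity. A minimal sketch under that contract (an illustrative re-implementation, not the Model Garden source):

import tensorflow as tf

class StochasticDepthSketch(tf.keras.layers.Layer):
  """Illustrative stand-in for nn_layers.StochasticDepth."""

  def __init__(self, drop_rate, **kwargs):
    super().__init__(**kwargs)
    self._drop_rate = drop_rate

  def call(self, inputs, training=None):
    if not training or not self._drop_rate:
      return inputs
    keep_prob = 1.0 - self._drop_rate
    # One Bernoulli draw per example, broadcast over the remaining dims.
    batch_size = tf.shape(inputs)[0]
    shape = [batch_size] + [1] * (inputs.shape.rank - 1)
    mask = tf.floor(keep_prob + tf.random.uniform(shape, dtype=inputs.dtype))
    return inputs / keep_prob * mask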
official/vision/beta/projects/vit/modeling/vit.py
View file @
40cd0a26
...
@@ -19,6 +19,8 @@ import tensorflow as tf
...
@@ -19,6 +19,8 @@ import tensorflow as tf
from
official.modeling
import
activations
from
official.modeling
import
activations
from
official.nlp
import
keras_nlp
from
official.nlp
import
keras_nlp
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.layers
import
nn_layers
from
official.vision.beta.projects.vit.modeling.layers
import
TransformerEncoderBlock
layers
=
tf
.
keras
.
layers
layers
=
tf
.
keras
.
layers
...
@@ -29,6 +31,18 @@ VIT_SPECS = {
...
@@ -29,6 +31,18 @@ VIT_SPECS = {
patch_size
=
16
,
patch_size
=
16
,
transformer
=
dict
(
mlp_dim
=
1
,
num_heads
=
1
,
num_layers
=
1
),
transformer
=
dict
(
mlp_dim
=
1
,
num_heads
=
1
,
num_layers
=
1
),
),
),
'vit-ti16'
:
dict
(
hidden_size
=
192
,
patch_size
=
16
,
transformer
=
dict
(
mlp_dim
=
3072
,
num_heads
=
3
,
num_layers
=
12
),
),
'vit-s16'
:
dict
(
hidden_size
=
384
,
patch_size
=
16
,
transformer
=
dict
(
mlp_dim
=
3072
,
num_heads
=
6
,
num_layers
=
12
),
),
'vit-b16'
:
'vit-b16'
:
dict
(
dict
(
hidden_size
=
768
,
hidden_size
=
768
,
...
@@ -112,6 +126,8 @@ class Encoder(tf.keras.layers.Layer):
...
@@ -112,6 +126,8 @@ class Encoder(tf.keras.layers.Layer):
attention_dropout_rate
=
0.1
,
attention_dropout_rate
=
0.1
,
kernel_regularizer
=
None
,
kernel_regularizer
=
None
,
inputs_positions
=
None
,
inputs_positions
=
None
,
init_stochastic_depth_rate
=
0.0
,
kernel_initializer
=
'glorot_uniform'
,
**
kwargs
):
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
_num_layers
=
num_layers
self
.
_num_layers
=
num_layers
...
@@ -121,6 +137,8 @@ class Encoder(tf.keras.layers.Layer):
...
@@ -121,6 +137,8 @@ class Encoder(tf.keras.layers.Layer):
self
.
_attention_dropout_rate
=
attention_dropout_rate
self
.
_attention_dropout_rate
=
attention_dropout_rate
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_inputs_positions
=
inputs_positions
self
.
_inputs_positions
=
inputs_positions
self
.
_init_stochastic_depth_rate
=
init_stochastic_depth_rate
self
.
_kernel_initializer
=
kernel_initializer
def
build
(
self
,
input_shape
):
def
build
(
self
,
input_shape
):
self
.
_pos_embed
=
AddPositionEmbs
(
self
.
_pos_embed
=
AddPositionEmbs
(
...
@@ -131,15 +149,18 @@ class Encoder(tf.keras.layers.Layer):
...
@@ -131,15 +149,18 @@ class Encoder(tf.keras.layers.Layer):
self
.
_encoder_layers
=
[]
self
.
_encoder_layers
=
[]
# Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
# Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
# https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
# https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
for
_
in
range
(
self
.
_num_layers
):
for
i
in
range
(
self
.
_num_layers
):
encoder_layer
=
keras_nlp
.
layers
.
TransformerEncoderBlock
(
encoder_layer
=
TransformerEncoderBlock
(
inner_activation
=
activations
.
gelu
,
inner_activation
=
activations
.
gelu
,
num_attention_heads
=
self
.
_num_heads
,
num_attention_heads
=
self
.
_num_heads
,
inner_dim
=
self
.
_mlp_dim
,
inner_dim
=
self
.
_mlp_dim
,
output_dropout
=
self
.
_dropout_rate
,
output_dropout
=
self
.
_dropout_rate
,
attention_dropout
=
self
.
_attention_dropout_rate
,
attention_dropout
=
self
.
_attention_dropout_rate
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_initializer
=
self
.
_kernel_initializer
,
norm_first
=
True
,
norm_first
=
True
,
stochastic_depth_drop_rate
=
nn_layers
.
get_stochastic_depth_rate
(
self
.
_init_stochastic_depth_rate
,
i
,
self
.
_num_layers
-
1
),
norm_epsilon
=
1e-6
)
norm_epsilon
=
1e-6
)
self
.
_encoder_layers
.
append
(
encoder_layer
)
self
.
_encoder_layers
.
append
(
encoder_layer
)
self
.
_norm
=
layers
.
LayerNormalization
(
epsilon
=
1e-6
)
self
.
_norm
=
layers
.
LayerNormalization
(
epsilon
=
1e-6
)
...
@@ -164,12 +185,14 @@ class VisionTransformer(tf.keras.Model):
...
@@ -164,12 +185,14 @@ class VisionTransformer(tf.keras.Model):
num_layers
=
12
,
num_layers
=
12
,
attention_dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
,
dropout_rate
=
0.1
,
dropout_rate
=
0.1
,
init_stochastic_depth_rate
=
0.0
,
input_specs
=
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
input_specs
=
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
patch_size
=
16
,
patch_size
=
16
,
hidden_size
=
768
,
hidden_size
=
768
,
representation_size
=
0
,
representation_size
=
0
,
classifier
=
'token'
,
classifier
=
'token'
,
kernel_regularizer
=
None
):
kernel_regularizer
=
None
,
original_init
=
True
):
"""VisionTransformer initialization function."""
"""VisionTransformer initialization function."""
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
...
@@ -178,7 +201,8 @@ class VisionTransformer(tf.keras.Model):
...
@@ -178,7 +201,8 @@ class VisionTransformer(tf.keras.Model):
kernel_size
=
patch_size
,
kernel_size
=
patch_size
,
strides
=
patch_size
,
strides
=
patch_size
,
padding
=
'valid'
,
padding
=
'valid'
,
kernel_regularizer
=
kernel_regularizer
)(
kernel_regularizer
=
kernel_regularizer
,
kernel_initializer
=
'lecun_normal'
if
original_init
else
'he_uniform'
)(
inputs
)
inputs
)
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
rows_axis
,
cols_axis
=
(
1
,
2
)
rows_axis
,
cols_axis
=
(
1
,
2
)
...
@@ -203,7 +227,10 @@ class VisionTransformer(tf.keras.Model):
...
@@ -203,7 +227,10 @@ class VisionTransformer(tf.keras.Model):
num_heads
=
num_heads
,
num_heads
=
num_heads
,
dropout_rate
=
dropout_rate
,
dropout_rate
=
dropout_rate
,
attention_dropout_rate
=
attention_dropout_rate
,
attention_dropout_rate
=
attention_dropout_rate
,
kernel_regularizer
=
kernel_regularizer
)(
kernel_regularizer
=
kernel_regularizer
,
kernel_initializer
=
'glorot_uniform'
if
original_init
else
dict
(
class_name
=
'TruncatedNormal'
,
config
=
dict
(
stddev
=
.
02
)),
init_stochastic_depth_rate
=
init_stochastic_depth_rate
)(
x
)
x
)
if
classifier
==
'token'
:
if
classifier
==
'token'
:
...
@@ -215,7 +242,8 @@ class VisionTransformer(tf.keras.Model):
...
@@ -215,7 +242,8 @@ class VisionTransformer(tf.keras.Model):
x
=
tf
.
keras
.
layers
.
Dense
(
x
=
tf
.
keras
.
layers
.
Dense
(
representation_size
,
representation_size
,
kernel_regularizer
=
kernel_regularizer
,
kernel_regularizer
=
kernel_regularizer
,
name
=
'pre_logits'
)(
name
=
'pre_logits'
,
kernel_initializer
=
'lecun_normal'
if
original_init
else
'he_uniform'
)(
x
)
x
)
x
=
tf
.
nn
.
tanh
(
x
)
x
=
tf
.
nn
.
tanh
(
x
)
else
:
else
:
...
@@ -225,7 +253,8 @@ class VisionTransformer(tf.keras.Model):
...
@@ -225,7 +253,8 @@ class VisionTransformer(tf.keras.Model):
tf
.
reshape
(
x
,
[
-
1
,
1
,
1
,
representation_size
or
hidden_size
])
tf
.
reshape
(
x
,
[
-
1
,
1
,
1
,
representation_size
or
hidden_size
])
}
}
super
(
VisionTransformer
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
endpoints
)
super
(
VisionTransformer
,
self
).
__init__
(
inputs
=
inputs
,
outputs
=
endpoints
)
@
factory
.
register_backbone_builder
(
'vit'
)
@
factory
.
register_backbone_builder
(
'vit'
)
...
@@ -247,9 +276,11 @@ def build_vit(input_specs,
       num_layers=backbone_cfg.transformer.num_layers,
       attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
       dropout_rate=backbone_cfg.transformer.dropout_rate,
+      init_stochastic_depth_rate=backbone_cfg.init_stochastic_depth_rate,
       input_specs=input_specs,
       patch_size=backbone_cfg.patch_size,
       hidden_size=backbone_cfg.hidden_size,
       representation_size=backbone_cfg.representation_size,
       classifier=backbone_cfg.classifier,
-      kernel_regularizer=l2_regularizer)
+      kernel_regularizer=l2_regularizer,
+      original_init=backbone_cfg.original_init)
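For `build_vit` to read `backbone_cfg.init_stochastic_depth_rate` and `backbone_cfg.original_init`, the VisionTransformer backbone config must carry matching fields. A sketch of the assumed dataclass additions follows; the defaults are guesses chosen so existing experiments keep the original behavior, and the real config carries many more fields.

import dataclasses
from official.modeling import hyperparams

@dataclasses.dataclass
class VisionTransformer(hyperparams.Config):
  """ViT backbone config (sketch: only the fields read above)."""
  # Assumed additions; 0.0 and True reproduce the pre-change behavior.
  init_stochastic_depth_rate: float = 0.0
  original_init: bool = True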
official/vision/beta/projects/yolo/configs/darknet_classification.py
View file @ 40cd0a26
...
@@ -58,7 +58,7 @@ class ImageClassificationTask(cfg.TaskConfig):
 @exp_factory.register_config_factory('darknet_classification')
-def image_classification() -> cfg.ExperimentConfig:
+def darknet_classification() -> cfg.ExperimentConfig:
   """Image classification general."""
   return cfg.ExperimentConfig(
       task=ImageClassificationTask(),
...
official/vision/beta/tasks/image_classification.py
View file @ 40cd0a26
...
@@ -26,6 +26,7 @@ from official.vision.beta.dataloaders import classification_input
 from official.vision.beta.dataloaders import input_reader_factory
 from official.vision.beta.dataloaders import tfds_factory
 from official.vision.beta.modeling import factory
+from official.vision.beta.ops import augment


 @task_factory.register_task_cls(exp_cfg.ImageClassificationTask)
...
@@ -103,14 +104,27 @@ class ImageClassificationTask(base_task.Task):
         decode_jpeg_only=params.decode_jpeg_only,
         aug_rand_hflip=params.aug_rand_hflip,
         aug_type=params.aug_type,
+        color_jitter=params.color_jitter,
+        random_erasing=params.random_erasing,
         is_multilabel=is_multilabel,
         dtype=params.dtype)

+    postprocess_fn = None
+    if params.mixup_and_cutmix:
+      postprocess_fn = augment.MixupAndCutmix(
+          mixup_alpha=params.mixup_and_cutmix.mixup_alpha,
+          cutmix_alpha=params.mixup_and_cutmix.cutmix_alpha,
+          prob=params.mixup_and_cutmix.prob,
+          label_smoothing=params.mixup_and_cutmix.label_smoothing,
+          num_classes=params.mixup_and_cutmix.num_classes)
+
     reader = input_reader_factory.input_reader_generator(
         params,
         dataset_fn=dataset_fn.pick_dataset_fn(params.file_type),
         decoder_fn=decoder.decode,
-        parser_fn=parser.parse_fn(params.is_training))
+        parser_fn=parser.parse_fn(params.is_training),
+        postprocess_fn=postprocess_fn)

     dataset = reader.read(input_context=input_context)
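Unlike the per-image parser augmentations, MixupAndCutmix runs as a `postprocess_fn`, i.e. on whole batches after parsing, because it mixes examples across the batch and turns integer labels into soft label vectors. A hedged usage sketch, assuming the class is callable on a batch (as its use as a postprocess function implies); the hyperparameter values are illustrative:

import tensorflow as tf
from official.vision.beta.ops import augment

mixer = augment.MixupAndCutmix(
    mixup_alpha=0.8, cutmix_alpha=1.0, prob=1.0,
    label_smoothing=0.1, num_classes=1000)

images = tf.random.uniform([8, 224, 224, 3])             # batched images
labels = tf.random.uniform([8], maxval=1000, dtype=tf.int32)
mixed_images, soft_labels = mixer(images, labels)        # soft_labels: [8, 1000]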
...
@@ -119,12 +133,15 @@ class ImageClassificationTask(base_task.Task):
   def build_losses(self,
                    labels: tf.Tensor,
                    model_outputs: tf.Tensor,
+                   is_validation: bool,
                    aux_losses: Optional[Any] = None) -> tf.Tensor:
     """Builds sparse categorical cross entropy loss.

     Args:
       labels: Input groundtruth labels.
       model_outputs: Output logits of the classifier.
+      is_validation: To handle that some augmentations need custom soft labels
+        while the validation should remain unchanged.
       aux_losses: The auxiliary loss tensors, i.e. `losses` in tf.keras.Model.

     Returns:
...
@@ -134,12 +151,19 @@ class ImageClassificationTask(base_task.Task):
     is_multilabel = self.task_config.train_data.is_multilabel
     if not is_multilabel:
-      if losses_config.one_hot:
+      # Some augmentations need custom soft labels in training, but validation
+      # should remain unchanged.
+      if losses_config.one_hot or is_validation:
         total_loss = tf.keras.losses.categorical_crossentropy(
             labels,
             model_outputs,
             from_logits=True,
             label_smoothing=losses_config.label_smoothing)
+      elif losses_config.soft_labels:
+        total_loss = tf.nn.softmax_cross_entropy_with_logits(
+            labels, model_outputs)
       else:
         total_loss = tf.keras.losses.sparse_categorical_crossentropy(
             labels, model_outputs, from_logits=True)
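The new `soft_labels` branch and the `one_hot` branch compute the same cross entropy; the difference is that `tf.nn.softmax_cross_entropy_with_logits` takes the label distribution as-is (MixupAndCutmix already folds smoothing into its targets), while the `one_hot` branch applies `losses_config.label_smoothing` on top. A quick equivalence check:

import tensorflow as tf

logits = tf.constant([[2.0, 0.5, -1.0]])
soft = tf.constant([[0.7, 0.2, 0.1]])   # e.g. a mixup target

a = tf.keras.losses.categorical_crossentropy(soft, logits, from_logits=True)
b = tf.nn.softmax_cross_entropy_with_logits(labels=soft, logits=logits)
# a == b; the branches only diverge when label_smoothing != 0.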
...
@@ -161,7 +185,8 @@ class ImageClassificationTask(base_task.Task):
     is_multilabel = self.task_config.train_data.is_multilabel
     if not is_multilabel:
       k = self.task_config.evaluation.top_k
-      if self.task_config.losses.one_hot:
+      if (self.task_config.losses.one_hot or
+          self.task_config.losses.soft_labels):
         metrics = [
             tf.keras.metrics.CategoricalAccuracy(name='accuracy'),
             tf.keras.metrics.TopKCategoricalAccuracy(
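CategoricalAccuracy and TopKCategoricalAccuracy compare argmax (or top-k) positions of `y_true` and `y_pred`, so they remain meaningful for mixup-style soft labels: the dominant mixed class counts as the target. For example:

import tensorflow as tf

m = tf.keras.metrics.CategoricalAccuracy(name='accuracy')
m.update_state([[0.7, 0.2, 0.1]], [[2.0, 0.5, -1.0]])  # soft target, logits
print(float(m.result()))  # 1.0: argmax of both is class 0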
...
@@ -222,8 +247,8 @@ class ImageClassificationTask(base_task.Task):
           lambda x: tf.cast(x, tf.float32), outputs)

       # Computes per-replica loss.
       loss = self.build_losses(
-          model_outputs=outputs, labels=labels, aux_losses=model.losses)
+          model_outputs=outputs, labels=labels, is_validation=False,
+          aux_losses=model.losses)
       # Scales loss as the default gradients allreduce performs sum inside the
       # optimizer.
       scaled_loss = loss / num_replicas
...
@@ -266,14 +291,16 @@ class ImageClassificationTask(base_task.Task):
       A dictionary of logs.
     """
     features, labels = inputs
+    one_hot = self.task_config.losses.one_hot
+    soft_labels = self.task_config.losses.soft_labels
     is_multilabel = self.task_config.train_data.is_multilabel
-    if self.task_config.losses.one_hot and not is_multilabel:
+    if (one_hot or soft_labels) and not is_multilabel:
       labels = tf.one_hot(labels, self.task_config.model.num_classes)

     outputs = self.inference_step(features, model)
     outputs = tf.nest.map_structure(lambda x: tf.cast(x, tf.float32), outputs)
     loss = self.build_losses(model_outputs=outputs, labels=labels,
-                             aux_losses=model.losses)
+                             is_validation=True, aux_losses=model.losses)

     logs = {self.loss: loss}
     if metrics:
...