Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
42ad9d5e
Commit
42ad9d5e
authored
Sep 16, 2021
by
A. Unique TensorFlower
Browse files
Merge pull request #10227 from sigeisler:master
PiperOrigin-RevId: 397161611
parents
b5416378
01b21983
Changes
15
Hide whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
854 additions
and
40 deletions
+854
-40
official/vision/beta/configs/common.py
official/vision/beta/configs/common.py
+26
-1
official/vision/beta/configs/image_classification.py
official/vision/beta/configs/image_classification.py
+5
-0
official/vision/beta/dataloaders/classification_input.py
official/vision/beta/dataloaders/classification_input.py
+32
-1
official/vision/beta/modeling/factory.py
official/vision/beta/modeling/factory.py
+1
-0
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+331
-13
official/vision/beta/ops/augment_test.py
official/vision/beta/ops/augment_test.py
+77
-0
official/vision/beta/ops/preprocess_ops.py
official/vision/beta/ops/preprocess_ops.py
+103
-1
official/vision/beta/ops/preprocess_ops_test.py
official/vision/beta/ops/preprocess_ops_test.py
+13
-0
official/vision/beta/projects/vit/README.md
official/vision/beta/projects/vit/README.md
+7
-5
official/vision/beta/projects/vit/configs/backbones.py
official/vision/beta/projects/vit/configs/backbones.py
+2
-0
official/vision/beta/projects/vit/configs/image_classification.py
.../vision/beta/projects/vit/configs/image_classification.py
+86
-1
official/vision/beta/projects/vit/modeling/nn_blocks.py
official/vision/beta/projects/vit/modeling/nn_blocks.py
+107
-0
official/vision/beta/projects/vit/modeling/vit.py
official/vision/beta/projects/vit/modeling/vit.py
+34
-11
official/vision/beta/projects/yolo/configs/darknet_classification.py
...sion/beta/projects/yolo/configs/darknet_classification.py
+1
-1
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+29
-6
No files found.
official/vision/beta/configs/common.py
View file @
42ad9d5e
...
@@ -16,7 +16,7 @@
...
@@ -16,7 +16,7 @@
"""Common configurations."""
"""Common configurations."""
import
dataclasses
import
dataclasses
from
typing
import
Optional
from
typing
import
List
,
Optional
# Import libraries
# Import libraries
...
@@ -60,7 +60,9 @@ class RandAugment(hyperparams.Config):
...
@@ -60,7 +60,9 @@ class RandAugment(hyperparams.Config):
magnitude
:
float
=
10
magnitude
:
float
=
10
cutout_const
:
float
=
40
cutout_const
:
float
=
40
translate_const
:
float
=
10
translate_const
:
float
=
10
magnitude_std
:
float
=
0.0
prob_to_apply
:
Optional
[
float
]
=
None
prob_to_apply
:
Optional
[
float
]
=
None
exclude_ops
:
List
[
str
]
=
dataclasses
.
field
(
default_factory
=
list
)
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -71,6 +73,29 @@ class AutoAugment(hyperparams.Config):
...
@@ -71,6 +73,29 @@ class AutoAugment(hyperparams.Config):
translate_const
:
float
=
250
translate_const
:
float
=
250
@
dataclasses
.
dataclass
class
RandomErasing
(
hyperparams
.
Config
):
"""Configuration for RandomErasing."""
probability
:
float
=
0.25
min_area
:
float
=
0.02
max_area
:
float
=
1
/
3
min_aspect
:
float
=
0.3
max_aspect
=
None
min_count
=
1
max_count
=
1
trials
=
10
@
dataclasses
.
dataclass
class
MixupAndCutmix
(
hyperparams
.
Config
):
"""Configuration for MixupAndCutmix."""
mixup_alpha
:
float
=
.
8
cutmix_alpha
:
float
=
1.
prob
:
float
=
1.0
switch_prob
:
float
=
0.5
label_smoothing
:
float
=
0.1
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
class
Augmentation
(
hyperparams
.
OneOfConfig
):
class
Augmentation
(
hyperparams
.
OneOfConfig
):
"""Configuration for input data augmentation.
"""Configuration for input data augmentation.
...
...
official/vision/beta/configs/image_classification.py
View file @
42ad9d5e
...
@@ -39,10 +39,13 @@ class DataConfig(cfg.DataConfig):
...
@@ -39,10 +39,13 @@ class DataConfig(cfg.DataConfig):
aug_rand_hflip
:
bool
=
True
aug_rand_hflip
:
bool
=
True
aug_type
:
Optional
[
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Choose from AutoAugment and RandAugment.
common
.
Augmentation
]
=
None
# Choose from AutoAugment and RandAugment.
color_jitter
:
float
=
0.
random_erasing
:
Optional
[
common
.
RandomErasing
]
=
None
file_type
:
str
=
'tfrecord'
file_type
:
str
=
'tfrecord'
image_field_key
:
str
=
'image/encoded'
image_field_key
:
str
=
'image/encoded'
label_field_key
:
str
=
'image/class/label'
label_field_key
:
str
=
'image/class/label'
decode_jpeg_only
:
bool
=
True
decode_jpeg_only
:
bool
=
True
mixup_and_cutmix
:
Optional
[
common
.
MixupAndCutmix
]
=
None
decoder
:
Optional
[
common
.
DataDecoder
]
=
common
.
DataDecoder
()
decoder
:
Optional
[
common
.
DataDecoder
]
=
common
.
DataDecoder
()
# Keep for backward compatibility.
# Keep for backward compatibility.
...
@@ -62,6 +65,7 @@ class ImageClassificationModel(hyperparams.Config):
...
@@ -62,6 +65,7 @@ class ImageClassificationModel(hyperparams.Config):
use_sync_bn
=
False
)
use_sync_bn
=
False
)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm
:
bool
=
False
add_head_batch_norm
:
bool
=
False
kernel_initializer
:
str
=
'random_uniform'
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -69,6 +73,7 @@ class Losses(hyperparams.Config):
...
@@ -69,6 +73,7 @@ class Losses(hyperparams.Config):
one_hot
:
bool
=
True
one_hot
:
bool
=
True
label_smoothing
:
float
=
0.0
label_smoothing
:
float
=
0.0
l2_weight_decay
:
float
=
0.0
l2_weight_decay
:
float
=
0.0
soft_labels
:
bool
=
False
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/vision/beta/dataloaders/classification_input.py
View file @
42ad9d5e
...
@@ -69,6 +69,8 @@ class Parser(parser.Parser):
...
@@ -69,6 +69,8 @@ class Parser(parser.Parser):
decode_jpeg_only
:
bool
=
True
,
decode_jpeg_only
:
bool
=
True
,
aug_rand_hflip
:
bool
=
True
,
aug_rand_hflip
:
bool
=
True
,
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
,
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
,
color_jitter
:
float
=
0.
,
random_erasing
:
Optional
[
common
.
RandomErasing
]
=
None
,
is_multilabel
:
bool
=
False
,
is_multilabel
:
bool
=
False
,
dtype
:
str
=
'float32'
):
dtype
:
str
=
'float32'
):
"""Initializes parameters for parsing annotations in the dataset.
"""Initializes parameters for parsing annotations in the dataset.
...
@@ -85,6 +87,11 @@ class Parser(parser.Parser):
...
@@ -85,6 +87,11 @@ class Parser(parser.Parser):
horizontal flip.
horizontal flip.
aug_type: An optional Augmentation object to choose from AutoAugment and
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment.
RandAugment.
color_jitter: Magnitude of color jitter. If > 0, the value is used to
generate random scale factor for brightness, contrast and saturation.
See `preprocess_ops.color_jitter` for more details.
random_erasing: if not None, augment input image by random erasing. See
`augment.RandomErasing` for more details.
is_multilabel: A `bool`, whether or not each example has multiple labels.
is_multilabel: A `bool`, whether or not each example has multiple labels.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
or 'bfloat16'.
...
@@ -113,13 +120,27 @@ class Parser(parser.Parser):
...
@@ -113,13 +120,27 @@ class Parser(parser.Parser):
magnitude
=
aug_type
.
randaug
.
magnitude
,
magnitude
=
aug_type
.
randaug
.
magnitude
,
cutout_const
=
aug_type
.
randaug
.
cutout_const
,
cutout_const
=
aug_type
.
randaug
.
cutout_const
,
translate_const
=
aug_type
.
randaug
.
translate_const
,
translate_const
=
aug_type
.
randaug
.
translate_const
,
prob_to_apply
=
aug_type
.
randaug
.
prob_to_apply
)
prob_to_apply
=
aug_type
.
randaug
.
prob_to_apply
,
exclude_ops
=
aug_type
.
randaug
.
exclude_ops
)
else
:
else
:
raise
ValueError
(
'Augmentation policy {} not supported.'
.
format
(
raise
ValueError
(
'Augmentation policy {} not supported.'
.
format
(
aug_type
.
type
))
aug_type
.
type
))
else
:
else
:
self
.
_augmenter
=
None
self
.
_augmenter
=
None
self
.
_label_field_key
=
label_field_key
self
.
_label_field_key
=
label_field_key
self
.
_color_jitter
=
color_jitter
if
random_erasing
:
self
.
_random_erasing
=
augment
.
RandomErasing
(
probability
=
random_erasing
.
probability
,
min_area
=
random_erasing
.
min_area
,
max_area
=
random_erasing
.
max_area
,
min_aspect
=
random_erasing
.
min_aspect
,
max_aspect
=
random_erasing
.
max_aspect
,
min_count
=
random_erasing
.
min_count
,
max_count
=
random_erasing
.
max_count
,
trials
=
random_erasing
.
trials
)
else
:
self
.
_random_erasing
=
None
self
.
_is_multilabel
=
is_multilabel
self
.
_is_multilabel
=
is_multilabel
self
.
_decode_jpeg_only
=
decode_jpeg_only
self
.
_decode_jpeg_only
=
decode_jpeg_only
...
@@ -173,6 +194,12 @@ class Parser(parser.Parser):
...
@@ -173,6 +194,12 @@ class Parser(parser.Parser):
if
self
.
_aug_rand_hflip
:
if
self
.
_aug_rand_hflip
:
image
=
tf
.
image
.
random_flip_left_right
(
image
)
image
=
tf
.
image
.
random_flip_left_right
(
image
)
# Color jitter.
if
self
.
_color_jitter
>
0
:
image
=
preprocess_ops
.
color_jitter
(
image
,
self
.
_color_jitter
,
self
.
_color_jitter
,
self
.
_color_jitter
)
# Resizes image.
# Resizes image.
image
=
tf
.
image
.
resize
(
image
=
tf
.
image
.
resize
(
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
...
@@ -187,6 +214,10 @@ class Parser(parser.Parser):
...
@@ -187,6 +214,10 @@ class Parser(parser.Parser):
offset
=
MEAN_RGB
,
offset
=
MEAN_RGB
,
scale
=
STDDEV_RGB
)
scale
=
STDDEV_RGB
)
# Random erasing after the image has been normalized
if
self
.
_random_erasing
is
not
None
:
image
=
self
.
_random_erasing
.
distort
(
image
)
# Convert image to self._dtype.
# Convert image to self._dtype.
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
...
...
official/vision/beta/modeling/factory.py
View file @
42ad9d5e
...
@@ -56,6 +56,7 @@ def build_classification_model(
...
@@ -56,6 +56,7 @@ def build_classification_model(
num_classes
=
model_config
.
num_classes
,
num_classes
=
model_config
.
num_classes
,
input_specs
=
input_specs
,
input_specs
=
input_specs
,
dropout_rate
=
model_config
.
dropout_rate
,
dropout_rate
=
model_config
.
dropout_rate
,
kernel_initializer
=
model_config
.
kernel_initializer
,
kernel_regularizer
=
l2_regularizer
,
kernel_regularizer
=
l2_regularizer
,
add_head_batch_norm
=
model_config
.
add_head_batch_norm
,
add_head_batch_norm
=
model_config
.
add_head_batch_norm
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
use_sync_bn
=
norm_activation_config
.
use_sync_bn
,
...
...
official/vision/beta/ops/augment.py
View file @
42ad9d5e
...
@@ -12,10 +12,18 @@
...
@@ -12,10 +12,18 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
"""
Auto
Augment
and RandAugment
policies for enhanced image/video preprocessing.
"""Augment
ation
policies for enhanced image/video preprocessing.
AutoAugment Reference: https://arxiv.org/abs/1805.09501
AutoAugment Reference: https://arxiv.org/abs/1805.09501
RandAugment Reference: https://arxiv.org/abs/1909.13719
RandAugment Reference: https://arxiv.org/abs/1909.13719
RandomErasing Reference: https://arxiv.org/abs/1708.04896
MixupAndCutmix:
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
RandomErasing, Mixup and Cutmix are inspired by
https://github.com/rwightman/pytorch-image-models
"""
"""
import
math
import
math
from
typing
import
Any
,
List
,
Iterable
,
Optional
,
Text
,
Tuple
from
typing
import
Any
,
List
,
Iterable
,
Optional
,
Text
,
Tuple
...
@@ -295,10 +303,26 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
...
@@ -295,10 +303,26 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
cutout_center_width
=
tf
.
random
.
uniform
(
cutout_center_width
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
image_width
,
dtype
=
tf
.
int32
)
shape
=
[],
minval
=
0
,
maxval
=
image_width
,
dtype
=
tf
.
int32
)
lower_pad
=
tf
.
maximum
(
0
,
cutout_center_height
-
pad_size
)
image
=
_fill_rectangle
(
image
,
cutout_center_width
,
cutout_center_height
,
upper_pad
=
tf
.
maximum
(
0
,
image_height
-
cutout_center_height
-
pad_size
)
pad_size
,
pad_size
,
replace
)
left_pad
=
tf
.
maximum
(
0
,
cutout_center_width
-
pad_size
)
right_pad
=
tf
.
maximum
(
0
,
image_width
-
cutout_center_width
-
pad_size
)
return
image
def
_fill_rectangle
(
image
,
center_width
,
center_height
,
half_width
,
half_height
,
replace
=
None
):
"""Fill blank area."""
image_height
=
tf
.
shape
(
image
)[
0
]
image_width
=
tf
.
shape
(
image
)[
1
]
lower_pad
=
tf
.
maximum
(
0
,
center_height
-
half_height
)
upper_pad
=
tf
.
maximum
(
0
,
image_height
-
center_height
-
half_height
)
left_pad
=
tf
.
maximum
(
0
,
center_width
-
half_width
)
right_pad
=
tf
.
maximum
(
0
,
image_width
-
center_width
-
half_width
)
cutout_shape
=
[
cutout_shape
=
[
image_height
-
(
lower_pad
+
upper_pad
),
image_height
-
(
lower_pad
+
upper_pad
),
...
@@ -311,9 +335,15 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
...
@@ -311,9 +335,15 @@ def cutout(image: tf.Tensor, pad_size: int, replace: int = 0) -> tf.Tensor:
constant_values
=
1
)
constant_values
=
1
)
mask
=
tf
.
expand_dims
(
mask
,
-
1
)
mask
=
tf
.
expand_dims
(
mask
,
-
1
)
mask
=
tf
.
tile
(
mask
,
[
1
,
1
,
3
])
mask
=
tf
.
tile
(
mask
,
[
1
,
1
,
3
])
image
=
tf
.
where
(
tf
.
equal
(
mask
,
0
),
if
replace
is
None
:
tf
.
ones_like
(
image
,
dtype
=
image
.
dtype
)
*
replace
,
image
)
fill
=
tf
.
random
.
normal
(
tf
.
shape
(
image
),
dtype
=
image
.
dtype
)
elif
isinstance
(
replace
,
tf
.
Tensor
):
fill
=
replace
else
:
fill
=
tf
.
ones_like
(
image
,
dtype
=
image
.
dtype
)
*
replace
image
=
tf
.
where
(
tf
.
equal
(
mask
,
0
),
fill
,
image
)
return
image
return
image
...
@@ -803,11 +833,20 @@ def level_to_arg(cutout_const: float, translate_const: float):
...
@@ -803,11 +833,20 @@ def level_to_arg(cutout_const: float, translate_const: float):
return
args
return
args
def
_parse_policy_info
(
name
:
Text
,
prob
:
float
,
level
:
float
,
def
_parse_policy_info
(
name
:
Text
,
replace_value
:
List
[
int
],
cutout_const
:
float
,
prob
:
float
,
translate_const
:
float
)
->
Tuple
[
Any
,
float
,
Any
]:
level
:
float
,
replace_value
:
List
[
int
],
cutout_const
:
float
,
translate_const
:
float
,
level_std
:
float
=
0.
)
->
Tuple
[
Any
,
float
,
Any
]:
"""Return the function that corresponds to `name` and update `level` param."""
"""Return the function that corresponds to `name` and update `level` param."""
func
=
NAME_TO_FUNC
[
name
]
func
=
NAME_TO_FUNC
[
name
]
if
level_std
>
0
:
level
+=
tf
.
random
.
normal
([],
dtype
=
tf
.
float32
)
level
=
tf
.
clip_by_value
(
level
,
0.
,
_MAX_LEVEL
)
args
=
level_to_arg
(
cutout_const
,
translate_const
)[
name
](
level
)
args
=
level_to_arg
(
cutout_const
,
translate_const
)[
name
](
level
)
if
name
in
REPLACE_FUNCS
:
if
name
in
REPLACE_FUNCS
:
...
@@ -1184,7 +1223,9 @@ class RandAugment(ImageAugment):
...
@@ -1184,7 +1223,9 @@ class RandAugment(ImageAugment):
magnitude
:
float
=
10.
,
magnitude
:
float
=
10.
,
cutout_const
:
float
=
40.
,
cutout_const
:
float
=
40.
,
translate_const
:
float
=
100.
,
translate_const
:
float
=
100.
,
prob_to_apply
:
Optional
[
float
]
=
None
):
magnitude_std
:
float
=
0.0
,
prob_to_apply
:
Optional
[
float
]
=
None
,
exclude_ops
:
Optional
[
List
[
str
]]
=
None
):
"""Applies the RandAugment policy to images.
"""Applies the RandAugment policy to images.
Args:
Args:
...
@@ -1196,8 +1237,11 @@ class RandAugment(ImageAugment):
...
@@ -1196,8 +1237,11 @@ class RandAugment(ImageAugment):
[5, 10].
[5, 10].
cutout_const: multiplier for applying cutout.
cutout_const: multiplier for applying cutout.
translate_const: multiplier for applying translation.
translate_const: multiplier for applying translation.
magnitude_std: randomness of the severity as proposed by the authors of
the timm library.
prob_to_apply: The probability to apply the selected augmentation at each
prob_to_apply: The probability to apply the selected augmentation at each
layer.
layer.
exclude_ops: exclude selected operations.
"""
"""
super
(
RandAugment
,
self
).
__init__
()
super
(
RandAugment
,
self
).
__init__
()
...
@@ -1212,6 +1256,11 @@ class RandAugment(ImageAugment):
...
@@ -1212,6 +1256,11 @@ class RandAugment(ImageAugment):
'Color'
,
'Contrast'
,
'Brightness'
,
'Sharpness'
,
'ShearX'
,
'ShearY'
,
'Color'
,
'Contrast'
,
'Brightness'
,
'Sharpness'
,
'ShearX'
,
'ShearY'
,
'TranslateX'
,
'TranslateY'
,
'Cutout'
,
'SolarizeAdd'
'TranslateX'
,
'TranslateY'
,
'Cutout'
,
'SolarizeAdd'
]
]
self
.
magnitude_std
=
magnitude_std
if
exclude_ops
:
self
.
available_ops
=
[
op
for
op
in
self
.
available_ops
if
op
not
in
exclude_ops
]
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Applies the RandAugment policy to `image`.
"""Applies the RandAugment policy to `image`.
...
@@ -1246,7 +1295,8 @@ class RandAugment(ImageAugment):
...
@@ -1246,7 +1295,8 @@ class RandAugment(ImageAugment):
dtype
=
tf
.
float32
)
dtype
=
tf
.
float32
)
func
,
_
,
args
=
_parse_policy_info
(
op_name
,
prob
,
self
.
magnitude
,
func
,
_
,
args
=
_parse_policy_info
(
op_name
,
prob
,
self
.
magnitude
,
replace_value
,
self
.
cutout_const
,
replace_value
,
self
.
cutout_const
,
self
.
translate_const
)
self
.
translate_const
,
self
.
magnitude_std
)
branch_fns
.
append
((
branch_fns
.
append
((
i
,
i
,
# pylint:disable=g-long-lambda
# pylint:disable=g-long-lambda
...
@@ -1267,3 +1317,271 @@ class RandAugment(ImageAugment):
...
@@ -1267,3 +1317,271 @@ class RandAugment(ImageAugment):
image
=
tf
.
cast
(
image
,
dtype
=
input_image_type
)
image
=
tf
.
cast
(
image
,
dtype
=
input_image_type
)
return
image
return
image
class
RandomErasing
(
ImageAugment
):
"""Applies RandomErasing to a single image.
Reference: https://arxiv.org/abs/1708.04896
Implementaion is inspired by https://github.com/rwightman/pytorch-image-models
"""
def
__init__
(
self
,
probability
:
float
=
0.25
,
min_area
:
float
=
0.02
,
max_area
:
float
=
1
/
3
,
min_aspect
:
float
=
0.3
,
max_aspect
=
None
,
min_count
=
1
,
max_count
=
1
,
trials
=
10
):
"""Applies RandomErasing to a single image.
Args:
probability (float, optional): Probability of augmenting the image.
Defaults to 0.25.
min_area (float, optional): Minimum area of the random erasing rectangle.
Defaults to 0.02.
max_area (float, optional): Maximum area of the random erasing rectangle.
Defaults to 1/3.
min_aspect (float, optional): Minimum aspect rate of the random erasing
rectangle. Defaults to 0.3.
max_aspect ([type], optional): Maximum aspect rate of the random erasing
rectangle. Defaults to None.
min_count (int, optional): Minimum number of erased rectangles. Defaults
to 1.
max_count (int, optional): Maximum number of erased rectangles. Defaults
to 1.
trials (int, optional): Maximum number of trials to randomly sample a
rectangle that fulfills constraint. Defaults to 10.
"""
self
.
_probability
=
probability
self
.
_min_area
=
float
(
min_area
)
self
.
_max_area
=
float
(
max_area
)
self
.
_min_log_aspect
=
math
.
log
(
min_aspect
)
self
.
_max_log_aspect
=
math
.
log
(
max_aspect
or
1
/
min_aspect
)
self
.
_min_count
=
min_count
self
.
_max_count
=
max_count
self
.
_trials
=
trials
def
distort
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Applies RandomErasing to single `image`.
Args:
image (tf.Tensor): Of shape [height, width, 3] representing an image.
Returns:
tf.Tensor: The augmented version of `image`.
"""
uniform_random
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0.
,
maxval
=
1.0
)
mirror_cond
=
tf
.
less
(
uniform_random
,
self
.
_probability
)
image
=
tf
.
cond
(
mirror_cond
,
lambda
:
self
.
_erase
(
image
),
lambda
:
image
)
return
image
@
tf
.
function
def
_erase
(
self
,
image
:
tf
.
Tensor
)
->
tf
.
Tensor
:
"""Erase an area."""
if
self
.
_min_count
==
self
.
_max_count
:
count
=
self
.
_min_count
else
:
count
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
int
(
self
.
_min_count
),
maxval
=
int
(
self
.
_max_count
-
self
.
_min_count
+
1
),
dtype
=
tf
.
int32
)
image_height
=
tf
.
shape
(
image
)[
0
]
image_width
=
tf
.
shape
(
image
)[
1
]
area
=
tf
.
cast
(
image_width
*
image_height
,
tf
.
float32
)
for
_
in
range
(
count
):
# Work around since break is not supported in tf.function
is_trial_successfull
=
False
for
_
in
range
(
self
.
_trials
):
if
not
is_trial_successfull
:
erase_area
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
area
*
self
.
_min_area
,
maxval
=
area
*
self
.
_max_area
)
aspect_ratio
=
tf
.
math
.
exp
(
tf
.
random
.
uniform
(
shape
=
[],
minval
=
self
.
_min_log_aspect
,
maxval
=
self
.
_max_log_aspect
))
half_height
=
tf
.
cast
(
tf
.
math
.
round
(
tf
.
math
.
sqrt
(
erase_area
*
aspect_ratio
)
/
2
),
dtype
=
tf
.
int32
)
half_width
=
tf
.
cast
(
tf
.
math
.
round
(
tf
.
math
.
sqrt
(
erase_area
/
aspect_ratio
)
/
2
),
dtype
=
tf
.
int32
)
if
2
*
half_height
<
image_height
and
2
*
half_width
<
image_width
:
center_height
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
int
(
image_height
-
2
*
half_height
),
dtype
=
tf
.
int32
)
center_width
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0
,
maxval
=
int
(
image_width
-
2
*
half_width
),
dtype
=
tf
.
int32
)
image
=
_fill_rectangle
(
image
,
center_width
,
center_height
,
half_width
,
half_height
,
replace
=
None
)
is_trial_successfull
=
True
return
image
class
MixupAndCutmix
:
"""Applies Mixup and/or Cutmix to a batch of images.
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
Implementaion is inspired by https://github.com/rwightman/pytorch-image-models
"""
def
__init__
(
self
,
mixup_alpha
:
float
=
.
8
,
cutmix_alpha
:
float
=
1.
,
prob
:
float
=
1.0
,
switch_prob
:
float
=
0.5
,
label_smoothing
:
float
=
0.1
,
num_classes
:
int
=
1001
):
"""Applies Mixup and/or Cutmix to a batch of images.
Args:
mixup_alpha (float, optional): For drawing a random lambda (`lam`) from a
beta distribution (for each image). If zero Mixup is deactivated.
Defaults to .8.
cutmix_alpha (float, optional): For drawing a random lambda (`lam`) from a
beta distribution (for each image). If zero Cutmix is deactivated.
Defaults to 1..
prob (float, optional): Of augmenting the batch. Defaults to 1.0.
switch_prob (float, optional): Probability of applying Cutmix for the
batch. Defaults to 0.5.
label_smoothing (float, optional): Constant for label smoothing. Defaults
to 0.1.
num_classes (int, optional): Number of classes. Defaults to 1001.
"""
self
.
mixup_alpha
=
mixup_alpha
self
.
cutmix_alpha
=
cutmix_alpha
self
.
mix_prob
=
prob
self
.
switch_prob
=
switch_prob
self
.
label_smoothing
=
label_smoothing
self
.
num_classes
=
num_classes
self
.
mode
=
'batch'
self
.
mixup_enabled
=
True
if
self
.
mixup_alpha
and
not
self
.
cutmix_alpha
:
self
.
switch_prob
=
-
1
elif
not
self
.
mixup_alpha
and
self
.
cutmix_alpha
:
self
.
switch_prob
=
1
def
__call__
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
return
self
.
distort
(
images
,
labels
)
def
distort
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
"""Applies Mixup and/or Cutmix to batch of images and transforms labels.
Args:
images (tf.Tensor): Of shape [batch_size,height, width, 3] representing a
batch of image.
labels (tf.Tensor): Of shape [batch_size, ] representing the class id for
each image of the batch.
Returns:
Tuple[tf.Tensor, tf.Tensor]: The augmented version of `image` and
`labels`.
"""
augment_cond
=
tf
.
less
(
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0.
,
maxval
=
1.0
),
self
.
mix_prob
)
# pylint: disable=g-long-lambda
augment_a
=
lambda
:
self
.
_update_labels
(
*
tf
.
cond
(
tf
.
less
(
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0.
,
maxval
=
1.0
),
self
.
switch_prob
),
lambda
:
self
.
_cutmix
(
images
,
labels
),
lambda
:
self
.
_mixup
(
images
,
labels
)))
augment_b
=
lambda
:
(
images
,
self
.
_smooth_labels
(
labels
))
# pylint: enable=g-long-lambda
return
tf
.
cond
(
augment_cond
,
augment_a
,
augment_b
)
@
staticmethod
def
_sample_from_beta
(
alpha
,
beta
,
shape
):
sample_alpha
=
tf
.
random
.
gamma
(
shape
,
1.
,
beta
=
alpha
)
sample_beta
=
tf
.
random
.
gamma
(
shape
,
1.
,
beta
=
beta
)
return
sample_alpha
/
(
sample_alpha
+
sample_beta
)
def
_cutmix
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
,
tf
.
Tensor
]:
"""Apply cutmix."""
lam
=
MixupAndCutmix
.
_sample_from_beta
(
self
.
cutmix_alpha
,
self
.
cutmix_alpha
,
labels
.
shape
)
ratio
=
tf
.
math
.
sqrt
(
1
-
lam
)
batch_size
=
tf
.
shape
(
images
)[
0
]
image_height
,
image_width
=
tf
.
shape
(
images
)[
1
],
tf
.
shape
(
images
)[
2
]
cut_height
=
tf
.
cast
(
ratio
*
tf
.
cast
(
image_height
,
dtype
=
tf
.
float32
),
dtype
=
tf
.
int32
)
cut_width
=
tf
.
cast
(
ratio
*
tf
.
cast
(
image_height
,
dtype
=
tf
.
float32
),
dtype
=
tf
.
int32
)
random_center_height
=
tf
.
random
.
uniform
(
shape
=
[
batch_size
],
minval
=
0
,
maxval
=
image_height
,
dtype
=
tf
.
int32
)
random_center_width
=
tf
.
random
.
uniform
(
shape
=
[
batch_size
],
minval
=
0
,
maxval
=
image_width
,
dtype
=
tf
.
int32
)
bbox_area
=
cut_height
*
cut_width
lam
=
1.
-
bbox_area
/
(
image_height
*
image_width
)
lam
=
tf
.
cast
(
lam
,
dtype
=
tf
.
float32
)
images
=
tf
.
map_fn
(
lambda
x
:
_fill_rectangle
(
*
x
),
(
images
,
random_center_width
,
random_center_height
,
cut_width
//
2
,
cut_height
//
2
,
tf
.
reverse
(
images
,
[
0
])),
dtype
=
(
tf
.
float32
,
tf
.
int32
,
tf
.
int32
,
tf
.
int32
,
tf
.
int32
,
tf
.
float32
),
fn_output_signature
=
tf
.
TensorSpec
(
images
.
shape
[
1
:],
dtype
=
tf
.
float32
))
return
images
,
labels
,
lam
def
_mixup
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
,
tf
.
Tensor
]:
lam
=
MixupAndCutmix
.
_sample_from_beta
(
self
.
mixup_alpha
,
self
.
mixup_alpha
,
labels
.
shape
)
lam
=
tf
.
reshape
(
lam
,
[
-
1
,
1
,
1
,
1
])
images
=
lam
*
images
+
(
1.
-
lam
)
*
tf
.
reverse
(
images
,
[
0
])
return
images
,
labels
,
tf
.
squeeze
(
lam
)
def
_smooth_labels
(
self
,
labels
:
tf
.
Tensor
)
->
tf
.
Tensor
:
off_value
=
self
.
label_smoothing
/
self
.
num_classes
on_value
=
1.
-
self
.
label_smoothing
+
off_value
smooth_labels
=
tf
.
one_hot
(
labels
,
self
.
num_classes
,
on_value
=
on_value
,
off_value
=
off_value
)
return
smooth_labels
def
_update_labels
(
self
,
images
:
tf
.
Tensor
,
labels
:
tf
.
Tensor
,
lam
:
tf
.
Tensor
)
->
Tuple
[
tf
.
Tensor
,
tf
.
Tensor
]:
labels_1
=
self
.
_smooth_labels
(
labels
)
labels_2
=
tf
.
reverse
(
labels_1
,
[
0
])
lam
=
tf
.
reshape
(
lam
,
[
-
1
,
1
])
labels
=
lam
*
labels_1
+
(
1.
-
lam
)
*
labels_2
return
images
,
labels
official/vision/beta/ops/augment_test.py
View file @
42ad9d5e
...
@@ -254,5 +254,82 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -254,5 +254,82 @@ class AutoaugmentTest(tf.test.TestCase, parameterized.TestCase):
augmenter
.
distort
(
image
)
augmenter
.
distort
(
image
)
class
RandomErasingTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
test_random_erase_replaces_some_pixels
(
self
):
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
float32
)
augmenter
=
augment
.
RandomErasing
(
probability
=
1.
,
max_count
=
10
)
aug_image
=
augmenter
.
distort
(
image
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
self
.
assertNotEqual
(
0
,
tf
.
reduce_max
(
aug_image
))
class
MixupAndCutmixTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
def
test_mixup_and_cutmix_smoothes_labels
(
self
):
batch_size
=
12
num_classes
=
1000
label_smoothing
=
0.1
images
=
tf
.
random
.
normal
((
batch_size
,
224
,
224
,
3
),
dtype
=
tf
.
float32
)
labels
=
tf
.
range
(
batch_size
)
augmenter
=
augment
.
MixupAndCutmix
(
num_classes
=
num_classes
,
label_smoothing
=
label_smoothing
)
aug_images
,
aug_labels
=
augmenter
.
distort
(
images
,
labels
)
self
.
assertEqual
(
images
.
shape
,
aug_images
.
shape
)
self
.
assertEqual
(
images
.
dtype
,
aug_images
.
dtype
)
self
.
assertEqual
([
batch_size
,
num_classes
],
aug_labels
.
shape
)
self
.
assertAllLessEqual
(
aug_labels
,
1.
-
label_smoothing
+
2.
/
num_classes
)
# With tolerance
self
.
assertAllGreaterEqual
(
aug_labels
,
label_smoothing
/
num_classes
-
1e4
)
# With tolerance
def
test_mixup_changes_image
(
self
):
batch_size
=
12
num_classes
=
1000
label_smoothing
=
0.1
images
=
tf
.
random
.
normal
((
batch_size
,
224
,
224
,
3
),
dtype
=
tf
.
float32
)
labels
=
tf
.
range
(
batch_size
)
augmenter
=
augment
.
MixupAndCutmix
(
mixup_alpha
=
1.
,
cutmix_alpha
=
0.
,
num_classes
=
num_classes
)
aug_images
,
aug_labels
=
augmenter
.
distort
(
images
,
labels
)
self
.
assertEqual
(
images
.
shape
,
aug_images
.
shape
)
self
.
assertEqual
(
images
.
dtype
,
aug_images
.
dtype
)
self
.
assertEqual
([
batch_size
,
num_classes
],
aug_labels
.
shape
)
self
.
assertAllLessEqual
(
aug_labels
,
1.
-
label_smoothing
+
2.
/
num_classes
)
# With tolerance
self
.
assertAllGreaterEqual
(
aug_labels
,
label_smoothing
/
num_classes
-
1e4
)
# With tolerance
self
.
assertFalse
(
tf
.
math
.
reduce_all
(
images
==
aug_images
))
def
test_cutmix_changes_image
(
self
):
batch_size
=
12
num_classes
=
1000
label_smoothing
=
0.1
images
=
tf
.
random
.
normal
((
batch_size
,
224
,
224
,
3
),
dtype
=
tf
.
float32
)
labels
=
tf
.
range
(
batch_size
)
augmenter
=
augment
.
MixupAndCutmix
(
mixup_alpha
=
0.
,
cutmix_alpha
=
1.
,
num_classes
=
num_classes
)
aug_images
,
aug_labels
=
augmenter
.
distort
(
images
,
labels
)
self
.
assertEqual
(
images
.
shape
,
aug_images
.
shape
)
self
.
assertEqual
(
images
.
dtype
,
aug_images
.
dtype
)
self
.
assertEqual
([
batch_size
,
num_classes
],
aug_labels
.
shape
)
self
.
assertAllLessEqual
(
aug_labels
,
1.
-
label_smoothing
+
2.
/
num_classes
)
# With tolerance
self
.
assertAllGreaterEqual
(
aug_labels
,
label_smoothing
/
num_classes
-
1e4
)
# With tolerance
self
.
assertFalse
(
tf
.
math
.
reduce_all
(
images
==
aug_images
))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
official/vision/beta/ops/preprocess_ops.py
View file @
42ad9d5e
...
@@ -15,12 +15,13 @@
...
@@ -15,12 +15,13 @@
"""Preprocessing ops."""
"""Preprocessing ops."""
import
math
import
math
from
typing
import
Optional
from
six.moves
import
range
from
six.moves
import
range
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.beta.ops
import
augment
from
official.vision.beta.ops
import
box_ops
from
official.vision.beta.ops
import
box_ops
CENTER_CROP_FRACTION
=
0.875
CENTER_CROP_FRACTION
=
0.875
...
@@ -557,6 +558,107 @@ def random_horizontal_flip(image, normalized_boxes=None, masks=None, seed=1):
...
@@ -557,6 +558,107 @@ def random_horizontal_flip(image, normalized_boxes=None, masks=None, seed=1):
return
image
,
normalized_boxes
,
masks
return
image
,
normalized_boxes
,
masks
def color_jitter(image: tf.Tensor,
                 brightness: Optional[float] = 0.,
                 contrast: Optional[float] = 0.,
                 saturation: Optional[float] = 0.,
                 seed: Optional[int] = None) -> tf.Tensor:
  """Applies color jitter to an image, similarly to torchvision`s ColorJitter.

  Args:
    image (tf.Tensor): Of shape [height, width, 3] and type uint8.
    brightness (float, optional): Magnitude for brightness jitter. Defaults to
      0.
    contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
    saturation (float, optional): Magnitude for saturation jitter. Defaults to
      0.
    seed (int, optional): Random seed. Defaults to None.

  Returns:
    tf.Tensor: The augmented `image` of type uint8.
  """
  # The per-component jitter ops below operate on uint8 images.
  jittered = tf.cast(image, dtype=tf.uint8)
  # Apply the three jitters in a fixed order: brightness, contrast, saturation.
  for jitter_fn, magnitude in ((random_brightness, brightness),
                               (random_contrast, contrast),
                               (random_saturation, saturation)):
    jittered = jitter_fn(jittered, magnitude, seed=seed)
  return jittered
def random_brightness(image: tf.Tensor,
                      brightness: float = 0.,
                      seed: Optional[int] = None) -> tf.Tensor:
  """Jitters brightness of an image.

  Args:
    image (tf.Tensor): Of shape [height, width, 3] and type uint8.
    brightness (float, optional): Magnitude for brightness jitter. Defaults to
      0.
    seed (int, optional): Random seed. Defaults to None.

  Returns:
    tf.Tensor: The augmented `image` of type uint8.
  """
  assert brightness >= 0, '`brightness` must be positive'
  # Sample a multiplicative factor in [max(0, 1 - brightness), 1 + brightness].
  lower = max(0, 1 - brightness)
  upper = 1 + brightness
  factor = tf.random.uniform([], lower, upper, seed=seed, dtype=tf.float32)
  return augment.brightness(image, factor)
def random_contrast(image: tf.Tensor,
                    contrast: float = 0.,
                    seed: Optional[int] = None) -> tf.Tensor:
  """Jitters contrast of an image, similarly to torchvision`s ColorJitter.

  Args:
    image (tf.Tensor): Of shape [height, width, 3] and type uint8.
    contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
    seed (int, optional): Random seed. Defaults to None.

  Returns:
    tf.Tensor: The augmented `image` of type uint8.
  """
  assert contrast >= 0, '`contrast` must be positive'
  # Sample a multiplicative factor in [max(0, 1 - contrast), 1 + contrast].
  lower = max(0, 1 - contrast)
  upper = 1 + contrast
  factor = tf.random.uniform([], lower, upper, seed=seed, dtype=tf.float32)
  return augment.contrast(image, factor)
def random_saturation(image: tf.Tensor,
                      saturation: float = 0.,
                      seed: Optional[int] = None) -> tf.Tensor:
  """Jitters saturation of an image, similarly to torchvision`s ColorJitter.

  Args:
    image (tf.Tensor): Of shape [height, width, 3] and type uint8.
    saturation (float, optional): Magnitude for saturation jitter. Defaults to
      0.
    seed (int, optional): Random seed. Defaults to None.

  Returns:
    tf.Tensor: The augmented `image` of type uint8.
  """
  assert saturation >= 0, '`saturation` must be positive'
  # Sample a blend factor in [max(0, 1 - saturation), 1 + saturation].
  lower = max(0, 1 - saturation)
  upper = 1 + saturation
  factor = tf.random.uniform([], lower, upper, seed=seed, dtype=tf.float32)
  return _saturation(image, factor)
def _saturation(image: tf.Tensor,
                saturation: Optional[float] = 0.) -> tf.Tensor:
  """Blends `image` with its grayscale version by factor `saturation`."""
  # Grayscale copy broadcast back to 3 channels so it can be blended with the
  # original color image.
  grayscale = tf.repeat(tf.image.rgb_to_grayscale(image), 3, axis=-1)
  return augment.blend(grayscale, image, saturation)
def
random_crop_image_with_boxes_and_labels
(
img
,
boxes
,
labels
,
min_scale
,
def
random_crop_image_with_boxes_and_labels
(
img
,
boxes
,
labels
,
min_scale
,
aspect_ratio_range
,
aspect_ratio_range
,
min_overlap_params
,
max_retry
):
min_overlap_params
,
max_retry
):
...
...
official/vision/beta/ops/preprocess_ops_test.py
View file @
42ad9d5e
...
@@ -197,6 +197,19 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -197,6 +197,19 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
_
=
preprocess_ops
.
random_crop_image_v2
(
_
=
preprocess_ops
.
random_crop_image_v2
(
image_bytes
,
tf
.
constant
([
input_height
,
input_width
,
3
],
tf
.
int32
))
image_bytes
,
tf
.
constant
([
input_height
,
input_width
,
3
],
tf
.
int32
))
@
parameterized
.
parameters
((
400
,
600
,
0
),
(
400
,
600
,
0.4
),
(
600
,
400
,
1.4
))
def
testColorJitter
(
self
,
input_height
,
input_width
,
color_jitter
):
image
=
tf
.
convert_to_tensor
(
np
.
random
.
rand
(
input_height
,
input_width
,
3
))
jittered_image
=
preprocess_ops
.
color_jitter
(
image
,
color_jitter
,
color_jitter
,
color_jitter
)
assert
jittered_image
.
shape
==
image
.
shape
@
parameterized
.
parameters
((
400
,
600
,
0
),
(
400
,
600
,
0.4
),
(
600
,
400
,
1
))
def
testSaturation
(
self
,
input_height
,
input_width
,
saturation
):
image
=
tf
.
convert_to_tensor
(
np
.
random
.
rand
(
input_height
,
input_width
,
3
))
jittered_image
=
preprocess_ops
.
_saturation
(
image
,
saturation
)
assert
jittered_image
.
shape
==
image
.
shape
@
parameterized
.
parameters
((
640
,
640
,
20
),
(
1280
,
1280
,
30
))
@
parameterized
.
parameters
((
640
,
640
,
20
),
(
1280
,
1280
,
30
))
def
test_random_crop
(
self
,
input_height
,
input_width
,
num_boxes
):
def
test_random_crop
(
self
,
input_height
,
input_width
,
num_boxes
):
image
=
tf
.
convert_to_tensor
(
np
.
random
.
rand
(
input_height
,
input_width
,
3
))
image
=
tf
.
convert_to_tensor
(
np
.
random
.
rand
(
input_height
,
input_width
,
3
))
...
...
official/vision/beta/projects/vit/README.md
View file @
42ad9d5e
# Vision Transformer (ViT)
# Vision Transformer (ViT)
and Data-Efficient Image Transformer (DEIT)
**DISCLAIMER**
: This implementation is still under development. No support will
**DISCLAIMER**
: This implementation is still under development. No support will
be provided during the development phase.
be provided during the development phase.
[

](https://arxiv.org/abs/2010.11929)
-
[

](https://arxiv.org/abs/2010.11929)
-
[

](https://arxiv.org/abs/2012.12877)
This repository is the implementations of Vision Transformer (ViT)
in
This repository is the implementations of Vision Transformer (ViT)
and
TensorFlow 2.
Data-Efficient Image Transformer (DEIT) in
TensorFlow 2.
*
Paper title:
*
Paper title:
[
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
](
https://arxiv.org/pdf/2010.11929.pdf
)
.
-
[
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale
](
https://arxiv.org/pdf/2010.11929.pdf
)
.
\ No newline at end of file
-
[
Training data-efficient image transformers & distillation through attention
](
https://arxiv.org/pdf/2012.12877.pdf
)
.
official/vision/beta/projects/vit/configs/backbones.py
View file @
42ad9d5e
...
@@ -42,6 +42,8 @@ class VisionTransformer(hyperparams.Config):
...
@@ -42,6 +42,8 @@ class VisionTransformer(hyperparams.Config):
hidden_size
:
int
=
1
hidden_size
:
int
=
1
patch_size
:
int
=
16
patch_size
:
int
=
16
transformer
:
Transformer
=
Transformer
()
transformer
:
Transformer
=
Transformer
()
init_stochastic_depth_rate
:
float
=
0.0
original_init
:
bool
=
True
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/vision/beta/projects/vit/configs/image_classification.py
View file @
42ad9d5e
...
@@ -44,6 +44,7 @@ class ImageClassificationModel(hyperparams.Config):
...
@@ -44,6 +44,7 @@ class ImageClassificationModel(hyperparams.Config):
use_sync_bn
=
False
)
use_sync_bn
=
False
)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm
:
bool
=
False
add_head_batch_norm
:
bool
=
False
kernel_initializer
:
str
=
'random_uniform'
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -51,6 +52,7 @@ class Losses(hyperparams.Config):
...
@@ -51,6 +52,7 @@ class Losses(hyperparams.Config):
one_hot
:
bool
=
True
one_hot
:
bool
=
True
label_smoothing
:
float
=
0.0
label_smoothing
:
float
=
0.0
l2_weight_decay
:
float
=
0.0
l2_weight_decay
:
float
=
0.0
soft_labels
:
bool
=
False
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
@@ -79,6 +81,87 @@ task_factory.register_task_cls(ImageClassificationTask)(
...
@@ -79,6 +81,87 @@ task_factory.register_task_cls(ImageClassificationTask)(
image_classification
.
ImageClassificationTask
)
image_classification
.
ImageClassificationTask
)
@exp_factory.register_config_factory('deit_imagenet_pretrain')
def image_classification_imagenet_deit_pretrain() -> cfg.ExperimentConfig:
  """Image classification on imagenet with vision transformer.

  DEIT-style pretraining recipe: ViT-B/16 backbone with stochastic depth,
  RandAugment (Cutout excluded), mixup+cutmix soft labels, and AdamW with a
  cosine schedule and linear warmup.
  """
  train_batch_size = 4096  # originally was 1024 but 4096 better for tpu v3-32
  eval_batch_size = 4096  # originally was 1024 but 4096 better for tpu v3-32
  num_classes = 1001
  label_smoothing = 0.1
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=ImageClassificationTask(
          model=ImageClassificationModel(
              num_classes=num_classes,
              input_size=[224, 224, 3],
              # Head kernel initialized to zeros (common for fine-tunable
              # classifier heads).
              kernel_initializer='zeros',
              backbone=backbones.Backbone(
                  type='vit',
                  vit=backbones.VisionTransformer(
                      model_name='vit-b16',
                      representation_size=768,
                      init_stochastic_depth_rate=0.1,
                      # original_init=False selects the DEIT-style init.
                      original_init=False,
                      transformer=backbones.Transformer(
                          dropout_rate=0.0, attention_dropout_rate=0.0)))),
          losses=Losses(
              l2_weight_decay=0.0,
              label_smoothing=label_smoothing,
              # Mixup/cutmix produce soft (non-one-hot) label vectors.
              one_hot=False,
              soft_labels=True),
          train_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size,
              aug_type=common.Augmentation(
                  type='randaug',
                  randaug=common.RandAugment(
                      magnitude=9, exclude_ops=['Cutout'])),
              mixup_and_cutmix=common.MixupAndCutmix(
                  label_smoothing=label_smoothing)),
          validation_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size)),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          # 300-epoch schedule.
          train_steps=300 * steps_per_epoch,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate': 0.05,
                      'include_in_weight_decay': r'.*(kernel|weight):0$',
                      'gradient_clip_norm': 0.0
                  }
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      # Linear LR scaling with batch size, base 5e-4 at 512.
                      'initial_learning_rate': 0.0005 * train_batch_size / 512,
                      'decay_steps': 300 * steps_per_epoch,
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 5 * steps_per_epoch,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
@
exp_factory
.
register_config_factory
(
'vit_imagenet_pretrain'
)
@
exp_factory
.
register_config_factory
(
'vit_imagenet_pretrain'
)
def
image_classification_imagenet_vit_pretrain
()
->
cfg
.
ExperimentConfig
:
def
image_classification_imagenet_vit_pretrain
()
->
cfg
.
ExperimentConfig
:
"""Image classification on imagenet with vision transformer."""
"""Image classification on imagenet with vision transformer."""
...
@@ -90,6 +173,7 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
...
@@ -90,6 +173,7 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
model
=
ImageClassificationModel
(
model
=
ImageClassificationModel
(
num_classes
=
1001
,
num_classes
=
1001
,
input_size
=
[
224
,
224
,
3
],
input_size
=
[
224
,
224
,
3
],
kernel_initializer
=
'zeros'
,
backbone
=
backbones
.
Backbone
(
backbone
=
backbones
.
Backbone
(
type
=
'vit'
,
type
=
'vit'
,
vit
=
backbones
.
VisionTransformer
(
vit
=
backbones
.
VisionTransformer
(
...
@@ -116,12 +200,13 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
...
@@ -116,12 +200,13 @@ def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
'adamw'
:
{
'adamw'
:
{
'weight_decay_rate'
:
0.3
,
'weight_decay_rate'
:
0.3
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'include_in_weight_decay'
:
r
'.*(kernel|weight):0$'
,
'gradient_clip_norm'
:
0.0
}
}
},
},
'learning_rate'
:
{
'learning_rate'
:
{
'type'
:
'cosine'
,
'type'
:
'cosine'
,
'cosine'
:
{
'cosine'
:
{
'initial_learning_rate'
:
0.003
,
'initial_learning_rate'
:
0.003
*
train_batch_size
/
4096
,
'decay_steps'
:
300
*
steps_per_epoch
,
'decay_steps'
:
300
*
steps_per_epoch
,
}
}
},
},
...
...
official/vision/beta/projects/vit/modeling/nn_blocks.py
0 → 100644
View file @
42ad9d5e
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Keras-based TransformerEncoder block layer."""
import
tensorflow
as
tf
from
official.nlp
import
keras_nlp
from
official.vision.beta.modeling.layers.nn_layers
import
StochasticDepth
class TransformerEncoderBlock(keras_nlp.layers.TransformerEncoderBlock):
  """TransformerEncoderBlock layer with stochastic depth.

  Extends the keras_nlp TransformerEncoderBlock by wrapping both residual
  branches (self-attention and MLP) in a StochasticDepth layer, which
  randomly drops the branch during training when a nonzero drop rate is
  configured.
  """

  def __init__(self, *args, stochastic_depth_drop_rate=0.0, **kwargs):
    """Initializes TransformerEncoderBlock."""
    super().__init__(*args, **kwargs)
    # Probability of dropping each residual branch during training; 0 disables
    # stochastic depth entirely.
    self._stochastic_depth_drop_rate = stochastic_depth_drop_rate

  def build(self, input_shape):
    if self._stochastic_depth_drop_rate:
      self._stochastic_depth = StochasticDepth(self._stochastic_depth_drop_rate)
    else:
      # No-op stand-in with the same call signature as StochasticDepth.
      self._stochastic_depth = lambda x, *args, **kwargs: tf.identity(x)
    super().build(input_shape)

  def get_config(self):
    # Merge the extra drop-rate field into the parent's config for
    # serialization round-trips.
    config = {"stochastic_depth_drop_rate": self._stochastic_depth_drop_rate}
    base_config = super().get_config()
    return dict(list(base_config.items()) + list(config.items()))

  def call(self, inputs, training=None):
    """Transformer self-attention encoder block call."""
    # Unpack inputs: a bare tensor, (tensor, mask), or
    # (tensor, key_value, mask).
    if isinstance(inputs, (list, tuple)):
      if len(inputs) == 2:
        input_tensor, attention_mask = inputs
        key_value = None
      elif len(inputs) == 3:
        input_tensor, key_value, attention_mask = inputs
      else:
        raise ValueError("Unexpected inputs to %s with length at %d" %
                         (self.__class__, len(inputs)))
    else:
      input_tensor, key_value, attention_mask = (inputs, None, None)

    if self._output_range:
      # Only the first `_output_range` positions are used as queries/outputs.
      if self._norm_first:
        source_tensor = input_tensor[:, 0:self._output_range, :]
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm(key_value)
      target_tensor = input_tensor[:, 0:self._output_range, :]
      if attention_mask is not None:
        attention_mask = attention_mask[:, 0:self._output_range, :]
    else:
      if self._norm_first:
        # Pre-norm: keep the un-normalized tensor for the residual connection.
        source_tensor = input_tensor
        input_tensor = self._attention_layer_norm(input_tensor)
        if key_value is not None:
          key_value = self._attention_layer_norm(key_value)
      target_tensor = input_tensor

    if key_value is None:
      key_value = input_tensor
    attention_output = self._attention_layer(
        query=target_tensor, value=key_value, attention_mask=attention_mask)
    attention_output = self._attention_dropout(attention_output)
    if self._norm_first:
      # Stochastic depth drops the attention branch; the skip path survives.
      attention_output = source_tensor + self._stochastic_depth(
          attention_output, training=training)
    else:
      attention_output = self._attention_layer_norm(
          target_tensor +
          self._stochastic_depth(attention_output, training=training))
    if self._norm_first:
      source_attention_output = attention_output
      attention_output = self._output_layer_norm(attention_output)
    inner_output = self._intermediate_dense(attention_output)
    inner_output = self._intermediate_activation_layer(inner_output)
    inner_output = self._inner_dropout_layer(inner_output)
    layer_output = self._output_dense(inner_output)
    layer_output = self._output_dropout(layer_output)

    if self._norm_first:
      return source_attention_output + self._stochastic_depth(
          layer_output, training=training)

    # During mixed precision training, layer norm output is always fp32 for now.
    # Casts fp32 for the subsequent add.
    layer_output = tf.cast(layer_output, tf.float32)
    # NOTE(review): in the post-norm path, stochastic depth is applied to
    # `attention_output` (the skip path) rather than `layer_output` (the MLP
    # branch). Confirm this operand choice is intentional; the pre-norm path
    # above drops the freshly computed branch instead.
    return self._output_layer_norm(layer_output + self._stochastic_depth(
        attention_output, training=training))
official/vision/beta/projects/vit/modeling/vit.py
View file @
42ad9d5e
...
@@ -17,17 +17,24 @@
...
@@ -17,17 +17,24 @@
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling
import
activations
from
official.modeling
import
activations
from
official.nlp
import
keras_nlp
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.backbones
import
factory
from
official.vision.beta.modeling.layers
import
nn_layers
from
official.vision.beta.projects.vit.modeling
import
nn_blocks
layers
=
tf
.
keras
.
layers
layers
=
tf
.
keras
.
layers
VIT_SPECS
=
{
VIT_SPECS
=
{
'vit-t
esting
'
:
'vit-t
i16
'
:
dict
(
dict
(
hidden_size
=
1
,
hidden_size
=
1
92
,
patch_size
=
16
,
patch_size
=
16
,
transformer
=
dict
(
mlp_dim
=
1
,
num_heads
=
1
,
num_layers
=
1
),
transformer
=
dict
(
mlp_dim
=
768
,
num_heads
=
3
,
num_layers
=
12
),
),
'vit-s16'
:
dict
(
hidden_size
=
384
,
patch_size
=
16
,
transformer
=
dict
(
mlp_dim
=
1536
,
num_heads
=
6
,
num_layers
=
12
),
),
),
'vit-b16'
:
'vit-b16'
:
dict
(
dict
(
...
@@ -112,6 +119,8 @@ class Encoder(tf.keras.layers.Layer):
...
@@ -112,6 +119,8 @@ class Encoder(tf.keras.layers.Layer):
attention_dropout_rate
=
0.1
,
attention_dropout_rate
=
0.1
,
kernel_regularizer
=
None
,
kernel_regularizer
=
None
,
inputs_positions
=
None
,
inputs_positions
=
None
,
init_stochastic_depth_rate
=
0.0
,
kernel_initializer
=
'glorot_uniform'
,
**
kwargs
):
**
kwargs
):
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
self
.
_num_layers
=
num_layers
self
.
_num_layers
=
num_layers
...
@@ -121,6 +130,8 @@ class Encoder(tf.keras.layers.Layer):
...
@@ -121,6 +130,8 @@ class Encoder(tf.keras.layers.Layer):
self
.
_attention_dropout_rate
=
attention_dropout_rate
self
.
_attention_dropout_rate
=
attention_dropout_rate
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_kernel_regularizer
=
kernel_regularizer
self
.
_inputs_positions
=
inputs_positions
self
.
_inputs_positions
=
inputs_positions
self
.
_init_stochastic_depth_rate
=
init_stochastic_depth_rate
self
.
_kernel_initializer
=
kernel_initializer
def
build
(
self
,
input_shape
):
def
build
(
self
,
input_shape
):
self
.
_pos_embed
=
AddPositionEmbs
(
self
.
_pos_embed
=
AddPositionEmbs
(
...
@@ -131,15 +142,18 @@ class Encoder(tf.keras.layers.Layer):
...
@@ -131,15 +142,18 @@ class Encoder(tf.keras.layers.Layer):
self
.
_encoder_layers
=
[]
self
.
_encoder_layers
=
[]
# Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
# Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
# https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
# https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
for
_
in
range
(
self
.
_num_layers
):
for
i
in
range
(
self
.
_num_layers
):
encoder_layer
=
keras_nlp
.
layer
s
.
TransformerEncoderBlock
(
encoder_layer
=
nn_block
s
.
TransformerEncoderBlock
(
inner_activation
=
activations
.
gelu
,
inner_activation
=
activations
.
gelu
,
num_attention_heads
=
self
.
_num_heads
,
num_attention_heads
=
self
.
_num_heads
,
inner_dim
=
self
.
_mlp_dim
,
inner_dim
=
self
.
_mlp_dim
,
output_dropout
=
self
.
_dropout_rate
,
output_dropout
=
self
.
_dropout_rate
,
attention_dropout
=
self
.
_attention_dropout_rate
,
attention_dropout
=
self
.
_attention_dropout_rate
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_regularizer
=
self
.
_kernel_regularizer
,
kernel_initializer
=
self
.
_kernel_initializer
,
norm_first
=
True
,
norm_first
=
True
,
stochastic_depth_drop_rate
=
nn_layers
.
get_stochastic_depth_rate
(
self
.
_init_stochastic_depth_rate
,
i
+
1
,
self
.
_num_layers
),
norm_epsilon
=
1e-6
)
norm_epsilon
=
1e-6
)
self
.
_encoder_layers
.
append
(
encoder_layer
)
self
.
_encoder_layers
.
append
(
encoder_layer
)
self
.
_norm
=
layers
.
LayerNormalization
(
epsilon
=
1e-6
)
self
.
_norm
=
layers
.
LayerNormalization
(
epsilon
=
1e-6
)
...
@@ -164,12 +178,14 @@ class VisionTransformer(tf.keras.Model):
...
@@ -164,12 +178,14 @@ class VisionTransformer(tf.keras.Model):
num_layers
=
12
,
num_layers
=
12
,
attention_dropout_rate
=
0.0
,
attention_dropout_rate
=
0.0
,
dropout_rate
=
0.1
,
dropout_rate
=
0.1
,
init_stochastic_depth_rate
=
0.0
,
input_specs
=
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
input_specs
=
layers
.
InputSpec
(
shape
=
[
None
,
None
,
None
,
3
]),
patch_size
=
16
,
patch_size
=
16
,
hidden_size
=
768
,
hidden_size
=
768
,
representation_size
=
0
,
representation_size
=
0
,
classifier
=
'token'
,
classifier
=
'token'
,
kernel_regularizer
=
None
):
kernel_regularizer
=
None
,
original_init
=
True
):
"""VisionTransformer initialization function."""
"""VisionTransformer initialization function."""
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
inputs
=
tf
.
keras
.
Input
(
shape
=
input_specs
.
shape
[
1
:])
...
@@ -178,7 +194,8 @@ class VisionTransformer(tf.keras.Model):
...
@@ -178,7 +194,8 @@ class VisionTransformer(tf.keras.Model):
kernel_size
=
patch_size
,
kernel_size
=
patch_size
,
strides
=
patch_size
,
strides
=
patch_size
,
padding
=
'valid'
,
padding
=
'valid'
,
kernel_regularizer
=
kernel_regularizer
)(
kernel_regularizer
=
kernel_regularizer
,
kernel_initializer
=
'lecun_normal'
if
original_init
else
'he_uniform'
)(
inputs
)
inputs
)
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
if
tf
.
keras
.
backend
.
image_data_format
()
==
'channels_last'
:
rows_axis
,
cols_axis
=
(
1
,
2
)
rows_axis
,
cols_axis
=
(
1
,
2
)
...
@@ -203,7 +220,10 @@ class VisionTransformer(tf.keras.Model):
...
@@ -203,7 +220,10 @@ class VisionTransformer(tf.keras.Model):
num_heads
=
num_heads
,
num_heads
=
num_heads
,
dropout_rate
=
dropout_rate
,
dropout_rate
=
dropout_rate
,
attention_dropout_rate
=
attention_dropout_rate
,
attention_dropout_rate
=
attention_dropout_rate
,
kernel_regularizer
=
kernel_regularizer
)(
kernel_regularizer
=
kernel_regularizer
,
kernel_initializer
=
'glorot_uniform'
if
original_init
else
dict
(
class_name
=
'TruncatedNormal'
,
config
=
dict
(
stddev
=
.
02
)),
init_stochastic_depth_rate
=
init_stochastic_depth_rate
)(
x
)
x
)
if
classifier
==
'token'
:
if
classifier
==
'token'
:
...
@@ -215,7 +235,8 @@ class VisionTransformer(tf.keras.Model):
...
@@ -215,7 +235,8 @@ class VisionTransformer(tf.keras.Model):
x
=
tf
.
keras
.
layers
.
Dense
(
x
=
tf
.
keras
.
layers
.
Dense
(
representation_size
,
representation_size
,
kernel_regularizer
=
kernel_regularizer
,
kernel_regularizer
=
kernel_regularizer
,
name
=
'pre_logits'
)(
name
=
'pre_logits'
,
kernel_initializer
=
'lecun_normal'
if
original_init
else
'he_uniform'
)(
x
)
x
)
x
=
tf
.
nn
.
tanh
(
x
)
x
=
tf
.
nn
.
tanh
(
x
)
else
:
else
:
...
@@ -247,9 +268,11 @@ def build_vit(input_specs,
...
@@ -247,9 +268,11 @@ def build_vit(input_specs,
num_layers
=
backbone_cfg
.
transformer
.
num_layers
,
num_layers
=
backbone_cfg
.
transformer
.
num_layers
,
attention_dropout_rate
=
backbone_cfg
.
transformer
.
attention_dropout_rate
,
attention_dropout_rate
=
backbone_cfg
.
transformer
.
attention_dropout_rate
,
dropout_rate
=
backbone_cfg
.
transformer
.
dropout_rate
,
dropout_rate
=
backbone_cfg
.
transformer
.
dropout_rate
,
init_stochastic_depth_rate
=
backbone_cfg
.
init_stochastic_depth_rate
,
input_specs
=
input_specs
,
input_specs
=
input_specs
,
patch_size
=
backbone_cfg
.
patch_size
,
patch_size
=
backbone_cfg
.
patch_size
,
hidden_size
=
backbone_cfg
.
hidden_size
,
hidden_size
=
backbone_cfg
.
hidden_size
,
representation_size
=
backbone_cfg
.
representation_size
,
representation_size
=
backbone_cfg
.
representation_size
,
classifier
=
backbone_cfg
.
classifier
,
classifier
=
backbone_cfg
.
classifier
,
kernel_regularizer
=
l2_regularizer
)
kernel_regularizer
=
l2_regularizer
,
original_init
=
backbone_cfg
.
original_init
)
official/vision/beta/projects/yolo/configs/darknet_classification.py
View file @
42ad9d5e
...
@@ -58,7 +58,7 @@ class ImageClassificationTask(cfg.TaskConfig):
...
@@ -58,7 +58,7 @@ class ImageClassificationTask(cfg.TaskConfig):
@
exp_factory
.
register_config_factory
(
'darknet_classification'
)
@
exp_factory
.
register_config_factory
(
'darknet_classification'
)
def
image
_classification
()
->
cfg
.
ExperimentConfig
:
def
darknet
_classification
()
->
cfg
.
ExperimentConfig
:
"""Image classification general."""
"""Image classification general."""
return
cfg
.
ExperimentConfig
(
return
cfg
.
ExperimentConfig
(
task
=
ImageClassificationTask
(),
task
=
ImageClassificationTask
(),
...
...
official/vision/beta/tasks/image_classification.py
View file @
42ad9d5e
...
@@ -26,6 +26,7 @@ from official.vision.beta.dataloaders import classification_input
...
@@ -26,6 +26,7 @@ from official.vision.beta.dataloaders import classification_input
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
input_reader_factory
from
official.vision.beta.dataloaders
import
tfds_factory
from
official.vision.beta.dataloaders
import
tfds_factory
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.modeling
import
factory
from
official.vision.beta.ops
import
augment
@
task_factory
.
register_task_cls
(
exp_cfg
.
ImageClassificationTask
)
@
task_factory
.
register_task_cls
(
exp_cfg
.
ImageClassificationTask
)
...
@@ -103,14 +104,26 @@ class ImageClassificationTask(base_task.Task):
...
@@ -103,14 +104,26 @@ class ImageClassificationTask(base_task.Task):
decode_jpeg_only
=
params
.
decode_jpeg_only
,
decode_jpeg_only
=
params
.
decode_jpeg_only
,
aug_rand_hflip
=
params
.
aug_rand_hflip
,
aug_rand_hflip
=
params
.
aug_rand_hflip
,
aug_type
=
params
.
aug_type
,
aug_type
=
params
.
aug_type
,
color_jitter
=
params
.
color_jitter
,
random_erasing
=
params
.
random_erasing
,
is_multilabel
=
is_multilabel
,
is_multilabel
=
is_multilabel
,
dtype
=
params
.
dtype
)
dtype
=
params
.
dtype
)
postprocess_fn
=
None
if
params
.
mixup_and_cutmix
:
postprocess_fn
=
augment
.
MixupAndCutmix
(
mixup_alpha
=
params
.
mixup_and_cutmix
.
mixup_alpha
,
cutmix_alpha
=
params
.
mixup_and_cutmix
.
cutmix_alpha
,
prob
=
params
.
mixup_and_cutmix
.
prob
,
label_smoothing
=
params
.
mixup_and_cutmix
.
label_smoothing
,
num_classes
=
num_classes
)
reader
=
input_reader_factory
.
input_reader_generator
(
reader
=
input_reader_factory
.
input_reader_generator
(
params
,
params
,
dataset_fn
=
dataset_fn
.
pick_dataset_fn
(
params
.
file_type
),
dataset_fn
=
dataset_fn
.
pick_dataset_fn
(
params
.
file_type
),
decoder_fn
=
decoder
.
decode
,
decoder_fn
=
decoder
.
decode
,
parser_fn
=
parser
.
parse_fn
(
params
.
is_training
))
parser_fn
=
parser
.
parse_fn
(
params
.
is_training
),
postprocess_fn
=
postprocess_fn
)
dataset
=
reader
.
read
(
input_context
=
input_context
)
dataset
=
reader
.
read
(
input_context
=
input_context
)
...
@@ -140,6 +153,9 @@ class ImageClassificationTask(base_task.Task):
...
@@ -140,6 +153,9 @@ class ImageClassificationTask(base_task.Task):
model_outputs
,
model_outputs
,
from_logits
=
True
,
from_logits
=
True
,
label_smoothing
=
losses_config
.
label_smoothing
)
label_smoothing
=
losses_config
.
label_smoothing
)
elif
losses_config
.
soft_labels
:
total_loss
=
tf
.
nn
.
softmax_cross_entropy_with_logits
(
labels
,
model_outputs
)
else
:
else
:
total_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
total_loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
model_outputs
,
from_logits
=
True
)
labels
,
model_outputs
,
from_logits
=
True
)
...
@@ -161,7 +177,8 @@ class ImageClassificationTask(base_task.Task):
...
@@ -161,7 +177,8 @@ class ImageClassificationTask(base_task.Task):
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
if
not
is_multilabel
:
if
not
is_multilabel
:
k
=
self
.
task_config
.
evaluation
.
top_k
k
=
self
.
task_config
.
evaluation
.
top_k
if
self
.
task_config
.
losses
.
one_hot
:
if
(
self
.
task_config
.
losses
.
one_hot
or
self
.
task_config
.
losses
.
soft_labels
):
metrics
=
[
metrics
=
[
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
CategoricalAccuracy
(
name
=
'accuracy'
),
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
tf
.
keras
.
metrics
.
TopKCategoricalAccuracy
(
...
@@ -223,7 +240,9 @@ class ImageClassificationTask(base_task.Task):
...
@@ -223,7 +240,9 @@ class ImageClassificationTask(base_task.Task):
# Computes per-replica loss.
# Computes per-replica loss.
loss
=
self
.
build_losses
(
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
# Scales loss as the default gradients allreduce performs sum inside the
# Scales loss as the default gradients allreduce performs sum inside the
# optimizer.
# optimizer.
scaled_loss
=
loss
/
num_replicas
scaled_loss
=
loss
/
num_replicas
...
@@ -266,14 +285,18 @@ class ImageClassificationTask(base_task.Task):
...
@@ -266,14 +285,18 @@ class ImageClassificationTask(base_task.Task):
A dictionary of logs.
A dictionary of logs.
"""
"""
features
,
labels
=
inputs
features
,
labels
=
inputs
one_hot
=
self
.
task_config
.
losses
.
one_hot
soft_labels
=
self
.
task_config
.
losses
.
soft_labels
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
is_multilabel
=
self
.
task_config
.
train_data
.
is_multilabel
if
self
.
task_config
.
losses
.
one_hot
and
not
is_multilabel
:
if
(
one_hot
or
soft_labels
)
and
not
is_multilabel
:
labels
=
tf
.
one_hot
(
labels
,
self
.
task_config
.
model
.
num_classes
)
labels
=
tf
.
one_hot
(
labels
,
self
.
task_config
.
model
.
num_classes
)
outputs
=
self
.
inference_step
(
features
,
model
)
outputs
=
self
.
inference_step
(
features
,
model
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
outputs
=
tf
.
nest
.
map_structure
(
lambda
x
:
tf
.
cast
(
x
,
tf
.
float32
),
outputs
)
loss
=
self
.
build_losses
(
model_outputs
=
outputs
,
labels
=
labels
,
loss
=
self
.
build_losses
(
aux_losses
=
model
.
losses
)
model_outputs
=
outputs
,
labels
=
labels
,
aux_losses
=
model
.
losses
)
logs
=
{
self
.
loss
:
loss
}
logs
=
{
self
.
loss
:
loss
}
if
metrics
:
if
metrics
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment