Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d4053d80
Commit
d4053d80
authored
Aug 20, 2021
by
Simon Geisler
Browse files
fix issues with color jitter and random erase
parent
34c6530a
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
66 additions
and
49 deletions
+66
-49
official/vision/beta/configs/common.py
official/vision/beta/configs/common.py
+1
-0
official/vision/beta/dataloaders/classification_input.py
official/vision/beta/dataloaders/classification_input.py
+9
-9
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+28
-25
official/vision/beta/ops/preprocess_ops.py
official/vision/beta/ops/preprocess_ops.py
+17
-15
official/vision/beta/ops/preprocess_ops_test.py
official/vision/beta/ops/preprocess_ops_test.py
+11
-0
No files found.
official/vision/beta/configs/common.py
View file @
d4053d80
...
@@ -31,6 +31,7 @@ class RandAugment(hyperparams.Config):
...
@@ -31,6 +31,7 @@ class RandAugment(hyperparams.Config):
magnitude
:
float
=
10
magnitude
:
float
=
10
cutout_const
:
float
=
40
cutout_const
:
float
=
40
translate_const
:
float
=
10
translate_const
:
float
=
10
magnitude_std
:
float
=
0.0
prob_to_apply
:
Optional
[
float
]
=
None
prob_to_apply
:
Optional
[
float
]
=
None
exclude_ops
:
List
[
str
]
=
dataclasses
.
field
(
default_factory
=
list
)
exclude_ops
:
List
[
str
]
=
dataclasses
.
field
(
default_factory
=
list
)
...
...
official/vision/beta/dataloaders/classification_input.py
View file @
d4053d80
...
@@ -196,6 +196,11 @@ class Parser(parser.Parser):
...
@@ -196,6 +196,11 @@ class Parser(parser.Parser):
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
.
set_shape
([
self
.
_output_size
[
0
],
self
.
_output_size
[
1
],
3
])
image
.
set_shape
([
self
.
_output_size
[
0
],
self
.
_output_size
[
1
],
3
])
# Color jitter.
if
self
.
_color_jitter
>
0
:
image
=
preprocess_ops
.
color_jitter
(
image
,
self
.
_color_jitter
,
self
.
_color_jitter
,
self
.
_color_jitter
)
# Apply autoaug or randaug.
# Apply autoaug or randaug.
if
self
.
_augmenter
is
not
None
:
if
self
.
_augmenter
is
not
None
:
image
=
self
.
_augmenter
.
distort
(
image
)
image
=
self
.
_augmenter
.
distort
(
image
)
...
@@ -205,6 +210,10 @@ class Parser(parser.Parser):
...
@@ -205,6 +210,10 @@ class Parser(parser.Parser):
offset
=
MEAN_RGB
,
offset
=
MEAN_RGB
,
scale
=
STDDEV_RGB
)
scale
=
STDDEV_RGB
)
# Random erasing after the image has been normalized
if
self
.
_random_erasing
is
not
None
:
image
=
self
.
_random_erasing
.
distort
(
image
)
# Convert image to self._dtype.
# Convert image to self._dtype.
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
...
@@ -231,20 +240,11 @@ class Parser(parser.Parser):
...
@@ -231,20 +240,11 @@ class Parser(parser.Parser):
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
.
set_shape
([
self
.
_output_size
[
0
],
self
.
_output_size
[
1
],
3
])
image
.
set_shape
([
self
.
_output_size
[
0
],
self
.
_output_size
[
1
],
3
])
# Color jitter.
if
self
.
_color_jitter
>
0
:
image
=
preprocess_ops
.
color_jitter
(
image
,
self
.
_color_jitter
,
self
.
_color_jitter
,
self
.
_color_jitter
)
# Normalizes image with mean and std pixel values.
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
,
image
=
preprocess_ops
.
normalize_image
(
image
,
offset
=
MEAN_RGB
,
offset
=
MEAN_RGB
,
scale
=
STDDEV_RGB
)
scale
=
STDDEV_RGB
)
# Random erasing after the image has been normalized
if
self
.
_random_erasing
is
not
None
:
image
=
self
.
_random_erasing
.
distort
(
image
)
# Convert image to self._dtype.
# Convert image to self._dtype.
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
...
...
official/vision/beta/ops/augment.py
View file @
d4053d80
...
@@ -1359,7 +1359,7 @@ class RandomErasing(ImageAugment):
...
@@ -1359,7 +1359,7 @@ class RandomErasing(ImageAugment):
"""
"""
uniform_random
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0.
,
maxval
=
1.0
)
uniform_random
=
tf
.
random
.
uniform
(
shape
=
[],
minval
=
0.
,
maxval
=
1.0
)
mirror_cond
=
tf
.
less
(
uniform_random
,
.
5
)
mirror_cond
=
tf
.
less
(
uniform_random
,
.
5
)
tf
.
cond
(
mirror_cond
,
self
.
_erase
,
lambda
:
image
)
tf
.
cond
(
mirror_cond
,
lambda
:
self
.
_erase
(
image
)
,
lambda
:
image
)
return
image
return
image
@
tf
.
function
@
tf
.
function
...
@@ -1374,31 +1374,34 @@ class RandomErasing(ImageAugment):
...
@@ -1374,31 +1374,34 @@ class RandomErasing(ImageAugment):
area
=
tf
.
cast
(
image_width
*
image_height
,
tf
.
float32
)
area
=
tf
.
cast
(
image_width
*
image_height
,
tf
.
float32
)
for
_
in
range
(
count
):
for
_
in
range
(
count
):
# Work around since break is not supported in tf.function
is_trial_successfull
=
False
for
_
in
range
(
self
.
_trials
):
for
_
in
range
(
self
.
_trials
):
erase_area
=
tf
.
random
.
uniform
(
shape
=
[],
if
not
is_trial_successfull
:
minval
=
area
*
self
.
_min_area
,
erase_area
=
tf
.
random
.
uniform
(
shape
=
[],
maxval
=
area
*
self
.
_max_area
)
minval
=
area
*
self
.
_min_area
,
aspect_ratio
=
tf
.
math
.
exp
(
tf
.
random
.
uniform
(
maxval
=
area
*
self
.
_max_area
)
shape
=
[],
minval
=
self
.
_min_log_aspect
,
aspect_ratio
=
tf
.
math
.
exp
(
tf
.
random
.
uniform
(
maxval
=
self
.
_max_log_aspect
))
shape
=
[],
minval
=
self
.
_min_log_aspect
,
maxval
=
self
.
_max_log_aspect
))
half_height
=
tf
.
cast
(
tf
.
math
.
round
(
tf
.
math
.
sqrt
(
erase_area
*
aspect_ratio
)
/
2
),
dtype
=
tf
.
int32
)
half_height
=
tf
.
cast
(
tf
.
math
.
round
(
tf
.
math
.
sqrt
(
half_width
=
tf
.
cast
(
tf
.
math
.
round
(
tf
.
math
.
sqrt
(
erase_area
*
aspect_ratio
)
/
2
),
dtype
=
tf
.
int32
)
erase_area
/
aspect_ratio
)
/
2
),
dtype
=
tf
.
int32
)
half_width
=
tf
.
cast
(
tf
.
math
.
round
(
tf
.
math
.
sqrt
(
erase_area
/
aspect_ratio
)
/
2
),
dtype
=
tf
.
int32
)
if
2
*
half_height
<
image_height
and
2
*
half_width
<
image_width
:
center_height
=
tf
.
random
.
uniform
(
if
2
*
half_height
<
image_height
and
2
*
half_width
<
image_width
:
shape
=
[],
minval
=
0
,
maxval
=
int
(
image_height
-
2
*
half_height
),
center_height
=
tf
.
random
.
uniform
(
dtype
=
tf
.
int32
)
shape
=
[],
minval
=
0
,
maxval
=
int
(
image_height
-
2
*
half_height
),
center_width
=
tf
.
random
.
uniform
(
dtype
=
tf
.
int32
)
shape
=
[],
minval
=
0
,
maxval
=
int
(
image_width
-
2
*
half_width
),
center_width
=
tf
.
random
.
uniform
(
dtype
=
tf
.
int32
)
shape
=
[],
minval
=
0
,
maxval
=
int
(
image_width
-
2
*
half_width
),
dtype
=
tf
.
int32
)
image
=
_fill_rectangle
(
image
,
center_width
,
center_height
,
half_width
,
half_height
,
replace
=
None
)
image
=
_fill_rectangle
(
image
,
center_width
,
center_height
,
half_width
,
half_height
,
replace
=
None
)
break
is_trial_successfull
=
True
return
image
return
image
...
...
official/vision/beta/ops/preprocess_ops.py
View file @
d4053d80
...
@@ -566,7 +566,7 @@ def color_jitter(image: tf.Tensor, brightness: Optional[float] = 0.,
...
@@ -566,7 +566,7 @@ def color_jitter(image: tf.Tensor, brightness: Optional[float] = 0.,
"""Applies color jitter to an image, similarly to torchvision`s ColorJitter.
"""Applies color jitter to an image, similarly to torchvision`s ColorJitter.
Args:
Args:
image (tf.Tensor): Of shape [height, width, 3]
representing an image
.
image (tf.Tensor): Of shape [height, width, 3]
and type uint8
.
brightness (float, optional): Magnitude for brightness jitter.
brightness (float, optional): Magnitude for brightness jitter.
Defaults to 0.
Defaults to 0.
contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
contrast (float, optional): Magnitude for contrast jitter. Defaults to 0.
...
@@ -575,8 +575,9 @@ def color_jitter(image: tf.Tensor, brightness: Optional[float] = 0.,
...
@@ -575,8 +575,9 @@ def color_jitter(image: tf.Tensor, brightness: Optional[float] = 0.,
seed (int, optional): Random seed. Defaults to None.
seed (int, optional): Random seed. Defaults to None.
Returns:
Returns:
tf.Tensor: The augmented
version of `image`
.
tf.Tensor: The augmented
`image` of type uint8
.
"""
"""
image
=
tf
.
cast
(
image
,
dtype
=
tf
.
uint8
)
image
=
random_brightness
(
image
,
brightness
,
seed
=
seed
)
image
=
random_brightness
(
image
,
brightness
,
seed
=
seed
)
image
=
random_contrast
(
image
,
contrast
,
seed
=
seed
)
image
=
random_contrast
(
image
,
contrast
,
seed
=
seed
)
image
=
random_saturation
(
image
,
saturation
,
seed
=
seed
)
image
=
random_saturation
(
image
,
saturation
,
seed
=
seed
)
...
@@ -588,17 +589,17 @@ def random_brightness(image: tf.Tensor, brightness: Optional[float] = 0.,
...
@@ -588,17 +589,17 @@ def random_brightness(image: tf.Tensor, brightness: Optional[float] = 0.,
"""Jitters brightness of an image, similarly to torchvision`s ColorJitter.
"""Jitters brightness of an image, similarly to torchvision`s ColorJitter.
Args:
Args:
image (tf.Tensor): Of shape [height, width, 3]
representing an image
.
image (tf.Tensor): Of shape [height, width, 3]
and type uint8
.
brightness (float, optional): Magnitude for brightness jitter.
brightness (float, optional): Magnitude for brightness jitter.
Defaults to 0.
Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
seed (int, optional): Random seed. Defaults to None.
Returns:
Returns:
tf.Tensor: The augmented
version of `image`
.
tf.Tensor: The augmented
`image` of type uint8
.
"""
"""
assert
brightness
>=
0
and
brightness
<=
1.
,
'`brightness` must be
in [0, 1]
'
assert
brightness
>=
0
,
'`brightness` must be
positive
'
brightness
=
tf
.
random
.
uniform
(
brightness
=
tf
.
random
.
uniform
(
[],
max
(
0
,
1
-
brightness
),
1
+
brightness
,
seed
=
seed
)
[],
max
(
0
,
1
-
brightness
),
1
+
brightness
,
seed
=
seed
,
dtype
=
tf
.
float32
)
return
augment
.
brightness
(
image
,
brightness
)
return
augment
.
brightness
(
image
,
brightness
)
...
@@ -607,17 +608,17 @@ def random_contrast(image: tf.Tensor, contrast: Optional[float] = 0.,
...
@@ -607,17 +608,17 @@ def random_contrast(image: tf.Tensor, contrast: Optional[float] = 0.,
"""Jitters contrast of an image, similarly to torchvision`s ColorJitter.
"""Jitters contrast of an image, similarly to torchvision`s ColorJitter.
Args:
Args:
image (tf.Tensor): Of shape [height, width, 3]
representing an image
.
image (tf.Tensor): Of shape [height, width, 3]
and type uint8
.
contrast (float, optional): Magnitude for contrast jitter.
contrast (float, optional): Magnitude for contrast jitter.
Defaults to 0.
Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
seed (int, optional): Random seed. Defaults to None.
Returns:
Returns:
tf.Tensor: The augmented
version of `image`
.
tf.Tensor: The augmented
`image` of type uint8
.
"""
"""
assert
contrast
>=
0
and
contrast
<=
1.
,
'`contrast` must be
in [0, 1]
'
assert
contrast
>=
0
,
'`contrast` must be
positive
'
contrast
=
tf
.
random
.
uniform
(
contrast
=
tf
.
random
.
uniform
(
[],
max
(
0
,
1
-
contrast
),
1
+
contrast
,
seed
=
seed
)
[],
max
(
0
,
1
-
contrast
),
1
+
contrast
,
seed
=
seed
,
dtype
=
tf
.
float32
)
return
augment
.
contrast
(
image
,
contrast
)
return
augment
.
contrast
(
image
,
contrast
)
...
@@ -626,15 +627,16 @@ def random_saturation(image: tf.Tensor, saturation: Optional[float] = 0.,
...
@@ -626,15 +627,16 @@ def random_saturation(image: tf.Tensor, saturation: Optional[float] = 0.,
"""Jitters saturation of an image, similarly to torchvision`s ColorJitter.
"""Jitters saturation of an image, similarly to torchvision`s ColorJitter.
Args:
Args:
image (tf.Tensor): Of shape [height, width, 3]
representing an image
.
image (tf.Tensor): Of shape [height, width, 3]
and type uint8
.
saturation (float, optional): Magnitude for saturation jitter.
saturation (float, optional): Magnitude for saturation jitter.
Defaults to 0.
Defaults to 0.
seed (int, optional): Random seed. Defaults to None.
seed (int, optional): Random seed. Defaults to None.
Returns:
Returns:
tf.Tensor: The augmented
version of `image`
.
tf.Tensor: The augmented
`image` of type uint8
.
"""
"""
assert
saturation
>=
0
and
saturation
<=
1.
,
'`saturation` must be
in [0, 1]
'
assert
saturation
>=
0
,
'`saturation` must be
positive
'
saturation
=
tf
.
random
.
uniform
(
saturation
=
tf
.
random
.
uniform
(
[],
max
(
0
,
1
-
saturation
),
1
+
saturation
,
seed
=
seed
)
[],
max
(
0
,
1
-
saturation
),
1
+
saturation
,
seed
=
seed
,
dtype
=
tf
.
float32
)
return
augment
.
blend
(
tf
.
image
.
rgb_to_grayscale
(
image
),
image
,
saturation
)
return
augment
.
blend
(
tf
.
repeat
(
tf
.
image
.
rgb_to_grayscale
(
image
),
3
,
axis
=-
1
),
image
,
saturation
)
official/vision/beta/ops/preprocess_ops_test.py
View file @
d4053d80
...
@@ -225,6 +225,17 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -225,6 +225,17 @@ class InputUtilsTest(parameterized.TestCase, tf.test.TestCase):
_
=
preprocess_ops
.
random_crop_image_v2
(
_
=
preprocess_ops
.
random_crop_image_v2
(
image_bytes
,
tf
.
constant
([
input_height
,
input_width
,
3
],
tf
.
int32
))
image_bytes
,
tf
.
constant
([
input_height
,
input_width
,
3
],
tf
.
int32
))
@
parameterized
.
parameters
(
(
400
,
600
,
0
),
(
400
,
600
,
0.4
),
(
600
,
400
,
1.4
)
)
def
testColorJitter
(
self
,
input_height
,
input_width
,
color_jitter
):
image
=
tf
.
convert_to_tensor
(
np
.
random
.
rand
(
input_height
,
input_width
,
3
))
jittered_image
=
preprocess_ops
.
color_jitter
(
image
,
color_jitter
,
color_jitter
,
color_jitter
)
assert
jittered_image
.
shape
==
image
.
shape
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment