Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1e205552
Commit
1e205552
authored
Dec 03, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 345529749
parent
495fbc4a
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1136 additions
and
5 deletions
+1136
-5
official/vision/beta/configs/image_classification.py
official/vision/beta/configs/image_classification.py
+1
-0
official/vision/beta/dataloaders/classification_input.py
official/vision/beta/dataloaders/classification_input.py
+23
-5
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+981
-0
official/vision/beta/ops/augment_test.py
official/vision/beta/ops/augment_test.py
+130
-0
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+1
-0
No files found.
official/vision/beta/configs/image_classification.py
View file @
1e205552
...
@@ -34,6 +34,7 @@ class DataConfig(cfg.DataConfig):
...
@@ -34,6 +34,7 @@ class DataConfig(cfg.DataConfig):
dtype
:
str
=
'float32'
dtype
:
str
=
'float32'
shuffle_buffer_size
:
int
=
10000
shuffle_buffer_size
:
int
=
10000
cycle_length
:
int
=
10
cycle_length
:
int
=
10
aug_policy
:
Optional
[
str
]
=
None
# None, 'autoaug', or 'randaug'
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/vision/beta/dataloaders/classification_input.py
View file @
1e205552
...
@@ -13,11 +13,13 @@
...
@@ -13,11 +13,13 @@
# limitations under the License.
# limitations under the License.
# ==============================================================================
# ==============================================================================
"""Classification decoder and parser."""
"""Classification decoder and parser."""
from
typing
import
List
,
Optional
# Import libraries
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
decoder
from
official.vision.beta.dataloaders
import
decoder
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.ops
import
augment
from
official.vision.beta.ops
import
preprocess_ops
from
official.vision.beta.ops
import
preprocess_ops
MEAN_RGB
=
(
0.485
*
255
,
0.456
*
255
,
0.406
*
255
)
MEAN_RGB
=
(
0.485
*
255
,
0.456
*
255
,
0.406
*
255
)
...
@@ -43,18 +45,20 @@ class Parser(parser.Parser):
...
@@ -43,18 +45,20 @@ class Parser(parser.Parser):
"""Parser to parse an image and its annotations into a dictionary of tensors."""
"""Parser to parse an image and its annotations into a dictionary of tensors."""
def
__init__
(
self
,
def
__init__
(
self
,
output_size
,
output_size
:
List
[
int
],
num_classes
,
num_classes
:
float
,
aug_rand_hflip
=
True
,
aug_rand_hflip
:
bool
=
True
,
dtype
=
'float32'
):
aug_policy
:
Optional
[
str
]
=
None
,
dtype
:
str
=
'float32'
):
"""Initializes parameters for parsing annotations in the dataset.
"""Initializes parameters for parsing annotations in the dataset.
Args:
Args:
output_size: `Tens
s
or` or `list` for [height, width] of output image. The
output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divided by the largest feature stride 2^max_level.
output_size should be divided by the largest feature stride 2^max_level.
num_classes: `float`, number of classes.
num_classes: `float`, number of classes.
aug_rand_hflip: `bool`, if True, augment training with random
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
horizontal flip.
aug_policy: `str`, augmentation policies. None, 'autoaug', or 'randaug'.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
or 'bfloat16'.
"""
"""
...
@@ -69,6 +73,16 @@ class Parser(parser.Parser):
...
@@ -69,6 +73,16 @@ class Parser(parser.Parser):
self
.
_dtype
=
tf
.
bfloat16
self
.
_dtype
=
tf
.
bfloat16
else
:
else
:
raise
ValueError
(
'dtype {!r} is not supported!'
.
format
(
dtype
))
raise
ValueError
(
'dtype {!r} is not supported!'
.
format
(
dtype
))
if
aug_policy
:
if
aug_policy
==
'autoaug'
:
self
.
_augmenter
=
augment
.
AutoAugment
()
elif
aug_policy
==
'randaug'
:
self
.
_augmenter
=
augment
.
RandAugment
(
num_layers
=
2
,
magnitude
=
20
)
else
:
raise
ValueError
(
'Augmentation policy {} not supported.'
.
format
(
aug_policy
))
else
:
self
.
_augmenter
=
None
def
_parse_train_data
(
self
,
decoded_tensors
):
def
_parse_train_data
(
self
,
decoded_tensors
):
"""Parses data for training."""
"""Parses data for training."""
...
@@ -93,6 +107,10 @@ class Parser(parser.Parser):
...
@@ -93,6 +107,10 @@ class Parser(parser.Parser):
image
=
tf
.
image
.
resize
(
image
=
tf
.
image
.
resize
(
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
# Apply autoaug or randaug.
if
self
.
_augmenter
is
not
None
:
image
=
self
.
_augmenter
.
distort
(
image
)
# Normalizes image with mean and std pixel values.
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
,
image
=
preprocess_ops
.
normalize_image
(
image
,
offset
=
MEAN_RGB
,
offset
=
MEAN_RGB
,
...
...
official/vision/beta/ops/augment.py
0 → 100644
View file @
1e205552
This diff is collapsed.
Click to expand it.
official/vision/beta/ops/augment_test.py
0 → 100644
View file @
1e205552
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for autoaugment."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.vision.beta.ops
import
augment
def
get_dtype_test_cases
():
return
[
(
'uint8'
,
tf
.
uint8
),
(
'int32'
,
tf
.
int32
),
(
'float16'
,
tf
.
float16
),
(
'float32'
,
tf
.
float32
),
]
@
parameterized
.
named_parameters
(
get_dtype_test_cases
())
class
TransformsTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
"""Basic tests for fundamental transformations."""
def
test_to_from_4d
(
self
,
dtype
):
for
shape
in
[(
10
,
10
),
(
10
,
10
,
10
),
(
10
,
10
,
10
,
10
)]:
original_ndims
=
len
(
shape
)
image
=
tf
.
zeros
(
shape
,
dtype
=
dtype
)
image_4d
=
augment
.
to_4d
(
image
)
self
.
assertEqual
(
4
,
tf
.
rank
(
image_4d
))
self
.
assertAllEqual
(
image
,
augment
.
from_4d
(
image_4d
,
original_ndims
))
def
test_transform
(
self
,
dtype
):
image
=
tf
.
constant
([[
1
,
2
],
[
3
,
4
]],
dtype
=
dtype
)
self
.
assertAllEqual
(
augment
.
transform
(
image
,
transforms
=
[
1
]
*
8
),
[[
4
,
4
],
[
4
,
4
]])
def
test_translate
(
self
,
dtype
):
image
=
tf
.
constant
(
[[
1
,
0
,
1
,
0
],
[
0
,
1
,
0
,
1
],
[
1
,
0
,
1
,
0
],
[
0
,
1
,
0
,
1
]],
dtype
=
dtype
)
translations
=
[
-
1
,
-
1
]
translated
=
augment
.
translate
(
image
=
image
,
translations
=
translations
)
expected
=
[[
1
,
0
,
1
,
1
],
[
0
,
1
,
0
,
0
],
[
1
,
0
,
1
,
1
],
[
1
,
0
,
1
,
1
]]
self
.
assertAllEqual
(
translated
,
expected
)
def
test_translate_shapes
(
self
,
dtype
):
translation
=
[
0
,
0
]
for
shape
in
[(
3
,
3
),
(
5
,
5
),
(
224
,
224
,
3
)]:
image
=
tf
.
zeros
(
shape
,
dtype
=
dtype
)
self
.
assertAllEqual
(
image
,
augment
.
translate
(
image
,
translation
))
def
test_translate_invalid_translation
(
self
,
dtype
):
image
=
tf
.
zeros
((
1
,
1
),
dtype
=
dtype
)
invalid_translation
=
[[[
1
,
1
]]]
with
self
.
assertRaisesRegex
(
TypeError
,
'rank 1 or 2'
):
_
=
augment
.
translate
(
image
,
invalid_translation
)
def
test_rotate
(
self
,
dtype
):
image
=
tf
.
reshape
(
tf
.
cast
(
tf
.
range
(
9
),
dtype
),
(
3
,
3
))
rotation
=
90.
transformed
=
augment
.
rotate
(
image
=
image
,
degrees
=
rotation
)
expected
=
[[
2
,
5
,
8
],
[
1
,
4
,
7
],
[
0
,
3
,
6
]]
self
.
assertAllEqual
(
transformed
,
expected
)
def
test_rotate_shapes
(
self
,
dtype
):
degrees
=
0.
for
shape
in
[(
3
,
3
),
(
5
,
5
),
(
224
,
224
,
3
)]:
image
=
tf
.
zeros
(
shape
,
dtype
=
dtype
)
self
.
assertAllEqual
(
image
,
augment
.
rotate
(
image
,
degrees
))
class
AutoaugmentTest
(
tf
.
test
.
TestCase
):
def
test_autoaugment
(
self
):
"""Smoke test to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
augmenter
=
augment
.
AutoAugment
()
aug_image
=
augmenter
.
distort
(
image
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
def
test_randaug
(
self
):
"""Smoke test to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
augmenter
=
augment
.
RandAugment
()
aug_image
=
augmenter
.
distort
(
image
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
def
test_all_policy_ops
(
self
):
"""Smoke test to be sure all augmentation functions can execute."""
prob
=
1
magnitude
=
10
replace_value
=
[
128
]
*
3
cutout_const
=
100
translate_const
=
250
image
=
tf
.
ones
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
for
op_name
in
augment
.
NAME_TO_FUNC
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
=
func
(
image
,
*
args
)
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/tasks/image_classification.py
View file @
1e205552
...
@@ -81,6 +81,7 @@ class ImageClassificationTask(base_task.Task):
...
@@ -81,6 +81,7 @@ class ImageClassificationTask(base_task.Task):
parser
=
classification_input
.
Parser
(
parser
=
classification_input
.
Parser
(
output_size
=
input_size
[:
2
],
output_size
=
input_size
[:
2
],
num_classes
=
num_classes
,
num_classes
=
num_classes
,
aug_policy
=
params
.
aug_policy
,
dtype
=
params
.
dtype
)
dtype
=
params
.
dtype
)
reader
=
input_reader
.
InputReader
(
reader
=
input_reader
.
InputReader
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment