Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
1e205552
Commit
1e205552
authored
Dec 03, 2020
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 345529749
parent
495fbc4a
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
1136 additions
and
5 deletions
+1136
-5
official/vision/beta/configs/image_classification.py
official/vision/beta/configs/image_classification.py
+1
-0
official/vision/beta/dataloaders/classification_input.py
official/vision/beta/dataloaders/classification_input.py
+23
-5
official/vision/beta/ops/augment.py
official/vision/beta/ops/augment.py
+981
-0
official/vision/beta/ops/augment_test.py
official/vision/beta/ops/augment_test.py
+130
-0
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+1
-0
No files found.
official/vision/beta/configs/image_classification.py
View file @
1e205552
...
...
@@ -34,6 +34,7 @@ class DataConfig(cfg.DataConfig):
dtype
:
str
=
'float32'
shuffle_buffer_size
:
int
=
10000
cycle_length
:
int
=
10
aug_policy
:
Optional
[
str
]
=
None
# None, 'autoaug', or 'randaug'
@
dataclasses
.
dataclass
...
...
official/vision/beta/dataloaders/classification_input.py
View file @
1e205552
...
...
@@ -13,11 +13,13 @@
# limitations under the License.
# ==============================================================================
"""Classification decoder and parser."""
from
typing
import
List
,
Optional
# Import libraries
import
tensorflow
as
tf
from
official.vision.beta.dataloaders
import
decoder
from
official.vision.beta.dataloaders
import
parser
from
official.vision.beta.ops
import
augment
from
official.vision.beta.ops
import
preprocess_ops
MEAN_RGB
=
(
0.485
*
255
,
0.456
*
255
,
0.406
*
255
)
...
...
@@ -43,18 +45,20 @@ class Parser(parser.Parser):
"""Parser to parse an image and its annotations into a dictionary of tensors."""
def
__init__
(
self
,
output_size
,
num_classes
,
aug_rand_hflip
=
True
,
dtype
=
'float32'
):
output_size
:
List
[
int
],
num_classes
:
float
,
aug_rand_hflip
:
bool
=
True
,
aug_policy
:
Optional
[
str
]
=
None
,
dtype
:
str
=
'float32'
):
"""Initializes parameters for parsing annotations in the dataset.
Args:
output_size: `Tens
s
or` or `list` for [height, width] of output image. The
output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divided by the largest feature stride 2^max_level.
num_classes: `float`, number of classes.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
aug_policy: `str`, augmentation policies. None, 'autoaug', or 'randaug'.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
"""
...
...
@@ -69,6 +73,16 @@ class Parser(parser.Parser):
self
.
_dtype
=
tf
.
bfloat16
else
:
raise
ValueError
(
'dtype {!r} is not supported!'
.
format
(
dtype
))
if
aug_policy
:
if
aug_policy
==
'autoaug'
:
self
.
_augmenter
=
augment
.
AutoAugment
()
elif
aug_policy
==
'randaug'
:
self
.
_augmenter
=
augment
.
RandAugment
(
num_layers
=
2
,
magnitude
=
20
)
else
:
raise
ValueError
(
'Augmentation policy {} not supported.'
.
format
(
aug_policy
))
else
:
self
.
_augmenter
=
None
def
_parse_train_data
(
self
,
decoded_tensors
):
"""Parses data for training."""
...
...
@@ -93,6 +107,10 @@ class Parser(parser.Parser):
image
=
tf
.
image
.
resize
(
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
# Apply autoaug or randaug.
if
self
.
_augmenter
is
not
None
:
image
=
self
.
_augmenter
.
distort
(
image
)
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
,
offset
=
MEAN_RGB
,
...
...
official/vision/beta/ops/augment.py
0 → 100644
View file @
1e205552
This diff is collapsed.
Click to expand it.
official/vision/beta/ops/augment_test.py
0 → 100644
View file @
1e205552
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for autoaugment."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.vision.beta.ops
import
augment
def
get_dtype_test_cases
():
return
[
(
'uint8'
,
tf
.
uint8
),
(
'int32'
,
tf
.
int32
),
(
'float16'
,
tf
.
float16
),
(
'float32'
,
tf
.
float32
),
]
@
parameterized
.
named_parameters
(
get_dtype_test_cases
())
class
TransformsTest
(
parameterized
.
TestCase
,
tf
.
test
.
TestCase
):
"""Basic tests for fundamental transformations."""
def
test_to_from_4d
(
self
,
dtype
):
for
shape
in
[(
10
,
10
),
(
10
,
10
,
10
),
(
10
,
10
,
10
,
10
)]:
original_ndims
=
len
(
shape
)
image
=
tf
.
zeros
(
shape
,
dtype
=
dtype
)
image_4d
=
augment
.
to_4d
(
image
)
self
.
assertEqual
(
4
,
tf
.
rank
(
image_4d
))
self
.
assertAllEqual
(
image
,
augment
.
from_4d
(
image_4d
,
original_ndims
))
def
test_transform
(
self
,
dtype
):
image
=
tf
.
constant
([[
1
,
2
],
[
3
,
4
]],
dtype
=
dtype
)
self
.
assertAllEqual
(
augment
.
transform
(
image
,
transforms
=
[
1
]
*
8
),
[[
4
,
4
],
[
4
,
4
]])
def
test_translate
(
self
,
dtype
):
image
=
tf
.
constant
(
[[
1
,
0
,
1
,
0
],
[
0
,
1
,
0
,
1
],
[
1
,
0
,
1
,
0
],
[
0
,
1
,
0
,
1
]],
dtype
=
dtype
)
translations
=
[
-
1
,
-
1
]
translated
=
augment
.
translate
(
image
=
image
,
translations
=
translations
)
expected
=
[[
1
,
0
,
1
,
1
],
[
0
,
1
,
0
,
0
],
[
1
,
0
,
1
,
1
],
[
1
,
0
,
1
,
1
]]
self
.
assertAllEqual
(
translated
,
expected
)
def
test_translate_shapes
(
self
,
dtype
):
translation
=
[
0
,
0
]
for
shape
in
[(
3
,
3
),
(
5
,
5
),
(
224
,
224
,
3
)]:
image
=
tf
.
zeros
(
shape
,
dtype
=
dtype
)
self
.
assertAllEqual
(
image
,
augment
.
translate
(
image
,
translation
))
def
test_translate_invalid_translation
(
self
,
dtype
):
image
=
tf
.
zeros
((
1
,
1
),
dtype
=
dtype
)
invalid_translation
=
[[[
1
,
1
]]]
with
self
.
assertRaisesRegex
(
TypeError
,
'rank 1 or 2'
):
_
=
augment
.
translate
(
image
,
invalid_translation
)
def
test_rotate
(
self
,
dtype
):
image
=
tf
.
reshape
(
tf
.
cast
(
tf
.
range
(
9
),
dtype
),
(
3
,
3
))
rotation
=
90.
transformed
=
augment
.
rotate
(
image
=
image
,
degrees
=
rotation
)
expected
=
[[
2
,
5
,
8
],
[
1
,
4
,
7
],
[
0
,
3
,
6
]]
self
.
assertAllEqual
(
transformed
,
expected
)
def
test_rotate_shapes
(
self
,
dtype
):
degrees
=
0.
for
shape
in
[(
3
,
3
),
(
5
,
5
),
(
224
,
224
,
3
)]:
image
=
tf
.
zeros
(
shape
,
dtype
=
dtype
)
self
.
assertAllEqual
(
image
,
augment
.
rotate
(
image
,
degrees
))
class
AutoaugmentTest
(
tf
.
test
.
TestCase
):
def
test_autoaugment
(
self
):
"""Smoke test to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
augmenter
=
augment
.
AutoAugment
()
aug_image
=
augmenter
.
distort
(
image
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
def
test_randaug
(
self
):
"""Smoke test to be sure there are no syntax errors."""
image
=
tf
.
zeros
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
augmenter
=
augment
.
RandAugment
()
aug_image
=
augmenter
.
distort
(
image
)
self
.
assertEqual
((
224
,
224
,
3
),
aug_image
.
shape
)
def
test_all_policy_ops
(
self
):
"""Smoke test to be sure all augmentation functions can execute."""
prob
=
1
magnitude
=
10
replace_value
=
[
128
]
*
3
cutout_const
=
100
translate_const
=
250
image
=
tf
.
ones
((
224
,
224
,
3
),
dtype
=
tf
.
uint8
)
for
op_name
in
augment
.
NAME_TO_FUNC
:
func
,
_
,
args
=
augment
.
_parse_policy_info
(
op_name
,
prob
,
magnitude
,
replace_value
,
cutout_const
,
translate_const
)
image
=
func
(
image
,
*
args
)
self
.
assertEqual
((
224
,
224
,
3
),
image
.
shape
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/vision/beta/tasks/image_classification.py
View file @
1e205552
...
...
@@ -81,6 +81,7 @@ class ImageClassificationTask(base_task.Task):
parser
=
classification_input
.
Parser
(
output_size
=
input_size
[:
2
],
num_classes
=
num_classes
,
aug_policy
=
params
.
aug_policy
,
dtype
=
params
.
dtype
)
reader
=
input_reader
.
InputReader
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment