Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
242f4098
Commit
242f4098
authored
Sep 28, 2022
by
Chaochao Yan
Committed by
A. Unique TensorFlower
Sep 28, 2022
Browse files
Internal change
PiperOrigin-RevId: 477486513
parent
051f1c96
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
21 additions
and
15 deletions
+21
-15
official/vision/configs/image_classification.py
official/vision/configs/image_classification.py
+2
-1
official/vision/dataloaders/classification_input.py
official/vision/dataloaders/classification_input.py
+19
-14
No files found.
official/vision/configs/image_classification.py
View file @
242f4098
...
...
@@ -15,7 +15,7 @@
"""Image classification configuration definition."""
import
dataclasses
import
os
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
,
Tuple
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
...
...
@@ -37,6 +37,7 @@ class DataConfig(cfg.DataConfig):
is_multilabel
:
bool
=
False
aug_rand_hflip
:
bool
=
True
aug_crop
:
Optional
[
bool
]
=
True
crop_area_range
:
Optional
[
Tuple
[
float
,
float
]]
=
(
0.08
,
1.0
)
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Choose from AutoAugment and RandAugment.
color_jitter
:
float
=
0.
...
...
official/vision/dataloaders/classification_input.py
View file @
242f4098
...
...
@@ -13,7 +13,7 @@
# limitations under the License.
"""Classification decoder and parser."""
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
# Import libraries
import
tensorflow
as
tf
...
...
@@ -54,8 +54,8 @@ class Decoder(decoder.Decoder):
self
.
_keys_to_features
=
keys_to_features
def
decode
(
self
,
serialized_example
):
return
tf
.
io
.
parse_single_example
(
serialized_example
,
self
.
_keys_to_features
)
return
tf
.
io
.
parse_single_example
(
serialized_example
,
self
.
_keys_to_features
)
class
Parser
(
parser
.
Parser
):
...
...
@@ -73,7 +73,8 @@ class Parser(parser.Parser):
color_jitter
:
float
=
0.
,
random_erasing
:
Optional
[
common
.
RandomErasing
]
=
None
,
is_multilabel
:
bool
=
False
,
dtype
:
str
=
'float32'
):
dtype
:
str
=
'float32'
,
crop_area_range
:
Optional
[
Tuple
[
float
,
float
]]
=
(
0.08
,
1.0
)):
"""Initializes parameters for parsing annotations in the dataset.
Args:
...
...
@@ -84,8 +85,8 @@ class Parser(parser.Parser):
label_field_key: `str`, the key name to label in tf.Example.
decode_jpeg_only: `bool`, if True, only JPEG format is decoded, this is
faster than decoding other types. Default is True.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal
flip.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal
flip.
aug_crop: `bool`, if True, perform random cropping during training and
center crop during validation.
aug_type: An optional Augmentation object to choose from AutoAugment and
...
...
@@ -98,6 +99,10 @@ class Parser(parser.Parser):
is_multilabel: A `bool`, whether or not each example has multiple labels.
dtype: `str`, cast output image in dtype. It can be 'float32', 'float16',
or 'bfloat16'.
crop_area_range: An optional `tuple` of (min_area, max_area) for image
random crop function to constraint crop operation. The cropped areas
of the image must contain a fraction of the input image within this
range. The default area range is (0.08, 1.0).
"""
self
.
_output_size
=
output_size
self
.
_aug_rand_hflip
=
aug_rand_hflip
...
...
@@ -147,6 +152,7 @@ class Parser(parser.Parser):
self
.
_random_erasing
=
None
self
.
_is_multilabel
=
is_multilabel
self
.
_decode_jpeg_only
=
decode_jpeg_only
self
.
_crop_area_range
=
crop_area_range
def
_parse_train_data
(
self
,
decoded_tensors
):
"""Parses data for training."""
...
...
@@ -177,7 +183,7 @@ class Parser(parser.Parser):
# Crops image.
cropped_image
=
preprocess_ops
.
random_crop_image_v2
(
image_bytes
,
image_shape
)
image_bytes
,
image_shape
,
area_range
=
self
.
_crop_area_range
)
image
=
tf
.
cond
(
tf
.
reduce_all
(
tf
.
equal
(
tf
.
shape
(
cropped_image
),
image_shape
)),
lambda
:
preprocess_ops
.
center_crop_image_v2
(
image_bytes
,
image_shape
),
...
...
@@ -189,7 +195,8 @@ class Parser(parser.Parser):
# Crops image.
if
self
.
_aug_crop
:
cropped_image
=
preprocess_ops
.
random_crop_image
(
image
)
cropped_image
=
preprocess_ops
.
random_crop_image
(
image
,
area_range
=
self
.
_crop_area_range
)
image
=
tf
.
cond
(
tf
.
reduce_all
(
tf
.
equal
(
tf
.
shape
(
cropped_image
),
tf
.
shape
(
image
))),
...
...
@@ -215,9 +222,8 @@ class Parser(parser.Parser):
image
=
self
.
_augmenter
.
distort
(
image
)
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
,
offset
=
MEAN_RGB
,
scale
=
STDDEV_RGB
)
image
=
preprocess_ops
.
normalize_image
(
image
,
offset
=
MEAN_RGB
,
scale
=
STDDEV_RGB
)
# Random erasing after the image has been normalized
if
self
.
_random_erasing
is
not
None
:
...
...
@@ -251,9 +257,8 @@ class Parser(parser.Parser):
image
.
set_shape
([
self
.
_output_size
[
0
],
self
.
_output_size
[
1
],
3
])
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
,
offset
=
MEAN_RGB
,
scale
=
STDDEV_RGB
)
image
=
preprocess_ops
.
normalize_image
(
image
,
offset
=
MEAN_RGB
,
scale
=
STDDEV_RGB
)
# Convert image to self._dtype.
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment