Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
069bdd28
Commit
069bdd28
authored
Jun 03, 2021
by
Abdullah Rashwan
Committed by
A. Unique TensorFlower
Jun 03, 2021
Browse files
Internal change
PiperOrigin-RevId: 377340715
parent
3fc55e9e
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
45 additions
and
15 deletions
+45
-15
official/vision/beta/configs/image_classification.py
official/vision/beta/configs/image_classification.py
+1
-0
official/vision/beta/dataloaders/classification_input.py
official/vision/beta/dataloaders/classification_input.py
+40
-14
official/vision/beta/dataloaders/tfexample_utils.py
official/vision/beta/dataloaders/tfexample_utils.py
+3
-1
official/vision/beta/tasks/image_classification.py
official/vision/beta/tasks/image_classification.py
+1
-0
No files found.
official/vision/beta/configs/image_classification.py
View file @
069bdd28
...
...
@@ -43,6 +43,7 @@ class DataConfig(cfg.DataConfig):
file_type
:
str
=
'tfrecord'
image_field_key
:
str
=
'image/encoded'
label_field_key
:
str
=
'image/class/label'
decode_jpeg_only
:
bool
=
True
# Keep for backward compatibility.
aug_policy
:
Optional
[
str
]
=
None
# None, 'autoaug', or 'randaug'.
...
...
official/vision/beta/dataloaders/classification_input.py
View file @
069bdd28
...
...
@@ -66,6 +66,7 @@ class Parser(parser.Parser):
num_classes
:
float
,
image_field_key
:
str
=
DEFAULT_IMAGE_FIELD_KEY
,
label_field_key
:
str
=
DEFAULT_LABEL_FIELD_KEY
,
decode_jpeg_only
:
bool
=
True
,
aug_rand_hflip
:
bool
=
True
,
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
,
is_multilabel
:
bool
=
False
,
...
...
@@ -78,6 +79,8 @@ class Parser(parser.Parser):
num_classes: `float`, number of classes.
image_field_key: `str`, the key name to encoded image in tf.Example.
label_field_key: `str`, the key name to label in tf.Example.
decode_jpeg_only: `bool`, if True, only JPEG format is decoded, this is
faster than decoding other types. Default is True.
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
aug_type: An optional Augmentation object to choose from AutoAugment and
...
...
@@ -118,6 +121,7 @@ class Parser(parser.Parser):
self
.
_augmenter
=
None
self
.
_label_field_key
=
label_field_key
self
.
_is_multilabel
=
is_multilabel
self
.
_decode_jpeg_only
=
decode_jpeg_only
def
_parse_train_data
(
self
,
decoded_tensors
):
"""Parses data for training."""
...
...
@@ -142,16 +146,29 @@ class Parser(parser.Parser):
def
_parse_train_image
(
self
,
decoded_tensors
):
"""Parses image data for training."""
image_bytes
=
decoded_tensors
[
self
.
_image_field_key
]
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
# Crops image.
# TODO(pengchong): support image format other than JPEG.
cropped_image
=
preprocess_ops
.
random_crop_image_v2
(
image_bytes
,
image_shape
)
image
=
tf
.
cond
(
tf
.
reduce_all
(
tf
.
equal
(
tf
.
shape
(
cropped_image
),
image_shape
)),
lambda
:
preprocess_ops
.
center_crop_image_v2
(
image_bytes
,
image_shape
),
lambda
:
cropped_image
)
if
self
.
_decode_jpeg_only
:
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
# Crops image.
cropped_image
=
preprocess_ops
.
random_crop_image_v2
(
image_bytes
,
image_shape
)
image
=
tf
.
cond
(
tf
.
reduce_all
(
tf
.
equal
(
tf
.
shape
(
cropped_image
),
image_shape
)),
lambda
:
preprocess_ops
.
center_crop_image_v2
(
image_bytes
,
image_shape
),
lambda
:
cropped_image
)
else
:
# Decodes image.
image
=
tf
.
io
.
decode_image
(
image_bytes
,
channels
=
3
)
image
.
set_shape
([
None
,
None
,
3
])
# Crops image.
cropped_image
=
preprocess_ops
.
random_crop_image
(
image
)
image
=
tf
.
cond
(
tf
.
reduce_all
(
tf
.
equal
(
tf
.
shape
(
cropped_image
),
tf
.
shape
(
image
))),
lambda
:
preprocess_ops
.
center_crop_image
(
image
),
lambda
:
cropped_image
)
if
self
.
_aug_rand_hflip
:
image
=
tf
.
image
.
random_flip_left_right
(
image
)
...
...
@@ -159,6 +176,7 @@ class Parser(parser.Parser):
# Resizes image.
image
=
tf
.
image
.
resize
(
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
.
set_shape
([
self
.
_output_size
[
0
],
self
.
_output_size
[
1
],
3
])
# Apply autoaug or randaug.
if
self
.
_augmenter
is
not
None
:
...
...
@@ -177,15 +195,23 @@ class Parser(parser.Parser):
def
_parse_eval_image
(
self
,
decoded_tensors
):
"""Parses image data for evaluation."""
image_bytes
=
decoded_tensors
[
self
.
_image_field_key
]
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
# Center crops and resizes image.
image
=
preprocess_ops
.
center_crop_image_v2
(
image_bytes
,
image_shape
)
if
self
.
_decode_jpeg_only
:
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
# Center crops.
image
=
preprocess_ops
.
center_crop_image_v2
(
image_bytes
,
image_shape
)
else
:
# Decodes image.
image
=
tf
.
io
.
decode_image
(
image_bytes
,
channels
=
3
)
image
.
set_shape
([
None
,
None
,
3
])
# Center crops.
image
=
preprocess_ops
.
center_crop_image
(
image
)
image
=
tf
.
image
.
resize
(
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
=
tf
.
reshape
(
image
,
[
self
.
_output_size
[
0
],
self
.
_output_size
[
1
],
3
])
image
.
set_shape
([
self
.
_output_size
[
0
],
self
.
_output_size
[
1
],
3
])
# Normalizes image with mean and std pixel values.
image
=
preprocess_ops
.
normalize_image
(
image
,
...
...
official/vision/beta/dataloaders/tfexample_utils.py
View file @
069bdd28
...
...
@@ -127,10 +127,12 @@ def _encode_image(image_array: np.ndarray, fmt: str) -> bytes:
def
create_classification_example
(
image_height
:
int
,
image_width
:
int
,
image_format
:
str
=
'JPEG'
,
is_multilabel
:
bool
=
False
)
->
tf
.
train
.
Example
:
"""Creates image and labels for image classification input pipeline."""
image
=
_encode_image
(
np
.
uint8
(
np
.
random
.
rand
(
image_height
,
image_width
,
3
)
*
255
),
fmt
=
'JPEG'
)
np
.
uint8
(
np
.
random
.
rand
(
image_height
,
image_width
,
3
)
*
255
),
fmt
=
image_format
)
labels
=
[
0
,
1
]
if
is_multilabel
else
[
0
]
serialized_example
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
...
...
official/vision/beta/tasks/image_classification.py
View file @
069bdd28
...
...
@@ -104,6 +104,7 @@ class ImageClassificationTask(base_task.Task):
num_classes
=
num_classes
,
image_field_key
=
image_field_key
,
label_field_key
=
label_field_key
,
decode_jpeg_only
=
params
.
decode_jpeg_only
,
aug_rand_hflip
=
params
.
aug_rand_hflip
,
aug_type
=
params
.
aug_type
,
is_multilabel
=
is_multilabel
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment