Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9b47a723
Commit
9b47a723
authored
Aug 04, 2022
by
Fan Yang
Committed by
A. Unique TensorFlower
Aug 04, 2022
Browse files
Internal change
PiperOrigin-RevId: 465437870
parent
02d00c0c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
16 additions
and
8 deletions
+16
-8
official/vision/configs/image_classification.py
official/vision/configs/image_classification.py
+1
-0
official/vision/dataloaders/classification_input.py
official/vision/dataloaders/classification_input.py
+14
-8
official/vision/tasks/image_classification.py
official/vision/tasks/image_classification.py
+1
-0
No files found.
official/vision/configs/image_classification.py
View file @
9b47a723
...
@@ -36,6 +36,7 @@ class DataConfig(cfg.DataConfig):
...
@@ -36,6 +36,7 @@ class DataConfig(cfg.DataConfig):
cycle_length
:
int
=
10
cycle_length
:
int
=
10
is_multilabel
:
bool
=
False
is_multilabel
:
bool
=
False
aug_rand_hflip
:
bool
=
True
aug_rand_hflip
:
bool
=
True
aug_crop
:
Optional
[
bool
]
=
True
aug_type
:
Optional
[
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
# Choose from AutoAugment and RandAugment.
common
.
Augmentation
]
=
None
# Choose from AutoAugment and RandAugment.
color_jitter
:
float
=
0.
color_jitter
:
float
=
0.
...
...
official/vision/dataloaders/classification_input.py
View file @
9b47a723
...
@@ -68,6 +68,7 @@ class Parser(parser.Parser):
...
@@ -68,6 +68,7 @@ class Parser(parser.Parser):
label_field_key
:
str
=
DEFAULT_LABEL_FIELD_KEY
,
label_field_key
:
str
=
DEFAULT_LABEL_FIELD_KEY
,
decode_jpeg_only
:
bool
=
True
,
decode_jpeg_only
:
bool
=
True
,
aug_rand_hflip
:
bool
=
True
,
aug_rand_hflip
:
bool
=
True
,
aug_crop
:
Optional
[
bool
]
=
True
,
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
,
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
,
color_jitter
:
float
=
0.
,
color_jitter
:
float
=
0.
,
random_erasing
:
Optional
[
common
.
RandomErasing
]
=
None
,
random_erasing
:
Optional
[
common
.
RandomErasing
]
=
None
,
...
@@ -85,6 +86,8 @@ class Parser(parser.Parser):
...
@@ -85,6 +86,8 @@ class Parser(parser.Parser):
faster than decoding other types. Default is True.
faster than decoding other types. Default is True.
aug_rand_hflip: `bool`, if True, augment training with random
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
horizontal flip.
aug_crop: `bool`, if True, perform random cropping during training and
center crop during validation.
aug_type: An optional Augmentation object to choose from AutoAugment and
aug_type: An optional Augmentation object to choose from AutoAugment and
RandAugment.
RandAugment.
color_jitter: Magnitude of color jitter. If > 0, the value is used to
color_jitter: Magnitude of color jitter. If > 0, the value is used to
...
@@ -98,6 +101,7 @@ class Parser(parser.Parser):
...
@@ -98,6 +101,7 @@ class Parser(parser.Parser):
"""
"""
self
.
_output_size
=
output_size
self
.
_output_size
=
output_size
self
.
_aug_rand_hflip
=
aug_rand_hflip
self
.
_aug_rand_hflip
=
aug_rand_hflip
self
.
_aug_crop
=
aug_crop
self
.
_num_classes
=
num_classes
self
.
_num_classes
=
num_classes
self
.
_image_field_key
=
image_field_key
self
.
_image_field_key
=
image_field_key
if
dtype
==
'float32'
:
if
dtype
==
'float32'
:
...
@@ -168,7 +172,7 @@ class Parser(parser.Parser):
...
@@ -168,7 +172,7 @@ class Parser(parser.Parser):
"""Parses image data for training."""
"""Parses image data for training."""
image_bytes
=
decoded_tensors
[
self
.
_image_field_key
]
image_bytes
=
decoded_tensors
[
self
.
_image_field_key
]
if
self
.
_decode_jpeg_only
:
if
self
.
_decode_jpeg_only
and
self
.
_aug_crop
:
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
# Crops image.
# Crops image.
...
@@ -184,12 +188,13 @@ class Parser(parser.Parser):
...
@@ -184,12 +188,13 @@ class Parser(parser.Parser):
image
.
set_shape
([
None
,
None
,
3
])
image
.
set_shape
([
None
,
None
,
3
])
# Crops image.
# Crops image.
cropped_image
=
preprocess_ops
.
random_crop_image
(
image
)
if
self
.
_aug_crop
:
cropped_image
=
preprocess_ops
.
random_crop_image
(
image
)
image
=
tf
.
cond
(
image
=
tf
.
cond
(
tf
.
reduce_all
(
tf
.
equal
(
tf
.
shape
(
cropped_image
),
tf
.
shape
(
image
))),
tf
.
reduce_all
(
tf
.
equal
(
tf
.
shape
(
cropped_image
),
tf
.
shape
(
image
))),
lambda
:
preprocess_ops
.
center_crop_image
(
image
),
lambda
:
preprocess_ops
.
center_crop_image
(
image
),
lambda
:
cropped_image
)
lambda
:
cropped_image
)
if
self
.
_aug_rand_hflip
:
if
self
.
_aug_rand_hflip
:
image
=
tf
.
image
.
random_flip_left_right
(
image
)
image
=
tf
.
image
.
random_flip_left_right
(
image
)
...
@@ -227,7 +232,7 @@ class Parser(parser.Parser):
...
@@ -227,7 +232,7 @@ class Parser(parser.Parser):
"""Parses image data for evaluation."""
"""Parses image data for evaluation."""
image_bytes
=
decoded_tensors
[
self
.
_image_field_key
]
image_bytes
=
decoded_tensors
[
self
.
_image_field_key
]
if
self
.
_decode_jpeg_only
:
if
self
.
_decode_jpeg_only
and
self
.
_aug_crop
:
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
# Center crops.
# Center crops.
...
@@ -238,7 +243,8 @@ class Parser(parser.Parser):
...
@@ -238,7 +243,8 @@ class Parser(parser.Parser):
image
.
set_shape
([
None
,
None
,
3
])
image
.
set_shape
([
None
,
None
,
3
])
# Center crops.
# Center crops.
image
=
preprocess_ops
.
center_crop_image
(
image
)
if
self
.
_aug_crop
:
image
=
preprocess_ops
.
center_crop_image
(
image
)
image
=
tf
.
image
.
resize
(
image
=
tf
.
image
.
resize
(
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
image
,
self
.
_output_size
,
method
=
tf
.
image
.
ResizeMethod
.
BILINEAR
)
...
...
official/vision/tasks/image_classification.py
View file @
9b47a723
...
@@ -106,6 +106,7 @@ class ImageClassificationTask(base_task.Task):
...
@@ -106,6 +106,7 @@ class ImageClassificationTask(base_task.Task):
label_field_key
=
label_field_key
,
label_field_key
=
label_field_key
,
decode_jpeg_only
=
params
.
decode_jpeg_only
,
decode_jpeg_only
=
params
.
decode_jpeg_only
,
aug_rand_hflip
=
params
.
aug_rand_hflip
,
aug_rand_hflip
=
params
.
aug_rand_hflip
,
aug_crop
=
params
.
aug_crop
,
aug_type
=
params
.
aug_type
,
aug_type
=
params
.
aug_type
,
color_jitter
=
params
.
color_jitter
,
color_jitter
=
params
.
color_jitter
,
random_erasing
=
params
.
random_erasing
,
random_erasing
=
params
.
random_erasing
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment