Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0bcf460a
Commit
0bcf460a
authored
May 10, 2021
by
Rajagopal Ananthanarayanan
Committed by
A. Unique TensorFlower
May 10, 2021
Browse files
Internal change
PiperOrigin-RevId: 372996541
parent
ebf268b6
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
52 additions
and
38 deletions
+52
-38
official/vision/beta/dataloaders/classification_input.py
official/vision/beta/dataloaders/classification_input.py
+52
-38
No files found.
official/vision/beta/dataloaders/classification_input.py
View file @
0bcf460a
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
# limitations under the License.
# limitations under the License.
"""Classification decoder and parser."""
"""Classification decoder and parser."""
from
typing
import
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
# Import libraries
# Import libraries
import
tensorflow
as
tf
import
tensorflow
as
tf
...
@@ -26,27 +26,34 @@ from official.vision.beta.ops import preprocess_ops
...
@@ -26,27 +26,34 @@ from official.vision.beta.ops import preprocess_ops
MEAN_RGB
=
(
0.485
*
255
,
0.456
*
255
,
0.406
*
255
)
MEAN_RGB
=
(
0.485
*
255
,
0.456
*
255
,
0.406
*
255
)
STDDEV_RGB
=
(
0.229
*
255
,
0.224
*
255
,
0.225
*
255
)
STDDEV_RGB
=
(
0.229
*
255
,
0.224
*
255
,
0.225
*
255
)
DEFAULT_IMAGE_FIELD_KEY
=
'image/encoded'
DEFAULT_LABEL_FIELD_KEY
=
'image/class/label'
class
Decoder
(
decoder
.
Decoder
):
class
Decoder
(
decoder
.
Decoder
):
"""A tf.Example decoder for classification task."""
"""A tf.Example decoder for classification task."""
def
__init__
(
self
,
def
__init__
(
self
,
image_field_key
:
str
=
'image/encoded'
,
image_field_key
:
str
=
DEFAULT_IMAGE_FIELD_KEY
,
label_field_key
:
str
=
'image/class/label'
,
label_field_key
:
str
=
DEFAULT_LABEL_FIELD_KEY
,
is_multilabel
:
bool
=
False
):
is_multilabel
:
bool
=
False
,
self
.
_keys_to_features
=
{
keys_to_features
:
Optional
[
Dict
[
str
,
Any
]]
=
None
):
image_field_key
:
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
),
if
not
keys_to_features
:
}
keys_to_features
=
{
if
is_multilabel
:
image_field_key
:
self
.
_keys_to_features
.
update
(
tf
.
io
.
FixedLenFeature
((),
tf
.
string
,
default_value
=
''
),
{
label_field_key
:
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
int64
)})
}
else
:
if
is_multilabel
:
self
.
_keys_to_features
.
update
({
keys_to_features
.
update
(
label_field_key
:
tf
.
io
.
FixedLenFeature
((),
tf
.
int64
,
default_value
=-
1
)
{
label_field_key
:
tf
.
io
.
VarLenFeature
(
dtype
=
tf
.
int64
)})
})
else
:
keys_to_features
.
update
({
label_field_key
:
tf
.
io
.
FixedLenFeature
((),
tf
.
int64
,
default_value
=-
1
)
})
self
.
_keys_to_features
=
keys_to_features
def
decode
(
self
,
def
decode
(
self
,
serialized_example
):
serialized_example
:
tf
.
train
.
Example
)
->
Dict
[
str
,
tf
.
Tensor
]:
return
tf
.
io
.
parse_single_example
(
return
tf
.
io
.
parse_single_example
(
serialized_example
,
self
.
_keys_to_features
)
serialized_example
,
self
.
_keys_to_features
)
...
@@ -57,8 +64,8 @@ class Parser(parser.Parser):
...
@@ -57,8 +64,8 @@ class Parser(parser.Parser):
def
__init__
(
self
,
def
__init__
(
self
,
output_size
:
List
[
int
],
output_size
:
List
[
int
],
num_classes
:
float
,
num_classes
:
float
,
image_field_key
:
str
=
'image/encoded'
,
image_field_key
:
str
=
DEFAULT_IMAGE_FIELD_KEY
,
label_field_key
:
str
=
'image/class/label'
,
label_field_key
:
str
=
DEFAULT_LABEL_FIELD_KEY
,
aug_rand_hflip
:
bool
=
True
,
aug_rand_hflip
:
bool
=
True
,
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
,
aug_type
:
Optional
[
common
.
Augmentation
]
=
None
,
is_multilabel
:
bool
=
False
,
is_multilabel
:
bool
=
False
,
...
@@ -69,8 +76,8 @@ class Parser(parser.Parser):
...
@@ -69,8 +76,8 @@ class Parser(parser.Parser):
output_size: `Tensor` or `list` for [height, width] of output image. The
output_size: `Tensor` or `list` for [height, width] of output image. The
output_size should be divided by the largest feature stride 2^max_level.
output_size should be divided by the largest feature stride 2^max_level.
num_classes: `float`, number of classes.
num_classes: `float`, number of classes.
image_field_key:
A
`str`
of
the key name to encoded image in
TF
Example.
image_field_key: `str`
,
the key name to encoded image in
tf.
Example.
label_field_key:
A
`str`
of
the key name to label in
TF
Example.
label_field_key: `str`
,
the key name to label in
tf.
Example.
aug_rand_hflip: `bool`, if True, augment training with random
aug_rand_hflip: `bool`, if True, augment training with random
horizontal flip.
horizontal flip.
aug_type: An optional Augmentation object to choose from AutoAugment and
aug_type: An optional Augmentation object to choose from AutoAugment and
...
@@ -83,9 +90,6 @@ class Parser(parser.Parser):
...
@@ -83,9 +90,6 @@ class Parser(parser.Parser):
self
.
_aug_rand_hflip
=
aug_rand_hflip
self
.
_aug_rand_hflip
=
aug_rand_hflip
self
.
_num_classes
=
num_classes
self
.
_num_classes
=
num_classes
self
.
_image_field_key
=
image_field_key
self
.
_image_field_key
=
image_field_key
self
.
_label_field_key
=
label_field_key
self
.
_is_multilabel
=
is_multilabel
if
dtype
==
'float32'
:
if
dtype
==
'float32'
:
self
.
_dtype
=
tf
.
float32
self
.
_dtype
=
tf
.
float32
elif
dtype
==
'float16'
:
elif
dtype
==
'float16'
:
...
@@ -111,10 +115,31 @@ class Parser(parser.Parser):
...
@@ -111,10 +115,31 @@ class Parser(parser.Parser):
aug_type
.
type
))
aug_type
.
type
))
else
:
else
:
self
.
_augmenter
=
None
self
.
_augmenter
=
None
self
.
_label_field_key
=
label_field_key
self
.
_is_multilabel
=
is_multilabel
def
_parse_train_data
(
self
,
decoded_tensors
):
def
_parse_train_data
(
self
,
decoded_tensors
):
"""Parses data for training."""
"""Parses data for training."""
image
=
self
.
_parse_train_image
(
decoded_tensors
)
label
=
tf
.
cast
(
decoded_tensors
[
self
.
_label_field_key
],
dtype
=
tf
.
int32
)
if
self
.
_is_multilabel
:
if
isinstance
(
label
,
tf
.
sparse
.
SparseTensor
):
label
=
tf
.
sparse
.
to_dense
(
label
)
label
=
tf
.
reduce_sum
(
tf
.
one_hot
(
label
,
self
.
_num_classes
),
axis
=
0
)
return
image
,
label
def
_parse_eval_data
(
self
,
decoded_tensors
):
"""Parses data for evaluation."""
image
=
self
.
_parse_eval_image
(
decoded_tensors
)
label
=
tf
.
cast
(
decoded_tensors
[
self
.
_label_field_key
],
dtype
=
tf
.
int32
)
label
=
tf
.
cast
(
decoded_tensors
[
self
.
_label_field_key
],
dtype
=
tf
.
int32
)
if
self
.
_is_multilabel
:
if
isinstance
(
label
,
tf
.
sparse
.
SparseTensor
):
label
=
tf
.
sparse
.
to_dense
(
label
)
label
=
tf
.
reduce_sum
(
tf
.
one_hot
(
label
,
self
.
_num_classes
),
axis
=
0
)
return
image
,
label
def
_parse_train_image
(
self
,
decoded_tensors
):
"""Parses image data for training."""
image_bytes
=
decoded_tensors
[
self
.
_image_field_key
]
image_bytes
=
decoded_tensors
[
self
.
_image_field_key
]
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
...
@@ -146,16 +171,10 @@ class Parser(parser.Parser):
...
@@ -146,16 +171,10 @@ class Parser(parser.Parser):
# Convert image to self._dtype.
# Convert image to self._dtype.
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
if
self
.
_is_multilabel
:
return
image
if
isinstance
(
label
,
tf
.
sparse
.
SparseTensor
):
label
=
tf
.
sparse
.
to_dense
(
label
)
label
=
tf
.
reduce_sum
(
tf
.
one_hot
(
label
,
self
.
_num_classes
),
axis
=
0
)
return
image
,
label
def
_parse_eval_data
(
self
,
decoded_tensors
):
def
_parse_eval_image
(
self
,
decoded_tensors
):
"""Parses data for evaluation."""
"""Parses image data for evaluation."""
label
=
tf
.
cast
(
decoded_tensors
[
self
.
_label_field_key
],
dtype
=
tf
.
int32
)
image_bytes
=
decoded_tensors
[
self
.
_image_field_key
]
image_bytes
=
decoded_tensors
[
self
.
_image_field_key
]
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
image_shape
=
tf
.
image
.
extract_jpeg_shape
(
image_bytes
)
...
@@ -175,9 +194,4 @@ class Parser(parser.Parser):
...
@@ -175,9 +194,4 @@ class Parser(parser.Parser):
# Convert image to self._dtype.
# Convert image to self._dtype.
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
image
=
tf
.
image
.
convert_image_dtype
(
image
,
self
.
_dtype
)
if
self
.
_is_multilabel
:
return
image
if
isinstance
(
label
,
tf
.
sparse
.
SparseTensor
):
label
=
tf
.
sparse
.
to_dense
(
label
)
label
=
tf
.
reduce_sum
(
tf
.
one_hot
(
label
,
self
.
_num_classes
),
axis
=
0
)
return
image
,
label
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment