Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
5e8370ca
Commit
5e8370ca
authored
Sep 21, 2022
by
Jiageng Zhang
Committed by
A. Unique TensorFlower
Sep 21, 2022
Browse files
Internal change
PiperOrigin-RevId: 475928219
parent
d5cd9b0a
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
91 additions
and
3 deletions
+91
-3
official/vision/data/image_utils.py
official/vision/data/image_utils.py
+15
-3
official/vision/data/image_utils_test.py
official/vision/data/image_utils_test.py
+15
-0
official/vision/data/tf_example_builder.py
official/vision/data/tf_example_builder.py
+4
-0
official/vision/data/tf_example_builder_test.py
official/vision/data/tf_example_builder_test.py
+57
-0
No files found.
official/vision/data/image_utils.py
View file @
5e8370ca
...
@@ -17,7 +17,7 @@
...
@@ -17,7 +17,7 @@
import
dataclasses
import
dataclasses
import
imghdr
import
imghdr
import
io
import
io
from
typing
import
Tuple
from
typing
import
Optional
,
Tuple
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
Image
...
@@ -35,6 +35,7 @@ class ImageFormat:
...
@@ -35,6 +35,7 @@ class ImageFormat:
bmp
:
str
=
'BMP'
bmp
:
str
=
'BMP'
png
:
str
=
'PNG'
png
:
str
=
'PNG'
jpeg
:
str
=
'JPEG'
jpeg
:
str
=
'JPEG'
raw
:
str
=
'RAW'
def
validate_image_format
(
format_str
:
str
)
->
str
:
def
validate_image_format
(
format_str
:
str
)
->
str
:
...
@@ -66,8 +67,11 @@ def encode_image(image_np: np.ndarray, image_format: str) -> bytes:
...
@@ -66,8 +67,11 @@ def encode_image(image_np: np.ndarray, image_format: str) -> bytes:
image_format: An enum specifying the format of the generated image.
image_format: An enum specifying the format of the generated image.
Returns:
Returns:
Encoded image string
Encoded image string
.
"""
"""
if
image_format
==
'RAW'
:
return
image_np
.
tobytes
()
if
len
(
image_np
.
shape
)
>
2
and
image_np
.
shape
[
2
]
==
1
:
if
len
(
image_np
.
shape
)
>
2
and
image_np
.
shape
[
2
]
==
1
:
image_pil
=
Image
.
fromarray
(
np
.
squeeze
(
image_np
),
'L'
)
image_pil
=
Image
.
fromarray
(
np
.
squeeze
(
image_np
),
'L'
)
else
:
else
:
...
@@ -77,7 +81,12 @@ def encode_image(image_np: np.ndarray, image_format: str) -> bytes:
...
@@ -77,7 +81,12 @@ def encode_image(image_np: np.ndarray, image_format: str) -> bytes:
return
output
.
getvalue
()
return
output
.
getvalue
()
def
decode_image
(
image_bytes
:
bytes
)
->
np
.
ndarray
:
def
decode_image
(
image_bytes
:
bytes
,
image_format
:
Optional
[
str
]
=
None
,
image_dtype
:
str
=
'uint8'
)
->
np
.
ndarray
:
"""Decodes image_bytes into numpy array."""
if
image_format
==
'RAW'
:
return
np
.
frombuffer
(
image_bytes
,
dtype
=
image_dtype
)
image_pil
=
Image
.
open
(
io
.
BytesIO
(
image_bytes
))
image_pil
=
Image
.
open
(
io
.
BytesIO
(
image_bytes
))
image_np
=
np
.
array
(
image_pil
)
image_np
=
np
.
array
(
image_pil
)
if
len
(
image_np
.
shape
)
<
3
:
if
len
(
image_np
.
shape
)
<
3
:
...
@@ -88,6 +97,9 @@ def decode_image(image_bytes: bytes) -> np.ndarray:
...
@@ -88,6 +97,9 @@ def decode_image(image_bytes: bytes) -> np.ndarray:
def
decode_image_metadata
(
image_bytes
:
bytes
)
->
Tuple
[
int
,
int
,
int
,
str
]:
def
decode_image_metadata
(
image_bytes
:
bytes
)
->
Tuple
[
int
,
int
,
int
,
str
]:
"""Decodes image metadata from encoded image string.
"""Decodes image metadata from encoded image string.
Note that if the image is encoded in RAW format, the metadata cannot be
inferred from the image bytes.
Args:
Args:
image_bytes: Encoded image string.
image_bytes: Encoded image string.
...
...
official/vision/data/image_utils_test.py
View file @
5e8370ca
...
@@ -39,6 +39,21 @@ class ImageUtilsTest(parameterized.TestCase, tf.test.TestCase):
...
@@ -39,6 +39,21 @@ class ImageUtilsTest(parameterized.TestCase, tf.test.TestCase):
self
.
assertAllClose
(
actual_image_np
,
image_np
)
self
.
assertAllClose
(
actual_image_np
,
image_np
)
self
.
assertEqual
(
actual_image_np
.
shape
,
image_np
.
shape
)
self
.
assertEqual
(
actual_image_np
.
shape
,
image_np
.
shape
)
@
parameterized
.
named_parameters
(
(
'RGB_RAW'
,
128
,
64
,
3
,
tf
.
bfloat16
.
as_numpy_dtype
),
(
'GREY_RAW'
,
32
,
32
,
1
,
tf
.
uint8
.
as_numpy_dtype
))
def
test_encode_raw_image_then_decode_raw_image
(
self
,
height
,
width
,
num_channels
,
image_dtype
):
image_np
=
fake_feature_generator
.
generate_image_np
(
height
,
width
,
num_channels
)
image_np
=
image_np
.
astype
(
image_dtype
)
image_str
=
image_utils
.
encode_image
(
image_np
,
'RAW'
)
actual_image_np
=
image_utils
.
decode_image
(
image_str
,
'RAW'
,
image_dtype
)
actual_image_np
=
actual_image_np
.
reshape
([
height
,
width
,
num_channels
])
self
.
assertAllClose
(
actual_image_np
,
image_np
)
self
.
assertEqual
(
actual_image_np
.
shape
,
image_np
.
shape
)
@
parameterized
.
named_parameters
(
@
parameterized
.
named_parameters
(
(
'RGB_PNG'
,
128
,
64
,
3
,
'PNG'
),
(
'RGB_JPEG'
,
64
,
128
,
3
,
'JPEG'
),
(
'RGB_PNG'
,
128
,
64
,
3
,
'PNG'
),
(
'RGB_JPEG'
,
64
,
128
,
3
,
'JPEG'
),
(
'GREY_BMP'
,
32
,
32
,
1
,
'BMP'
),
(
'GREY_PNG'
,
128
,
128
,
1
,
'png'
))
(
'GREY_BMP'
,
32
,
32
,
1
,
'BMP'
),
(
'GREY_PNG'
,
128
,
128
,
1
,
'png'
))
...
...
official/vision/data/tf_example_builder.py
View file @
5e8370ca
...
@@ -119,6 +119,10 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
...
@@ -119,6 +119,10 @@ class TfExampleBuilder(tf_example_builder.TfExampleBuilder):
Returns:
Returns:
The builder object for subsequent method calls.
The builder object for subsequent method calls.
"""
"""
if
image_format
==
'RAW'
:
if
not
(
height
and
width
and
num_channels
):
raise
ValueError
(
'For raw image feature, height, width and '
'num_channels fields are required.'
)
if
not
all
((
height
,
width
,
num_channels
,
image_format
)):
if
not
all
((
height
,
width
,
num_channels
,
image_format
)):
(
height
,
width
,
num_channels
,
image_format
)
=
(
(
height
,
width
,
num_channels
,
image_format
)
=
(
image_utils
.
decode_image_metadata
(
encoded_image
))
image_utils
.
decode_image_metadata
(
encoded_image
))
...
...
official/vision/data/tf_example_builder_test.py
View file @
5e8370ca
...
@@ -24,6 +24,7 @@ from official.vision.data import tf_example_builder
...
@@ -24,6 +24,7 @@ from official.vision.data import tf_example_builder
class
TfExampleBuilderTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
class
TfExampleBuilderTest
(
tf
.
test
.
TestCase
,
parameterized
.
TestCase
):
@
parameterized
.
named_parameters
((
'RGB_PNG'
,
128
,
64
,
3
,
'PNG'
,
'15125990'
),
@
parameterized
.
named_parameters
((
'RGB_PNG'
,
128
,
64
,
3
,
'PNG'
,
'15125990'
),
(
'RGB_RAW'
,
128
,
128
,
3
,
'RAW'
,
'5607919'
),
(
'RGB_JPEG'
,
64
,
128
,
3
,
'JPEG'
,
'3107796'
))
(
'RGB_JPEG'
,
64
,
128
,
3
,
'JPEG'
,
'3107796'
))
def
test_add_image_matrix_feature_success
(
self
,
height
,
width
,
num_channels
,
def
test_add_image_matrix_feature_success
(
self
,
height
,
width
,
num_channels
,
image_format
,
hashed_image
):
image_format
,
hashed_image
):
...
@@ -111,6 +112,62 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -111,6 +112,62 @@ class TfExampleBuilderTest(tf.test.TestCase, parameterized.TestCase):
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
hashed_image
]))
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
hashed_image
]))
})),
example
)
})),
example
)
def
test_add_encoded_raw_image_feature_success
(
self
):
height
=
128
width
=
128
num_channels
=
3
image_format
=
'RAW'
image_bytes
=
tf
.
bfloat16
.
as_numpy_dtype
image_np
=
fake_feature_generator
.
generate_image_np
(
height
,
width
,
num_channels
)
image_np
=
image_np
.
astype
(
image_bytes
)
expected_image_bytes
=
image_utils
.
encode_image
(
image_np
,
image_format
)
hashed_image
=
bytes
(
'3572575'
,
'ascii'
)
example_builder
=
tf_example_builder
.
TfExampleBuilder
()
example_builder
.
add_encoded_image_feature
(
expected_image_bytes
,
'RAW'
,
height
,
width
,
num_channels
)
example
=
example_builder
.
example
self
.
assertProtoEquals
(
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
{
'image/encoded'
:
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
expected_image_bytes
])),
'image/format'
:
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
bytes
(
image_format
,
'ascii'
)])),
'image/height'
:
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
[
height
])),
'image/width'
:
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
[
width
])),
'image/channels'
:
tf
.
train
.
Feature
(
int64_list
=
tf
.
train
.
Int64List
(
value
=
[
num_channels
])),
'image/source_id'
:
tf
.
train
.
Feature
(
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
hashed_image
]))
})),
example
)
def
test_add_encoded_raw_image_feature_valueerror
(
self
):
image_format
=
'RAW'
image_bytes
=
tf
.
bfloat16
.
as_numpy_dtype
image_np
=
fake_feature_generator
.
generate_image_np
(
1
,
1
,
1
)
image_np
=
image_np
.
astype
(
image_bytes
)
expected_image_bytes
=
image_utils
.
encode_image
(
image_np
,
image_format
)
example_builder
=
tf_example_builder
.
TfExampleBuilder
()
with
self
.
assertRaises
(
ValueError
):
example_builder
.
add_encoded_image_feature
(
expected_image_bytes
,
image_format
)
@
parameterized
.
parameters
(
@
parameterized
.
parameters
(
(
True
,
True
,
True
,
True
,
True
),
(
True
,
True
,
True
,
True
,
True
),
(
False
,
False
,
False
,
False
,
False
),
(
False
,
False
,
False
,
False
,
False
),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment