Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
374cff58
Commit
374cff58
authored
Jan 14, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 351798345
parent
1fb4c559
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
24 additions
and
6 deletions
+24
-6
official/vision/beta/data/tfrecord_lib.py
official/vision/beta/data/tfrecord_lib.py
+24
-6
No files found.
official/vision/beta/data/tfrecord_lib.py
View file @
374cff58
...
...
@@ -19,6 +19,7 @@ import io
import
itertools
from
absl
import
logging
import
numpy
as
np
from
PIL
import
Image
import
tensorflow
as
tf
...
...
@@ -45,10 +46,10 @@ def convert_to_feature(value, value_type=None):
if
isinstance
(
element
,
bytes
):
value_type
=
'bytes'
elif
isinstance
(
element
,
int
):
elif
isinstance
(
element
,
(
int
,
np
.
integer
)
):
value_type
=
'int64'
elif
isinstance
(
element
,
float
):
elif
isinstance
(
element
,
(
float
,
np
.
floating
)
):
value_type
=
'float'
else
:
...
...
@@ -104,8 +105,9 @@ def encode_binary_mask_as_png(binary_mask):
return
output_io
.
getvalue
()
def
write_tf_record_dataset
(
output_path
,
annotation_iterator
,
process_func
,
num_shards
,
use_multiprocessing
=
True
):
def
write_tf_record_dataset
(
output_path
,
annotation_iterator
,
process_func
,
num_shards
,
use_multiprocessing
=
True
,
unpack_arguments
=
True
):
"""Iterates over annotations, processes them and writes into TFRecords.
Args:
...
...
@@ -118,6 +120,9 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func,
num_shards: int, the number of shards to write for the dataset.
use_multiprocessing:
Whether or not to use multiple processes to write TF Records.
unpack_arguments:
Whether to unpack the tuples from annotation_iterator as individual
arguments to the process func or to pass the returned value as it is.
Returns:
num_skipped: The total number of skipped annotations.
...
...
@@ -133,9 +138,15 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func,
if
use_multiprocessing
:
pool
=
mp
.
Pool
()
tf_example_iterator
=
pool
.
starmap
(
process_func
,
annotation_iterator
)
if
unpack_arguments
:
tf_example_iterator
=
pool
.
starmap
(
process_func
,
annotation_iterator
)
else
:
tf_example_iterator
=
pool
.
imap
(
process_func
,
annotation_iterator
)
else
:
tf_example_iterator
=
itertools
.
starmap
(
process_func
,
annotation_iterator
)
if
unpack_arguments
:
tf_example_iterator
=
itertools
.
starmap
(
process_func
,
annotation_iterator
)
else
:
tf_example_iterator
=
map
(
process_func
,
annotation_iterator
)
for
idx
,
(
tf_example
,
num_annotations_skipped
)
in
enumerate
(
tf_example_iterator
):
...
...
@@ -155,3 +166,10 @@ def write_tf_record_dataset(output_path, annotation_iterator, process_func,
logging
.
info
(
'Finished writing, skipped %d annotations.'
,
total_num_annotations_skipped
)
return
total_num_annotations_skipped
def
check_and_make_dir
(
directory
):
"""Creates the directory if it doesn't exist."""
if
not
tf
.
io
.
gfile
.
isdir
(
directory
):
tf
.
io
.
gfile
.
makedirs
(
directory
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment