Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
8c248a5a
Commit
8c248a5a
authored
Nov 30, 2020
by
A. Unique TensorFlower
Browse files
Branch create_coco_tf_record into tf-vision code base.
PiperOrigin-RevId: 344826362
parent
5634bf23
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
620 additions
and
0 deletions
+620
-0
official/vision/beta/data/create_coco_tf_record.py
official/vision/beta/data/create_coco_tf_record.py
+370
-0
official/vision/beta/data/tfrecord_lib.py
official/vision/beta/data/tfrecord_lib.py
+157
-0
official/vision/beta/data/tfrecord_lib_test.py
official/vision/beta/data/tfrecord_lib_test.py
+93
-0
No files found.
official/vision/beta/data/create_coco_tf_record.py
0 → 100644
View file @
8c248a5a
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r
"""Convert raw COCO dataset to TFRecord format.
This scripts follows the label map decoder format and supports detection
boxes, instance masks and captions.
Example usage:
python create_coco_tf_record.py --logtostderr \
--image_dir="${TRAIN_IMAGE_DIR}" \
--image_info_file="${TRAIN_IMAGE_INFO_FILE}" \
--object_annotations_file="${TRAIN_ANNOTATIONS_FILE}" \
--caption_annotations_file="${CAPTION_ANNOTATIONS_FILE}" \
--output_file_prefix="${OUTPUT_DIR/FILE_PREFIX}" \
--num_shards=100
"""
import
collections
import
json
import
logging
import
os
from
absl
import
app
# pylint:disable=unused-import
from
absl
import
flags
import
numpy
as
np
from
pycocotools
import
mask
import
tensorflow
as
tf
import
multiprocessing
as
mp
from
official.vision.beta.data
import
tfrecord_lib
# Command-line interface. Help strings are unchanged; definitions are grouped
# and formatted per the surrounding file's style.
flags.DEFINE_boolean('include_masks', False,
                     'Whether to include instance segmentations masks '
                     '(PNG encoded) in the result. default: False.')
flags.DEFINE_string('image_dir', '', 'Directory containing images.')
flags.DEFINE_string('image_info_file', '',
                    'File containing image information. '
                    'Tf Examples in the output files correspond to the image '
                    'info entries in this file. If this file is not provided '
                    'object_annotations_file is used if present. Otherwise, '
                    'caption_annotations_file is used to get image info.')
flags.DEFINE_string('object_annotations_file', '',
                    'File containing object '
                    'annotations - boxes and instance masks.')
flags.DEFINE_string('caption_annotations_file', '',
                    'File containing image '
                    'captions.')
flags.DEFINE_string('output_file_prefix', '/tmp/train', 'Path to output file')
flags.DEFINE_integer('num_shards', 32, 'Number of shards for output file.')

FLAGS = flags.FLAGS

# Route this script's log output through TensorFlow's logger at INFO level so
# progress messages are visible by default.
logger = tf.get_logger()
logger.setLevel(logging.INFO)
def coco_segmentation_to_mask_png(segmentation, height, width, is_crowd):
  """Encode a COCO mask segmentation as PNG string."""
  rle = mask.frPyObjects(segmentation, height, width)
  decoded_mask = mask.decode(rle)
  if not is_crowd:
    # Polygon segmentations decode to one channel per polygon; collapse them
    # into a single binary mask. Crowd (RLE) masks are already 2-D.
    decoded_mask = np.amax(decoded_mask, axis=2)
  return tfrecord_lib.encode_binary_mask_as_png(decoded_mask)
def coco_annotations_to_lists(bbox_annotations, id_to_name_map, image_height,
                              image_width, include_masks):
  """Convert COCO annotations to feature lists.

  Args:
    bbox_annotations: list of COCO annotation dicts for one image.
    id_to_name_map: dict mapping category id to category name.
    image_height: image height in pixels; normalizes y coordinates.
    image_width: image width in pixels; normalizes x coordinates.
    include_masks: whether to also PNG-encode instance masks.

  Returns:
    A tuple (data, num_annotations_skipped) where `data` maps feature-list
    names to parallel per-object lists and `num_annotations_skipped` counts
    degenerate or out-of-bounds boxes that were dropped.
  """
  list_names = ['xmin', 'xmax', 'ymin', 'ymax', 'is_crowd', 'category_id',
                'category_names', 'area']
  data = {name: [] for name in list_names}
  if include_masks:
    data['encoded_mask_png'] = []

  num_annotations_skipped = 0
  for annotation in bbox_annotations:
    x, y, width, height = annotation['bbox']
    # Drop degenerate boxes and boxes extending past the image border.
    if width <= 0 or height <= 0:
      num_annotations_skipped += 1
      continue
    if x + width > image_width or y + height > image_height:
      num_annotations_skipped += 1
      continue
    data['xmin'].append(float(x) / image_width)
    data['xmax'].append(float(x + width) / image_width)
    data['ymin'].append(float(y) / image_height)
    data['ymax'].append(float(y + height) / image_height)
    data['is_crowd'].append(annotation['iscrowd'])
    category_id = int(annotation['category_id'])
    data['category_id'].append(category_id)
    data['category_names'].append(id_to_name_map[category_id].encode('utf8'))
    data['area'].append(annotation['area'])
    if include_masks:
      data['encoded_mask_png'].append(
          coco_segmentation_to_mask_png(annotation['segmentation'],
                                        image_height, image_width,
                                        annotation['iscrowd']))
  return data, num_annotations_skipped
def bbox_annotations_to_feature_dict(bbox_annotations, image_height,
                                     image_width, id_to_name_map,
                                     include_masks):
  """Convert COCO annotations to an encoded feature dict."""
  data, num_skipped = coco_annotations_to_lists(bbox_annotations,
                                                id_to_name_map, image_height,
                                                image_width, include_masks)
  # Map each per-object list onto its canonical TF Example feature key.
  key_to_list_name = {
      'image/object/bbox/xmin': 'xmin',
      'image/object/bbox/xmax': 'xmax',
      'image/object/bbox/ymin': 'ymin',
      'image/object/bbox/ymax': 'ymax',
      'image/object/class/text': 'category_names',
      'image/object/class/label': 'category_id',
      'image/object/is_crowd': 'is_crowd',
      'image/object/area': 'area',
  }
  feature_dict = {
      key: tfrecord_lib.convert_to_feature(data[list_name])
      for key, list_name in key_to_list_name.items()
  }
  if include_masks:
    feature_dict['image/object/mask'] = (
        tfrecord_lib.convert_to_feature(data['encoded_mask_png']))
  return feature_dict, num_skipped
def encode_caption_annotations(caption_annotations):
  """Returns the UTF-8 encoded 'caption' string of each annotation."""
  return [annotation['caption'].encode('utf8')
          for annotation in caption_annotations]
def create_tf_example(image,
                      image_dir,
                      bbox_annotations=None,
                      id_to_name_map=None,
                      caption_annotations=None,
                      include_masks=False):
  """Converts image and annotations to a tf.Example proto.

  Args:
    image: dict with keys: [u'license', u'file_name', u'coco_url', u'height',
      u'width', u'date_captured', u'flickr_url', u'id']
    image_dir: directory containing the image files.
    bbox_annotations: list of dicts with keys: [u'segmentation', u'area',
      u'iscrowd', u'image_id', u'bbox', u'category_id', u'id']. Bounding box
      coordinates in the official COCO dataset are given as [x, y, width,
      height] tuples using absolute coordinates where x, y represent the
      top-left (0-indexed) corner. This function converts to the format
      expected by the Tensorflow Object Detection API (which is [ymin, xmin,
      ymax, xmax] with coordinates normalized relative to image size).
    id_to_name_map: a dict mapping category IDs to string names.
    caption_annotations: list of dicts with keys: [u'id', u'image_id',
      u'caption'].
    include_masks: Whether to include instance segmentations masks
      (PNG encoded) in the result. default: False.

  Returns:
    example: The converted tf.Example
    num_annotations_skipped: Number of (invalid) annotations that were ignored.

  Raises:
    ValueError: if the image pointed to by data['filename'] is not a valid JPEG
  """
  image_height = image['height']
  image_width = image['width']
  filename = image['file_name']
  image_id = image['id']

  full_path = os.path.join(image_dir, filename)
  with tf.io.gfile.GFile(full_path, 'rb') as fid:
    encoded_jpg = fid.read()

  feature_dict = tfrecord_lib.image_info_to_feature_dict(
      image_height, image_width, filename, image_id, encoded_jpg, 'jpg')

  num_annotations_skipped = 0
  if bbox_annotations:
    box_feature_dict, num_skipped = bbox_annotations_to_feature_dict(
        bbox_annotations, image_height, image_width, id_to_name_map,
        include_masks)
    num_annotations_skipped += num_skipped
    feature_dict.update(box_feature_dict)

  if caption_annotations:
    encoded_captions = encode_caption_annotations(caption_annotations)
    feature_dict.update(
        {'image/caption': tfrecord_lib.convert_to_feature(encoded_captions)})

  example = tf.train.Example(features=tf.train.Features(feature=feature_dict))
  return example, num_annotations_skipped
def _load_object_annotations(object_annotations_file):
  """Loads object annotation JSON file.

  Args:
    object_annotations_file: path to a COCO-style instance-annotation JSON.

  Returns:
    A tuple (img_to_obj_annotation, id_to_name_map): a defaultdict mapping
    image id to its list of annotations, and a dict mapping category id to
    category name.
  """
  with tf.io.gfile.GFile(object_annotations_file, 'r') as fid:
    obj_annotations = json.load(fid)

  images = obj_annotations['images']
  id_to_name_map = {
      category['id']: category['name']
      for category in obj_annotations['categories']
  }

  img_to_obj_annotation = collections.defaultdict(list)
  logging.info('Building bounding box index.')
  for annotation in obj_annotations['annotations']:
    img_to_obj_annotation[annotation['image_id']].append(annotation)

  missing_annotation_count = sum(
      1 for image in images if image['id'] not in img_to_obj_annotation)
  logging.info('%d images are missing bboxes.', missing_annotation_count)
  return img_to_obj_annotation, id_to_name_map
def _load_caption_annotations(caption_annotations_file):
  """Loads caption annotation JSON file.

  Args:
    caption_annotations_file: path to a COCO-style caption-annotation JSON.

  Returns:
    A defaultdict mapping image id to its list of caption annotations.
  """
  with tf.io.gfile.GFile(caption_annotations_file, 'r') as fid:
    caption_annotations = json.load(fid)

  img_to_caption_annotation = collections.defaultdict(list)
  logging.info('Building caption index.')
  for annotation in caption_annotations['annotations']:
    img_to_caption_annotation[annotation['image_id']].append(annotation)

  missing_annotation_count = sum(
      1 for image in caption_annotations['images']
      if image['id'] not in img_to_caption_annotation)
  logging.info('%d images are missing captions.', missing_annotation_count)
  return img_to_caption_annotation
def _load_images_info(images_info_file):
  """Returns the 'images' entry of a COCO-style image-info JSON file."""
  with tf.io.gfile.GFile(images_info_file, 'r') as fid:
    info_dict = json.load(fid)
  return info_dict['images']
def generate_annotations(images,
                         image_dir,
                         img_to_obj_annotation=None,
                         img_to_caption_annotation=None,
                         id_to_name_map=None,
                         include_masks=False):
  """Generator for COCO annotations.

  Args:
    images: list of COCO image-info dicts; each must contain an 'id' key.
    image_dir: directory containing the image files.
    img_to_obj_annotation: optional dict mapping image id to a list of object
      annotations for that image.
    img_to_caption_annotation: optional dict mapping image id to a list of
      caption annotations for that image.
    id_to_name_map: optional dict mapping category IDs to string names.
    include_masks: whether instance masks should be encoded downstream.

  Yields:
    Tuples (image, image_dir, object_annotation, id_to_name_map,
    caption_annotation, include_masks) matching create_tf_example's argument
    order.
  """
  for image in images:
    # Default both entries to None each iteration so the yielded tuple is
    # always well-formed. Previously these names were left unbound — a
    # NameError — whenever the corresponding annotation index was not
    # supplied (e.g. an image_info_file-only run).
    object_annotation = None
    caption_annotation = None
    if img_to_obj_annotation:
      object_annotation = img_to_obj_annotation.get(image['id'], None)
    if img_to_caption_annotation:
      caption_annotation = img_to_caption_annotation.get(image['id'], None)
    yield (image, image_dir, object_annotation, id_to_name_map,
           caption_annotation, include_masks)
def _create_tf_record_from_coco_annotations(images_info_file,
                                            image_dir,
                                            output_path,
                                            num_shards,
                                            object_annotations_file=None,
                                            caption_annotations_file=None,
                                            include_masks=False):
  """Loads COCO annotation json files and converts to tf.Record format.

  Args:
    images_info_file: JSON file containing image info. The number of
      tf.Examples in the output tf Record files is exactly equal to the number
      of image info entries in this file. This can be any of train/val/test
      annotation json files Eg. 'image_info_test-dev2017.json',
      'instance_annotations_train2017.json',
      'caption_annotations_train2017.json', etc.
    image_dir: Directory containing the image files.
    output_path: Path to output tf.Record file.
    num_shards: Number of output files to create.
    object_annotations_file: JSON file containing bounding box annotations.
    caption_annotations_file: JSON file containing caption annotations.
    include_masks: Whether to include instance segmentations masks
      (PNG encoded) in the result. default: False.
  """
  logging.info('writing to output path: %s', output_path)

  images = _load_images_info(images_info_file)

  # Each index is built only when its annotation file was supplied.
  img_to_obj_annotation = None
  img_to_caption_annotation = None
  id_to_name_map = None
  if object_annotations_file:
    img_to_obj_annotation, id_to_name_map = (
        _load_object_annotations(object_annotations_file))
  if caption_annotations_file:
    img_to_caption_annotation = (
        _load_caption_annotations(caption_annotations_file))

  coco_annotations_iter = generate_annotations(
      images,
      image_dir,
      img_to_obj_annotation,
      img_to_caption_annotation,
      id_to_name_map=id_to_name_map,
      include_masks=include_masks)

  num_skipped = tfrecord_lib.write_tf_record_dataset(
      output_path, coco_annotations_iter, create_tf_example, num_shards)

  logging.info('Finished writing, skipped %d annotations.', num_skipped)
def main(_):
  """Validates flags and writes the COCO TFRecord dataset."""
  # Validate with explicit raises rather than `assert`, which is stripped
  # when Python runs with -O.
  if not FLAGS.image_dir:
    raise ValueError('`image_dir` missing.')
  if not (FLAGS.image_info_file or FLAGS.object_annotations_file or
          FLAGS.caption_annotations_file):
    raise ValueError('All annotation files are missing.')

  # Image info is taken from the first available source, in priority order.
  if FLAGS.image_info_file:
    images_info_file = FLAGS.image_info_file
  elif FLAGS.object_annotations_file:
    images_info_file = FLAGS.object_annotations_file
  else:
    images_info_file = FLAGS.caption_annotations_file

  directory = os.path.dirname(FLAGS.output_file_prefix)
  # `directory` is empty when the output prefix has no directory component;
  # there is nothing to create then (makedirs('') would fail).
  if directory and not tf.io.gfile.isdir(directory):
    tf.io.gfile.makedirs(directory)

  _create_tf_record_from_coco_annotations(images_info_file, FLAGS.image_dir,
                                          FLAGS.output_file_prefix,
                                          FLAGS.num_shards,
                                          FLAGS.object_annotations_file,
                                          FLAGS.caption_annotations_file,
                                          FLAGS.include_masks)


if __name__ == '__main__':
  app.run(main)
official/vision/beta/data/tfrecord_lib.py
0 → 100644
View file @
8c248a5a
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helper functions for creating TFRecord datasets."""
import
hashlib
import
io
import
itertools
from
absl
import
logging
from
PIL
import
Image
import
tensorflow
as
tf
import
multiprocessing
as
mp
def convert_to_feature(value, value_type=None):
  """Converts the given python object to a tf.train.Feature.

  Args:
    value: int, float, bytes or a list of them.
    value_type: optional, if specified, forces the feature to be of the given
      type. Otherwise, type is inferred automatically. Can be one of
      ['bytes', 'int64', 'float', 'bytes_list', 'int64_list', 'float_list']

  Returns:
    feature: A tf.train.Feature object.

  Raises:
    ValueError: if the element type cannot be converted, if `value` is an
      empty list and `value_type` was not given, or if `value_type` is not
      one of the supported names.
  """
  if value_type is None:
    if isinstance(value, list):
      if not value:
        # An empty list carries no element to infer the type from. Fail with
        # a clear message instead of the opaque IndexError the bare
        # `value[0]` used to raise.
        raise ValueError(
            'value_type must be specified when value is an empty list.')
      element = value[0]
    else:
      element = value

    # Note: bool is a subclass of int and maps to 'int64'.
    if isinstance(element, bytes):
      value_type = 'bytes'
    elif isinstance(element, int):
      value_type = 'int64'
    elif isinstance(element, float):
      value_type = 'float'
    else:
      raise ValueError('Cannot convert type {} to feature'.format(
          type(element)))

    if isinstance(value, list):
      value_type = value_type + '_list'

  if value_type == 'int64':
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
  elif value_type == 'int64_list':
    return tf.train.Feature(int64_list=tf.train.Int64List(value=value))
  elif value_type == 'float':
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))
  elif value_type == 'float_list':
    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
  elif value_type == 'bytes':
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
  elif value_type == 'bytes_list':
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))
  else:
    raise ValueError('Unknown value_type parameter - {}'.format(value_type))
def image_info_to_feature_dict(height, width, filename, image_id, encoded_str,
                               encoded_format):
  """Convert image information to a dict of features."""
  # SHA-256 of the encoded bytes serves as a stable per-image key.
  sha256_key = hashlib.sha256(encoded_str).hexdigest()

  return {
      'image/height': convert_to_feature(height),
      'image/width': convert_to_feature(width),
      'image/filename': convert_to_feature(filename.encode('utf8')),
      'image/source_id': convert_to_feature(str(image_id).encode('utf8')),
      'image/key/sha256': convert_to_feature(sha256_key.encode('utf8')),
      'image/encoded': convert_to_feature(encoded_str),
      'image/format': convert_to_feature(encoded_format.encode('utf8')),
  }
def encode_binary_mask_as_png(binary_mask):
  """Returns the PNG-encoded bytes of a binary mask array."""
  buffer = io.BytesIO()
  Image.fromarray(binary_mask).save(buffer, format='PNG')
  return buffer.getvalue()
def write_tf_record_dataset(output_path, annotation_iterator, process_func,
                            num_shards, use_multiprocessing=True):
  """Iterates over annotations, processes them and writes into TFRecords.

  Args:
    output_path: The prefix path to create TF record files.
    annotation_iterator: An iterator of tuples containing details about the
      dataset.
    process_func: A function which takes the elements from the tuples of
      annotation_iterator as arguments and returns a tuple of
      (tf.train.Example, int). The integer indicates the number of annotations
      that were skipped.
    num_shards: int, the number of shards to write for the dataset.
    use_multiprocessing: Whether or not to use multiple processes to write TF
      Records.

  Returns:
    num_skipped: The total number of skipped annotations.
  """
  writers = [
      tf.io.TFRecordWriter(output_path +
                           '-%05d-of-%05d.tfrecord' % (i, num_shards))
      for i in range(num_shards)
  ]

  total_num_annotations_skipped = 0

  pool = None
  if use_multiprocessing:
    pool = mp.Pool()
    tf_example_iterator = pool.starmap(process_func, annotation_iterator)
  else:
    tf_example_iterator = itertools.starmap(process_func, annotation_iterator)

  try:
    for idx, (tf_example, num_annotations_skipped) in enumerate(
        tf_example_iterator):
      if idx % 100 == 0:
        logging.info('On image %d', idx)
      total_num_annotations_skipped += num_annotations_skipped
      # Round-robin examples across shards.
      writers[idx % num_shards].write(tf_example.SerializeToString())
  finally:
    # Always release the worker pool and flush/close every shard file, even
    # if process_func raises part-way through — the previous version leaked
    # both on error.
    if pool is not None:
      pool.close()
      pool.join()
    for writer in writers:
      writer.close()

  logging.info('Finished writing, skipped %d annotations.',
               total_num_annotations_skipped)
  return total_num_annotations_skipped
official/vision/beta/data/tfrecord_lib_test.py
0 → 100644
View file @
8c248a5a
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for tfrecord_lib."""
import
os
from
absl
import
flags
from
absl.testing
import
parameterized
import
tensorflow
as
tf
from
official.vision.beta.data
import
tfrecord_lib
FLAGS
=
flags
.
FLAGS
def process_sample(x):
  """Wraps one feature in a (tf.train.Example, num_skipped) tuple."""
  features = {'x': x}
  example = tf.train.Example(features=tf.train.Features(feature=features))
  return example, 0
def parse_function(example_proto):
  """Parses a serialized Example holding a single int64 feature 'x'."""
  feature_description = {
      'x': tf.io.FixedLenFeature([], tf.int64, default_value=-1)
  }
  return tf.io.parse_single_example(example_proto, feature_description)
class TfrecordLibTest(parameterized.TestCase):
  """Unit tests for the tfrecord_lib helpers."""

  def test_write_tf_record_dataset(self):
    # Seventeen single-feature samples spread round-robin over three shards.
    samples = [(tfrecord_lib.convert_to_feature(i),) for i in range(17)]

    prefix = os.path.join(FLAGS.test_tmpdir, 'train')

    tfrecord_lib.write_tf_record_dataset(
        prefix, samples, process_sample, 3, use_multiprocessing=False)
    shard_files = tf.io.gfile.glob(prefix + '*')
    self.assertLen(shard_files, 3)

    parsed = tf.data.TFRecordDataset(shard_files).map(parse_function)

    recovered = set(record['x'] for record in parsed.as_numpy_iterator())
    self.assertSetEqual(recovered, set(range(17)))

  def test_convert_to_feature_float(self):
    feature = tfrecord_lib.convert_to_feature(0.0)
    self.assertEqual(feature.float_list.value[0], 0.0)

  def test_convert_to_feature_int(self):
    feature = tfrecord_lib.convert_to_feature(0)
    self.assertEqual(feature.int64_list.value[0], 0)

  def test_convert_to_feature_bytes(self):
    feature = tfrecord_lib.convert_to_feature(b'123')
    self.assertEqual(feature.bytes_list.value[0], b'123')

  def test_convert_to_feature_float_list(self):
    feature = tfrecord_lib.convert_to_feature([0.0, 1.0])
    self.assertSequenceAlmostEqual(feature.float_list.value, [0.0, 1.0])

  def test_convert_to_feature_int_list(self):
    feature = tfrecord_lib.convert_to_feature([0, 1])
    self.assertSequenceAlmostEqual(feature.int64_list.value, [0, 1])

  def test_convert_to_feature_bytes_list(self):
    feature = tfrecord_lib.convert_to_feature([b'123', b'456'])
    self.assertSequenceAlmostEqual(feature.bytes_list.value, [b'123', b'456'])


if __name__ == '__main__':
  tf.test.main()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment