Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
672ac40b
Commit
672ac40b
authored
Jun 23, 2018
by
Chris Shallue
Committed by
Christopher Shallue
Jun 23, 2018
Browse files
Add utilities for getting and setting features in tf.train.Example protos.
PiperOrigin-RevId: 201839334
parent
da7925b7
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
325 additions
and
0 deletions
+325
-0
research/astronet/astronet/util/BUILD
research/astronet/astronet/util/BUILD
+14
-0
research/astronet/astronet/util/example_util.py
research/astronet/astronet/util/example_util.py
+131
-0
research/astronet/astronet/util/example_util_test.py
research/astronet/astronet/util/example_util_test.py
+180
-0
No files found.
research/astronet/astronet/util/BUILD
View file @
672ac40b
...
...
@@ -42,3 +42,17 @@ py_library(
"//astronet/ops:training"
,
],
)
py_library
(
name
=
"example_util"
,
srcs
=
[
"example_util.py"
],
srcs_version
=
"PY2AND3"
,
)
py_test
(
name
=
"example_util_test"
,
size
=
"small"
,
srcs
=
[
"example_util_test.py"
],
srcs_version
=
"PY2AND3"
,
deps
=
[
":example_util"
],
)
research/astronet/astronet/util/example_util.py
0 → 100644
View file @
672ac40b
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for getting and setting values in tf.Example protocol buffers."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
def
get_feature
(
ex
,
name
,
kind
=
None
,
strict
=
True
):
"""Gets a feature value from a tf.train.Example.
Args:
ex: A tf.train.Example.
name: Name of the feature to look up.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
strict: Whether to raise a KeyError if there is no such feature.
Returns:
A numpy array containing to the values of the specified feature.
Raises:
KeyError: If there is no feature with the specified name.
TypeError: If the feature has a different type to that specified.
"""
if
name
not
in
ex
.
features
.
feature
:
if
strict
:
raise
KeyError
(
name
)
return
np
.
array
([])
inferred_kind
=
ex
.
features
.
feature
[
name
].
WhichOneof
(
"kind"
)
if
not
inferred_kind
:
return
np
.
array
([])
# Feature exists, but it's empty.
if
kind
and
kind
!=
inferred_kind
:
raise
TypeError
(
"Requested %s, but Feature has %s"
%
(
kind
,
inferred_kind
))
return
np
.
array
(
getattr
(
ex
.
features
.
feature
[
name
],
inferred_kind
).
value
)
def
get_bytes_feature
(
ex
,
name
,
strict
=
True
):
"""Gets the value of a bytes feature from a tf.train.Example."""
return
get_feature
(
ex
,
name
,
"bytes_list"
,
strict
)
def
get_float_feature
(
ex
,
name
,
strict
=
True
):
"""Gets the value of a float feature from a tf.train.Example."""
return
get_feature
(
ex
,
name
,
"float_list"
,
strict
)
def
get_int64_feature
(
ex
,
name
,
strict
=
True
):
"""Gets the value of an int64 feature from a tf.train.Example."""
return
get_feature
(
ex
,
name
,
"int64_list"
,
strict
)
def
_infer_kind
(
value
):
"""Infers the tf.train.Feature kind from a value."""
if
np
.
issubdtype
(
type
(
value
[
0
]),
np
.
integer
):
return
"int64_list"
try
:
float
(
value
[
0
])
return
"float_list"
except
ValueError
:
return
"bytes_list"
def
set_feature
(
ex
,
name
,
value
,
kind
=
None
,
allow_overwrite
=
False
):
"""Sets a feature value in a tf.train.Example.
Args:
ex: A tf.train.Example.
name: Name of the feature to set.
value: Feature value to set. Must be a sequence.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
allow_overwrite: Whether to overwrite the existing value of the feature.
Raises:
ValueError: If `allow_overwrite` is False and the feature already exists, or
if `kind` is unrecognized.
"""
if
name
in
ex
.
features
.
feature
:
if
allow_overwrite
:
del
ex
.
features
.
feature
[
name
]
else
:
raise
ValueError
(
"Attempting to set duplicate feature with name: %s"
%
name
)
if
not
kind
:
kind
=
_infer_kind
(
value
)
if
kind
==
"bytes_list"
:
value
=
[
str
(
v
).
encode
(
"latin-1"
)
for
v
in
value
]
elif
kind
==
"float_list"
:
value
=
[
float
(
v
)
for
v
in
value
]
elif
kind
==
"int64_list"
:
value
=
[
int
(
v
)
for
v
in
value
]
else
:
raise
ValueError
(
"Unrecognized kind: %s"
%
kind
)
getattr
(
ex
.
features
.
feature
[
name
],
kind
).
value
.
extend
(
value
)
def
set_float_feature
(
ex
,
name
,
value
,
allow_overwrite
=
False
):
"""Sets the value of a float feature in a tf.train.Example."""
set_feature
(
ex
,
name
,
value
,
"float_list"
,
allow_overwrite
)
def
set_bytes_feature
(
ex
,
name
,
value
,
allow_overwrite
=
False
):
"""Sets the value of a bytes feature in a tf.train.Example."""
set_feature
(
ex
,
name
,
value
,
"bytes_list"
,
allow_overwrite
)
def
set_int64_feature
(
ex
,
name
,
value
,
allow_overwrite
=
False
):
"""Sets the value of an int64 feature in a tf.train.Example."""
set_feature
(
ex
,
name
,
value
,
"int64_list"
,
allow_overwrite
)
research/astronet/astronet/util/example_util_test.py
0 → 100644
View file @
672ac40b
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for example_util.py."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
numpy
as
np
import
tensorflow
as
tf
from
astronet.util
import
example_util
class
ExampleUtilTest
(
tf
.
test
.
TestCase
):
def
test_get_feature
(
self
):
# Create Example.
bytes_list
=
tf
.
train
.
BytesList
(
value
=
[
v
.
encode
(
"latin-1"
)
for
v
in
[
"a"
,
"b"
,
"c"
]])
float_list
=
tf
.
train
.
FloatList
(
value
=
[
1.0
,
2.0
,
3.0
])
int64_list
=
tf
.
train
.
Int64List
(
value
=
[
11
,
22
,
33
])
ex
=
tf
.
train
.
Example
(
features
=
tf
.
train
.
Features
(
feature
=
{
"a_bytes"
:
tf
.
train
.
Feature
(
bytes_list
=
bytes_list
),
"b_float"
:
tf
.
train
.
Feature
(
float_list
=
float_list
),
"c_int64"
:
tf
.
train
.
Feature
(
int64_list
=
int64_list
),
"d_empty"
:
tf
.
train
.
Feature
(),
}))
# Get bytes feature.
np
.
testing
.
assert_array_equal
(
example_util
.
get_feature
(
ex
,
"a_bytes"
).
astype
(
str
),
[
"a"
,
"b"
,
"c"
])
np
.
testing
.
assert_array_equal
(
example_util
.
get_feature
(
ex
,
"a_bytes"
,
"bytes_list"
).
astype
(
str
),
[
"a"
,
"b"
,
"c"
])
np
.
testing
.
assert_array_equal
(
example_util
.
get_bytes_feature
(
ex
,
"a_bytes"
).
astype
(
str
),
[
"a"
,
"b"
,
"c"
])
with
self
.
assertRaises
(
TypeError
):
example_util
.
get_feature
(
ex
,
"a_bytes"
,
"float_list"
)
with
self
.
assertRaises
(
TypeError
):
example_util
.
get_float_feature
(
ex
,
"a_bytes"
)
with
self
.
assertRaises
(
TypeError
):
example_util
.
get_int64_feature
(
ex
,
"a_bytes"
)
# Get float feature.
np
.
testing
.
assert_array_almost_equal
(
example_util
.
get_feature
(
ex
,
"b_float"
),
[
1.0
,
2.0
,
3.0
])
np
.
testing
.
assert_array_almost_equal
(
example_util
.
get_feature
(
ex
,
"b_float"
,
"float_list"
),
[
1.0
,
2.0
,
3.0
])
np
.
testing
.
assert_array_almost_equal
(
example_util
.
get_float_feature
(
ex
,
"b_float"
),
[
1.0
,
2.0
,
3.0
])
with
self
.
assertRaises
(
TypeError
):
example_util
.
get_feature
(
ex
,
"b_float"
,
"int64_list"
)
with
self
.
assertRaises
(
TypeError
):
example_util
.
get_bytes_feature
(
ex
,
"b_float"
)
with
self
.
assertRaises
(
TypeError
):
example_util
.
get_int64_feature
(
ex
,
"b_float"
)
# Get int64 feature.
np
.
testing
.
assert_array_equal
(
example_util
.
get_feature
(
ex
,
"c_int64"
),
[
11
,
22
,
33
])
np
.
testing
.
assert_array_equal
(
example_util
.
get_feature
(
ex
,
"c_int64"
,
"int64_list"
),
[
11
,
22
,
33
])
np
.
testing
.
assert_array_equal
(
example_util
.
get_int64_feature
(
ex
,
"c_int64"
),
[
11
,
22
,
33
])
with
self
.
assertRaises
(
TypeError
):
example_util
.
get_feature
(
ex
,
"c_int64"
,
"bytes_list"
)
with
self
.
assertRaises
(
TypeError
):
example_util
.
get_bytes_feature
(
ex
,
"c_int64"
)
with
self
.
assertRaises
(
TypeError
):
example_util
.
get_float_feature
(
ex
,
"c_int64"
)
# Get empty feature.
np
.
testing
.
assert_array_equal
(
example_util
.
get_feature
(
ex
,
"d_empty"
),
[])
np
.
testing
.
assert_array_equal
(
example_util
.
get_feature
(
ex
,
"d_empty"
,
"float_list"
),
[])
np
.
testing
.
assert_array_equal
(
example_util
.
get_bytes_feature
(
ex
,
"d_empty"
),
[])
np
.
testing
.
assert_array_equal
(
example_util
.
get_float_feature
(
ex
,
"d_empty"
),
[])
np
.
testing
.
assert_array_equal
(
example_util
.
get_int64_feature
(
ex
,
"d_empty"
),
[])
# Get nonexistent feature.
with
self
.
assertRaises
(
KeyError
):
example_util
.
get_feature
(
ex
,
"nonexistent"
)
with
self
.
assertRaises
(
KeyError
):
example_util
.
get_feature
(
ex
,
"nonexistent"
,
"bytes_list"
)
with
self
.
assertRaises
(
KeyError
):
example_util
.
get_bytes_feature
(
ex
,
"nonexistent"
)
with
self
.
assertRaises
(
KeyError
):
example_util
.
get_float_feature
(
ex
,
"nonexistent"
)
with
self
.
assertRaises
(
KeyError
):
example_util
.
get_int64_feature
(
ex
,
"nonexistent"
)
np
.
testing
.
assert_array_equal
(
example_util
.
get_feature
(
ex
,
"nonexistent"
,
strict
=
False
),
[])
np
.
testing
.
assert_array_equal
(
example_util
.
get_bytes_feature
(
ex
,
"nonexistent"
,
strict
=
False
),
[])
np
.
testing
.
assert_array_equal
(
example_util
.
get_float_feature
(
ex
,
"nonexistent"
,
strict
=
False
),
[])
np
.
testing
.
assert_array_equal
(
example_util
.
get_int64_feature
(
ex
,
"nonexistent"
,
strict
=
False
),
[])
def
test_set_feature
(
self
):
ex
=
tf
.
train
.
Example
()
# Set bytes features.
example_util
.
set_feature
(
ex
,
"a1_bytes"
,
[
"a"
,
"b"
])
example_util
.
set_feature
(
ex
,
"a2_bytes"
,
[
"A"
,
"B"
],
kind
=
"bytes_list"
)
example_util
.
set_bytes_feature
(
ex
,
"a3_bytes"
,
[
"x"
,
"y"
])
np
.
testing
.
assert_array_equal
(
np
.
array
(
ex
.
features
.
feature
[
"a1_bytes"
].
bytes_list
.
value
).
astype
(
str
),
[
"a"
,
"b"
])
np
.
testing
.
assert_array_equal
(
np
.
array
(
ex
.
features
.
feature
[
"a2_bytes"
].
bytes_list
.
value
).
astype
(
str
),
[
"A"
,
"B"
])
np
.
testing
.
assert_array_equal
(
np
.
array
(
ex
.
features
.
feature
[
"a3_bytes"
].
bytes_list
.
value
).
astype
(
str
),
[
"x"
,
"y"
])
with
self
.
assertRaises
(
ValueError
):
example_util
.
set_feature
(
ex
,
"a3_bytes"
,
[
"xxx"
])
# Duplicate.
# Set float features.
example_util
.
set_feature
(
ex
,
"b1_float"
,
[
1.0
,
2.0
])
example_util
.
set_feature
(
ex
,
"b2_float"
,
[
10.0
,
20.0
],
kind
=
"float_list"
)
example_util
.
set_float_feature
(
ex
,
"b3_float"
,
[
88.0
,
99.0
])
np
.
testing
.
assert_array_almost_equal
(
ex
.
features
.
feature
[
"b1_float"
].
float_list
.
value
,
[
1.0
,
2.0
])
np
.
testing
.
assert_array_almost_equal
(
ex
.
features
.
feature
[
"b2_float"
].
float_list
.
value
,
[
10.0
,
20.0
])
np
.
testing
.
assert_array_almost_equal
(
ex
.
features
.
feature
[
"b3_float"
].
float_list
.
value
,
[
88.0
,
99.0
])
with
self
.
assertRaises
(
ValueError
):
example_util
.
set_feature
(
ex
,
"b3_float"
,
[
1234.0
])
# Duplicate.
# Set int64 features.
example_util
.
set_feature
(
ex
,
"c1_int64"
,
[
1
,
2
,
3
])
example_util
.
set_feature
(
ex
,
"c2_int64"
,
[
11
,
22
,
33
],
kind
=
"int64_list"
)
example_util
.
set_int64_feature
(
ex
,
"c3_int64"
,
[
88
,
99
])
np
.
testing
.
assert_array_equal
(
ex
.
features
.
feature
[
"c1_int64"
].
int64_list
.
value
,
[
1
,
2
,
3
])
np
.
testing
.
assert_array_equal
(
ex
.
features
.
feature
[
"c2_int64"
].
int64_list
.
value
,
[
11
,
22
,
33
])
np
.
testing
.
assert_array_equal
(
ex
.
features
.
feature
[
"c3_int64"
].
int64_list
.
value
,
[
88
,
99
])
with
self
.
assertRaises
(
ValueError
):
example_util
.
set_feature
(
ex
,
"c3_int64"
,
[
1234
])
# Duplicate.
# Overwrite features.
example_util
.
set_feature
(
ex
,
"a3_bytes"
,
[
"xxx"
],
allow_overwrite
=
True
)
np
.
testing
.
assert_array_equal
(
np
.
array
(
ex
.
features
.
feature
[
"a3_bytes"
].
bytes_list
.
value
).
astype
(
str
),
[
"xxx"
])
example_util
.
set_feature
(
ex
,
"b3_float"
,
[
1234.0
],
allow_overwrite
=
True
)
np
.
testing
.
assert_array_almost_equal
(
ex
.
features
.
feature
[
"b3_float"
].
float_list
.
value
,
[
1234.0
])
example_util
.
set_feature
(
ex
,
"c3_int64"
,
[
1234
],
allow_overwrite
=
True
)
np
.
testing
.
assert_array_equal
(
ex
.
features
.
feature
[
"c3_int64"
].
int64_list
.
value
,
[
1234
])
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment