"include/ck/utility/dynamic_buffer.hpp" did not exist on "880fbee95782a30fb16654f830502d03dd92fae2"
Commit 672ac40b authored by Chris Shallue's avatar Chris Shallue Committed by Christopher Shallue
Browse files

Add utilities for getting and setting features in tf.train.Example protos.

PiperOrigin-RevId: 201839334
parent da7925b7
...@@ -42,3 +42,17 @@ py_library( ...@@ -42,3 +42,17 @@ py_library(
"//astronet/ops:training", "//astronet/ops:training",
], ],
) )
py_library(
name = "example_util",
srcs = ["example_util.py"],
srcs_version = "PY2AND3",
)
py_test(
name = "example_util_test",
size = "small",
srcs = ["example_util_test.py"],
srcs_version = "PY2AND3",
deps = [":example_util"],
)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helpers for getting and setting values in tf.Example protocol buffers."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
def get_feature(ex, name, kind=None, strict=True):
"""Gets a feature value from a tf.train.Example.
Args:
ex: A tf.train.Example.
name: Name of the feature to look up.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
strict: Whether to raise a KeyError if there is no such feature.
Returns:
A numpy array containing to the values of the specified feature.
Raises:
KeyError: If there is no feature with the specified name.
TypeError: If the feature has a different type to that specified.
"""
if name not in ex.features.feature:
if strict:
raise KeyError(name)
return np.array([])
inferred_kind = ex.features.feature[name].WhichOneof("kind")
if not inferred_kind:
return np.array([]) # Feature exists, but it's empty.
if kind and kind != inferred_kind:
raise TypeError("Requested %s, but Feature has %s" % (kind, inferred_kind))
return np.array(getattr(ex.features.feature[name], inferred_kind).value)
def get_bytes_feature(ex, name, strict=True):
"""Gets the value of a bytes feature from a tf.train.Example."""
return get_feature(ex, name, "bytes_list", strict)
def get_float_feature(ex, name, strict=True):
"""Gets the value of a float feature from a tf.train.Example."""
return get_feature(ex, name, "float_list", strict)
def get_int64_feature(ex, name, strict=True):
"""Gets the value of an int64 feature from a tf.train.Example."""
return get_feature(ex, name, "int64_list", strict)
def _infer_kind(value):
"""Infers the tf.train.Feature kind from a value."""
if np.issubdtype(type(value[0]), np.integer):
return "int64_list"
try:
float(value[0])
return "float_list"
except ValueError:
return "bytes_list"
def set_feature(ex, name, value, kind=None, allow_overwrite=False):
"""Sets a feature value in a tf.train.Example.
Args:
ex: A tf.train.Example.
name: Name of the feature to set.
value: Feature value to set. Must be a sequence.
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified.
allow_overwrite: Whether to overwrite the existing value of the feature.
Raises:
ValueError: If `allow_overwrite` is False and the feature already exists, or
if `kind` is unrecognized.
"""
if name in ex.features.feature:
if allow_overwrite:
del ex.features.feature[name]
else:
raise ValueError(
"Attempting to set duplicate feature with name: %s" % name)
if not kind:
kind = _infer_kind(value)
if kind == "bytes_list":
value = [str(v).encode("latin-1") for v in value]
elif kind == "float_list":
value = [float(v) for v in value]
elif kind == "int64_list":
value = [int(v) for v in value]
else:
raise ValueError("Unrecognized kind: %s" % kind)
getattr(ex.features.feature[name], kind).value.extend(value)
def set_float_feature(ex, name, value, allow_overwrite=False):
"""Sets the value of a float feature in a tf.train.Example."""
set_feature(ex, name, value, "float_list", allow_overwrite)
def set_bytes_feature(ex, name, value, allow_overwrite=False):
"""Sets the value of a bytes feature in a tf.train.Example."""
set_feature(ex, name, value, "bytes_list", allow_overwrite)
def set_int64_feature(ex, name, value, allow_overwrite=False):
"""Sets the value of an int64 feature in a tf.train.Example."""
set_feature(ex, name, value, "int64_list", allow_overwrite)
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for example_util.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from astronet.util import example_util
class ExampleUtilTest(tf.test.TestCase):
def test_get_feature(self):
# Create Example.
bytes_list = tf.train.BytesList(
value=[v.encode("latin-1") for v in ["a", "b", "c"]])
float_list = tf.train.FloatList(value=[1.0, 2.0, 3.0])
int64_list = tf.train.Int64List(value=[11, 22, 33])
ex = tf.train.Example(
features=tf.train.Features(
feature={
"a_bytes": tf.train.Feature(bytes_list=bytes_list),
"b_float": tf.train.Feature(float_list=float_list),
"c_int64": tf.train.Feature(int64_list=int64_list),
"d_empty": tf.train.Feature(),
}))
# Get bytes feature.
np.testing.assert_array_equal(
example_util.get_feature(ex, "a_bytes").astype(str), ["a", "b", "c"])
np.testing.assert_array_equal(
example_util.get_feature(ex, "a_bytes", "bytes_list").astype(str),
["a", "b", "c"])
np.testing.assert_array_equal(
example_util.get_bytes_feature(ex, "a_bytes").astype(str),
["a", "b", "c"])
with self.assertRaises(TypeError):
example_util.get_feature(ex, "a_bytes", "float_list")
with self.assertRaises(TypeError):
example_util.get_float_feature(ex, "a_bytes")
with self.assertRaises(TypeError):
example_util.get_int64_feature(ex, "a_bytes")
# Get float feature.
np.testing.assert_array_almost_equal(
example_util.get_feature(ex, "b_float"), [1.0, 2.0, 3.0])
np.testing.assert_array_almost_equal(
example_util.get_feature(ex, "b_float", "float_list"), [1.0, 2.0, 3.0])
np.testing.assert_array_almost_equal(
example_util.get_float_feature(ex, "b_float"), [1.0, 2.0, 3.0])
with self.assertRaises(TypeError):
example_util.get_feature(ex, "b_float", "int64_list")
with self.assertRaises(TypeError):
example_util.get_bytes_feature(ex, "b_float")
with self.assertRaises(TypeError):
example_util.get_int64_feature(ex, "b_float")
# Get int64 feature.
np.testing.assert_array_equal(
example_util.get_feature(ex, "c_int64"), [11, 22, 33])
np.testing.assert_array_equal(
example_util.get_feature(ex, "c_int64", "int64_list"), [11, 22, 33])
np.testing.assert_array_equal(
example_util.get_int64_feature(ex, "c_int64"), [11, 22, 33])
with self.assertRaises(TypeError):
example_util.get_feature(ex, "c_int64", "bytes_list")
with self.assertRaises(TypeError):
example_util.get_bytes_feature(ex, "c_int64")
with self.assertRaises(TypeError):
example_util.get_float_feature(ex, "c_int64")
# Get empty feature.
np.testing.assert_array_equal(example_util.get_feature(ex, "d_empty"), [])
np.testing.assert_array_equal(
example_util.get_feature(ex, "d_empty", "float_list"), [])
np.testing.assert_array_equal(
example_util.get_bytes_feature(ex, "d_empty"), [])
np.testing.assert_array_equal(
example_util.get_float_feature(ex, "d_empty"), [])
np.testing.assert_array_equal(
example_util.get_int64_feature(ex, "d_empty"), [])
# Get nonexistent feature.
with self.assertRaises(KeyError):
example_util.get_feature(ex, "nonexistent")
with self.assertRaises(KeyError):
example_util.get_feature(ex, "nonexistent", "bytes_list")
with self.assertRaises(KeyError):
example_util.get_bytes_feature(ex, "nonexistent")
with self.assertRaises(KeyError):
example_util.get_float_feature(ex, "nonexistent")
with self.assertRaises(KeyError):
example_util.get_int64_feature(ex, "nonexistent")
np.testing.assert_array_equal(
example_util.get_feature(ex, "nonexistent", strict=False), [])
np.testing.assert_array_equal(
example_util.get_bytes_feature(ex, "nonexistent", strict=False), [])
np.testing.assert_array_equal(
example_util.get_float_feature(ex, "nonexistent", strict=False), [])
np.testing.assert_array_equal(
example_util.get_int64_feature(ex, "nonexistent", strict=False), [])
def test_set_feature(self):
ex = tf.train.Example()
# Set bytes features.
example_util.set_feature(ex, "a1_bytes", ["a", "b"])
example_util.set_feature(ex, "a2_bytes", ["A", "B"], kind="bytes_list")
example_util.set_bytes_feature(ex, "a3_bytes", ["x", "y"])
np.testing.assert_array_equal(
np.array(ex.features.feature["a1_bytes"].bytes_list.value).astype(str),
["a", "b"])
np.testing.assert_array_equal(
np.array(ex.features.feature["a2_bytes"].bytes_list.value).astype(str),
["A", "B"])
np.testing.assert_array_equal(
np.array(ex.features.feature["a3_bytes"].bytes_list.value).astype(str),
["x", "y"])
with self.assertRaises(ValueError):
example_util.set_feature(ex, "a3_bytes", ["xxx"]) # Duplicate.
# Set float features.
example_util.set_feature(ex, "b1_float", [1.0, 2.0])
example_util.set_feature(ex, "b2_float", [10.0, 20.0], kind="float_list")
example_util.set_float_feature(ex, "b3_float", [88.0, 99.0])
np.testing.assert_array_almost_equal(
ex.features.feature["b1_float"].float_list.value, [1.0, 2.0])
np.testing.assert_array_almost_equal(
ex.features.feature["b2_float"].float_list.value, [10.0, 20.0])
np.testing.assert_array_almost_equal(
ex.features.feature["b3_float"].float_list.value, [88.0, 99.0])
with self.assertRaises(ValueError):
example_util.set_feature(ex, "b3_float", [1234.0]) # Duplicate.
# Set int64 features.
example_util.set_feature(ex, "c1_int64", [1, 2, 3])
example_util.set_feature(ex, "c2_int64", [11, 22, 33], kind="int64_list")
example_util.set_int64_feature(ex, "c3_int64", [88, 99])
np.testing.assert_array_equal(
ex.features.feature["c1_int64"].int64_list.value, [1, 2, 3])
np.testing.assert_array_equal(
ex.features.feature["c2_int64"].int64_list.value, [11, 22, 33])
np.testing.assert_array_equal(
ex.features.feature["c3_int64"].int64_list.value, [88, 99])
with self.assertRaises(ValueError):
example_util.set_feature(ex, "c3_int64", [1234]) # Duplicate.
# Overwrite features.
example_util.set_feature(ex, "a3_bytes", ["xxx"], allow_overwrite=True)
np.testing.assert_array_equal(
np.array(ex.features.feature["a3_bytes"].bytes_list.value).astype(str),
["xxx"])
example_util.set_feature(ex, "b3_float", [1234.0], allow_overwrite=True)
np.testing.assert_array_almost_equal(
ex.features.feature["b3_float"].float_list.value, [1234.0])
example_util.set_feature(ex, "c3_int64", [1234], allow_overwrite=True)
np.testing.assert_array_equal(
ex.features.feature["c3_int64"].int64_list.value, [1234])
if __name__ == "__main__":
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment