Commit f6437667 authored by Chris Shallue, committed by Christopher Shallue
Browse files

Add bytes_encoding parameter to set_feature().

PiperOrigin-RevId: 205427760
parent 8793267f
...@@ -47,6 +47,7 @@ py_library( ...@@ -47,6 +47,7 @@ py_library(
name = "example_util", name = "example_util",
srcs = ["example_util.py"], srcs = ["example_util.py"],
srcs_version = "PY2AND3", srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
) )
py_test( py_test(
......
...@@ -79,7 +79,12 @@ def _infer_kind(value): ...@@ -79,7 +79,12 @@ def _infer_kind(value):
return "bytes_list" return "bytes_list"
def set_feature(ex, name, value, kind=None, allow_overwrite=False): def set_feature(ex,
name,
value,
kind=None,
allow_overwrite=False,
bytes_encoding="latin-1"):
"""Sets a feature value in a tf.train.Example. """Sets a feature value in a tf.train.Example.
Args: Args:
...@@ -89,6 +94,7 @@ def set_feature(ex, name, value, kind=None, allow_overwrite=False): ...@@ -89,6 +94,7 @@ def set_feature(ex, name, value, kind=None, allow_overwrite=False):
kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
not specified. not specified.
allow_overwrite: Whether to overwrite the existing value of the feature. allow_overwrite: Whether to overwrite the existing value of the feature.
bytes_encoding: Codec for encoding strings when kind = 'bytes_list'.
Raises: Raises:
ValueError: If `allow_overwrite` is False and the feature already exists, or ValueError: If `allow_overwrite` is False and the feature already exists, or
...@@ -105,7 +111,7 @@ def set_feature(ex, name, value, kind=None, allow_overwrite=False): ...@@ -105,7 +111,7 @@ def set_feature(ex, name, value, kind=None, allow_overwrite=False):
kind = _infer_kind(value) kind = _infer_kind(value)
if kind == "bytes_list": if kind == "bytes_list":
value = [str(v).encode("latin-1") for v in value] value = [str(v).encode(bytes_encoding) for v in value]
elif kind == "float_list": elif kind == "float_list":
value = [float(v) for v in value] value = [float(v) for v in value]
elif kind == "int64_list": elif kind == "int64_list":
...@@ -121,9 +127,13 @@ def set_float_feature(ex, name, value, allow_overwrite=False): ...@@ -121,9 +127,13 @@ def set_float_feature(ex, name, value, allow_overwrite=False):
set_feature(ex, name, value, "float_list", allow_overwrite) set_feature(ex, name, value, "float_list", allow_overwrite)
def set_bytes_feature(ex, name, value, allow_overwrite=False): def set_bytes_feature(ex,
name,
value,
allow_overwrite=False,
bytes_encoding="latin-1"):
"""Sets the value of a bytes feature in a tf.train.Example.""" """Sets the value of a bytes feature in a tf.train.Example."""
set_feature(ex, name, value, "bytes_list", allow_overwrite) set_feature(ex, name, value, "bytes_list", allow_overwrite, bytes_encoding)
def set_int64_feature(ex, name, value, allow_overwrite=False): def set_int64_feature(ex, name, value, allow_overwrite=False):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment