example_util.py 4.58 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for getting and setting values in tf.Example protocol buffers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


def get_feature(ex, name, kind=None, strict=True):
  """Gets a feature value from a tf.train.Example.

  Args:
    ex: A tf.train.Example.
    name: Name of the feature to look up.
    kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
        not specified.
    strict: Whether to raise a KeyError if there is no such feature.

  Returns:
    A numpy array containing the values of the specified feature. An empty
    array is returned when the feature is missing (and `strict` is False), or
    when the feature exists but none of its value lists is set.

  Raises:
    KeyError: If `strict` is True and there is no feature with the specified
        name.
    TypeError: If the feature has a different type to that specified.
  """
  features = ex.features.feature
  if name not in features:
    if strict:
      raise KeyError(name)
    return np.array([])

  # Look the feature up once. WhichOneof("kind") names the populated value
  # list, or returns a falsy value when none is set.
  feature = features[name]
  inferred_kind = feature.WhichOneof("kind")
  if not inferred_kind:
    return np.array([])  # Feature exists, but it's empty.

  if kind and kind != inferred_kind:
    raise TypeError("Requested {}, but Feature has {}".format(
        kind, inferred_kind))

  return np.array(getattr(feature, inferred_kind).value)


def get_bytes_feature(ex, name, strict=True):
  """Returns the values of a 'bytes_list' feature of a tf.train.Example.

  Equivalent to get_feature with the kind pinned to 'bytes_list'.
  """
  return get_feature(ex, name, kind="bytes_list", strict=strict)


def get_float_feature(ex, name, strict=True):
  """Returns the values of a 'float_list' feature of a tf.train.Example.

  Equivalent to get_feature with the kind pinned to 'float_list'.
  """
  return get_feature(ex, name, kind="float_list", strict=strict)


def get_int64_feature(ex, name, strict=True):
  """Returns the values of an 'int64_list' feature of a tf.train.Example.

  Equivalent to get_feature with the kind pinned to 'int64_list'.
  """
  return get_feature(ex, name, kind="int64_list", strict=strict)


def _infer_kind(value):
  """Infers the tf.train.Feature kind from a value."""
  if np.issubdtype(type(value[0]), np.integer):
    return "int64_list"
  try:
    float(value[0])
    return "float_list"
  except ValueError:
    return "bytes_list"


def set_feature(ex,
                name,
                value,
                kind=None,
                allow_overwrite=False,
                bytes_encoding="latin-1"):
  """Sets a feature value in a tf.train.Example.

  The new value is validated and converted before `ex` is modified, so if an
  exception is raised any existing feature with this name is left intact.

  Args:
    ex: A tf.train.Example.
    name: Name of the feature to set.
    value: Feature value to set. Must be a sequence.
    kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
        not specified.
    allow_overwrite: Whether to overwrite the existing value of the feature.
    bytes_encoding: Codec for encoding strings when kind = 'bytes_list'.
        Values that are already `bytes`/`bytearray` are stored unchanged.

  Raises:
    ValueError: If `allow_overwrite` is False and the feature already exists, or
        if `kind` is unrecognized. Errors raised while converting elements
        (float(), int(), str.encode) also propagate.
  """
  exists = name in ex.features.feature
  if exists and not allow_overwrite:
    raise ValueError(
        "Attempting to overwrite feature with name: {}. "
        "Set allow_overwrite=True if this is desired.".format(name))

  if not kind:
    kind = _infer_kind(value)

  if kind == "bytes_list":
    # Already-encoded bytes must not go through str(): that would store their
    # repr, e.g. b"abc" -> b"b'abc'".
    value = [
        bytes(v) if isinstance(v, (bytes, bytearray))
        else str(v).encode(bytes_encoding)
        for v in value
    ]
  elif kind == "float_list":
    value = [float(v) for v in value]
  elif kind == "int64_list":
    value = [int(v) for v in value]
  else:
    raise ValueError("Unrecognized kind: {}".format(kind))

  # Remove the old feature only after the new value is known to be valid.
  if exists:
    del ex.features.feature[name]
  getattr(ex.features.feature[name], kind).value.extend(value)


def set_float_feature(ex, name, value, allow_overwrite=False):
  """Stores `value` as a 'float_list' feature on a tf.train.Example.

  Delegates to set_feature with the kind pinned to 'float_list'.
  """
  set_feature(ex, name, value, kind="float_list",
              allow_overwrite=allow_overwrite)


def set_bytes_feature(ex,
                      name,
                      value,
                      allow_overwrite=False,
                      bytes_encoding="latin-1"):
  """Sets the value of a bytes feature in a tf.train.Example.

  Delegates to set_feature with the kind pinned to 'bytes_list'.

  Args:
    ex: A tf.train.Example.
    name: Name of the feature to set.
    value: Feature value to set. Must be a sequence.
    allow_overwrite: Whether to overwrite the existing value of the feature.
    bytes_encoding: Codec set_feature uses to encode string values.
  """
  set_feature(ex, name, value, kind="bytes_list",
              allow_overwrite=allow_overwrite, bytes_encoding=bytes_encoding)


def set_int64_feature(ex, name, value, allow_overwrite=False):
  """Stores `value` as an 'int64_list' feature on a tf.train.Example.

  Delegates to set_feature with the kind pinned to 'int64_list'.
  """
  set_feature(ex, name, value, kind="int64_list",
              allow_overwrite=allow_overwrite)