# Copyright 2018 The TensorFlow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for getting and setting values in tf.Example protocol buffers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np


def get_feature(ex, name, kind=None, strict=True):
  """Gets a feature value from a tf.train.Example.

  Args:
    ex: A tf.train.Example.
    name: Name of the feature to look up.
    kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
      not specified.
    strict: Whether to raise a KeyError if there is no such feature.

  Returns:
    A numpy array containing the values of the specified feature. An empty
    array is returned if the feature is missing (and `strict` is False), or if
    the feature exists but has no value list set.

  Raises:
    KeyError: If there is no feature with the specified name and `strict` is
      True.
    TypeError: If the feature has a different type to that specified.
  """
  features = ex.features.feature
  # Test membership before indexing: indexing a protobuf message map with a
  # missing key inserts a new, empty entry for that key.
  if name not in features:
    if strict:
      raise KeyError(name)
    return np.array([])

  feature = features[name]
  inferred_kind = feature.WhichOneof("kind")
  if not inferred_kind:
    return np.array([])  # Feature exists, but it's empty.

  if kind and kind != inferred_kind:
    raise TypeError("Requested {}, but Feature has {}".format(
        kind, inferred_kind))

  return np.array(getattr(feature, inferred_kind).value)


def get_bytes_feature(ex, name, strict=True):
  """Returns the 'bytes_list' values of feature `name` in a tf.train.Example."""
  return get_feature(ex, name, kind="bytes_list", strict=strict)


def get_float_feature(ex, name, strict=True):
  """Returns the 'float_list' values of feature `name` in a tf.train.Example."""
  return get_feature(ex, name, kind="float_list", strict=strict)


def get_int64_feature(ex, name, strict=True):
  """Returns the 'int64_list' values of feature `name` in a tf.train.Example."""
  return get_feature(ex, name, kind="int64_list", strict=strict)


def _infer_kind(value):
  """Infers the tf.train.Feature kind from a value."""
  if np.issubdtype(type(value[0]), np.integer):
    return "int64_list"
  try:
    float(value[0])
    return "float_list"
  except ValueError:
    return "bytes_list"


def set_feature(ex,
                name,
                value,
                kind=None,
                allow_overwrite=False,
                bytes_encoding="latin-1"):
  """Sets a feature value in a tf.train.Example.

  Args:
    ex: A tf.train.Example.
    name: Name of the feature to set.
    value: Feature value to set. Must be a sequence.
    kind: Optional: one of 'bytes_list', 'float_list', 'int64_list'. Inferred if
      not specified.
    allow_overwrite: Whether to overwrite the existing value of the feature.
    bytes_encoding: Codec for encoding non-bytes values when
      kind = 'bytes_list'. Values that are already bytes are stored unchanged.

  Raises:
    ValueError: If `allow_overwrite` is False and the feature already exists, or
        if `kind` is unrecognized. Errors raised while converting individual
        values (by float(), int() or str.encode()) also propagate, and in that
        case `ex` is left unmodified.
  """
  features = ex.features.feature
  if name in features and not allow_overwrite:
    raise ValueError(
        "Attempting to overwrite feature with name: {}. "
        "Set allow_overwrite=True if this is desired.".format(name))

  if not kind:
    kind = _infer_kind(value)

  # Convert every value before mutating `ex`, so that a bad value or kind can
  # never leave the example with its old feature deleted and no replacement.
  if kind == "bytes_list":
    # str() of a bytes object is its repr (e.g. "b'abc'"), so bytes-like values
    # are passed through as-is and only other values are stringified/encoded.
    value = [
        bytes(v) if isinstance(v, (bytes, bytearray)) else
        str(v).encode(bytes_encoding) for v in value
    ]
  elif kind == "float_list":
    value = [float(v) for v in value]
  elif kind == "int64_list":
    value = [int(v) for v in value]
  else:
    raise ValueError("Unrecognized kind: {}".format(kind))

  if name in features:
    del features[name]

  # Indexing the feature map with `name` yields a fresh, empty Feature entry.
  getattr(features[name], kind).value.extend(value)


def set_float_feature(ex, name, value, allow_overwrite=False):
  """Stores `value` as the 'float_list' feature `name` of a tf.train.Example."""
  set_feature(
      ex, name, value, kind="float_list", allow_overwrite=allow_overwrite)


def set_bytes_feature(ex,
                      name,
                      value,
                      allow_overwrite=False,
                      bytes_encoding="latin-1"):
  """Sets the value of a bytes feature in a tf.train.Example.

  Args:
    ex: A tf.train.Example.
    name: Name of the feature to set.
    value: Feature value to set. Must be a sequence.
    allow_overwrite: Whether to overwrite the existing value of the feature.
    bytes_encoding: Codec forwarded to set_feature for encoding values into
      the 'bytes_list'.
  """
  set_feature(ex, name, value, "bytes_list", allow_overwrite, bytes_encoding)


def set_int64_feature(ex, name, value, allow_overwrite=False):
  """Stores `value` as the 'int64_list' feature `name` of a tf.train.Example."""
  set_feature(
      ex, name, value, kind="int64_list", allow_overwrite=allow_overwrite)