example_input.py 4.75 KB
Newer Older
A. Unique TensorFlower's avatar
A. Unique TensorFlower committed
1
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Fan Yang's avatar
Fan Yang committed
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Example classification decoder and parser.

This file defines the Decoder and Parser used to load data. The example loads
standard tf.Example data, but non-standard tf.Example or other data formats
can be supported by implementing a suitable decoder and parser.
"""
from typing import Mapping, List, Tuple
# Import libraries
import tensorflow as tf

Fan Yang's avatar
Fan Yang committed
25
26
27
from official.vision.dataloaders import decoder
from official.vision.dataloaders import parser
from official.vision.ops import preprocess_ops
Fan Yang's avatar
Fan Yang committed
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137

# Per-channel (R, G, B) statistics used by Parser._parse_data as the `offset`
# and `scale` arguments of the image normalization step. The fractional values
# are multiplied by 255 so they are expressed on the 0-255 pixel scale.
MEAN_RGB = tuple(fraction * 255 for fraction in (0.485, 0.456, 0.406))
STDDEV_RGB = tuple(fraction * 255 for fraction in (0.229, 0.224, 0.225))


class Decoder(decoder.Decoder):
  """A tf.Example decoder for classification task."""

  def __init__(self):
    """Initializes the decoder.

    The constructor defines the mapping between the field name and the value
    from an input tf.Example. For example, we define two fields for image bytes
    and labels. There is no limit on the number of fields to decode.

    Both fields are scalar features with a default, so an example that lacks
    one of them still parses: a missing image decodes to the empty string and
    a missing label decodes to -1.
    """
    self._keys_to_features = {
        'image/encoded':
            tf.io.FixedLenFeature((), tf.string, default_value=''),
        'image/class/label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1)
    }

  def decode(self, serialized_example: tf.Tensor) -> Mapping[str, tf.Tensor]:
    """Decodes a serialized tf.Example to a dictionary.

    This function decodes a serialized tf.Example to a dictionary. The output
    will be consumed by `_parse_train_data` and `_parse_eval_data` in Parser.

    Args:
      serialized_example: A scalar string tensor holding one serialized
        tf.Example proto (not an already-parsed `tf.train.Example` object),
        as required by `tf.io.parse_single_example`.

    Returns:
      A dictionary of field key name and decoded tensor mapping. The keys are
      those declared in the constructor: 'image/encoded' and
      'image/class/label'.
    """
    return tf.io.parse_single_example(
        serialized_example, self._keys_to_features)


class Parser(parser.Parser):
  """Parser to parse an image and its annotations.

  To define own Parser, client should override _parse_train_data and
  _parse_eval_data functions, where decoded tensors are parsed with optional
  pre-processing steps. The output from the two functions can be any structure
  like tuple, list or dictionary.

  In this example training and evaluation share one pipeline, implemented in
  `_parse_data`.
  """

  def __init__(self, output_size: List[int], num_classes: int):
    """Initializes parameters for parsing annotations in the dataset.

    This example only takes two arguments but one can freely add as many
    arguments as needed. For example, pre-processing and augmentations usually
    happen in Parser, and related parameters can be passed in by this
    constructor.

    Args:
      output_size: `list` (or another Python sequence such as a `tuple`) of two
        ints, [height, width], of the output image. The values must be known
        statically because they are also used to build the static shape that
        is asserted on the output image.
      num_classes: `int`, number of classes. Stored on the instance; the
        parsing methods of this class do not read it.
    """
    self._output_size = output_size
    self._num_classes = num_classes
    # Dtype of the image tensor returned by the parsing methods.
    self._dtype = tf.float32

  def _parse_data(
      self, decoded_tensors: Mapping[str,
                                     tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Builds the (image, label) pair shared by training and evaluation.

    The encoded image is decoded as a 3-channel JPEG, resized bilinearly to
    the configured output size, normalized with MEAN_RGB / STDDEV_RGB and
    converted to the parser's dtype. The label is cast to int32.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor mapping
        from Decoder. Must contain 'image/encoded' and 'image/class/label'.

    Returns:
      A tuple of (image, label) tensors.
    """
    label = tf.cast(decoded_tensors['image/class/label'], dtype=tf.int32)
    image_bytes = decoded_tensors['image/encoded']
    image = tf.io.decode_jpeg(image_bytes, channels=3)
    image = tf.image.resize(
        image, self._output_size, method=tf.image.ResizeMethod.BILINEAR)
    # Build the static [height, width, 3] shape from a plain list so that a
    # tuple `output_size` is accepted too; `tuple + [3]` would raise TypeError.
    image = tf.ensure_shape(image, list(self._output_size) + [3])

    # Normalizes image with mean and std pixel values.
    image = preprocess_ops.normalize_image(
        image, offset=MEAN_RGB, scale=STDDEV_RGB)

    image = tf.image.convert_image_dtype(image, self._dtype)
    return image, label

  def _parse_train_data(
      self, decoded_tensors: Mapping[str,
                                     tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses data for training.

    No training-only augmentation is applied in this example; the call is
    forwarded to `_parse_data`.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor mapping
        from Decoder.

    Returns:
      A tuple of (image, label) tensors.
    """
    return self._parse_data(decoded_tensors)

  def _parse_eval_data(
      self, decoded_tensors: Mapping[str,
                                     tf.Tensor]) -> Tuple[tf.Tensor, tf.Tensor]:
    """Parses data for evaluation.

    Evaluation uses the same pipeline as training; the call is forwarded to
    `_parse_data`.

    Args:
      decoded_tensors: A dictionary of field key name and decoded tensor mapping
        from Decoder.

    Returns:
      A tuple of (image, label) tensors.
    """
    return self._parse_data(decoded_tensors)