# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains code for loading and preprocessing image data."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function


import numpy as np
import tensorflow as tf


def normalize_image(image):
  """Rescale from range [0, 255] to [-1, 1]."""
  return (tf.to_float(image) - 127.5) / 127.5


def undo_normalize_image(normalized_image):
  """Convert to a numpy array that can be read by PIL."""
  # Convert from NHWC to HWC.
  normalized_image = np.squeeze(normalized_image, axis=0)
  return np.uint8(normalized_image * 127.5 + 127.5)
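
# A quick round-trip sketch of the two helpers above.  `image_batch` is a
# hypothetical uint8 NHWC tensor holding a single image (not part of this
# module):
#
#   normalized = normalize_image(image_batch)   # float32, values in [-1, 1]
#   # ...evaluate `normalized` to get a numpy array `normalized_np`...
#   pil_ready = undo_normalize_image(normalized_np)  # uint8 HWC numpy array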


def _sample_patch(image, patch_size):
  """Crop image to square shape and resize it to `patch_size`.

  Args:
    image: A 3D `Tensor` of HWC format.
    patch_size: A Python scalar.  The output image size.

  Returns:
    A 3D `Tensor` of HWC format which has the shape of
    [patch_size, patch_size, 3].
  """
  image_shape = tf.shape(image)
  height, width = image_shape[0], image_shape[1]
  target_size = tf.minimum(height, width)
  image = tf.image.resize_image_with_crop_or_pad(image, target_size,
                                                 target_size)
  # Add a batch dimension so the image can be resized as a 4D tensor.
  image = tf.expand_dims(image, axis=0)
  image = tf.image.resize_images(image, [patch_size, patch_size])
  image = tf.squeeze(image, axis=0)
  # Force num_channels = 3: tile grayscale images along the channel axis and
  # slice off any extra (e.g. alpha) channel.
  image = tf.tile(image, [1, 1, tf.maximum(1, 4 - tf.shape(image)[2])])
  image = tf.slice(image, [0, 0, 0], [patch_size, patch_size, 3])
  return image
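
# Behavior sketch for `_sample_patch` (the shapes here are illustrative
# assumptions): a 400x300 RGB input is center-cropped to 300x300 and resized,
# while a grayscale input is additionally tiled to 3 channels:
#
#   patch = _sample_patch(image, patch_size=128)  # -> [128, 128, 3]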


def full_image_to_patch(image, patch_size):
  """Normalizes an image and extracts a square patch from it.

  Args:
    image: A 3D `Tensor` of HWC format with values in [0, 255].
    patch_size: A Python scalar.  The output image size.

  Returns:
    A 3D `Tensor` of HWC format with shape [patch_size, patch_size, 3] and
    values in [-1, 1].
  """
  image = normalize_image(image)
  # Sample a patch of fixed size.
  image_patch = _sample_patch(image, patch_size)
  image_patch.shape.assert_is_compatible_with([patch_size, patch_size, 3])
  return image_patch


def _provide_custom_dataset(image_file_pattern,
                            batch_size,
                            shuffle=True,
                            num_threads=1,
                            patch_size=128):
  """Provides batches of custom image data.

  Args:
    image_file_pattern: A string of glob pattern of image files.
    batch_size: The number of images in each batch.
    shuffle: Whether to shuffle the read images.  Defaults to True.
    num_threads: Number of mapping threads.  Defaults to 1.
    patch_size: Size of the patch to extract from the image.  Defaults to 128.

  Returns:
    A tf.data.Dataset with Tensors of shape
    [batch_size, patch_size, patch_size, 3] representing a batch of images.

  Raises:
    ValueError: If no files match `image_file_pattern`.
  """
  if not tf.gfile.Glob(image_file_pattern):
    raise ValueError('No files match `image_file_pattern`.')
  filenames_ds = tf.data.Dataset.list_files(image_file_pattern)
  bytes_ds = filenames_ds.map(tf.io.read_file, num_parallel_calls=num_threads)
  images_ds = bytes_ds.map(
      tf.image.decode_image, num_parallel_calls=num_threads)
  patches_ds = images_ds.map(
      lambda img: full_image_to_patch(img, patch_size),
      num_parallel_calls=num_threads)
  patches_ds = patches_ds.repeat()

  if shuffle:
    patches_ds = patches_ds.shuffle(5 * batch_size)

  patches_ds = patches_ds.prefetch(5 * batch_size)
  patches_ds = patches_ds.batch(batch_size)

  return patches_ds
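
# Example usage (a minimal sketch; the glob pattern is a placeholder):
#
#   dataset = _provide_custom_dataset('/tmp/images/*.jpg', batch_size=16)
#   images = dataset.make_one_shot_iterator().get_next()
#   # `images` has shape [16, 128, 128, 3] with values in [-1, 1].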


def provide_custom_datasets(image_file_patterns,
                            batch_size,
                            shuffle=True,
                            num_threads=1,
                            patch_size=128):
  """Provides multiple batches of custom image data.

  Args:
    image_file_patterns: A list of glob patterns of image files.
    batch_size: The number of images in each batch.
    shuffle: Whether to shuffle the read images.  Defaults to True.
    num_threads: Number of mapping threads.  Defaults to 1.
    patch_size: Size of the patch to extract from the image.  Defaults to 128.

  Returns:
    A list of `tf.data.Dataset`s, one for each pattern in
    `image_file_patterns`.  Each dataset yields `Tensor`s of shape
    [batch_size, patch_size, patch_size, 3] representing a batch of images.

  Raises:
    ValueError: If image_file_patterns is not a list or tuple.
  """
  if not isinstance(image_file_patterns, (list, tuple)):
    raise ValueError(
        '`image_file_patterns` should be either list or tuple, but was {}.'.
        format(type(image_file_patterns)))
  custom_datasets = []
  for pattern in image_file_patterns:
    custom_datasets.append(
        _provide_custom_dataset(
            pattern,
            batch_size=batch_size,
            shuffle=shuffle,
            num_threads=num_threads,
            patch_size=patch_size))

  return custom_datasets
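
# Example usage for two image domains (a sketch; the glob patterns are
# placeholders):
#
#   ds_x, ds_y = provide_custom_datasets(
#       ['/tmp/domain_x/*.jpg', '/tmp/domain_y/*.jpg'], batch_size=8)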


def provide_custom_data(image_file_patterns,
                        batch_size,
                        shuffle=True,
                        num_threads=1,
                        patch_size=128):
  """Provides multiple batches of custom image data.

  Args:
    image_file_patterns: A list of glob patterns of image files.
    batch_size: The number of images in each batch.
    shuffle: Whether to shuffle the read images.  Defaults to True.
    num_threads: Number of mapping threads.  Defaults to 1.
    patch_size: Size of the patch to extract from the image.  Defaults to 128.

  Returns:
    A list of float `Tensor`s, one for each pattern in `image_file_patterns`.
    Each `Tensor` in the list has a shape of
    [batch_size, patch_size, patch_size, 3] representing a batch of images.  As
    a side effect, each dataset's iterator initializer is added to the
    tf.GraphKeys.TABLE_INITIALIZERS collection.

  Raises:
    ValueError: If image_file_patterns is not a list or tuple.
  """
  datasets = provide_custom_datasets(
      image_file_patterns, batch_size, shuffle, num_threads, patch_size)

  tensors = []
  for ds in datasets:
    iterator = ds.make_initializable_iterator()
    tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, iterator.initializer)
    tensors.append(iterator.get_next())

  return tensors
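
# Example usage (a sketch; the glob patterns are placeholders).  The iterator
# initializers land in tf.GraphKeys.TABLE_INITIALIZERS, so running
# tf.tables_initializer() initializes them before the tensors are fetched:
#
#   images_x, images_y = provide_custom_data(
#       ['/tmp/domain_x/*.jpg', '/tmp/domain_y/*.jpg'], batch_size=8)
#   with tf.Session() as sess:
#     sess.run(tf.tables_initializer())
#     batch_x = sess.run(images_x)  # numpy array of shape [8, 128, 128, 3]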