data_provider.py 921 Bytes
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
"""StarGAN data provider."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import data_provider


def provide_data(image_file_patterns, batch_size, patch_size):
  """Data provider wrapper on for the data_provider in gan/cyclegan.

  Args:
    image_file_patterns: A list of file pattern globs.
    batch_size: Python int. Batch size.
    patch_size: Python int. The patch size to extract.

  Returns:
    List of `Tensor` of shape (N, H, W, C) representing the images.
    List of `Tensor` of shape (N, num_domains) representing the labels.
  """

  images = data_provider.provide_custom_data(
      image_file_patterns,
      batch_size=batch_size,
      patch_size=patch_size)

  num_domains = len(images)
  labels = [tf.one_hot([idx] * batch_size, num_domains) for idx in
            range(num_domains)]

  return images, labels