# data_provider_test.py (1.17 KB)
"""Tests for google3.third_party.tensorflow_models.gan.stargan.data_provider."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from google3.testing.pybase import googletest
import data_provider

# Short alias for the mock module reached through `tf.test`; used by the
# `mock.patch.object` decorator on the test below.
mock = tf.test.mock


class DataProviderTest(googletest.TestCase):
  """Exercises `data_provider.provide_data` with the custom-data loader mocked."""

  @mock.patch.object(
      data_provider.data_provider, 'provide_custom_data', autospec=True)
  def test_data_provider(self, mock_provide_custom_data):
    """Checks the count and static shapes of the returned images and labels.

    `provide_custom_data` is patched to return zero tensors, one per domain,
    and `provide_data` is then expected to return `num_domains` images of the
    stubbed shape and `num_domains` labels of shape
    `[batch_size, num_domains]`.

    Args:
      mock_provide_custom_data: Mock injected by `mock.patch.object` in place
        of `provide_custom_data`.
    """
    num_domains = 3
    batch_size, patch_size = 2, 8
    images_shape = [batch_size, patch_size, patch_size, 3]
    labels_shape = [batch_size, num_domains]

    # Build a distinct zero tensor for every domain as the stubbed output.
    stub_images = []
    for _ in range(num_domains):
      stub_images.append(tf.zeros(images_shape))
    mock_provide_custom_data.return_value = stub_images

    images, labels = data_provider.provide_data(
        image_file_patterns=None,
        batch_size=batch_size,
        patch_size=patch_size)

    # One image batch and one label batch are expected per domain.
    self.assertEqual(num_domains, len(images))
    self.assertEqual(num_domains, len(labels))
    for label in labels:
      self.assertListEqual(labels_shape, label.shape.as_list())
    for image in images:
      self.assertListEqual(images_shape, image.shape.as_list())


# Hand control to googletest's entry point when this file is run as a script.
if __name__ == '__main__':
  googletest.main()