train_test.py 3.05 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os


from absl import flags
from absl.testing import absltest
import tensorflow as tf

import train


FLAGS = flags.FLAGS


32
def provide_random_data(batch_size=2, patch_size=4, colors=1, **unused_kwargs):
33
34
35
36
37
38
39
  return tf.random_normal([batch_size, patch_size, patch_size, colors])


class TrainTest(absltest.TestCase):

  def setUp(self):
    self._config = {
40
41
        'start_height': 2,
        'start_width': 2,
42
43
        'scale_base': 2,
        'num_resolutions': 2,
44
        'batch_size_schedule': [2],
45
46
47
        'colors': 1,
        'to_rgb_use_tanh_activation': True,
        'kernel_size': 3,
48
49
50
51
52
        'stable_stage_num_images': 1,
        'transition_stage_num_images': 1,
        'total_num_images': 3,
        'save_summaries_num_images': 2,
        'latent_vector_size': 2,
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
        'fmap_base': 8,
        'fmap_decay': 1.0,
        'fmap_max': 8,
        'gradient_penalty_target': 1.0,
        'gradient_penalty_weight': 10.0,
        'real_score_penalty_weight': 0.001,
        'generator_learning_rate': 0.001,
        'discriminator_learning_rate': 0.001,
        'adam_beta1': 0.0,
        'adam_beta2': 0.99,
        'fake_grid_size': 2,
        'interp_grid_size': 2,
        'train_root_dir': os.path.join(FLAGS.test_tmpdir, 'progressive_gan'),
        'master': '',
        'task': 0
    }

  def test_train_success(self):
    train_root_dir = self._config['train_root_dir']
    if not tf.gfile.Exists(train_root_dir):
      tf.gfile.MakeDirs(train_root_dir)

    for stage_id in train.get_stage_ids(**self._config):
76
      batch_size = train.get_batch_size(stage_id, **self._config)
77
      tf.reset_default_graph()
78
79
80
      real_images = provide_random_data(batch_size=batch_size)
      model = train.build_model(stage_id, batch_size, real_images,
                                **self._config)
81
82
83
      train.add_model_summaries(model, **self._config)
      train.train(model, **self._config)

84
85
86
87
88
89
90
  def test_get_batch_size(self):
    config = {'num_resolutions': 5, 'batch_size_schedule': [8, 4, 2]}
    # batch_size_schedule is expanded to [8, 8, 8, 4, 2]
    # At stage level it is [8, 8, 8, 8, 8, 4, 4, 2, 2]
    for i, expected_batch_size in enumerate([8, 8, 8, 8, 8, 4, 4, 2, 2]):
      self.assertEqual(train.get_batch_size(i, **config), expected_batch_size)

91
92
93

if __name__ == '__main__':
  absltest.main()