launch_jobs.sh 5.5 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
#!/bin/bash
#
# This script performs the following operations:
# 1. Downloads the MNIST dataset.
# 2. Trains an unconditional, conditional, or InfoGAN model on the MNIST
#    training set.
# 3. Evaluates the models and writes sample images to disk.
#
# These examples are intended to be fast. For better final results, tune
# hyperparameters or train longer.
#
# NOTE: Each training step takes about 0.5 seconds with a batch size of 32 on
# CPU. On GPU, it takes ~5 milliseconds.
#
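# With the default batch size, approximate train times are listed below.
# NOTE: this script currently sets NUM_STEPS=3000 for every GAN type, so the
# step counts shown for `unconditional` and `conditional` differ from what is
# actually run: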
#
#   unconditional: CPU: 800  steps, ~10 min   GPU: 800  steps, ~1 min
#   conditional:   CPU: 2000 steps, ~20 min   GPU: 2000 steps, ~2 min
#   infogan:       CPU: 3000 steps, ~20 min   GPU: 3000 steps, ~6 min
#
# Usage:
# cd models/research/gan/mnist
# ./launch_jobs.sh ${gan_type} ${git_repo}
set -e

# Type of GAN to run. Supported values are `unconditional`, `conditional`, and
# `infogan`.
gan_type=$1
# Match the whole argument exactly. A prefix-only check would accept a value
# such as `conditional_v2`, which then matches none of the per-type sections
# later in the script and leaves the run looking successful.
case "$gan_type" in
  unconditional|conditional|infogan)
    ;;
  *)
    echo "'gan_type' must be one of: 'unconditional', 'conditional', 'infogan'." >&2
    # Exit non-zero so callers can detect the usage error.
    exit 1
    ;;
esac

# Location of the git repository.
git_repo=$2
if [[ -z "$git_repo" ]]; then
  echo "'git_repo' must not be empty." >&2
  exit 1
fi

# Base name for where the checkpoint and logs will be saved to.
TRAIN_DIR=/tmp/mnist-model

# Base name for where the evaluation images will be saved to.
EVAL_DIR=/tmp/mnist-model/eval

# Where the dataset is saved to.
DATASET_DIR=/tmp/mnist-data

# Location of the classifier frozen graph used for evaluation.
FROZEN_GRAPH="${git_repo}/research/gan/mnist/data/classify_mnist_graph_def.pb"

# Make the repository root, `research`, and `research/slim` importable by the
# Python scripts launched below.
export PYTHONPATH="${PYTHONPATH}:${git_repo}:${git_repo}/research:${git_repo}/research/slim"

# Prints a status message in green so progress lines stand out in the log.
# Arguments: $1 - text to print; backslash escapes in it are interpreted.
# Outputs:   the colored text followed by a newline, on stdout.
Banner () {
  local message=$1
  local -r color_on='\033[0;32m' color_off='\033[0m'
  # %b expands backslash escapes in the argument the same way `echo -e` does.
  printf '%b\n' "${color_on}${message}${color_off}"
}

# Download the MNIST dataset with slim's download_and_convert_data.py, storing
# it under DATASET_DIR. The expansion is quoted so the flag stays a single
# argument regardless of the directory's value.
python "${git_repo}/research/slim/download_and_convert_data.py" \
  --dataset_name=mnist \
  --dataset_dir="${DATASET_DIR}"

# Run unconditional GAN: train with train.py, then evaluate the resulting
# checkpoints with eval.py.
if [[ "$gan_type" == "unconditional" ]]; then
  UNCONDITIONAL_TRAIN_DIR="${TRAIN_DIR}/unconditional"
  UNCONDITIONAL_EVAL_DIR="${EVAL_DIR}/unconditional"
  NUM_STEPS=3000

  # Run training. Checkpoints and logs go to UNCONDITIONAL_TRAIN_DIR.
  Banner "Starting training unconditional GAN for ${NUM_STEPS} steps..."
  python "${git_repo}/research/gan/mnist/train.py" \
    --train_log_dir="${UNCONDITIONAL_TRAIN_DIR}" \
    --dataset_dir="${DATASET_DIR}" \
    --max_number_of_steps="${NUM_STEPS}" \
    --gan_type="unconditional" \
    --alsologtostderr
  Banner "Finished training unconditional GAN ${NUM_STEPS} steps."

  # Run evaluation against the checkpoints written by the training step.
  # FROZEN_GRAPH is built from the user-supplied repository path, so every
  # expansion is quoted to keep each flag a single argument even when that
  # path contains whitespace.
  Banner "Starting evaluation of unconditional GAN..."
  python "${git_repo}/research/gan/mnist/eval.py" \
    --checkpoint_dir="${UNCONDITIONAL_TRAIN_DIR}" \
    --eval_dir="${UNCONDITIONAL_EVAL_DIR}" \
    --dataset_dir="${DATASET_DIR}" \
    --eval_real_images=false \
    --classifier_filename="${FROZEN_GRAPH}" \
    --max_number_of_evaluation=1
  Banner "Finished unconditional evaluation. See ${UNCONDITIONAL_EVAL_DIR} for output images."
fi

# Run conditional GAN: train with train.py, then evaluate the resulting
# checkpoints with conditional_eval.py.
if [[ "$gan_type" == "conditional" ]]; then
  CONDITIONAL_TRAIN_DIR="${TRAIN_DIR}/conditional"
  CONDITIONAL_EVAL_DIR="${EVAL_DIR}/conditional"
  NUM_STEPS=3000

  # Run training. Checkpoints and logs go to CONDITIONAL_TRAIN_DIR.
  Banner "Starting training conditional GAN for ${NUM_STEPS} steps..."
  python "${git_repo}/research/gan/mnist/train.py" \
    --train_log_dir="${CONDITIONAL_TRAIN_DIR}" \
    --dataset_dir="${DATASET_DIR}" \
    --max_number_of_steps="${NUM_STEPS}" \
    --gan_type="conditional" \
    --alsologtostderr
  Banner "Finished training conditional GAN ${NUM_STEPS} steps."

  # Run evaluation against the checkpoints written by the training step.
  # FROZEN_GRAPH is built from the user-supplied repository path, so every
  # expansion is quoted to keep each flag a single argument even when that
  # path contains whitespace.
  Banner "Starting evaluation of conditional GAN..."
  python "${git_repo}/research/gan/mnist/conditional_eval.py" \
    --checkpoint_dir="${CONDITIONAL_TRAIN_DIR}" \
    --eval_dir="${CONDITIONAL_EVAL_DIR}" \
    --classifier_filename="${FROZEN_GRAPH}" \
    --max_number_of_evaluation=1
  Banner "Finished conditional evaluation. See ${CONDITIONAL_EVAL_DIR} for output images."
fi

# Run InfoGAN: train with train.py, then evaluate the resulting checkpoints
# with infogan_eval.py.
if [[ "$gan_type" == "infogan" ]]; then
  INFOGAN_TRAIN_DIR="${TRAIN_DIR}/infogan"
  INFOGAN_EVAL_DIR="${EVAL_DIR}/infogan"
  NUM_STEPS=3000

  # Run training. Checkpoints and logs go to INFOGAN_TRAIN_DIR.
  Banner "Starting training infogan GAN for ${NUM_STEPS} steps..."
  python "${git_repo}/research/gan/mnist/train.py" \
    --train_log_dir="${INFOGAN_TRAIN_DIR}" \
    --dataset_dir="${DATASET_DIR}" \
    --max_number_of_steps="${NUM_STEPS}" \
    --gan_type="infogan" \
    --alsologtostderr
  Banner "Finished training InfoGAN ${NUM_STEPS} steps."

  # Run evaluation against the checkpoints written by the training step.
  # FROZEN_GRAPH is built from the user-supplied repository path, so every
  # expansion is quoted to keep each flag a single argument even when that
  # path contains whitespace.
  Banner "Starting evaluation of infogan..."
  python "${git_repo}/research/gan/mnist/infogan_eval.py" \
    --checkpoint_dir="${INFOGAN_TRAIN_DIR}" \
    --eval_dir="${INFOGAN_EVAL_DIR}" \
    --classifier_filename="${FROZEN_GRAPH}" \
    --max_number_of_evaluation=1
  Banner "Finished InfoGAN evaluation. See ${INFOGAN_EVAL_DIR} for output images."
fi