Commit 825f7084 authored by joel-shor's avatar joel-shor
Browse files

Make TFGAN tutorial Python 3 compatible.

parent 5f99e589
...@@ -66,7 +66,7 @@ ...@@ -66,7 +66,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": null,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
...@@ -101,6 +101,7 @@ ...@@ -101,6 +101,7 @@
"import numpy as np\n", "import numpy as np\n",
"import time\n", "import time\n",
"import functools\n", "import functools\n",
"from six.moves import xrange # pylint: disable=redefined-builtin\n",
"\n", "\n",
"import tensorflow as tf\n", "import tensorflow as tf\n",
"\n", "\n",
...@@ -1066,14 +1067,20 @@ ...@@ -1066,14 +1067,20 @@
} }
], ],
"source": [ "source": [
"def _get_next(iterable):\n",
" try:\n",
" return iterable.next() # Python 2.x.x\n",
" except AttributeError:\n",
" return iterable.__next__() # Python 3.x.x\n",
"\n",
"# Run inference.\n", "# Run inference.\n",
"predict_input_fn = _get_predict_input_fn(36, NOISE_DIMS)\n", "predict_input_fn = _get_predict_input_fn(36, NOISE_DIMS)\n",
"prediction_iterable = gan_estimator.predict(\n", "prediction_iterable = gan_estimator.predict(\n",
" predict_input_fn, hooks=[tf.train.StopAtStepHook(last_step=1)])\n", " predict_input_fn, hooks=[tf.train.StopAtStepHook(last_step=1)])\n",
"predictions = [prediction_iterable.next() for _ in xrange(36)]\n", "predictions = [_get_next(prediction_iterable) for _ in xrange(36)]\n",
"\n", "\n",
"try: # Close the predict session.\n", "try: # Close the predict session.\n",
" prediction_iterable.next()\n", " _get_next(prediction_iterable)\n",
"except StopIteration:\n", "except StopIteration:\n",
" pass\n", " pass\n",
"\n", "\n",
...@@ -1889,7 +1896,7 @@ ...@@ -1889,7 +1896,7 @@
"assert images_to_eval % cat_dim == 0\n", "assert images_to_eval % cat_dim == 0\n",
"\n", "\n",
"unstructured_inputs = tf.random_normal([images_to_eval, noise_dims-cont_dim])\n", "unstructured_inputs = tf.random_normal([images_to_eval, noise_dims-cont_dim])\n",
"cat_noise = tf.constant(range(cat_dim) * (images_to_eval // cat_dim))\n", "cat_noise = tf.constant(list(range(cat_dim)) * (images_to_eval // cat_dim))\n",
"cont_noise = tf.random_uniform([images_to_eval, cont_dim], -1.0, 1.0)\n", "cont_noise = tf.random_uniform([images_to_eval, cont_dim], -1.0, 1.0)\n",
"\n", "\n",
"with tf.variable_scope(infogan_model.generator_scope, reuse=True):\n", "with tf.variable_scope(infogan_model.generator_scope, reuse=True):\n",
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment