Commit b2136abf authored by Mark Daoust's avatar Mark Daoust
Browse files

Pull AutoGraph workshop into docs.

From tensorflow/tensorflow/blob/8a6ef2cb4f98bacc1f821f60c21914b4bd5faaef
parent 3a05570f
{
"cells": [
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "qWUV0FYjDSKj"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"from tensorflow.contrib import autograph\n",
"\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "kGXS3UWBBNoc"
},
"source": [
"# 1. AutoGraph writes graph code for you\n",
"\n",
"[AutoGraph](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/README.md) helps you write complicated graph code using just plain Python -- behind the scenes, AutoGraph automatically transforms your code into the equivalent TF graph code. We support a large chunk of the Python language, which is growing. [Please see this document for what we currently support, and what we're working on](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/LIMITATIONS.md).\n",
"\n",
"Here's a quick example of how it works:\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "aA3gOodCBkOw"
},
"outputs": [],
"source": [
"# Autograph can convert functions like this...\n",
"def g(x):\n",
" if x \u003e 0:\n",
" x = x * x\n",
" else:\n",
" x = 0.0\n",
" return x\n",
"\n",
"# ...into graph-building functions like this:\n",
"def tf_g(x):\n",
" with tf.name_scope('g'):\n",
" \n",
" def if_true():\n",
" with tf.name_scope('if_true'):\n",
" x_1, = x,\n",
" x_1 = x_1 * x_1\n",
" return x_1,\n",
"\n",
" def if_false():\n",
" with tf.name_scope('if_false'):\n",
" x_1, = x,\n",
" x_1 = 0.0\n",
" return x_1,\n",
"\n",
" x = autograph_utils.run_cond(tf.greater(x, 0), if_true, if_false)\n",
" return x\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "I1RtBvoKBxq5"
},
"outputs": [],
"source": [
"# You can run your plain-Python code in graph mode,\n",
"# and get the same results out, but with all the benfits of graphs:\n",
"print('Original value: %2.2f' % g(9.0))\n",
"\n",
"# Generate a graph-version of g and call it:\n",
"tf_g = autograph.to_graph(g)\n",
"\n",
"with tf.Graph().as_default(): \n",
" # The result works like a regular op: takes tensors in, returns tensors.\n",
" # You can inspect the graph using tf.get_default_graph().as_graph_def()\n",
" g_ops = tf_g(tf.constant(9.0))\n",
" with tf.Session() as sess:\n",
" print('Autograph value: %2.2f\\n' % sess.run(g_ops))\n",
" \n",
" \n",
"# You can view, debug and tweak the generated code:\n",
"print(autograph.to_code(g))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "m-jWmsCmByyw"
},
"source": [
"#### Automatically converting complex control flow\n",
"\n",
"AutoGraph can convert a large chunk of the Python language into equivalent graph-construction code, and we're adding new supported language features all the time. In this section, we'll give you a taste of some of the functionality in AutoGraph.\n",
"AutoGraph will automatically convert most Python control flow statements into their correct graph equivalent. \n",
" \n",
"We support common statements like `while`, `for`, `if`, `break`, `return` and more. You can even nest them as much as you like. Imagine trying to write the graph version of this code by hand:\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "toxKBOXbB1ro"
},
"outputs": [],
"source": [
"# Continue in a loop\n",
"def f(l):\n",
" s = 0\n",
" for c in l:\n",
" if c % 2 \u003e 0:\n",
" continue\n",
" s += c\n",
" return s\n",
"\n",
"print('Original value: %d' % f([10,12,15,20]))\n",
"\n",
"tf_f = autograph.to_graph(f)\n",
"with tf.Graph().as_default(): \n",
" with tf.Session():\n",
" print('Graph value: %d\\n\\n' % tf_f(tf.constant([10,12,15,20])).eval())\n",
" \n",
"print(autograph.to_code(f))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "FUJJ-WTdCGeq"
},
"source": [
"Try replacing the `continue` in the above code with `break` -- AutoGraph supports that as well! \n",
" \n",
"Let's try some other useful Python constructs, like `print` and `assert`. We automatically convert Python `assert` statements into the equivalent `tf.Assert` code. "
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "IAOgh62zCPZ4"
},
"outputs": [],
"source": [
"def f(x):\n",
" assert x != 0, 'Do not pass zero!'\n",
" return x * x\n",
"\n",
"tf_f = autograph.to_graph(f)\n",
"with tf.Graph().as_default(): \n",
" with tf.Session():\n",
" try:\n",
" print(tf_f(tf.constant(0)).eval())\n",
" except tf.errors.InvalidArgumentError as e:\n",
" print('Got error message:\\n%s' % e.message)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "KRu8iIPBCQr5"
},
"source": [
"You can also use plain Python `print` functions in in-graph"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "ySTsuxnqCTQi"
},
"outputs": [],
"source": [
"def f(n):\n",
" if n \u003e= 0:\n",
" while n \u003c 5:\n",
" n += 1\n",
" print(n)\n",
" return n\n",
" \n",
"tf_f = autograph.to_graph(f)\n",
"with tf.Graph().as_default():\n",
" with tf.Session():\n",
" tf_f(tf.constant(0)).eval()"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "NqF0GT-VCVFh"
},
"source": [
"Appending to lists in loops also works (we create a `TensorArray` for you behind the scenes)"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "ABX070KwCczR"
},
"outputs": [],
"source": [
"def f(n):\n",
" z = []\n",
" # We ask you to tell us the element dtype of the list\n",
" z = autograph.utils.set_element_type(z, tf.int32)\n",
" for i in range(n):\n",
" z.append(i)\n",
" # when you're done with the list, stack it\n",
" # (this is just like np.stack)\n",
" return autograph.stack(z) \n",
"\n",
"tf_f = autograph.to_graph(f)\n",
"with tf.Graph().as_default(): \n",
" with tf.Session():\n",
" print(tf_f(tf.constant(3)).eval())\n",
"\n",
"print('\\n\\n'+autograph.to_code(f))"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "iu5IF7n2Df7C"
},
"outputs": [],
"source": [
"def fizzbuzz(num):\n",
" if num % 3 == 0 and num % 5 == 0:\n",
" print('FizzBuzz')\n",
" elif num % 3 == 0:\n",
" print('Fizz')\n",
" elif num % 5 == 0:\n",
" print('Buzz')\n",
" else:\n",
" print(num)\n",
" return num"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "EExAjWuwDPpR"
},
"outputs": [],
"source": [
"tf_g = autograph.to_graph(fizzbuzz)\n",
"\n",
"with tf.Graph().as_default(): \n",
" # The result works like a regular op: takes tensors in, returns tensors.\n",
" # You can inspect the graph using tf.get_default_graph().as_graph_def()\n",
" g_ops = tf_g(tf.constant(15))\n",
" with tf.Session() as sess:\n",
" sess.run(g_ops) \n",
" \n",
"# You can view, debug and tweak the generated code:\n",
"print('\\n')\n",
"print(autograph.to_code(fizzbuzz))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "SzpKGzVpBkph"
},
"source": [
"# De-graphify Exercises\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "8k23dxcSmmXq"
},
"source": [
"#### Easy print statements"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "dE1Vsmp-mlpK"
},
"outputs": [],
"source": [
"# See what happens when you turn AutoGraph off.\n",
"# Do you see the type or the value of x when you print it?\n",
"\n",
"# @autograph.convert()\n",
"def square_log(x):\n",
" x = x * x\n",
" print('Squared value of x =', x)\n",
" return x\n",
"\n",
"\n",
"with tf.Graph().as_default(): \n",
" with tf.Session() as sess:\n",
" print(sess.run(square_log(tf.constant(4))))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "_R-Q7BbxmkBF"
},
"source": [
"#### Now some exercises. Convert the TensorFlow code into AutoGraph'd Python code."
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "SwA11tO-yCvg"
},
"outputs": [],
"source": [
"def square_if_positive(x):\n",
" x = tf.cond(tf.greater(x, 0), lambda: x * x, lambda: x)\n",
" return x\n",
"\n",
"with tf.Session() as sess:\n",
" print(sess.run(square_if_positive(tf.constant(4))))"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "GPmx4CNhyPI_"
},
"outputs": [],
"source": [
"@autograph.convert()\n",
"def square_if_positive(x):\n",
" ... # \u003c\u003c\u003c fill it in!\n",
" \n",
"with tf.Session() as sess:\n",
" print(sess.run(square_if_positive(tf.constant(4))))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "qqsjik-QyA9R"
},
"source": [
"#### Uncollapse to see answer"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "DaSmaWUEvMRv"
},
"outputs": [],
"source": [
"# Simple cond\n",
"@autograph.convert()\n",
"def square_if_positive(x):\n",
" if x \u003e 0:\n",
" x = x * x\n",
" return x\n",
"\n",
"with tf.Graph().as_default(): \n",
" with tf.Session() as sess:\n",
" print(sess.run(square_if_positive(tf.constant(4))))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "qj7am2I_xvTJ"
},
"source": [
"#### Nested If statement"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "4yyNOf-Twr6s"
},
"outputs": [],
"source": [
"def nearest_odd_square(x):\n",
"\n",
" def if_positive():\n",
" x1 = x * x\n",
" x1 = tf.cond(tf.equal(x1 % 2, 0), lambda: x1 + 1, lambda: x1)\n",
" return x1,\n",
"\n",
" x = tf.cond(tf.greater(x, 0), if_positive, lambda: x)\n",
" return x\n",
"\n",
"with tf.Graph().as_default(): \n",
" with tf.Session() as sess:\n",
" print(sess.run(nearest_odd_square(tf.constant(4))))"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "hqmh5b2VyU9w"
},
"outputs": [],
"source": [
"@autograph.convert()\n",
"def nearest_odd_square(x):\n",
" ... # \u003c\u003c\u003c fill it in!\n",
" \n",
"with tf.Session() as sess:\n",
" print(sess.run(nearest_odd_square(tf.constant(4))))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "b9AXIkNLxp6J"
},
"source": [
"#### Uncollapse to reveal answer"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "8RlCVEpNxD91"
},
"outputs": [],
"source": [
"@autograph.convert()\n",
"def nearest_odd_square(x):\n",
" if x \u003e 0:\n",
" x = x * x\n",
" if x % 2 == 0:\n",
" x = x + 1\n",
" return x\n",
"\n",
"with tf.Graph().as_default(): \n",
" with tf.Session() as sess:\n",
" print(sess.run(nearest_odd_square(tf.constant(4))))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "jXAxjeBr1qWK"
},
"source": [
"#### Convert a while loop"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "kWkv7anlxoee"
},
"outputs": [],
"source": [
"# Convert a while loop\n",
"def square_until_stop(x, y):\n",
" x = tf.while_loop(lambda x: tf.less(x, y), lambda x: x * x, [x])\n",
" return x\n",
" \n",
"with tf.Graph().as_default(): \n",
" with tf.Session() as sess:\n",
" print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "zVUsc1eA1u2K"
},
"outputs": [],
"source": [
"@autograph.convert()\n",
"def square_until_stop(x, y):\n",
" ... # fill it in!\n",
" \n",
"with tf.Graph().as_default(): \n",
" with tf.Session() as sess:\n",
" print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "L2psuzPI02S9"
},
"source": [
"#### Uncollapse for the answer\n"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "ucmZyQVL03bF"
},
"outputs": [],
"source": [
"@autograph.convert()\n",
"def square_until_stop(x, y):\n",
" while x \u003c y:\n",
" x = x * x\n",
" return x\n",
" \n",
"with tf.Graph().as_default(): \n",
" with tf.Session() as sess:\n",
" print(sess.run(square_until_stop(tf.constant(4), tf.constant(100))))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "FXB0Zbwl13PY"
},
"source": [
"#### Nested loop and conditional"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "clGymxdf15Ig"
},
"outputs": [],
"source": [
"@autograph.convert()\n",
"def argwhere_cumsum(x, threshold):\n",
" current_sum = 0.0\n",
" idx = 0\n",
" \n",
" for i in range(len(x)):\n",
" idx = i\n",
" if current_sum \u003e= threshold:\n",
" break\n",
" current_sum += x[i]\n",
" return idx\n",
"\n",
"N = 10\n",
"with tf.Graph().as_default(): \n",
" with tf.Session() as sess:\n",
" idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))\n",
" print(sess.run(idx))"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "i7PF-uId9lp5"
},
"outputs": [],
"source": [
"@autograph.convert()\n",
"def argwhere_cumsum(x, threshold):\n",
" ...\n",
"\n",
"N = 10\n",
"with tf.Graph().as_default(): \n",
" with tf.Session() as sess:\n",
" idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))\n",
" print(sess.run(idx))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "weKFXAb615Vp"
},
"source": [
"#### Uncollapse to see answer"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "1sjaFcL717Ig"
},
"outputs": [],
"source": [
"@autograph.convert()\n",
"def argwhere_cumsum(x, threshold):\n",
" current_sum = 0.0\n",
" idx = 0\n",
" for i in range(len(x)):\n",
" idx = i\n",
" if current_sum \u003e= threshold:\n",
" break\n",
" current_sum += x[i]\n",
" return idx\n",
"\n",
"N = 10\n",
"with tf.Graph().as_default(): \n",
" with tf.Session() as sess:\n",
" idx = argwhere_cumsum(tf.ones(N), tf.constant(float(N/2)))\n",
" print(sess.run(idx))"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "4LfnJjm0Bm0B"
},
"source": [
"# 3. Training MNIST in-graph\n",
"\n",
"Writing control flow in AutoGraph is easy, so running a training loop in a TensorFlow graph should be easy as well! \n",
"\n",
"Here, we show an example of training a simple Keras model on MNIST, where the entire training process -- loading batches, calculating gradients, updating parameters, calculating validation accuracy, and repeating until convergence -- is done in-graph."
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Em5dzSUOtLRP"
},
"source": [
"#### Download data"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "xqoxumv0ssQW"
},
"outputs": [],
"source": [
"import gzip\n",
"import shutil\n",
"\n",
"from six.moves import urllib\n",
"\n",
"\n",
"def download(directory, filename):\n",
" filepath = os.path.join(directory, filename)\n",
" if tf.gfile.Exists(filepath):\n",
" return filepath\n",
" if not tf.gfile.Exists(directory):\n",
" tf.gfile.MakeDirs(directory)\n",
" url = 'https://storage.googleapis.com/cvdf-datasets/mnist/' + filename + '.gz'\n",
" zipped_filepath = filepath + '.gz'\n",
" print('Downloading %s to %s' % (url, zipped_filepath))\n",
" urllib.request.urlretrieve(url, zipped_filepath)\n",
" with gzip.open(zipped_filepath, 'rb') as f_in, open(filepath, 'wb') as f_out:\n",
" shutil.copyfileobj(f_in, f_out)\n",
" os.remove(zipped_filepath)\n",
" return filepath\n",
"\n",
"\n",
"def dataset(directory, images_file, labels_file):\n",
" images_file = download(directory, images_file)\n",
" labels_file = download(directory, labels_file)\n",
"\n",
" def decode_image(image):\n",
" # Normalize from [0, 255] to [0.0, 1.0]\n",
" image = tf.decode_raw(image, tf.uint8)\n",
" image = tf.cast(image, tf.float32)\n",
" image = tf.reshape(image, [784])\n",
" return image / 255.0\n",
"\n",
" def decode_label(label):\n",
" label = tf.decode_raw(label, tf.uint8)\n",
" label = tf.reshape(label, [])\n",
" return tf.to_int32(label)\n",
"\n",
" images = tf.data.FixedLengthRecordDataset(\n",
" images_file, 28 * 28, header_bytes=16).map(decode_image)\n",
" labels = tf.data.FixedLengthRecordDataset(\n",
" labels_file, 1, header_bytes=8).map(decode_label)\n",
" return tf.data.Dataset.zip((images, labels))\n",
"\n",
"\n",
"def mnist_train(directory):\n",
" return dataset(directory, 'train-images-idx3-ubyte',\n",
" 'train-labels-idx1-ubyte')\n",
"\n",
"def mnist_test(directory):\n",
" return dataset(directory, 't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "znmy4l8ntMvW"
},
"source": [
"#### Define the model"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "Pe-erWQdBoC5"
},
"outputs": [],
"source": [
"def mlp_model(input_shape):\n",
" model = tf.keras.Sequential((\n",
" tf.keras.layers.Dense(100, activation='relu', input_shape=input_shape),\n",
" tf.keras.layers.Dense(100, activation='relu'),\n",
" tf.keras.layers.Dense(10, activation='softmax')))\n",
" model.build()\n",
" return model\n",
"\n",
"\n",
"def predict(m, x, y):\n",
" y_p = m(x)\n",
" losses = tf.keras.losses.categorical_crossentropy(y, y_p)\n",
" l = tf.reduce_mean(losses)\n",
" accuracies = tf.keras.metrics.categorical_accuracy(y, y_p)\n",
" accuracy = tf.reduce_mean(accuracies)\n",
" return l, accuracy\n",
"\n",
"\n",
"def fit(m, x, y, opt):\n",
" l, accuracy = predict(m, x, y)\n",
" opt.minimize(l)\n",
" return l, accuracy\n",
"\n",
"\n",
"def setup_mnist_data(is_training, hp, batch_size):\n",
" if is_training:\n",
" ds = mnist_train('/tmp/autograph_mnist_data')\n",
" ds = ds.shuffle(batch_size * 10)\n",
" else:\n",
" ds = mnist_test('/tmp/autograph_mnist_data')\n",
" ds = ds.repeat()\n",
" ds = ds.batch(batch_size)\n",
" return ds\n",
"\n",
"\n",
"def get_next_batch(ds):\n",
" itr = ds.make_one_shot_iterator()\n",
" image, label = itr.get_next()\n",
" x = tf.to_float(tf.reshape(image, (-1, 28 * 28)))\n",
" y = tf.one_hot(tf.squeeze(label), 10)\n",
" return x, y"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "oeYV6mKnJGMr"
},
"source": [
"#### Define the training loop"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "3xtg_MMhJETd"
},
"outputs": [],
"source": [
"def train(train_ds, test_ds, hp):\n",
" m = mlp_model((28 * 28,))\n",
" opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n",
" \n",
" # We'd like to save our losses to a list. In order for AutoGraph\n",
" # to convert these lists into their graph equivalent,\n",
" # we need to specify the element type of the lists.\n",
" train_losses = []\n",
" train_losses = autograph.utils.set_element_type(train_losses, tf.float32)\n",
" test_losses = []\n",
" test_losses = autograph.utils.set_element_type(test_losses, tf.float32)\n",
" train_accuracies = []\n",
" train_accuracies = autograph.utils.set_element_type(train_accuracies, tf.float32)\n",
" test_accuracies = []\n",
" test_accuracies = autograph.utils.set_element_type(test_accuracies, tf.float32)\n",
" \n",
" # This entire training loop will be run in-graph.\n",
" i = tf.constant(0)\n",
" while i \u003c hp.max_steps:\n",
" train_x, train_y = get_next_batch(train_ds)\n",
" test_x, test_y = get_next_batch(test_ds)\n",
" # add get next\n",
" step_train_loss, step_train_accuracy = fit(m, train_x, train_y, opt)\n",
" step_test_loss, step_test_accuracy = predict(m, test_x, test_y)\n",
" if i % (hp.max_steps // 10) == 0:\n",
" print('Step', i, 'train loss:', step_train_loss, 'test loss:',\n",
" step_test_loss, 'train accuracy:', step_train_accuracy,\n",
" 'test accuracy:', step_test_accuracy)\n",
" train_losses.append(step_train_loss)\n",
" test_losses.append(step_test_loss)\n",
" train_accuracies.append(step_train_accuracy)\n",
" test_accuracies.append(step_test_accuracy)\n",
" i += 1\n",
" \n",
" # We've recorded our loss values and accuracies \n",
" # to a list in a graph with AutoGraph's help.\n",
" # In order to return the values as a Tensor, \n",
" # we need to stack them before returning them.\n",
" return (autograph.stack(train_losses), autograph.stack(test_losses), autograph.stack(train_accuracies),\n",
" autograph.stack(test_accuracies))"
]
},
{
"cell_type": "code",
"execution_count": 0,
"metadata": {
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
},
"colab_type": "code",
"id": "HYh6MSZyJOag"
},
"outputs": [],
"source": [
"with tf.Graph().as_default():\n",
" hp = tf.contrib.training.HParams(\n",
" learning_rate=0.05,\n",
" max_steps=500,\n",
" )\n",
" train_ds = setup_mnist_data(True, hp, 50)\n",
" test_ds = setup_mnist_data(False, hp, 1000)\n",
" tf_train = autograph.to_graph(train)\n",
" (train_losses, test_losses, train_accuracies,\n",
" test_accuracies) = tf_train(train_ds, test_ds, hp)\n",
"\n",
" with tf.Session() as sess:\n",
" sess.run(tf.global_variables_initializer())\n",
" (train_losses, test_losses, train_accuracies,\n",
" test_accuracies) = sess.run([train_losses, test_losses, train_accuracies,\n",
" test_accuracies])\n",
" plt.title('MNIST train/test losses')\n",
" plt.plot(train_losses, label='train loss')\n",
" plt.plot(test_losses, label='test loss')\n",
" plt.legend()\n",
" plt.xlabel('Training step')\n",
" plt.ylabel('Loss')\n",
" plt.show()\n",
" plt.title('MNIST train/test accuracies')\n",
" plt.plot(train_accuracies, label='train accuracy')\n",
" plt.plot(test_accuracies, label='test accuracy')\n",
" plt.legend(loc='lower right')\n",
" plt.xlabel('Training step')\n",
" plt.ylabel('Accuracy')\n",
" plt.show()"
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [
"qqsjik-QyA9R",
"b9AXIkNLxp6J",
"L2psuzPI02S9",
"weKFXAb615Vp",
"Em5dzSUOtLRP"
],
"default_view": {},
"name": "AutoGraph Workshop.ipynb",
"provenance": [
{
"file_id": "1kE2gz_zuwdYySL4K2HQSz13uLCYi-fYP",
"timestamp": 1530563781803
}
],
"version": "0.3.2",
"views": {}
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment