{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Overview" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This example demonstrates the use of `tf.feature_column.crossed_column` on some simulated Atlanta housing price data. \n", "This spatial data is used primarily so the results can be easily visualized. \n", "\n", "These functions are designed primarily for categorical data, not to build interpolation tables. \n", "\n", "If you actually want to build smart interpolation tables in TensorFlow you may want to consider [TensorFlow Lattice](https://research.googleblog.com/2017/10/tensorflow-lattice-flexibility.html)." ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "yEHBFimYk-Mu" }, "source": [ "# Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "DiAklWTFk-My" }, "outputs": [], "source": [ "import os\n", "import subprocess\n", "import tempfile\n", "\n", "import tensorflow as tf\n", "import numpy as np\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "fHqxuCUu8Bvm" }, "outputs": [], "source": [ "assert tf.VERSION.split('.') >= ['1','4']" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "%matplotlib inline\n", "mpl.rcParams['figure.figsize'] = 12, 6\n", "mpl.rcParams['image.cmap'] = 'viridis'" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "Oj4Jv4Pik-M1" }, "outputs": [], "source": [ "logdir = tempfile.mkdtemp()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { 
"colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "output_extras": [ {} ] }, "colab_type": "code", "id": "cTrSkk1zmvO0", "outputId": "41532b3b-2bf8-4abb-bc46-a92c76fe70f8" }, "outputs": [], "source": [ "logdir" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "p3me9zPGk-M3" }, "source": [ "# Start TensorBoard\n", "The following command will kill all running TensorBoard processes, and start a new one monitoring the above logdir. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 34, "output_extras": [ {} ] }, "colab_type": "code", "executionInfo": { "elapsed": 327, "status": "ok", "timestamp": 1508962289209, "user": { "displayName": "Mark Daoust", "photoUrl": "//lh5.googleusercontent.com/-2bdrhkqhwhc/AAAAAAAAAAI/AAAAAAAAAYY/WEdKp4OXSFY/s50-c-k-no/photo.jpg", "userId": "106546680081284977106" }, "user_tz": 240 }, "id": "umxoWdz9k-M3", "outputId": "e2d2af2f-e56f-4b5e-aa62-bd9119966b53" }, "outputs": [], "source": [ "subprocess.Popen(['pkill','-f','tensorboard'])\n", "subprocess.Popen(['tensorboard', '--logdir', logdir])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "WFYz5eg1k-M7" }, "source": [ "# Build Synthetic Data" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "F3ouYc9N9zW3" }, "outputs": [], "source": [ "# Define the grid\n", "min_latitude = 33.641336\n", "max_latitude = 33.887157\n", "delta_latitude = max_latitude-min_latitude\n", "\n", "min_longitude = -84.558798\n", "max_longitude = -84.287259\n", "delta_longitude = max_longitude-min_longitude\n", "\n", "resolution = 100" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Use RandomState so the behavior is 
repeatable. \n", "R = np.random.RandomState(1)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "JQ7wXpURk-M8" }, "outputs": [], "source": [ "# The price data will be a sum of Gaussians, at random locations.\n", "n_centers = 20\n", "centers = R.rand(n_centers, 2) # shape: (centers, dimensions)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "nR1wQiqSk-NA" }, "outputs": [], "source": [ "# Each Gaussian has a maximum price contribution, at the center.\n", "# Price_\n", "price_delta = 0.5+2*R.rand(n_centers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "rGX3j-DOk-NC" }, "outputs": [], "source": [ "# Each Gaussian also has a standard-deviation and variance.\n", "std = 0.2*R.rand(n_centers) # shape: (centers)\n", "var = std**2" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "kWu2ba4Tk-NE" }, "outputs": [], "source": [ "def price(latitude, longitude):\n", " # Convert latitude, longitude to x,y in [0,1]\n", " x = (longitude - min_longitude)/delta_longitude\n", " y = (latitude - min_latitude)/delta_latitude\n", " \n", " # Cache the shape, and flatten the inputs.\n", " shape = x.shape\n", " assert y.shape == x.shape\n", " x = x.flatten()\n", " y = y.flatten()\n", " \n", " # Convert x, y examples into an array with shape (examples, dimensions)\n", " xy = np.array([x,y]).T\n", "\n", " # Calculate the square distance from each example to each center. 
\n", " components2 = (xy[:,None,:] - centers[None,:,:])**2 # shape: (examples, centers, dimensions)\n", " r2 = components2.sum(axis=2) # shape: (examples, centers)\n", " \n", " # Calculate the z**2 for each example from each center.\n", " z2 = r2/var[None,:]\n", " price = (np.exp(-z2)*price_delta).sum(1) # shape: (examples,)\n", " \n", " # Restore the original shape.\n", " return price.reshape(shape)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "BPoSndW8k-NG" }, "outputs": [], "source": [ "# Build the grid. We want `resolution` cells between `min` and `max` on each dimension\n", "# so we need `resolution+1` evenly spaced edges. The centers are at the average of the\n", "# upper and lower edge. \n", "\n", "latitude_edges = np.linspace(min_latitude, max_latitude, resolution+1)\n", "latitude_centers = (latitude_edges[:-1] + latitude_edges[1:])/2\n", "\n", "longitude_edges = np.linspace(min_longitude, max_longitude, resolution+1)\n", "longitude_centers = (longitude_edges[:-1] + longitude_edges[1:])/2\n", "\n", "latitude_grid, longitude_grid = np.meshgrid(\n", " latitude_centers,\n", " longitude_centers)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "output_extras": [ {} ] }, "colab_type": "code", "id": "0Y5fSCpWk-NI", "outputId": "35737491-93bd-4911-cb6e-8163849983a3" }, "outputs": [], "source": [ "# Evaluate the price at each center-point\n", "actual_price_grid = price(latitude_grid, longitude_grid)\n", "\n", "price_min = actual_price_grid.min()\n", "price_max = actual_price_grid.max()\n", "price_mean = actual_price_grid.mean()\n", "price_mean" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "wN2wMUOck-NK" }, 
"outputs": [], "source": [ "def show_price(price):\n", " plt.imshow(\n", " price, \n", " # The color axis goes from `price_min` to `price_max`.\n", " vmin=price_min, vmax=price_max,\n", " # Put the image at the correct latitude and longitude.\n", " extent=(min_longitude, max_longitude, min_latitude, max_latitude), \n", " # Make the image square.\n", " aspect = 1.0*delta_longitude/delta_latitude)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 592, "output_extras": [ {} ] }, "colab_type": "code", "executionInfo": { "elapsed": 678, "status": "ok", "timestamp": 1508962293265, "user": { "displayName": "Mark Daoust", "photoUrl": "//lh5.googleusercontent.com/-2bdrhkqhwhc/AAAAAAAAAAI/AAAAAAAAAYY/WEdKp4OXSFY/s50-c-k-no/photo.jpg", "userId": "106546680081284977106" }, "user_tz": 240 }, "id": "FkHXxJGuAsC1", "outputId": "40bf2c0f-51b0-4026-a275-c3669e5366cc" }, "outputs": [], "source": [ "show_price(actual_price_grid)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "miDtqLRek-NM" }, "source": [ "# Build Datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "zqBPKjBvk-NM" }, "outputs": [], "source": [ "# For test data we will use the grid centers.\n", "test_features = {'latitude':latitude_grid.flatten(), 'longitude':longitude_grid.flatten()}\n", "test_ds = tf.data.Dataset.from_tensor_slices((test_features, \n", " actual_price_grid.flatten()))\n", "test_ds = test_ds.cache().batch(512).prefetch(1)\n", "\n", "# For training data we will use a set of random points.\n", "train_latitude = min_latitude + np.random.rand(50000)*delta_latitude\n", "train_longitude = min_longitude + np.random.rand(50000)*delta_longitude\n", "train_price = price(train_latitude, train_longitude)\n", "\n", 
"train_features = {'latitude':train_latitude, 'longitude':train_longitude}\n", "train_ds = tf.data.Dataset.from_tensor_slices((train_features, train_price))\n", "train_ds = train_ds.cache().repeat().shuffle(100000).batch(512).prefetch(1)\n", "\n", "# A shortcut to build an `input_fn` from a `Dataset`\n", "def in_fn(ds):\n", " return lambda : ds.make_one_shot_iterator().get_next()\n" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "i-ToQzSqk-NO" }, "source": [ "# Generate a plot from an Estimator" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "N9NmQLUOk-NP" }, "outputs": [], "source": [ "def plot_est(est, ds = test_ds):\n", " # Create two plot axes\n", " actual, predicted = plt.subplot(1,2,1), plt.subplot(1,2,2)\n", "\n", " # Plot the actual price.\n", " plt.sca(actual)\n", " show_price(actual_price_grid.reshape(resolution, resolution))\n", " \n", " # Generate predictions over the grid from the estimator.\n", " pred = est.predict(in_fn(ds))\n", " # Convert them to a numpy array.\n", " pred = np.fromiter((item['predictions'] for item in pred), np.float32)\n", " # Plot the predictions on the secodn axis.\n", " plt.sca(predicted)\n", " show_price(pred.reshape(resolution, resolution))" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "gfgngu0lk-NQ" }, "source": [ "# Using `numeric_column` with DNNRegressor\n", "Important: Pure categorical data doesn't the spatial relationships that make this example possible. Embeddings are a way your model can learn spatial relationships." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 2010, "output_extras": [ {}, {} ] }, "colab_type": "code", "executionInfo": { "elapsed": 8234, "status": "ok", "timestamp": 1508962479411, "user": { "displayName": "Mark Daoust", "photoUrl": "//lh5.googleusercontent.com/-2bdrhkqhwhc/AAAAAAAAAAI/AAAAAAAAAYY/WEdKp4OXSFY/s50-c-k-no/photo.jpg", "userId": "106546680081284977106" }, "user_tz": 240 }, "id": "On-_j4Jtk-NR", "outputId": "1c18ba56-7246-4096-c32d-822c2fe3dd36" }, "outputs": [], "source": [ "# Use `normalizer_fn` so that the model only sees values in [-0.5, 0.5]\n", "norm_latitude = lambda latitude:(latitude-min_latitude)/delta_latitude - 0.5\n", "norm_longitude = lambda longitude:(longitude-min_longitude)/delta_longitude - 0.5\n", "\n", "fc = [tf.feature_column.numeric_column('latitude', normalizer_fn = norm_latitude), \n", " tf.feature_column.numeric_column('longitude', normalizer_fn = norm_longitude)]\n", "\n", "# Build and train the Estimator\n", "est = tf.estimator.DNNRegressor(\n", " hidden_units=[100,100], \n", " feature_columns=fc, \n", " model_dir = os.path.join(logdir,'DNN'))\n", "\n", "est.train(in_fn(train_ds), steps = 5000)\n", "est.evaluate(in_fn(test_ds))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 388, "output_extras": [ {}, {} ] }, "colab_type": "code", "executionInfo": { "elapsed": 1002, "status": "ok", "timestamp": 1508962552800, "user": { "displayName": "Mark Daoust", "photoUrl": "//lh5.googleusercontent.com/-2bdrhkqhwhc/AAAAAAAAAAI/AAAAAAAAAYY/WEdKp4OXSFY/s50-c-k-no/photo.jpg", "userId": "106546680081284977106" }, "user_tz": 240 }, "id": "WTcIX78Tk-NV", "outputId": "8d08be57-7c37-4268-e38a-9675978567ff" }, "outputs": [], "source": [ "plot_est(est)" ] }, { "cell_type": "markdown", 
"metadata": { "colab_type": "text", "id": "WIlTd5VEk-NZ" }, "source": [ "# Using `bucketized_column`" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 2010, "output_extras": [ {}, {} ] }, "colab_type": "code", "executionInfo": { "elapsed": 9805, "status": "ok", "timestamp": 1508962565572, "user": { "displayName": "Mark Daoust", "photoUrl": "//lh5.googleusercontent.com/-2bdrhkqhwhc/AAAAAAAAAAI/AAAAAAAAAYY/WEdKp4OXSFY/s50-c-k-no/photo.jpg", "userId": "106546680081284977106" }, "user_tz": 240 }, "id": "NdKt4g2Kk-NZ", "outputId": "a72e8a55-0239-4b42-e66b-a57ba79fb47b" }, "outputs": [], "source": [ "# Bucketize the latitude and longitude using the `edges`\n", "latitude_bucket_fc = tf.feature_column.bucketized_column(\n", " tf.feature_column.numeric_column('latitude'), \n", " list(latitude_edges))\n", "\n", "longitude_bucket_fc = tf.feature_column.bucketized_column(\n", " tf.feature_column.numeric_column('longitude'),\n", " list(longitude_edges))\n", "\n", "fc = [\n", " latitude_bucket_fc,\n", " longitude_bucket_fc]\n", "\n", "# Build and train the Estimator.\n", "est = tf.estimator.LinearRegressor(fc, model_dir = os.path.join(logdir,'separable'))\n", "est.train(in_fn(train_ds), steps = 5000)\n", "est.evaluate(in_fn(test_ds))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 388, "output_extras": [ {}, {} ] }, "colab_type": "code", "executionInfo": { "elapsed": 1152, "status": "ok", "timestamp": 1508962568631, "user": { "displayName": "Mark Daoust", "photoUrl": "//lh5.googleusercontent.com/-2bdrhkqhwhc/AAAAAAAAAAI/AAAAAAAAAYY/WEdKp4OXSFY/s50-c-k-no/photo.jpg", "userId": "106546680081284977106" }, "user_tz": 240 }, "id": "j4lhsk5rk-Nc", "outputId": "f12b5ba5-1d28-46b2-b1d6-322a60ba40d9" }, "outputs": [], "source": 
[ "plot_est(est)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "G-QLbA0hKBbY" }, "source": [ "# Using `crossed_column` on its own.\n", "The single-cell \"holes\" in the figure are caused by cells which do not contain examples." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 2010, "output_extras": [ {}, {} ] }, "colab_type": "code", "executionInfo": { "elapsed": 11707, "status": "ok", "timestamp": 1508962619513, "user": { "displayName": "Mark Daoust", "photoUrl": "//lh5.googleusercontent.com/-2bdrhkqhwhc/AAAAAAAAAAI/AAAAAAAAAYY/WEdKp4OXSFY/s50-c-k-no/photo.jpg", "userId": "106546680081284977106" }, "user_tz": 240 }, "id": "JoIXtYykKJei", "outputId": "08e45310-f006-454a-8bb4-4f14466c5013" }, "outputs": [], "source": [ "# Cross the bucketized columns, using 5000 hash bins (for an average weight sharing of 2).\n", "crossed_lat_lon_fc = tf.feature_column.crossed_column(\n", " [latitude_bucket_fc, longitude_bucket_fc], int(5e3))\n", "\n", "fc = [crossed_lat_lon_fc]\n", "\n", "# Build and train the Estimator.\n", "est = tf.estimator.LinearRegressor(fc, model_dir=os.path.join(logdir, 'crossed'))\n", "\n", "est.train(in_fn(train_ds), steps = 5000)\n", "est.evaluate(in_fn(test_ds))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 388, "output_extras": [ {}, {} ] }, "colab_type": "code", "executionInfo": { "elapsed": 1112, "status": "ok", "timestamp": 1508962600110, "user": { "displayName": "Mark Daoust", "photoUrl": "//lh5.googleusercontent.com/-2bdrhkqhwhc/AAAAAAAAAAI/AAAAAAAAAYY/WEdKp4OXSFY/s50-c-k-no/photo.jpg", "userId": "106546680081284977106" }, "user_tz": 240 }, "id": "R-itu9itLe0K", "outputId": "449a6d84-c6f5-4582-f8e3-0f007ae68e8a", "scrolled": false }, "outputs": [], 
"source": [ "plot_est(est)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "KHJ_KsRUk-Nj" }, "source": [ "# Using raw categories with `crossed_column` \n", "The model generalizes better if it also has access to the raw categories, outside of the cross. " ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 2010, "output_extras": [ {}, {} ] }, "colab_type": "code", "executionInfo": { "elapsed": 13233, "status": "ok", "timestamp": 1508963622115, "user": { "displayName": "Mark Daoust", "photoUrl": "//lh5.googleusercontent.com/-2bdrhkqhwhc/AAAAAAAAAAI/AAAAAAAAAYY/WEdKp4OXSFY/s50-c-k-no/photo.jpg", "userId": "106546680081284977106" }, "user_tz": 240 }, "id": "ukHo6NrTk-Nk", "outputId": "12fba55e-c496-4007-b7b7-11b1aad38ba8" }, "outputs": [], "source": [ "fc = [\n", " latitude_bucket_fc,\n", " longitude_bucket_fc,\n", " crossed_lat_lon_fc]\n", "\n", "# Build and train the Estimator.\n", "est = tf.estimator.LinearRegressor(fc, model_dir=os.path.join(logdir, 'both'))\n", "est.train(in_fn(train_ds), steps = 5000)\n", "est.evaluate(in_fn(test_ds))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 388, "output_extras": [ {}, {} ] }, "colab_type": "code", "executionInfo": { "elapsed": 1307, "status": "ok", "timestamp": 1508963623450, "user": { "displayName": "Mark Daoust", "photoUrl": "//lh5.googleusercontent.com/-2bdrhkqhwhc/AAAAAAAAAAI/AAAAAAAAAYY/WEdKp4OXSFY/s50-c-k-no/photo.jpg", "userId": "106546680081284977106" }, "user_tz": 240 }, "id": "QjOwalvDk-Nm", "outputId": "f0a520a1-8253-494e-f3c8-2e6df168100a" }, "outputs": [], "source": [ "plot_est(est)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Mx66A5ETk-Ns" }, "source": [ "# Open TensorBoard" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 }, "base_uri": "https://localhost:8080/", "height": 820, "output_extras": [ { "item_id": 1 } ] }, "colab_type": "code", "executionInfo": { "elapsed": 478, "status": "ok", "timestamp": 1508986589529, "user": { "displayName": "Mark Daoust", "photoUrl": "//lh5.googleusercontent.com/-2bdrhkqhwhc/AAAAAAAAAAI/AAAAAAAAAYY/WEdKp4OXSFY/s50-c-k-no/photo.jpg", "userId": "106546680081284977106" }, "user_tz": 240 }, "id": "fESYrJamm_Z5", "outputId": "d982a677-a217-491b-ef85-93f932e6afc5" }, "outputs": [], "source": [ "%%html\n", "" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "autoexec": { "startup": false, "wait_interval": 0 } }, "colab_type": "code", "collapsed": true, "id": "_YHrJneHnA9K" }, "outputs": [], "source": [] } ], "metadata": { "colab": { "collapsed_sections": [], "default_view": {}, "name": "Housing Prices (1).ipynb", "provenance": [], "version": "0.3.2", "views": {} }, "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.10" } }, "nbformat": 4, "nbformat_minor": 1 }