Commit 1093f4a5 authored by Mark Daoust's avatar Mark Daoust

Re-add regularization section, with weight distributions.

parent b07b494e
......@@ -7,7 +7,13 @@
"version": "0.3.2",
"views": {},
"default_view": {},
"provenance": [],
"provenance": [
{
"file_id": "https://github.com/tensorflow/models/blob/master/samples/core/tutorials/estimators/linear.ipynb",
"timestamp": 1531763859998
}
],
"private_outputs": true,
"collapsed_sections": [
"MWW1TyjaecRh"
],
......@@ -17,7 +23,8 @@
"display_name": "Python 3",
"language": "python",
"name": "python3"
}
},
"accelerator": "GPU"
},
"cells": [
{
......@@ -126,6 +133,8 @@
"\n",
"import os\n",
"import sys\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from IPython.display import clear_output"
],
"execution_count": 0,
......@@ -257,7 +266,16 @@
},
"cell_type": "markdown",
"source": [
"### Command line usage\n",
"### Command line usage\n"
]
},
{
"metadata": {
"id": "ZhMGuU8v2sxh",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"\n",
"The repo includes a complete program for experimenting with this type of model.\n",
"\n",
......@@ -399,7 +417,7 @@
},
"cell_type": "markdown",
"source": [
"[pandas](https://pandas.pydata.org/) provides some convenient utilities for data analysis. Here's a list of columns available in the Census Income dataset:"
"[Pandas](https://pandas.pydata.org/) provides some convenient utilities for data analysis. Here's a list of columns available in the Census Income dataset:"
]
},
{
......@@ -1267,7 +1285,9 @@
"]\n",
"\n",
"model = tf.estimator.LinearClassifier(\n",
" model_dir=tempfile.mkdtemp(), feature_columns=base_columns + crossed_columns)"
" model_dir=tempfile.mkdtemp(), \n",
" feature_columns=base_columns + crossed_columns,\n",
" optimizer=tf.train.FtrlOptimizer(learning_rate=0.1))"
],
"execution_count": 0,
"outputs": []
......@@ -1299,6 +1319,9 @@
},
"cell_type": "code",
"source": [
"train_inpf = functools.partial(census_dataset.input_fn, train_file, \n",
" num_epochs=40, shuffle=True, batch_size=64)\n",
"\n",
"model.train(train_inpf)\n",
"\n",
"clear_output() # used for notebook display"
......@@ -1399,6 +1422,211 @@
"source": [
"For a working end-to-end example, download our [example code](https://github.com/tensorflow/models/tree/master/official/wide_deep/census_main.py) and set the `model_type` flag to `wide`."
]
},
{
"metadata": {
"id": "oyKy1lM_3gkL",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Adding Regularization to Prevent Overfitting\n",
"\n",
"Regularization is a technique used to avoid **overfitting**. Overfitting happens\n",
"when your model does well on the data it is trained on, but worse on test data\n",
"that the model has not seen before, such as live traffic. Overfitting generally\n",
"occurs when a model is excessively complex, such as having too many parameters\n",
"relative to the number of observed training data. Regularization allows for you\n",
"to control your model's complexity and makes the model more generalizable to\n",
"unseen data.\n",
"\n",
"You can add L1 and L2 regularizations to the model with the following code:"
]
},
{
"metadata": {
"id": "lzMUSBQ03hHx",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"model_l1 = tf.estimator.LinearClassifier(\n",
" feature_columns=base_columns + crossed_columns,\n",
" optimizer=tf.train.FtrlOptimizer(\n",
" learning_rate=0.1,\n",
" l1_regularization_strength=10.0,\n",
" l2_regularization_strength=0.0))\n",
"\n",
"model_l1.train(train_inpf)\n",
"\n",
"results = model_l1.evaluate(test_inpf)\n",
"clear_output()\n",
"for key in sorted(results):\n",
" print('%s: %0.2f' % (key, results[key]))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "ofmPL212JIy2",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"model_l2 = tf.estimator.LinearClassifier(\n",
" feature_columns=base_columns + crossed_columns,\n",
" optimizer=tf.train.FtrlOptimizer(\n",
" learning_rate=0.1,\n",
" l1_regularization_strength=0.0,\n",
" l2_regularization_strength=10.0))\n",
"\n",
"model_l2.train(train_inpf)\n",
"\n",
"results = model_l2.evaluate(test_inpf)\n",
"clear_output()\n",
"for key in sorted(results):\n",
" print('%s: %0.2f' % (key, results[key]))"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Lp1Rfy_k4e7w",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"These regularized models don't perform very differently base model. Let look ar the models' weight distributions to better see the effect of the regularization:"
]
},
{
"metadata": {
"id": "Wb6093N04XlS",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"def get_flat_weights(model):\n",
" weight_names = [\n",
" name for name in model.get_variable_names()\n",
" if \"linear_model\" in name and \"Ftrl\" not in name]\n",
"\n",
" weight_values = [model.get_variable_value(name) for name in weight_names]\n",
"\n",
" weights_flat = np.concatenate([item.flatten() for item in weight_values], axis=0)\n",
"\n",
" return weights_flat\n",
"\n",
"weights_flat = get_flat_weights(model)\n",
"weights_flat_l1 = get_flat_weights(model_l1)\n",
"weights_flat_l2 = get_flat_weights(model_l2)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "GskJmtfmL0p-",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"The models have many zero-vlaued weights caused by unused hash bins (There are many more hash bins than categories in some columns). We will mask these weights when viewing the weight distributions:"
]
},
{
"metadata": {
"id": "rM3agZe3MT3D",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"weight_mask = weights_flat != 0\n",
"\n",
"weights_base = weights_flat[weight_mask]\n",
"weights_l1 = weights_flat_l1[weight_mask]\n",
"weights_l2 = weights_flat_l2[weight_mask]"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "NqBpxLLQNEBE",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Now plot the distributions:"
]
},
{
"metadata": {
"id": "IdFK7wWa5_0K",
"colab_type": "code",
"colab": {
"autoexec": {
"startup": false,
"wait_interval": 0
}
}
},
"cell_type": "code",
"source": [
"plt.figure()\n",
"_ = plt.hist(weights_base, bins=np.linspace(-3,3,30))\n",
"plt.title('Base Model')\n",
"plt.ylim([0,500])\n",
"\n",
"plt.figure()\n",
"_ = plt.hist(weights_l1, bins=np.linspace(-3,3,30))\n",
"plt.title('L1 - Regularization')\n",
"plt.ylim([0,500])\n",
"\n",
"plt.figure()\n",
"_ = plt.hist(weights_l2, bins=np.linspace(-3,3,30))\n",
"plt.title('L2 - Regularization')\n",
"_=plt.ylim([0,500])\n",
"\n"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "Mv6knhFa5-iJ",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"Both types of regularization squeeze the distribution of weights towards zero. L2 regularization has a greater effect in the tails of the distribution eliminating extreme weights. L1 regularization produces more exactly-zero values (In this case it sets ~200 to zero)."
]
}
]
}
\ No newline at end of file
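The contrast the final cell describes (L1 producing exact zeros, L2 shrinking the tails) can be sketched independently of the notebook with plain NumPy. This is a hedged illustration, not the estimator's actual update rule: the random `weights_base` vector and the `threshold`/`shrink` constants are made up for the demo, and soft-thresholding is used only as a stand-in for the sparsifying effect of an L1 penalty.

```python
import numpy as np

# Hypothetical weight vector standing in for the notebook's flattened
# model weights (the real ones come from get_flat_weights(model)).
rng = np.random.default_rng(0)
weights_base = rng.normal(0.0, 1.0, size=1000)

# L1 regularization drives many weights to exactly zero (sparsity).
# Soft-thresholding, the proximal operator of the L1 penalty, mimics this:
# any weight with magnitude below the threshold becomes exactly 0.
threshold = 0.5
weights_l1 = np.sign(weights_base) * np.maximum(np.abs(weights_base) - threshold, 0.0)

# L2 regularization instead shrinks all weights proportionally,
# pulling in the tails without creating exact zeros.
shrink = 0.5
weights_l2 = shrink * weights_base

print("exact zeros (base):", int(np.sum(weights_base == 0)))
print("exact zeros (L1):  ", int(np.sum(weights_l1 == 0)))
print("exact zeros (L2):  ", int(np.sum(weights_l2 == 0)))
print("max |weight| (base):", float(np.abs(weights_base).max()))
print("max |weight| (L2):  ", float(np.abs(weights_l2).max()))
```

Plotting histograms of these three vectors reproduces the qualitative picture in the notebook: the L1 histogram gains a spike at zero, while the L2 histogram keeps its shape but with compressed tails.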