"For a working end-to-end example, download our [example code](https://github.com/tensorflow/models/tree/master/official/wide_deep/census_main.py) and set the `model_type` flag to `wide`."
"For a working end-to-end example, download our [example code](https://github.com/tensorflow/models/tree/master/official/wide_deep/census_main.py) and set the `model_type` flag to `wide`."
]
},
{
"metadata": {
"id": "oyKy1lM_3gkL",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"## Adding Regularization to Prevent Overfitting\n",
"\n",
"Regularization is a technique used to avoid overfitting. Overfitting happens when a model performs well on the data it is trained on, but worse on test data that the model has not seen before. Overfitting can occur when a model is excessively complex, such as having too many parameters relative to the number of observed training data. Regularization allows you to control the model's complexity and make the model more generalizable to unseen data.\n",
"\n",
"You can add L1 and L2 regularizations to the model with the following code:"
"These regularized models don't perform much better than the base model. Let's look at the model's weight distributions to better see the effect of the regularization:"
]
},
{
"metadata": {
"id": "Wb6093N04XlS",
"colab_type": "code",
"colab": {}
},
"cell_type": "code",
"source": [
"def get_flat_weights(model):\n",
" weight_names = [\n",
" name for name in model.get_variable_names()\n",
" if \"linear_model\" in name and \"Ftrl\" not in name]\n",
"\n",
" weight_values = [model.get_variable_value(name) for name in weight_names]\n",
"\n",
" weights_flat = np.concatenate([item.flatten() for item in weight_values], axis=0)\n",
"\n",
" return weights_flat\n",
"\n",
"weights_flat = get_flat_weights(model)\n",
"weights_flat_l1 = get_flat_weights(model_l1)\n",
"weights_flat_l2 = get_flat_weights(model_l2)"
],
"execution_count": 0,
"outputs": []
},
{
"metadata": {
"id": "GskJmtfmL0p-",
"colab_type": "text"
},
"cell_type": "markdown",
"source": [
"The models have many zero-valued weights caused by unused hash bins (there are many more hash bins than categories in some columns). We can mask these weights when viewing the weight distributions:"
"Both types of regularization squeeze the distribution of weights towards zero. L2 regularization has a greater effect in the tails of the distribution eliminating extreme weights. L1 regularization produces more exactly-zero values, in this case it sets ~200 to zero."