"git@developer.sourcefind.cn:modelzoo/resnet50_tensorflow.git" did not exist on "8815a8606c45e7c6c01549179903329af2b2d442"
Commit 392b8130 authored by Dan Moldovan's avatar Dan Moldovan
Browse files

Update the training schedule for better convergence. The current...

Update the training schedule for better convergence. The current hyperparameters tend to diverge in Colab.
parent a141d020
...@@ -10,7 +10,8 @@ ...@@ -10,7 +10,8 @@
"collapsed_sections": [ "collapsed_sections": [
"Jxv6goXm7oGF" "Jxv6goXm7oGF"
], ],
"toc_visible": true "toc_visible": true,
"include_colab_link": true
}, },
"kernelspec": { "kernelspec": {
"name": "python3", "name": "python3",
...@@ -18,6 +19,16 @@ ...@@ -18,6 +19,16 @@
} }
}, },
"cells": [ "cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"[View in Colaboratory](https://colab.research.google.com/github/mdanatg/models/blob/master/samples/core/guide/autograph.ipynb)"
]
},
{ {
"metadata": { "metadata": {
"id": "Jxv6goXm7oGF", "id": "Jxv6goXm7oGF",
...@@ -740,7 +751,7 @@ ...@@ -740,7 +751,7 @@
"@autograph.convert(recursive=True)\n", "@autograph.convert(recursive=True)\n",
"def train(train_ds, test_ds, hp):\n", "def train(train_ds, test_ds, hp):\n",
" m = mlp_model((28 * 28,))\n", " m = mlp_model((28 * 28,))\n",
" opt = tf.train.MomentumOptimizer(hp.learning_rate, 0.9)\n", " opt = tf.train.AdamOptimizer(hp.learning_rate)\n",
" \n", " \n",
" # We'd like to save our losses to a list. In order for AutoGraph\n", " # We'd like to save our losses to a list. In order for AutoGraph\n",
" # to convert these lists into their graph equivalent,\n", " # to convert these lists into their graph equivalent,\n",
...@@ -802,7 +813,7 @@ ...@@ -802,7 +813,7 @@
"source": [ "source": [
"with tf.Graph().as_default() as g:\n", "with tf.Graph().as_default() as g:\n",
" hp = tf.contrib.training.HParams(\n", " hp = tf.contrib.training.HParams(\n",
" learning_rate=0.05,\n", " learning_rate=0.005,\n",
" max_steps=500,\n", " max_steps=500,\n",
" )\n", " )\n",
" train_ds = setup_mnist_data(True, 50)\n", " train_ds = setup_mnist_data(True, 50)\n",
...@@ -837,4 +848,4 @@ ...@@ -837,4 +848,4 @@
"outputs": [] "outputs": []
} }
] ]
} }
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment