Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
4d543417
Commit
4d543417
authored
Aug 02, 2017
by
Toby Boyd
Committed by
GitHub
Aug 02, 2017
Browse files
Merge pull request #2090 from mari-linhares/patch-8
Adding parameter to change learning_rate
parents
c3b69841
99cb3f70
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
13 deletions
+24
-13
tutorials/image/cifar10_estimator/cifar10.py
tutorials/image/cifar10_estimator/cifar10.py
+1
-1
tutorials/image/cifar10_estimator/cifar10_main.py
tutorials/image/cifar10_estimator/cifar10_main.py
+23
-12
No files found.
tutorials/image/cifar10_estimator/cifar10.py
View file @
4d543417
...
@@ -64,7 +64,7 @@ class Cifar10DataSet(object):
...
@@ -64,7 +64,7 @@ class Cifar10DataSet(object):
tf
.
float32
)
tf
.
float32
)
label
=
tf
.
cast
(
features
[
'label'
],
tf
.
int32
)
label
=
tf
.
cast
(
features
[
'label'
],
tf
.
int32
)
# Custom preprocessing
.
# Custom preprocessing.
image
=
self
.
preprocess
(
image
)
image
=
self
.
preprocess
(
image
)
return
image
,
label
return
image
,
label
...
...
tutorials/image/cifar10_estimator/cifar10_main.py
View file @
4d543417
...
@@ -73,23 +73,32 @@ tf.flags.DEFINE_float('momentum', 0.9, 'Momentum for MomentumOptimizer.')
...
@@ -73,23 +73,32 @@ tf.flags.DEFINE_float('momentum', 0.9, 'Momentum for MomentumOptimizer.')
tf
.
flags
.
DEFINE_float
(
'weight_decay'
,
2e-4
,
'Weight decay for convolutions.'
)
tf
.
flags
.
DEFINE_float
(
'weight_decay'
,
2e-4
,
'Weight decay for convolutions.'
)
tf
.
flags
.
DEFINE_float
(
'learning_rate'
,
0.1
,
"""This is the inital learning rate value.
The learning rate will decrease during training.
For more details check the model_fn implementation
in this file.
"""
.)
tf
.
flags
.
DEFINE_boolean
(
'use_distortion_for_training'
,
True
,
tf
.
flags
.
DEFINE_boolean
(
'use_distortion_for_training'
,
True
,
'If doing image distortion for training.'
)
'If doing image distortion for training.'
)
tf
.
flags
.
DEFINE_boolean
(
'run_experiment'
,
False
,
tf
.
flags
.
DEFINE_boolean
(
'run_experiment'
,
False
,
'If True will run an experiment,'
"""If True will run an experiment,
'otherwise will run training and evaluation'
otherwise will run training and evaluation
'using the estimator interface.'
using the estimator interface.
'Experiments perform training on several workers in'
Experiments perform training on several workers in
'parallel, in other words experiments know how to'
parallel, in other words experiments know how to
' invoke train and eval in a sensible fashion for'
invoke train and eval in a sensible fashion for
' distributed training.'
)
distributed training.
"""
)
tf
.
flags
.
DEFINE_boolean
(
'sync'
,
False
,
tf
.
flags
.
DEFINE_boolean
(
'sync'
,
False
,
'If true when running in a distributed environment'
"""If true when running in a distributed environment
'will run on sync mode'
)
will run on sync mode.
"""
)
tf
.
flags
.
DEFINE_integer
(
'num_workers'
,
1
,
'Number of workers'
)
tf
.
flags
.
DEFINE_integer
(
'num_workers'
,
1
,
'Number of workers
.
'
)
# Perf flags
# Perf flags
tf
.
flags
.
DEFINE_integer
(
'num_intra_threads'
,
1
,
tf
.
flags
.
DEFINE_integer
(
'num_intra_threads'
,
1
,
...
@@ -233,7 +242,7 @@ def _resnet_model_fn(features, labels, mode):
...
@@ -233,7 +242,7 @@ def _resnet_model_fn(features, labels, mode):
Support single host, one or more GPU training. Parameter distribution can be
Support single host, one or more GPU training. Parameter distribution can be
either one of the following scheme.
either one of the following scheme.
1. CPU is the parameter server and manages gradient updates.
1. CPU is the parameter server and manages gradient updates.
2. Paramters are distributed evenly across all GPUs, and the first GPU
2. Param
e
ters are distributed evenly across all GPUs, and the first GPU
manages gradient updates.
manages gradient updates.
Args:
Args:
...
@@ -308,7 +317,9 @@ def _resnet_model_fn(features, labels, mode):
...
@@ -308,7 +317,9 @@ def _resnet_model_fn(features, labels, mode):
num_batches_per_epoch
*
x
num_batches_per_epoch
*
x
for
x
in
np
.
array
([
82
,
123
,
300
],
dtype
=
np
.
int64
)
for
x
in
np
.
array
([
82
,
123
,
300
],
dtype
=
np
.
int64
)
]
]
staged_lr
=
[
0.1
,
0.01
,
0.001
,
0.0002
]
staged_lr
=
[
FLAGS
.
learning_rate
*
x
for
x
in
[
1
,
0.1
,
0.01
,
0.002
]]
learning_rate
=
tf
.
train
.
piecewise_constant
(
tf
.
train
.
get_global_step
(),
learning_rate
=
tf
.
train
.
piecewise_constant
(
tf
.
train
.
get_global_step
(),
boundaries
,
staged_lr
)
boundaries
,
staged_lr
)
# Create a nicely-named tensor for logging
# Create a nicely-named tensor for logging
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment