ModelZoo / ResNet50_tensorflow · Commits

Unverified commit 9bdfb04a, authored Mar 12, 2019 by Toby Boyd and committed by GitHub on Mar 12, 2019.

V1 optimizer fix (#6350)

* optimizer back to compat.v1
* add doc string to fix lint

Parent: 0b0dc7f5
Showing 1 changed file with 4 additions and 6 deletions.

official/resnet/resnet_run_loop.py (+4, -6)
@@ -266,8 +266,6 @@ def learning_rate_with_decay(
                      false_fn=lambda: lr)
     return lr
 
   def poly_rate_fn(global_step):
     """Handles linear scaling rule, gradual warmup, and LR decay.
@@ -277,10 +275,10 @@ def learning_rate_with_decay(
     decay schedule with power 2.0.
 
     Args:
       global_step: the current global_step
 
     Returns:
       returns the current learning rate
     """
     # Learning rate schedule for LARS polynomial schedule
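The poly_rate_fn docstring above describes the schedule only in words: linear scaling with gradual warmup, then polynomial decay with power 2.0. Below is a minimal runnable sketch of that general pattern; the peak rate, warmup length, total steps, and end rate are illustrative assumptions, not the values this file derives from the batch size.

import tensorflow as tf

BASE_LR = 0.1        # assumed peak learning rate reached after warmup
WARMUP_STEPS = 500   # assumed length of the linear warmup phase
TOTAL_STEPS = 10000  # assumed total number of training steps
END_LR = 1e-4        # assumed floor for the decayed rate


def poly_rate_with_warmup(global_step):
  # Linear warmup to BASE_LR, then polynomial decay with power 2.0,
  # as the docstring above describes in words.
  step = tf.cast(global_step, tf.float32)
  warmup_lr = BASE_LR * step / float(WARMUP_STEPS)
  decay_frac = (step - WARMUP_STEPS) / float(TOTAL_STEPS - WARMUP_STEPS)
  decay_frac = tf.clip_by_value(decay_frac, 0.0, 1.0)
  poly_lr = (BASE_LR - END_LR) * tf.pow(1.0 - decay_frac, 2.0) + END_LR
  return tf.where(step < WARMUP_STEPS, warmup_lr, poly_lr)


print(float(poly_rate_with_warmup(250)))    # mid-warmup: ~0.05
print(float(poly_rate_with_warmup(10000)))  # fully decayed: ~1e-4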
@@ -318,7 +316,6 @@ def learning_rate_with_decay(
   if flags.FLAGS.enable_lars:
     return poly_rate_fn
 
   return learning_rate_fn
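The hunk above only decides which schedule function learning_rate_with_decay returns (the LARS flag selects the polynomial schedule). Per the commit message, the optimizer consuming the returned function was moved back to tf.compat.v1; the graph-mode sketch below shows one way such wiring can look, where the stand-in schedules, momentum=0.9, and the toy loss are assumptions for illustration, not the code in this file.

import tensorflow as tf


def pick_lr_fn(enable_lars, poly_rate_fn, learning_rate_fn):
  # Mirrors the selection shown in the hunk above: the LARS flag picks the
  # polynomial schedule, otherwise the piecewise schedule is returned.
  if enable_lars:
    return poly_rate_fn
  return learning_rate_fn


# Graph-mode wiring of the chosen schedule into a compat.v1 optimizer.
graph = tf.Graph()
with graph.as_default():
  global_step = tf.compat.v1.train.get_or_create_global_step()
  lr_fn = pick_lr_fn(
      enable_lars=False,
      poly_rate_fn=lambda step: tf.constant(0.1),       # stand-in schedule
      learning_rate_fn=lambda step: tf.constant(0.01))  # stand-in schedule
  learning_rate = lr_fn(global_step)
  optimizer = tf.compat.v1.train.MomentumOptimizer(
      learning_rate=learning_rate, momentum=0.9)
  weight = tf.compat.v1.get_variable("w", initializer=tf.constant(1.0))
  train_op = optimizer.minimize(weight * weight, global_step=global_step)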
@@ -360,6 +357,7 @@ def resnet_model_fn(features, labels, mode, model_class,
       from the loss.
     dtype: the TensorFlow dtype to use for calculations.
     fine_tune: If True only train the dense layers(final layers).
+    label_smoothing: If greater than 0 then smooth the labels.
 
   Returns:
     EstimatorSpec parameterized according to the input params and the
@@ -402,7 +400,7 @@ def resnet_model_fn(features, labels, mode, model_class,
         logits=logits, onehot_labels=one_hot_labels,
         label_smoothing=label_smoothing)
   else:
-    cross_entropy = tf.losses.sparse_softmax_cross_entropy(
+    cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
         logits=logits, labels=labels)
 
   # Create a tensor named cross_entropy for logging purposes.
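The only API edit visible in the file is the sparse cross-entropy call now routed through tf.compat.v1.losses. Since the fragment above is not self-contained, here is a standalone sketch of the same branch on toy data; the class count, batch, and 0.1 smoothing factor are assumptions, and both loss calls go through tf.compat.v1.losses so the snippet runs on a TF 2.x install.

import tensorflow as tf

NUM_CLASSES = 5        # assumed class count for the toy example
logits = tf.random.normal([4, NUM_CLASSES])
labels = tf.constant([0, 2, 1, 4])
label_smoothing = 0.1  # assumed smoothing factor; 0.0 exercises the sparse branch

if label_smoothing != 0.0:
  # Smoothing needs dense one-hot targets, so convert before the loss call.
  one_hot_labels = tf.one_hot(labels, NUM_CLASSES)
  cross_entropy = tf.compat.v1.losses.softmax_cross_entropy(
      onehot_labels=one_hot_labels, logits=logits,
      label_smoothing=label_smoothing)
else:
  # Integer labels feed the sparse variant, the call the commit moved to compat.v1.
  cross_entropy = tf.compat.v1.losses.sparse_softmax_cross_entropy(
      labels=labels, logits=logits)

print(float(cross_entropy))  # scalar loss averaged over the batch of 4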