Commit a9fe9ba9 authored by Scott Zhu's avatar Scott Zhu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 366900579
parent cc12499b
......@@ -22,7 +22,6 @@ import os
from absl import flags
import tensorflow as tf
from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2
import tensorflow_model_optimization as tfmot
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
......@@ -109,7 +108,7 @@ class PiecewiseConstantDecayWithWarmup(
def get_optimizer(learning_rate=0.1):
"""Returns optimizer to use."""
# The learning_rate is overwritten at the beginning of each step by callback.
return gradient_descent_v2.SGD(learning_rate=learning_rate, momentum=0.9)
return tf.keras.optimizers.SGD(learning_rate=learning_rate, momentum=0.9)
def get_callbacks(pruning_method=None,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment