Commit b47e5fe9 authored by Alan Chiao's avatar Alan Chiao Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 330777498
parent 095f89db
...@@ -145,7 +145,9 @@ def run(flags_obj): ...@@ -145,7 +145,9 @@ def run(flags_obj):
# output format should be same as the keras backend image data format or just # output format should be same as the keras backend image data format or just
# channel-last format. # channel-last format.
use_keras_image_data_format = \ use_keras_image_data_format = \
(flags_obj.model == 'mobilenet' or 'mobilenet_pretrained') (flags_obj.model == 'mobilenet' or
flags_obj.model == 'mobilenet_pretrained')
train_input_dataset = input_fn( train_input_dataset = input_fn(
is_training=True, is_training=True,
data_dir=flags_obj.data_dir, data_dir=flags_obj.data_dir,
...@@ -183,7 +185,7 @@ def run(flags_obj): ...@@ -183,7 +185,7 @@ def run(flags_obj):
with strategy_scope: with strategy_scope:
if flags_obj.optimizer == 'resnet50_default': if flags_obj.optimizer == 'resnet50_default':
optimizer = common.get_optimizer(lr_schedule) optimizer = common.get_optimizer(lr_schedule)
elif flags_obj.optimizer == 'mobilenet_default' or 'mobilenet_fine_tune': elif flags_obj.optimizer == 'mobilenet_default' or flags_obj.optimizer == 'mobilenet_fine_tune':
initial_learning_rate = \ initial_learning_rate = \
flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size flags_obj.initial_learning_rate_per_sample * flags_obj.batch_size
if flags_obj.optimizer == 'mobilenet_fine_tune': if flags_obj.optimizer == 'mobilenet_fine_tune':
...@@ -211,7 +213,7 @@ def run(flags_obj): ...@@ -211,7 +213,7 @@ def run(flags_obj):
elif flags_obj.model == 'resnet50_v1.5': elif flags_obj.model == 'resnet50_v1.5':
model = resnet_model.resnet50( model = resnet_model.resnet50(
num_classes=imagenet_preprocessing.NUM_CLASSES) num_classes=imagenet_preprocessing.NUM_CLASSES)
elif flags_obj.model == 'mobilenet' or 'mobilenet_pretrained': elif flags_obj.model == 'mobilenet' or flags_obj.model == 'mobilenet_pretrained':
# TODO(kimjaehong): Remove layers attribute when minimum TF version # TODO(kimjaehong): Remove layers attribute when minimum TF version
# support 2.0 layers by default. # support 2.0 layers by default.
if flags_obj.model == 'mobilenet_pretrained': if flags_obj.model == 'mobilenet_pretrained':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment