Unverified Commit f7a44074 authored by rxsang's avatar rxsang Committed by GitHub
Browse files

Code cleanup. (#6989)

parent 415e8a45
......@@ -199,25 +199,12 @@ def run(flags_obj):
optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
default_for_fp16=128))
if flags_obj.enable_xla and not flags_obj.enable_eager:
# TODO(b/129861005): Fix OOM issue in eager mode when setting
# `batch_size` in keras.Input layer.
if strategy and strategy.num_replicas_in_sync > 1:
# TODO(b/129791381): Specify `input_layer_batch_size` value in
# DistributionStrategy multi-replica case.
input_layer_batch_size = None
else:
input_layer_batch_size = flags_obj.batch_size
else:
input_layer_batch_size = None
if flags_obj.use_trivial_model:
model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES, dtype)
else:
model = resnet_model.resnet50(
num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype,
batch_size=input_layer_batch_size)
dtype=dtype)
model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment