"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "52fe5baa8649309a5f0cd78685e34c9871201650"
Unverified Commit f1a59682 authored by rxsang's avatar rxsang Committed by GitHub
Browse files

Set input layer `batch_size` in multi-replica mode (#6578)

parent b4b8c723
...@@ -176,14 +176,9 @@ def run(flags_obj): ...@@ -176,14 +176,9 @@ def run(flags_obj):
if flags_obj.enable_xla and not flags_obj.enable_eager: if flags_obj.enable_xla and not flags_obj.enable_eager:
# TODO(b/129861005): Fix OOM issue in eager mode when setting # TODO(b/129861005): Fix OOM issue in eager mode when setting
# `batch_size` in keras.Input layer. # `batch_size` in keras.Input layer.
if strategy and strategy.num_replicas_in_sync > 1: input_layer_batch_size = flags_obj.batch_size
# TODO(b/129791381): Specify `per_replica_batch_size` value in
# DistributionStrategy multi-replica case.
per_replica_batch_size = None
else:
per_replica_batch_size = flags_obj.batch_size
else: else:
per_replica_batch_size = None input_layer_batch_size = None
if flags_obj.use_trivial_model: if flags_obj.use_trivial_model:
model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES) model = trivial_model.trivial_model(imagenet_main.NUM_CLASSES)
...@@ -191,7 +186,7 @@ def run(flags_obj): ...@@ -191,7 +186,7 @@ def run(flags_obj):
model = resnet_model.resnet50( model = resnet_model.resnet50(
num_classes=imagenet_main.NUM_CLASSES, num_classes=imagenet_main.NUM_CLASSES,
dtype=dtype, dtype=dtype,
batch_size=per_replica_batch_size) batch_size=input_layer_batch_size)
model.compile(loss='sparse_categorical_crossentropy', model.compile(loss='sparse_categorical_crossentropy',
optimizer=optimizer, optimizer=optimizer,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment