Unverified Commit 6dea4846 authored by rxsang's avatar rxsang Committed by GitHub

Fix Resnet XLA with multi-GPUs (#6510)

Don't pass `batch_size` to keras.layers.Input in the DistributionStrategy multi-replica case. There is currently a bug on the Keras side that causes a batch-size incompatibility error.
parent 74d924e9
@@ -174,9 +174,10 @@ def run(flags_obj):
       optimizer, loss_scale=flags_core.get_loss_scale(flags_obj))
   if flags_obj.enable_xla:
-    if strategy:
-      per_replica_batch_size = (
-          flags_obj.batch_size // strategy.num_replicas_in_sync)
+    if strategy and strategy.num_replicas_in_sync > 1:
+      # TODO(b/129791381): Specify `per_replica_batch_size` value in
+      # DistributionStrategy multi-replica case.
+      per_replica_batch_size = None
     else:
       per_replica_batch_size = flags_obj.batch_size
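The batch-size selection above can be sketched as a standalone helper (this is a hypothetical illustration, not code from the repo; the function name `choose_per_replica_batch_size` is invented for clarity). Under a multi-replica strategy it returns None, so the value later passed to keras.layers.Input leaves the batch dimension unspecified and sidesteps the Keras incompatibility; with a single replica it returns the full batch size.

```python
def choose_per_replica_batch_size(global_batch_size, num_replicas_in_sync):
    """Return the batch size to pass to keras.layers.Input.

    Mirrors the logic in the diff: with more than one replica in sync,
    return None so the model is built without a fixed batch size
    (TODO(b/129791381) tracks specifying a real value here); otherwise
    the single replica processes the full global batch.
    """
    if num_replicas_in_sync > 1:
        # Multi-replica case: leave the batch dimension unset.
        return None
    return global_batch_size


print(choose_per_replica_batch_size(256, 8))  # multi-GPU -> None
print(choose_per_replica_batch_size(256, 1))  # single replica -> 256
```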