Commit 7e67dbbc authored by Yeqing Li's avatar Yeqing Li Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 286061591
parent 68146271
...@@ -32,6 +32,7 @@ RETINANET_CFG = { ...@@ -32,6 +32,7 @@ RETINANET_CFG = {
'type': 'retinanet', 'type': 'retinanet',
'model_dir': '', 'model_dir': '',
'use_tpu': True, 'use_tpu': True,
'strategy_type': 'tpu',
'train': { 'train': {
'batch_size': 64, 'batch_size': 64,
'iterations_per_loop': 500, 'iterations_per_loop': 500,
......
...@@ -94,7 +94,9 @@ class InputFn(object): ...@@ -94,7 +94,9 @@ class InputFn(object):
dataset = dataset.cache() dataset = dataset.cache()
if self._is_training: if self._is_training:
dataset = dataset.shuffle(64) # Large shuffle size is critical for 2vm input pipeline. Can use small
# value (e.g. 64) for 1vm.
dataset = dataset.shuffle(1000)
if self._num_examples > 0: if self._num_examples > 0:
dataset = dataset.take(self._num_examples) dataset = dataset.take(self._num_examples)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment