"megatron/git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "7120e9315cb85774f01b3f2738814246ae73b39d"
Commit fa66c645 authored by Yeqing Li, committed by A. Unique TensorFlower
Browse files

Fix the missing shuffle issue when reading sharded data.

PiperOrigin-RevId: 326246781
parent 9b179e8e
......@@ -138,6 +138,8 @@ class InputReader:
input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
if self._is_training:
dataset = dataset.repeat()
dataset = dataset.interleave(
map_func=self._dataset_fn,
......@@ -163,6 +165,8 @@ class InputReader:
input_context.num_input_pipelines > 1):
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
if self._is_training:
dataset = dataset.repeat()
return dataset
def _read_tfds(
......@@ -188,6 +192,8 @@ class InputReader:
decoders=decoders,
read_config=read_config)
if self._is_training:
dataset = dataset.repeat()
return dataset
@property
......@@ -212,6 +218,12 @@ class InputReader:
assert self._num_files == 1
dataset = self._read_single_file(input_context)
if self._cache:
dataset = dataset.cache()
if self._is_training:
dataset = dataset.shuffle(self._shuffle_buffer_size)
if self._examples_consume > 0:
dataset = dataset.take(self._examples_consume)
......@@ -222,16 +234,6 @@ class InputReader:
dataset = maybe_map_fn(dataset, self._decoder_fn)
dataset = maybe_map_fn(dataset, self._parser_fn)
if self._cache:
dataset = dataset.cache()
if self._is_training:
dataset = dataset.shuffle(
self._shuffle_buffer_size,
seed=self._seed,
reshuffle_each_iteration=True)
dataset = dataset.repeat()
if self._transform_and_batch_fn is not None:
dataset = self._transform_and_batch_fn(dataset, input_context)
else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment