"examples/sampling/graphbolt/vscode:/vscode.git/clone" did not exist on "d3176272c6e58b9d500bd6939e061c47bb5e92ab"
Commit aac33549 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Internal change

PiperOrigin-RevId: 437358569
parent b6fcc07d
...@@ -160,22 +160,44 @@ def _read_tfds(tfds_builder: tfds.core.DatasetBuilder, ...@@ -160,22 +160,44 @@ def _read_tfds(tfds_builder: tfds.core.DatasetBuilder,
"""Reads a dataset from tfds.""" """Reads a dataset from tfds."""
# No op if exist. # No op if exist.
tfds_builder.download_and_prepare() tfds_builder.download_and_prepare()
read_config = tfds.ReadConfig(
interleave_cycle_length=cycle_length,
interleave_block_length=block_length,
input_context=input_context,
shuffle_seed=seed)
decoders = {} decoders = {}
if tfds_skip_decoding_feature: if tfds_skip_decoding_feature:
for skip_feature in tfds_skip_decoding_feature.split(','): for skip_feature in tfds_skip_decoding_feature.split(','):
decoders[skip_feature.strip()] = tfds.decode.SkipDecoding() decoders[skip_feature.strip()] = tfds.decode.SkipDecoding()
dataset = tfds_builder.as_dataset( if tfds_builder.info.splits:
split=tfds_split, num_shards = len(tfds_builder.info.splits[tfds_split].file_instructions)
shuffle_files=is_training, else:
as_supervised=tfds_as_supervised, # The tfds mock path often does not provide splits.
decoders=decoders, num_shards = 1
read_config=read_config) if input_context and num_shards < input_context.num_input_pipelines:
# The number of files in the dataset split is smaller than the number of
# input pipelines. We read the entire dataset first and then shard in the
# host memory.
read_config = tfds.ReadConfig(
interleave_cycle_length=cycle_length,
interleave_block_length=block_length,
input_context=None,
shuffle_seed=seed)
dataset = tfds_builder.as_dataset(
split=tfds_split,
shuffle_files=is_training,
as_supervised=tfds_as_supervised,
decoders=decoders,
read_config=read_config)
dataset = dataset.shard(input_context.num_input_pipelines,
input_context.input_pipeline_id)
else:
read_config = tfds.ReadConfig(
interleave_cycle_length=cycle_length,
interleave_block_length=block_length,
input_context=input_context,
shuffle_seed=seed)
dataset = tfds_builder.as_dataset(
split=tfds_split,
shuffle_files=is_training,
as_supervised=tfds_as_supervised,
decoders=decoders,
read_config=read_config)
if is_training and not cache: if is_training and not cache:
dataset = dataset.repeat() dataset = dataset.repeat()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment