"mmdet3d/vscode:/vscode.git/clone" did not exist on "db44cc50cb678dde52eab6307627c63623964465"
Commit 74a271b8 authored by Chen Chen, committed by A. Unique TensorFlower

When the number of files is less than the number of input pipelines, fall back to "_read_single_file", which will send all files to every worker.

PiperOrigin-RevId: 348865649
parent e2a31b15
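
For background (not part of the commit itself): with file-level sharding, each input pipeline keeps every N-th file, so when there are fewer files than pipelines some workers end up with an empty shard and produce no data. A minimal sketch of the failure mode, with hypothetical file names and pipeline counts:

import tensorflow as tf

# Hypothetical setup: 2 matched files but 4 input pipelines (one per worker).
matched_files = ['train-00000-of-00002', 'train-00001-of-00002']
num_input_pipelines = 4

for pipeline_id in range(num_input_pipelines):
  files = tf.data.Dataset.from_tensor_slices(matched_files)
  # File-level sharding: pipeline i keeps files i, i + N, i + 2N, ...
  shard = files.shard(num_input_pipelines, pipeline_id)
  print(pipeline_id, list(shard.as_numpy_iterator()))
# Pipelines 2 and 3 receive an empty file list and would yield no records,
# which is why the reader now falls back to reading all files on every worker.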
@@ -14,10 +14,10 @@
 # limitations under the License.
 # ==============================================================================
 """A common dataset reader."""
 import random
 from typing import Any, Callable, Optional
 
 from absl import logging
 import tensorflow as tf
 import tensorflow_datasets as tfds
@@ -139,10 +139,9 @@ class InputReader:
     self._tf_data_service_address = params.tf_data_service_address
     self._tf_data_service_job_name = params.tf_data_service_job_name
 
-  def _read_sharded_files(self,
-                          input_context: Optional[
-                              tf.distribute.InputContext] = None):
-    """Reads a dataset from sharded files."""
+  def _shard_files_then_read(
+      self, input_context: Optional[tf.distribute.InputContext] = None):
+    """Shards the data files and then sends a split to every worker to read."""
     dataset = tf.data.Dataset.from_tensor_slices(self._matched_files)
 
     # Shuffle and repeat at file level.
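
For readers outside the diff, a rough sketch of the shard-then-read pattern that _shard_files_then_read implements: shard the file list across input pipelines, then shuffle, repeat, and interleave reads within each worker's shard. The helper name and parameters (dataset_fn, cycle_length, is_training) are illustrative stand-ins for the reader's configuration, not the exact class internals.

import tensorflow as tf

def shard_files_then_read(matched_files, dataset_fn, input_context=None,
                          is_training=True, cycle_length=8):
  """Shards the file list per input pipeline, then reads from each shard."""
  dataset = tf.data.Dataset.from_tensor_slices(matched_files)
  if input_context:
    # Each worker keeps only its own slice of the file list.
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
  if is_training:
    # Shuffle and repeat at file level before opening any file.
    dataset = dataset.shuffle(len(matched_files)).repeat()
  # Open several files concurrently and mix their records,
  # e.g. with dataset_fn=tf.data.TFRecordDataset.
  return dataset.interleave(
      dataset_fn,
      cycle_length=cycle_length,
      num_parallel_calls=tf.data.experimental.AUTOTUNE,
      deterministic=False)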
@@ -170,14 +169,13 @@ class InputReader:
         deterministic=self._deterministic)
     return dataset
 
-  def _read_single_file(self,
-                        input_context: Optional[
-                            tf.distribute.InputContext] = None):
-    """Reads a dataset from a single file."""
-    # Read from `self._shards` if it is provided.
+  def _read_files_then_shard(
+      self, input_context: Optional[tf.distribute.InputContext] = None):
+    """Sends all data files to every worker and then shards by data."""
     dataset = self._dataset_fn(self._matched_files)
 
-    # When `input_file` is a path to a single file, disable auto sharding
+    # When `input_file` is a path to a single file or the number of files is
+    # less than the number of input pipelines, disable auto sharding
     # so that same input file is sent to all workers.
     options = tf.data.Options()
     options.experimental_distribute.auto_shard_policy = (
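
Likewise outside the diff, a sketch of the read-then-shard fallback, assuming the auto-shard policy is set to OFF (so tf.distribute does not try to split the short file list) and records are then sharded explicitly per worker; the helper name and arguments are illustrative:

import tensorflow as tf

def read_files_then_shard(matched_files, dataset_fn, input_context=None):
  """Every worker opens all files, then keeps only its own slice of records."""
  dataset = dataset_fn(matched_files)
  # Disable auto-sharding so the same file list is sent to all workers.
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.OFF)
  dataset = dataset.with_options(options)
  if input_context:
    # Shard by record so each worker still sees disjoint data.
    dataset = dataset.shard(input_context.num_input_pipelines,
                            input_context.input_pipeline_id)
  return dataset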
@@ -238,9 +236,18 @@ class InputReader:
     if self._tfds_builder:
       dataset = self._read_tfds(input_context)
     elif len(self._matched_files) > 1:
-      dataset = self._read_sharded_files(input_context)
+      if input_context and (len(self._matched_files) <
+                            input_context.num_input_pipelines):
+        logging.warn(
+            'The number of files %d is less than the number of input pipelines '
+            '%d. We will send all input files to every worker. '
+            'Please consider sharding your data into more files.',
+            len(self._matched_files), input_context.num_input_pipelines)
+        dataset = self._read_files_then_shard(input_context)
+      else:
+        dataset = self._shard_files_then_read(input_context)
     elif len(self._matched_files) == 1:
-      dataset = self._read_single_file(input_context)
+      dataset = self._read_files_then_shard(input_context)
     else:
       raise ValueError('It is unexpected that `tfds_builder` is None and '
                        'there is also no `matched_files`.')
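
The new branch can be exercised locally with a synthetic tf.distribute.InputContext; the file names and pipeline counts below are made up for illustration:

import tensorflow as tf

matched_files = ['shard-0.tfrecord', 'shard-1.tfrecord']  # hypothetical paths
input_context = tf.distribute.InputContext(
    num_input_pipelines=8, input_pipeline_id=0, num_replicas_in_sync=8)

if input_context and len(matched_files) < input_context.num_input_pipelines:
  # 2 files for 8 pipelines: every worker reads all files and shards by record.
  print('falling back to read-files-then-shard')
else:
  print('sharding files across input pipelines')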