"vscode:/vscode.git/clone" did not exist on "56cb8c15bd1c0785cc6a93429cc00b98dbce26bd"
Commit 2659ca30 authored by Chen Chen, committed by A. Unique TensorFlower

Change the dataset_transform_fn argument in InputReader's constructor to transform_and_batch_fn.

PiperOrigin-RevId: 323013252
parent 07484704
@@ -32,7 +32,8 @@ class InputReader:
                dataset_fn=tf.data.TFRecordDataset,
                decoder_fn: Optional[Callable[..., Any]] = None,
                parser_fn: Optional[Callable[..., Any]] = None,
-               dataset_transform_fn: Optional[Callable[[tf.data.Dataset],
-                                                       tf.data.Dataset]] = None,
+               transform_and_batch_fn: Optional[Callable[
+                   [tf.data.Dataset, Optional[tf.distribute.InputContext]],
+                   tf.data.Dataset]] = None,
                postprocess_fn: Optional[Callable[..., Any]] = None):
     """Initializes an InputReader instance.
@@ -48,9 +49,12 @@ class InputReader:
       parser_fn: An optional `callable` that takes the decoded raw tensors dict
         and parse them into a dictionary of tensors that can be consumed by the
         model. It will be executed after decoder_fn.
-      dataset_transform_fn: An optional `callable` that takes a
-        `tf.data.Dataset` object and returns a `tf.data.Dataset`. It will be
-        executed after parser_fn.
+      transform_and_batch_fn: An optional `callable` that takes a
+        `tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
+        input, and returns a `tf.data.Dataset` object. It will be
+        executed after `parser_fn` to transform and batch the dataset; if None,
+        after `parser_fn` is executed, the dataset will be batched into
+        per-replica batch size.
       postprocess_fn: A optional `callable` that processes batched tensors. It
         will be executed after batching.
     """
@@ -101,7 +105,7 @@ class InputReader:
     self._dataset_fn = dataset_fn
     self._decoder_fn = decoder_fn
     self._parser_fn = parser_fn
-    self._dataset_transform_fn = dataset_transform_fn
+    self._transform_and_batch_fn = transform_and_batch_fn
     self._postprocess_fn = postprocess_fn

   def _read_sharded_files(
@@ -214,13 +218,13 @@ class InputReader:
     dataset = maybe_map_fn(dataset, self._decoder_fn)
     dataset = maybe_map_fn(dataset, self._parser_fn)

-    if self._dataset_transform_fn is not None:
-      dataset = self._dataset_transform_fn(dataset)
-
-    per_replica_batch_size = input_context.get_per_replica_batch_size(
-        self._global_batch_size) if input_context else self._global_batch_size
-
-    dataset = dataset.batch(
-        per_replica_batch_size, drop_remainder=self._drop_remainder)
+    if self._transform_and_batch_fn is not None:
+      dataset = self._transform_and_batch_fn(dataset, input_context)
+    else:
+      per_replica_batch_size = input_context.get_per_replica_batch_size(
+          self._global_batch_size) if input_context else self._global_batch_size
+      dataset = dataset.batch(
+          per_replica_batch_size, drop_remainder=self._drop_remainder)
+
     dataset = maybe_map_fn(dataset, self._postprocess_fn)
     return dataset.prefetch(tf.data.experimental.AUTOTUNE)
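For illustration, a user-supplied transform_and_batch_fn might shuffle the parsed dataset and then batch it to the per-replica batch size. This is a minimal sketch, not part of the commit: the function name shuffle_and_batch, the global batch size of 64, and the shuffle buffer of 1024 are hypothetical.

import tensorflow as tf
from typing import Optional

_GLOBAL_BATCH_SIZE = 64  # hypothetical value, for illustration only


def shuffle_and_batch(
    dataset: tf.data.Dataset,
    input_context: Optional[tf.distribute.InputContext] = None
) -> tf.data.Dataset:
  """Example transform_and_batch_fn: shuffle, then batch per replica."""
  # Under tf.distribute, derive the per-replica batch size from the
  # InputContext; otherwise fall back to the global batch size.
  batch_size = (
      input_context.get_per_replica_batch_size(_GLOBAL_BATCH_SIZE)
      if input_context else _GLOBAL_BATCH_SIZE)
  dataset = dataset.shuffle(1024)  # hypothetical shuffle buffer size
  return dataset.batch(batch_size, drop_remainder=True)

A function like this would be passed to the constructor as transform_and_batch_fn=shuffle_and_batch; leaving the argument as None keeps the default path added in this commit, where the dataset is simply batched to the per-replica batch size after parser_fn runs.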