Commit c6970b7f authored by A. Unique TensorFlower
Browse files

Add sample_fn to input_reader.

PiperOrigin-RevId: 336973004
parent 6e94b63b
......@@ -35,6 +35,7 @@ class InputReader:
params: cfg.DataConfig,
dataset_fn=tf.data.TFRecordDataset,
decoder_fn: Optional[Callable[..., Any]] = None,
sample_fn: Optional[Callable[..., Any]] = None,
parser_fn: Optional[Callable[..., Any]] = None,
transform_and_batch_fn: Optional[Callable[
[tf.data.Dataset, Optional[tf.distribute.InputContext]],
......@@ -48,6 +49,9 @@ class InputReader:
example, it can be `tf.data.TFRecordDataset`.
decoder_fn: An optional `callable` that takes the serialized data string
and decodes it into the raw tensor dictionary.
sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
input and outputs the transformed dataset. It performs sampling on the
decoded raw tensors dict before the parser_fn.
parser_fn: An optional `callable` that takes the decoded raw tensors dict
and parses it into a dictionary of tensors that can be consumed by the
model. It will be executed after decoder_fn.
......@@ -124,6 +128,7 @@ class InputReader:
self._dataset_fn = dataset_fn
self._decoder_fn = decoder_fn
self._sample_fn = sample_fn
self._parser_fn = parser_fn
self._transform_and_batch_fn = transform_and_batch_fn
self._postprocess_fn = postprocess_fn
......@@ -251,6 +256,8 @@ class InputReader:
fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = maybe_map_fn(dataset, self._decoder_fn)
if self._sample_fn is not None:
dataset = dataset.apply(self._sample_fn)
dataset = maybe_map_fn(dataset, self._parser_fn)
if self._transform_and_batch_fn is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment