Commit c6970b7f authored by A. Unique TensorFlower
Browse files

Add sample_fn to input_reader.

PiperOrigin-RevId: 336973004
parent 6e94b63b
...@@ -35,6 +35,7 @@ class InputReader: ...@@ -35,6 +35,7 @@ class InputReader:
params: cfg.DataConfig, params: cfg.DataConfig,
dataset_fn=tf.data.TFRecordDataset, dataset_fn=tf.data.TFRecordDataset,
decoder_fn: Optional[Callable[..., Any]] = None, decoder_fn: Optional[Callable[..., Any]] = None,
sample_fn: Optional[Callable[..., Any]] = None,
parser_fn: Optional[Callable[..., Any]] = None, parser_fn: Optional[Callable[..., Any]] = None,
transform_and_batch_fn: Optional[Callable[ transform_and_batch_fn: Optional[Callable[
[tf.data.Dataset, Optional[tf.distribute.InputContext]], [tf.data.Dataset, Optional[tf.distribute.InputContext]],
...@@ -48,6 +49,9 @@ class InputReader: ...@@ -48,6 +49,9 @@ class InputReader:
example, it can be `tf.data.TFRecordDataset`. example, it can be `tf.data.TFRecordDataset`.
decoder_fn: An optional `callable` that takes the serialized data string decoder_fn: An optional `callable` that takes the serialized data string
and decodes them into the raw tensor dictionary. and decodes them into the raw tensor dictionary.
sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
input and outputs the transformed dataset. It performs sampling on the
decoded raw tensors dict before the parser_fn.
parser_fn: An optional `callable` that takes the decoded raw tensors dict parser_fn: An optional `callable` that takes the decoded raw tensors dict
and parse them into a dictionary of tensors that can be consumed by the and parse them into a dictionary of tensors that can be consumed by the
model. It will be executed after decoder_fn. model. It will be executed after decoder_fn.
...@@ -124,6 +128,7 @@ class InputReader: ...@@ -124,6 +128,7 @@ class InputReader:
self._dataset_fn = dataset_fn self._dataset_fn = dataset_fn
self._decoder_fn = decoder_fn self._decoder_fn = decoder_fn
self._sample_fn = sample_fn
self._parser_fn = parser_fn self._parser_fn = parser_fn
self._transform_and_batch_fn = transform_and_batch_fn self._transform_and_batch_fn = transform_and_batch_fn
self._postprocess_fn = postprocess_fn self._postprocess_fn = postprocess_fn
...@@ -251,6 +256,8 @@ class InputReader: ...@@ -251,6 +256,8 @@ class InputReader:
fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) fn, num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = maybe_map_fn(dataset, self._decoder_fn) dataset = maybe_map_fn(dataset, self._decoder_fn)
if self._sample_fn is not None:
dataset = dataset.apply(self._sample_fn)
dataset = maybe_map_fn(dataset, self._parser_fn) dataset = maybe_map_fn(dataset, self._parser_fn)
if self._transform_and_batch_fn is not None: if self._transform_and_batch_fn is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment