Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
d32be917
Commit
d32be917
authored
Oct 13, 2020
by
A. Unique TensorFlower
Browse files
Add sample_fn to input_reader.
PiperOrigin-RevId: 336973004
parent
2b2d4820
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
0 deletions
+7
-0
official/core/input_reader.py
official/core/input_reader.py
+7
-0
No files found.
official/core/input_reader.py
View file @
d32be917
...
@@ -35,6 +35,7 @@ class InputReader:
...
@@ -35,6 +35,7 @@ class InputReader:
params
:
cfg
.
DataConfig
,
params
:
cfg
.
DataConfig
,
dataset_fn
=
tf
.
data
.
TFRecordDataset
,
dataset_fn
=
tf
.
data
.
TFRecordDataset
,
decoder_fn
:
Optional
[
Callable
[...,
Any
]]
=
None
,
decoder_fn
:
Optional
[
Callable
[...,
Any
]]
=
None
,
sample_fn
:
Optional
[
Callable
[...,
Any
]]
=
None
,
parser_fn
:
Optional
[
Callable
[...,
Any
]]
=
None
,
parser_fn
:
Optional
[
Callable
[...,
Any
]]
=
None
,
transform_and_batch_fn
:
Optional
[
Callable
[
transform_and_batch_fn
:
Optional
[
Callable
[
[
tf
.
data
.
Dataset
,
Optional
[
tf
.
distribute
.
InputContext
]],
[
tf
.
data
.
Dataset
,
Optional
[
tf
.
distribute
.
InputContext
]],
...
@@ -48,6 +49,9 @@ class InputReader:
...
@@ -48,6 +49,9 @@ class InputReader:
example, it can be `tf.data.TFRecordDataset`.
example, it can be `tf.data.TFRecordDataset`.
decoder_fn: An optional `callable` that takes the serialized data string
decoder_fn: An optional `callable` that takes the serialized data string
and decodes them into the raw tensor dictionary.
and decodes them into the raw tensor dictionary.
sample_fn: An optional `callable` that takes a `tf.data.Dataset` object as
input and outputs the transformed dataset. It performs sampling on the
decoded raw tensors dict before the parser_fn.
parser_fn: An optional `callable` that takes the decoded raw tensors dict
parser_fn: An optional `callable` that takes the decoded raw tensors dict
and parse them into a dictionary of tensors that can be consumed by the
and parse them into a dictionary of tensors that can be consumed by the
model. It will be executed after decoder_fn.
model. It will be executed after decoder_fn.
...
@@ -124,6 +128,7 @@ class InputReader:
...
@@ -124,6 +128,7 @@ class InputReader:
self
.
_dataset_fn
=
dataset_fn
self
.
_dataset_fn
=
dataset_fn
self
.
_decoder_fn
=
decoder_fn
self
.
_decoder_fn
=
decoder_fn
self
.
_sample_fn
=
sample_fn
self
.
_parser_fn
=
parser_fn
self
.
_parser_fn
=
parser_fn
self
.
_transform_and_batch_fn
=
transform_and_batch_fn
self
.
_transform_and_batch_fn
=
transform_and_batch_fn
self
.
_postprocess_fn
=
postprocess_fn
self
.
_postprocess_fn
=
postprocess_fn
...
@@ -251,6 +256,8 @@ class InputReader:
...
@@ -251,6 +256,8 @@ class InputReader:
fn
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
fn
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
dataset
=
maybe_map_fn
(
dataset
,
self
.
_decoder_fn
)
dataset
=
maybe_map_fn
(
dataset
,
self
.
_decoder_fn
)
if
self
.
_sample_fn
is
not
None
:
dataset
=
dataset
.
apply
(
self
.
_sample_fn
)
dataset
=
maybe_map_fn
(
dataset
,
self
.
_parser_fn
)
dataset
=
maybe_map_fn
(
dataset
,
self
.
_parser_fn
)
if
self
.
_transform_and_batch_fn
is
not
None
:
if
self
.
_transform_and_batch_fn
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment