Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
2659ca30
Commit
2659ca30
authored
Jul 24, 2020
by
Chen Chen
Committed by
A. Unique TensorFlower
Jul 24, 2020
Browse files
Change the dataset_transform_fn argument in InputReader's constructor to transform_and_batch_fn.
PiperOrigin-RevId: 323013252
parent
07484704
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
13 deletions
+17
-13
official/core/input_reader.py
official/core/input_reader.py
+17
-13
No files found.
official/core/input_reader.py
View file @
2659ca30
...
...
@@ -32,8 +32,9 @@ class InputReader:
dataset_fn
=
tf
.
data
.
TFRecordDataset
,
decoder_fn
:
Optional
[
Callable
[...,
Any
]]
=
None
,
parser_fn
:
Optional
[
Callable
[...,
Any
]]
=
None
,
dataset_transform_fn
:
Optional
[
Callable
[[
tf
.
data
.
Dataset
],
tf
.
data
.
Dataset
]]
=
None
,
transform_and_batch_fn
:
Optional
[
Callable
[
[
tf
.
data
.
Dataset
,
Optional
[
tf
.
distribute
.
InputContext
]],
tf
.
data
.
Dataset
]]
=
None
,
postprocess_fn
:
Optional
[
Callable
[...,
Any
]]
=
None
):
"""Initializes an InputReader instance.
...
...
@@ -48,9 +49,12 @@ class InputReader:
parser_fn: An optional `callable` that takes the decoded raw tensors dict
and parse them into a dictionary of tensors that can be consumed by the
model. It will be executed after decoder_fn.
dataset_transform_fn: An optional `callable` that takes a
`tf.data.Dataset` object and returns a `tf.data.Dataset`. It will be
executed after parser_fn.
transform_and_batch_fn: An optional `callable` that takes a
`tf.data.Dataset` object and an optional `tf.distribute.InputContext` as
input, and returns a `tf.data.Dataset` object. It will be
executed after `parser_fn` to transform and batch the dataset; if None,
after `parser_fn` is executed, the dataset will be batched into
per-replica batch size.
postprocess_fn: A optional `callable` that processes batched tensors. It
will be executed after batching.
"""
...
...
@@ -101,7 +105,7 @@ class InputReader:
self
.
_dataset_fn
=
dataset_fn
self
.
_decoder_fn
=
decoder_fn
self
.
_parser_fn
=
parser_fn
self
.
_
dataset_
transform_
fn
=
dataset_transform
_fn
self
.
_transform_
and_batch_fn
=
transform_and_batch
_fn
self
.
_postprocess_fn
=
postprocess_fn
def
_read_sharded_files
(
...
...
@@ -214,13 +218,13 @@ class InputReader:
dataset
=
maybe_map_fn
(
dataset
,
self
.
_decoder_fn
)
dataset
=
maybe_map_fn
(
dataset
,
self
.
_parser_fn
)
if
self
.
_dataset_transform_fn
is
not
None
:
dataset
=
self
.
_dataset_transform_fn
(
dataset
)
per_replica_batch_size
=
input_context
.
get_per_replica_batch_size
(
self
.
_global_batch_size
)
if
input_context
else
self
.
_global_batch_size
if
self
.
_transform_and_batch_fn
is
not
None
:
dataset
=
self
.
_transform_and_batch_fn
(
dataset
,
input_context
)
else
:
per_replica_batch_size
=
input_context
.
get_per_replica_batch_size
(
self
.
_global_batch_size
)
if
input_context
else
self
.
_global_batch_size
dataset
=
dataset
.
batch
(
per_replica_batch_size
,
drop_remainder
=
self
.
_drop_remainder
)
dataset
=
dataset
.
batch
(
per_replica_batch_size
,
drop_remainder
=
self
.
_drop_remainder
)
dataset
=
maybe_map_fn
(
dataset
,
self
.
_postprocess_fn
)
return
dataset
.
prefetch
(
tf
.
data
.
experimental
.
AUTOTUNE
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment