Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f5a953c8
Commit
f5a953c8
authored
Mar 14, 2018
by
Chris Shallue
Committed by
Christopher Shallue
Mar 15, 2018
Browse files
Make the input_fn return the dataset directly.
PiperOrigin-RevId: 189060074
parent
bd855ed1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
9 deletions
+9
-9
research/astronet/astronet/util/estimator_util.py
research/astronet/astronet/util/estimator_util.py
+9
-9
No files found.
research/astronet/astronet/util/estimator_util.py
View file @
f5a953c8
...
@@ -69,15 +69,7 @@ def create_input_fn(file_pattern,
...
@@ -69,15 +69,7 @@ def create_input_fn(file_pattern,
repeat
=
repeat
,
repeat
=
repeat
,
use_tpu
=
use_tpu
)
use_tpu
=
use_tpu
)
# We must use an initializable iterator, rather than a one-shot iterator,
return
dataset
# because the input pipeline contains a stateful table that requires
# initialization. We add the initializer to the TABLE_INITIALIZERS
# collection to ensure it is run during initialization.
iterator
=
dataset
.
make_initializable_iterator
()
tf
.
add_to_collection
(
tf
.
GraphKeys
.
TABLE_INITIALIZERS
,
iterator
.
initializer
)
inputs
=
iterator
.
get_next
()
return
inputs
,
inputs
.
pop
(
"labels"
,
None
)
return
input_fn
return
input_fn
...
@@ -103,6 +95,14 @@ def create_model_fn(model_class, hparams, use_tpu=False):
...
@@ -103,6 +95,14 @@ def create_model_fn(model_class, hparams, use_tpu=False):
if
"batch_size"
in
params
:
if
"batch_size"
in
params
:
hparams
.
batch_size
=
params
[
"batch_size"
]
hparams
.
batch_size
=
params
[
"batch_size"
]
# Allow labels to be passed in the features dictionary.
if
"labels"
in
features
:
if
labels
is
not
None
and
labels
is
not
features
[
"labels"
]:
raise
ValueError
(
"Conflicting labels: features['labels'] = %s, labels = %s"
%
(
features
[
"labels"
],
labels
))
labels
=
features
.
pop
(
"labels"
)
model
=
model_class
(
features
,
labels
,
hparams
,
mode
)
model
=
model_class
(
features
,
labels
,
hparams
,
mode
)
model
.
build
()
model
.
build
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment