Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0734276a
Commit
0734276a
authored
Nov 01, 2019
by
Will Cromar
Committed by
A. Unique TensorFlower
Nov 01, 2019
Browse files
Fix mnist_test.py
PiperOrigin-RevId: 278024052
parent
ffa522ea
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
22 additions
and
17 deletions
+22
-17
official/vision/image_classification/mnist_main.py
official/vision/image_classification/mnist_main.py
+4
-2
official/vision/image_classification/mnist_test.py
official/vision/image_classification/mnist_test.py
+18
-15
No files found.
official/vision/image_classification/mnist_main.py
View file @
0734276a
...
@@ -69,11 +69,13 @@ def decode_image(example, feature):
...
@@ -69,11 +69,13 @@ def decode_image(example, feature):
return
tf
.
cast
(
feature
.
decode_example
(
example
),
dtype
=
tf
.
float32
)
/
255
return
tf
.
cast
(
feature
.
decode_example
(
example
),
dtype
=
tf
.
float32
)
/
255
def
run
(
flags_obj
,
strategy_override
=
None
):
def
run
(
flags_obj
,
datasets_override
=
None
,
strategy_override
=
None
):
"""Run MNIST model training and eval loop using native Keras APIs.
"""Run MNIST model training and eval loop using native Keras APIs.
Args:
Args:
flags_obj: An object containing parsed flag values.
flags_obj: An object containing parsed flag values.
datasets_override: A pair of `tf.data.Dataset` objects to train the model,
representing the train and test sets.
strategy_override: A `tf.distribute.Strategy` object to use for model.
strategy_override: A `tf.distribute.Strategy` object to use for model.
Returns:
Returns:
...
@@ -90,7 +92,7 @@ def run(flags_obj, strategy_override=None):
...
@@ -90,7 +92,7 @@ def run(flags_obj, strategy_override=None):
if
flags_obj
.
download
:
if
flags_obj
.
download
:
mnist
.
download_and_prepare
()
mnist
.
download_and_prepare
()
mnist_train
,
mnist_test
=
mnist
.
as_dataset
(
mnist_train
,
mnist_test
=
datasets_override
or
mnist
.
as_dataset
(
split
=
[
'train'
,
'test'
],
split
=
[
'train'
,
'test'
],
decoders
=
{
'image'
:
decode_image
()},
# pylint: disable=no-value-for-parameter
decoders
=
{
'image'
:
decode_image
()},
# pylint: disable=no-value-for-parameter
as_supervised
=
True
)
as_supervised
=
True
)
...
...
official/vision/image_classification/mnist_test.py
View file @
0734276a
...
@@ -67,16 +67,19 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
...
@@ -67,16 +67,19 @@ class KerasMnistTest(tf.test.TestCase, parameterized.TestCase):
"--data_dir="
"--data_dir="
]
]
def
_mock_dataset
(
self
,
*
args
,
**
kwargs
):
# pylint: disable=unused-argument
dummy_data
=
(
"""Generate mock dataset with TPU-compatible dtype (instead of uint8)."""
tf
.
ones
(
shape
=
(
10
,
28
,
28
,
1
),
dtype
=
tf
.
int32
),
return
tf
.
data
.
Dataset
.
from_tensor_slices
({
tf
.
range
(
10
),
"image"
:
tf
.
ones
(
shape
=
(
10
,
28
,
28
,
1
),
dtype
=
tf
.
int32
),
)
"label"
:
tf
.
range
(
10
),
datasets
=
(
})
tf
.
data
.
Dataset
.
from_tensor_slices
(
dummy_data
),
tf
.
data
.
Dataset
.
from_tensor_slices
(
dummy_data
),
)
run
=
functools
.
partial
(
mnist_main
.
run
,
strategy_override
=
distribution
)
run
=
functools
.
partial
(
mnist_main
.
run
,
datasets_override
=
datasets
,
strategy_override
=
distribution
)
with
tfds
.
testing
.
mock_data
(
as_dataset_fn
=
_mock_dataset
):
integration
.
run_synthetic
(
integration
.
run_synthetic
(
main
=
run
,
main
=
run
,
synth
=
False
,
synth
=
False
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment