Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ba39a3db
Commit
ba39a3db
authored
Dec 21, 2018
by
Toby Boyd
Browse files
Change test to use run() to match API change.
parent
80dcd27c
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
6 additions
and
8 deletions
+6
-8
official/resnet/keras/keras_cifar_benchmark.py
official/resnet/keras/keras_cifar_benchmark.py
+6
-8
No files found.
official/resnet/keras/keras_cifar_benchmark.py
View file @
ba39a3db
...
@@ -9,6 +9,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
...
@@ -9,6 +9,7 @@ import tensorflow as tf # pylint: disable=g-bad-import-order
from
official.resnet
import
cifar10_main
as
cifar_main
from
official.resnet
import
cifar10_main
as
cifar_main
import
official.resnet.keras.keras_cifar_main
as
keras_cifar_main
import
official.resnet.keras.keras_cifar_main
as
keras_cifar_main
import
official.resnet.keras.keras_common
as
keras_common
DATA_DIR
=
'/data/cifar10_data/'
DATA_DIR
=
'/data/cifar10_data/'
...
@@ -32,7 +33,7 @@ class KerasCifar10BenchmarkTests(object):
...
@@ -32,7 +33,7 @@ class KerasCifar10BenchmarkTests(object):
flags
.
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'keras_resnet56_1_gpu'
)
flags
.
FLAGS
.
model_dir
=
self
.
_get_model_dir
(
'keras_resnet56_1_gpu'
)
flags
.
FLAGS
.
resnet_size
=
56
flags
.
FLAGS
.
resnet_size
=
56
flags
.
FLAGS
.
dtype
=
'fp32'
flags
.
FLAGS
.
dtype
=
'fp32'
stats
=
keras_cifar_main
.
run
_cifar_with_keras
(
flags
.
FLAGS
)
stats
=
keras_cifar_main
.
run
(
flags
.
FLAGS
)
self
.
_fill_report_object
(
stats
)
self
.
_fill_report_object
(
stats
)
def
keras_resnet56_4_gpu
(
self
):
def
keras_resnet56_4_gpu
(
self
):
...
@@ -45,7 +46,7 @@ class KerasCifar10BenchmarkTests(object):
...
@@ -45,7 +46,7 @@ class KerasCifar10BenchmarkTests(object):
flags
.
FLAGS
.
model_dir
=
''
flags
.
FLAGS
.
model_dir
=
''
flags
.
FLAGS
.
resnet_size
=
56
flags
.
FLAGS
.
resnet_size
=
56
flags
.
FLAGS
.
dtype
=
'fp32'
flags
.
FLAGS
.
dtype
=
'fp32'
stats
=
keras_cifar_main
.
run
_cifar_with_keras
(
flags
.
FLAGS
)
stats
=
keras_cifar_main
.
run
(
flags
.
FLAGS
)
self
.
_fill_report_object
(
stats
)
self
.
_fill_report_object
(
stats
)
def
keras_resnet56_no_dist_strat_1_gpu
(
self
):
def
keras_resnet56_no_dist_strat_1_gpu
(
self
):
...
@@ -60,7 +61,7 @@ class KerasCifar10BenchmarkTests(object):
...
@@ -60,7 +61,7 @@ class KerasCifar10BenchmarkTests(object):
'keras_resnet56_no_dist_strat_1_gpu'
)
'keras_resnet56_no_dist_strat_1_gpu'
)
flags
.
FLAGS
.
resnet_size
=
56
flags
.
FLAGS
.
resnet_size
=
56
flags
.
FLAGS
.
dtype
=
'fp32'
flags
.
FLAGS
.
dtype
=
'fp32'
stats
=
keras_cifar_main
.
run
_cifar_with_keras
(
flags
.
FLAGS
)
stats
=
keras_cifar_main
.
run
(
flags
.
FLAGS
)
self
.
_fill_report_object
(
stats
)
self
.
_fill_report_object
(
stats
)
def
_fill_report_object
(
self
,
stats
):
def
_fill_report_object
(
self
,
stats
):
...
@@ -76,17 +77,14 @@ class KerasCifar10BenchmarkTests(object):
...
@@ -76,17 +77,14 @@ class KerasCifar10BenchmarkTests(object):
return
os
.
path
.
join
(
self
.
output_dir
,
folder_name
)
return
os
.
path
.
join
(
self
.
output_dir
,
folder_name
)
def
_setup
(
self
):
def
_setup
(
self
):
"""Setups up and resets flags before each test."""
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
DEBUG
)
tf
.
logging
.
set_verbosity
(
tf
.
logging
.
DEBUG
)
if
KerasCifar10BenchmarkTests
.
local_flags
is
None
:
if
KerasCifar10BenchmarkTests
.
local_flags
is
None
:
print
(
'Build Flags!!!!'
)
keras_common
.
define_keras_flags
()
keras_cifar_main
.
define_keras_cifar_flags
()
cifar_main
.
define_cifar_flags
()
cifar_main
.
define_cifar_flags
()
# Loads flags to get defaults to then override.
# Loads flags to get defaults to then override.
flags
.
FLAGS
([
'foo'
])
flags
.
FLAGS
([
'foo'
])
saved_flag_values
=
flagsaver
.
save_flag_values
()
saved_flag_values
=
flagsaver
.
save_flag_values
()
KerasCifar10BenchmarkTests
.
local_flags
=
saved_flag_values
KerasCifar10BenchmarkTests
.
local_flags
=
saved_flag_values
return
return
print
(
'Restore Flags'
)
flagsaver
.
restore_flag_values
(
KerasCifar10BenchmarkTests
.
local_flags
)
flagsaver
.
restore_flag_values
(
KerasCifar10BenchmarkTests
.
local_flags
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment