Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
f2c61881
Commit
f2c61881
authored
Oct 18, 2019
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Oct 18, 2019
Browse files
Move CTL resnet example.
PiperOrigin-RevId: 275417626
parent
6cd426d9
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
23 additions
and
294 deletions
+23
-294
official/benchmark/resnet_ctl_imagenet_benchmark.py
official/benchmark/resnet_ctl_imagenet_benchmark.py
+5
-6
official/resnet/README.md
official/resnet/README.md
+2
-2
official/resnet/__init__.py
official/resnet/__init__.py
+0
-38
official/resnet/ctl/__init__.py
official/resnet/ctl/__init__.py
+0
-0
official/resnet/ctl/ctl_common.py
official/resnet/ctl/ctl_common.py
+0
-32
official/resnet/ctl/ctl_imagenet_test_tpu.py
official/resnet/ctl/ctl_imagenet_test_tpu.py
+0
-103
official/vision/image_classification/resnet_ctl_imagenet_main.py
...l/vision/image_classification/resnet_ctl_imagenet_main.py
+7
-3
official/vision/image_classification/resnet_ctl_imagenet_test.py
...l/vision/image_classification/resnet_ctl_imagenet_test.py
+9
-20
official/vision/image_classification/resnet_imagenet_test_tpu.py
...l/vision/image_classification/resnet_imagenet_test_tpu.py
+0
-90
No files found.
official/benchmark/resnet_ctl_imagenet_benchmark.py
View file @
f2c61881
...
@@ -23,8 +23,7 @@ from absl import flags
...
@@ -23,8 +23,7 @@ from absl import flags
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.vision.image_classification
import
common
from
official.vision.image_classification
import
common
from
official.resnet.ctl
import
ctl_imagenet_main
from
official.vision.image_classification
import
resnet_ctl_imagenet_main
from
official.resnet.ctl
import
ctl_common
from
official.utils.testing.perfzero_benchmark
import
PerfZeroBenchmark
from
official.utils.testing.perfzero_benchmark
import
PerfZeroBenchmark
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
...
@@ -121,7 +120,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
...
@@ -121,7 +120,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
arguments before updating the constructor.
arguments before updating the constructor.
"""
"""
flag_methods
=
[
ctl_common
.
define_ctl_flags
,
common
.
define_keras_flags
]
flag_methods
=
[
common
.
define_keras_flags
]
self
.
data_dir
=
os
.
path
.
join
(
root_data_dir
,
'imagenet'
)
self
.
data_dir
=
os
.
path
.
join
(
root_data_dir
,
'imagenet'
)
super
(
Resnet50CtlAccuracy
,
self
).
__init__
(
super
(
Resnet50CtlAccuracy
,
self
).
__init__
(
...
@@ -158,7 +157,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
...
@@ -158,7 +157,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
def
_run_and_report_benchmark
(
self
):
def
_run_and_report_benchmark
(
self
):
start_time_sec
=
time
.
time
()
start_time_sec
=
time
.
time
()
stats
=
ctl_imagenet_main
.
run
(
flags
.
FLAGS
)
stats
=
resnet_
ctl_imagenet_main
.
run
(
flags
.
FLAGS
)
wall_time_sec
=
time
.
time
()
-
start_time_sec
wall_time_sec
=
time
.
time
()
-
start_time_sec
super
(
Resnet50CtlAccuracy
,
self
).
_report_benchmark
(
super
(
Resnet50CtlAccuracy
,
self
).
_report_benchmark
(
...
@@ -177,7 +176,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
...
@@ -177,7 +176,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
"""Resnet50 benchmarks."""
"""Resnet50 benchmarks."""
def
__init__
(
self
,
output_dir
=
None
,
default_flags
=
None
):
def
__init__
(
self
,
output_dir
=
None
,
default_flags
=
None
):
flag_methods
=
[
ctl_common
.
define_ctl_flags
,
common
.
define_keras_flags
]
flag_methods
=
[
common
.
define_keras_flags
]
super
(
Resnet50CtlBenchmarkBase
,
self
).
__init__
(
super
(
Resnet50CtlBenchmarkBase
,
self
).
__init__
(
output_dir
=
output_dir
,
output_dir
=
output_dir
,
...
@@ -186,7 +185,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
...
@@ -186,7 +185,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
def
_run_and_report_benchmark
(
self
):
def
_run_and_report_benchmark
(
self
):
start_time_sec
=
time
.
time
()
start_time_sec
=
time
.
time
()
stats
=
ctl_imagenet_main
.
run
(
FLAGS
)
stats
=
resnet_
ctl_imagenet_main
.
run
(
FLAGS
)
wall_time_sec
=
time
.
time
()
-
start_time_sec
wall_time_sec
=
time
.
time
()
-
start_time_sec
# Number of logged step time entries that are excluded in performance
# Number of logged step time entries that are excluded in performance
...
...
official/resnet/README.md
View file @
f2c61881
...
@@ -2,6 +2,6 @@
...
@@ -2,6 +2,6 @@
*
For the Keras version of the ResNet model, see
*
For the Keras version of the ResNet model, see
[
`official/vision/image_classification`
](
../vision/image_classification
)
.
[
`official/vision/image_classification`
](
../vision/image_classification
)
.
*
For the Keras custom training loop version, see
*
For the Keras custom training loop version,
also
see
[
`official/
resnet/ctl`
](
ctl
)
.
[
`official/
vision/image_classification`
](
../vision/image_classification
)
.
*
For the Estimator version, see
[
`official/r1/resnet`
](
../r1/resnet
)
.
*
For the Estimator version, see
[
`official/r1/resnet`
](
../r1/resnet
)
.
official/resnet/__init__.py
deleted
100644 → 0
View file @
6cd426d9
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in the shared ResNet modules into this module.
The TensorFlow v1 official models are moved under official/r1/resnet. In order
to be backward compatible with models that directly import v1 modules, we import
the v1 ResNet modules under official.resnet.
New TF models should not depend on modules directly under this path (which will
soon be deprecated and removed).
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
official.r1.resnet
import
cifar10_main
from
official.r1.resnet
import
imagenet_main
from
official.r1.resnet
import
imagenet_preprocessing
from
official.r1.resnet
import
resnet_model
from
official.r1.resnet
import
resnet_run_loop
del
absolute_import
del
division
del
print_function
official/resnet/ctl/__init__.py
deleted
100644 → 0
View file @
6cd426d9
official/resnet/ctl/ctl_common.py
deleted
100644 → 0
View file @
6cd426d9
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Common util functions and classes used by CTL."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
absl
import
flags
def
define_ctl_flags
():
"""Define flags for CTL."""
flags
.
DEFINE_boolean
(
name
=
'use_tf_function'
,
default
=
True
,
help
=
'Wrap the train and test step inside a '
'tf.function.'
)
flags
.
DEFINE_boolean
(
name
=
'single_l2_loss_op'
,
default
=
False
,
help
=
'Calculate L2_loss on concatenated weights, '
'instead of using Keras per-layer L2 loss.'
)
official/resnet/ctl/ctl_imagenet_test_tpu.py
deleted
100644 → 0
View file @
6cd426d9
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the ResNet model with ImageNet data using CTL."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
tempfile
import
mkdtemp
import
tensorflow
as
tf
from
tensorflow.python.platform
import
googletest
from
official.resnet.ctl
import
ctl_common
from
official.resnet.ctl
import
ctl_imagenet_main
from
official.vision.image_classification
import
imagenet_preprocessing
from
official.vision.image_classification
import
common
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
class
CtlImagenetTest
(
googletest
.
TestCase
):
"""Unit tests for Keras ResNet with ImageNet using CTL."""
_extra_flags
=
[
'-batch_size'
,
'4'
,
'-train_steps'
,
'4'
,
'-use_synthetic_data'
,
'true'
]
_tempdir
=
None
def
get_temp_dir
(
self
):
if
not
self
.
_tempdir
:
self
.
_tempdir
=
mkdtemp
(
dir
=
googletest
.
GetTempDir
())
return
self
.
_tempdir
@
classmethod
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
super
(
CtlImagenetTest
,
cls
).
setUpClass
()
common
.
define_keras_flags
()
ctl_common
.
define_ctl_flags
()
def
setUp
(
self
):
super
(
CtlImagenetTest
,
self
).
setUp
()
if
not
keras_utils
.
is_v2_0
():
tf
.
compat
.
v1
.
enable_v2_behavior
()
imagenet_preprocessing
.
NUM_IMAGES
[
'validation'
]
=
4
def
tearDown
(
self
):
super
(
CtlImagenetTest
,
self
).
tearDown
()
tf
.
io
.
gfile
.
rmtree
(
self
.
get_temp_dir
())
def
test_end_to_end_tpu
(
self
):
"""Test Keras model with TPU distribution strategy."""
extra_flags
=
[
'-distribution_strategy'
,
'tpu'
,
'-model_dir'
,
'ctl_imagenet_tpu_dist_strat'
,
'-data_format'
,
'channels_last'
,
'-use_tf_function'
,
'true'
,
'-single_l2_loss_op'
,
'true'
,
]
extra_flags
=
extra_flags
+
self
.
_extra_flags
integration
.
run_synthetic
(
main
=
ctl_imagenet_main
.
run
,
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
)
def
test_end_to_end_tpu_bf16
(
self
):
"""Test Keras model with TPU and bfloat16 activation."""
extra_flags
=
[
'-distribution_strategy'
,
'tpu'
,
'-model_dir'
,
'ctl_imagenet_tpu_dist_strat_bf16'
,
'-data_format'
,
'channels_last'
,
'-use_tf_function'
,
'true'
,
'-single_l2_loss_op'
,
'true'
,
'-dtype'
,
'bf16'
,
]
extra_flags
=
extra_flags
+
self
.
_extra_flags
integration
.
run_synthetic
(
main
=
ctl_imagenet_main
.
run
,
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
)
if
__name__
==
'__main__'
:
googletest
.
main
()
official/
resnet/ctl/
ctl_imagenet_main.py
→
official/
vision/image_classification/resnet_
ctl_imagenet_main.py
View file @
f2c61881
...
@@ -23,7 +23,6 @@ from absl import flags
...
@@ -23,7 +23,6 @@ from absl import flags
from
absl
import
logging
from
absl
import
logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.resnet.ctl
import
ctl_common
from
official.vision.image_classification
import
imagenet_preprocessing
from
official.vision.image_classification
import
imagenet_preprocessing
from
official.vision.image_classification
import
common
from
official.vision.image_classification
import
common
from
official.vision.image_classification
import
resnet_model
from
official.vision.image_classification
import
resnet_model
...
@@ -33,6 +32,13 @@ from official.utils.misc import distribution_utils
...
@@ -33,6 +32,13 @@ from official.utils.misc import distribution_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
model_helpers
from
official.utils.misc
import
model_helpers
flags
.
DEFINE_boolean
(
name
=
'use_tf_function'
,
default
=
True
,
help
=
'Wrap the train and test step inside a '
'tf.function.'
)
flags
.
DEFINE_boolean
(
name
=
'single_l2_loss_op'
,
default
=
False
,
help
=
'Calculate L2_loss on concatenated weights, '
'instead of using Keras per-layer L2 loss.'
)
def
build_stats
(
train_result
,
eval_result
,
time_callback
):
def
build_stats
(
train_result
,
eval_result
,
time_callback
):
"""Normalizes and returns dictionary of stats.
"""Normalizes and returns dictionary of stats.
...
@@ -379,6 +385,4 @@ def main(_):
...
@@ -379,6 +385,4 @@ def main(_):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
logging
.
set_verbosity
(
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
common
.
define_keras_flags
()
common
.
define_keras_flags
()
ctl_common
.
define_ctl_flags
()
flags
.
adopt_module_key_flags
(
ctl_common
)
app
.
run
(
main
)
app
.
run
(
main
)
official/
resnet/ctl/
ctl_imagenet_test.py
→
official/
vision/image_classification/resnet_
ctl_imagenet_test.py
View file @
f2c61881
...
@@ -18,20 +18,16 @@ from __future__ import absolute_import
...
@@ -18,20 +18,16 @@ from __future__ import absolute_import
from
__future__
import
division
from
__future__
import
division
from
__future__
import
print_function
from
__future__
import
print_function
from
tempfile
import
mkdtemp
import
tensorflow.compat.v2
as
tf
import
tensorflow
as
tf
from
tensorflow.python.eager
import
context
from
tensorflow.python.eager
import
context
from
tensorflow.python.platform
import
googletest
from
official.resnet.ctl
import
ctl_common
from
official.resnet.ctl
import
ctl_imagenet_main
from
official.vision.image_classification
import
imagenet_preprocessing
from
official.vision.image_classification
import
common
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
from
official.utils.testing
import
integration
from
official.vision.image_classification
import
common
from
official.vision.image_classification
import
imagenet_preprocessing
from
official.vision.image_classification
import
resnet_ctl_imagenet_main
class
CtlImagenetTest
(
google
test
.
TestCase
):
class
CtlImagenetTest
(
tf
.
test
.
TestCase
):
"""Unit tests for Keras ResNet with ImageNet using CTL."""
"""Unit tests for Keras ResNet with ImageNet using CTL."""
_extra_flags
=
[
_extra_flags
=
[
...
@@ -41,21 +37,13 @@ class CtlImagenetTest(googletest.TestCase):
...
@@ -41,21 +37,13 @@ class CtlImagenetTest(googletest.TestCase):
]
]
_tempdir
=
None
_tempdir
=
None
def
get_temp_dir
(
self
):
if
not
self
.
_tempdir
:
self
.
_tempdir
=
mkdtemp
(
dir
=
googletest
.
GetTempDir
())
return
self
.
_tempdir
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
super
(
CtlImagenetTest
,
cls
).
setUpClass
()
super
(
CtlImagenetTest
,
cls
).
setUpClass
()
common
.
define_keras_flags
()
common
.
define_keras_flags
()
ctl_common
.
define_ctl_flags
()
def
setUp
(
self
):
def
setUp
(
self
):
super
(
CtlImagenetTest
,
self
).
setUp
()
super
(
CtlImagenetTest
,
self
).
setUp
()
if
not
keras_utils
.
is_v2_0
():
tf
.
compat
.
v1
.
enable_v2_behavior
()
imagenet_preprocessing
.
NUM_IMAGES
[
'validation'
]
=
4
imagenet_preprocessing
.
NUM_IMAGES
[
'validation'
]
=
4
def
tearDown
(
self
):
def
tearDown
(
self
):
...
@@ -73,7 +61,7 @@ class CtlImagenetTest(googletest.TestCase):
...
@@ -73,7 +61,7 @@ class CtlImagenetTest(googletest.TestCase):
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
integration
.
run_synthetic
(
integration
.
run_synthetic
(
main
=
ctl_imagenet_main
.
run
,
main
=
resnet_
ctl_imagenet_main
.
run
,
tmp_root
=
self
.
get_temp_dir
(),
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
extra_flags
=
extra_flags
)
)
...
@@ -93,10 +81,11 @@ class CtlImagenetTest(googletest.TestCase):
...
@@ -93,10 +81,11 @@ class CtlImagenetTest(googletest.TestCase):
extra_flags
=
extra_flags
+
self
.
_extra_flags
extra_flags
=
extra_flags
+
self
.
_extra_flags
integration
.
run_synthetic
(
integration
.
run_synthetic
(
main
=
ctl_imagenet_main
.
run
,
main
=
resnet_
ctl_imagenet_main
.
run
,
tmp_root
=
self
.
get_temp_dir
(),
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
extra_flags
=
extra_flags
)
)
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
googletest
.
main
()
assert
tf
.
version
.
VERSION
.
startswith
(
'2.'
)
tf
.
test
.
main
()
official/vision/image_classification/resnet_imagenet_test_tpu.py
deleted
100644 → 0
View file @
6cd426d9
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test the keras ResNet model with ImageNet data on TPU."""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
import
tensorflow
as
tf
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
from
official.vision.image_classification
import
imagenet_preprocessing
from
official.vision.image_classification
import
resnet_imagenet_main
class
KerasImagenetTest
(
tf
.
test
.
TestCase
):
"""Unit tests for Keras ResNet with ImageNet."""
_extra_flags
=
[
"-batch_size"
,
"4"
,
"-train_steps"
,
"1"
,
"-use_synthetic_data"
,
"true"
]
_tempdir
=
None
@
classmethod
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
super
(
KerasImagenetTest
,
cls
).
setUpClass
()
resnet_imagenet_main
.
define_imagenet_keras_flags
()
def
setUp
(
self
):
super
(
KerasImagenetTest
,
self
).
setUp
()
imagenet_preprocessing
.
NUM_IMAGES
[
"validation"
]
=
4
def
tearDown
(
self
):
super
(
KerasImagenetTest
,
self
).
tearDown
()
tf
.
io
.
gfile
.
rmtree
(
self
.
get_temp_dir
())
def
test_end_to_end_tpu
(
self
):
"""Test Keras model with TPU distribution strategy."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
extra_flags
=
[
"-distribution_strategy"
,
"tpu"
,
"-data_format"
,
"channels_last"
,
]
extra_flags
=
extra_flags
+
self
.
_extra_flags
integration
.
run_synthetic
(
main
=
resnet_imagenet_main
.
run
,
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
)
def
test_end_to_end_tpu_bf16
(
self
):
"""Test Keras model with TPU and bfloat16 activation."""
config
=
keras_utils
.
get_config_proto_v1
()
tf
.
compat
.
v1
.
enable_eager_execution
(
config
=
config
)
extra_flags
=
[
"-distribution_strategy"
,
"tpu"
,
"-data_format"
,
"channels_last"
,
"-dtype"
,
"bf16"
,
]
extra_flags
=
extra_flags
+
self
.
_extra_flags
integration
.
run_synthetic
(
main
=
resnet_imagenet_main
.
run
,
tmp_root
=
self
.
get_temp_dir
(),
extra_flags
=
extra_flags
)
if
__name__
==
"__main__"
:
tf
.
compat
.
v1
.
enable_v2_behavior
()
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment