Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
9af989ce
Commit
9af989ce
authored
Aug 22, 2019
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Aug 22, 2019
Browse files
Internal change
PiperOrigin-RevId: 264895439
parent
519ad098
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
16 additions
and
55 deletions
+16
-55
official/benchmark/models/__init__.py
official/benchmark/models/__init__.py
+0
-0
official/benchmark/models/trivial_model.py
official/benchmark/models/trivial_model.py
+0
-0
official/resnet/ctl/ctl_imagenet_benchmark.py
official/resnet/ctl/ctl_imagenet_benchmark.py
+3
-3
official/resnet/ctl/ctl_imagenet_main.py
official/resnet/ctl/ctl_imagenet_main.py
+8
-8
official/resnet/ctl/ctl_imagenet_test.py
official/resnet/ctl/ctl_imagenet_test.py
+3
-3
official/resnet/keras/__init__.py
official/resnet/keras/__init__.py
+0
-40
official/vision/image_classification/resnet_imagenet_main.py
official/vision/image_classification/resnet_imagenet_main.py
+2
-1
No files found.
official/benchmark/models/__init__.py
0 → 100644
View file @
9af989ce
official/
vision/image_classification
/trivial_model.py
→
official/
benchmark/models
/trivial_model.py
View file @
9af989ce
File moved
official/resnet/ctl/ctl_imagenet_benchmark.py
View file @
9af989ce
...
@@ -22,7 +22,7 @@ import time
...
@@ -22,7 +22,7 @@ import time
from
absl
import
flags
from
absl
import
flags
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.
resnet.keras
import
keras_
common
from
official.
vision.image_classification
import
common
from
official.resnet.ctl
import
ctl_imagenet_main
from
official.resnet.ctl
import
ctl_imagenet_main
from
official.resnet.ctl
import
ctl_common
from
official.resnet.ctl
import
ctl_common
from
official.utils.testing.perfzero_benchmark
import
PerfZeroBenchmark
from
official.utils.testing.perfzero_benchmark
import
PerfZeroBenchmark
...
@@ -118,7 +118,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
...
@@ -118,7 +118,7 @@ class Resnet50CtlAccuracy(CtlBenchmark):
flag_methods
=
[
flag_methods
=
[
ctl_common
.
define_ctl_flags
,
ctl_common
.
define_ctl_flags
,
keras_
common
.
define_keras_flags
common
.
define_keras_flags
]
]
self
.
data_dir
=
os
.
path
.
join
(
root_data_dir
,
'imagenet'
)
self
.
data_dir
=
os
.
path
.
join
(
root_data_dir
,
'imagenet'
)
...
@@ -162,7 +162,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
...
@@ -162,7 +162,7 @@ class Resnet50CtlBenchmarkBase(CtlBenchmark):
def
__init__
(
self
,
output_dir
=
None
,
default_flags
=
None
):
def
__init__
(
self
,
output_dir
=
None
,
default_flags
=
None
):
flag_methods
=
[
flag_methods
=
[
ctl_common
.
define_ctl_flags
,
ctl_common
.
define_ctl_flags
,
keras_
common
.
define_keras_flags
common
.
define_keras_flags
]
]
super
(
Resnet50CtlBenchmarkBase
,
self
).
__init__
(
super
(
Resnet50CtlBenchmarkBase
,
self
).
__init__
(
...
...
official/resnet/ctl/ctl_imagenet_main.py
View file @
9af989ce
...
@@ -24,10 +24,10 @@ from absl import logging
...
@@ -24,10 +24,10 @@ from absl import logging
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.resnet.ctl
import
ctl_common
from
official.resnet.ctl
import
ctl_common
from
official.
resnet.keras
import
imagenet_preprocessing
from
official.
vision.image_classification
import
imagenet_preprocessing
from
official.
resnet.keras
import
keras_
common
from
official.
vision.image_classification
import
common
from
official.
resnet.keras
import
keras
_imagenet_main
from
official.
vision.image_classification
import
resnet
_imagenet_main
from
official.
resnet.keras
import
resnet_model
from
official.
vision.image_classification
import
resnet_model
from
official.utils.flags
import
core
as
flags_core
from
official.utils.flags
import
core
as
flags_core
from
official.utils.logs
import
logger
from
official.utils.logs
import
logger
from
official.utils.misc
import
distribution_utils
from
official.utils.misc
import
distribution_utils
...
@@ -73,7 +73,7 @@ def get_input_dataset(flags_obj, strategy):
...
@@ -73,7 +73,7 @@ def get_input_dataset(flags_obj, strategy):
"""Returns the test and train input datasets."""
"""Returns the test and train input datasets."""
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
dtype
=
flags_core
.
get_tf_dtype
(
flags_obj
)
if
flags_obj
.
use_synthetic_data
:
if
flags_obj
.
use_synthetic_data
:
input_fn
=
keras_
common
.
get_synth_input_fn
(
input_fn
=
common
.
get_synth_input_fn
(
height
=
imagenet_preprocessing
.
DEFAULT_IMAGE_SIZE
,
height
=
imagenet_preprocessing
.
DEFAULT_IMAGE_SIZE
,
width
=
imagenet_preprocessing
.
DEFAULT_IMAGE_SIZE
,
width
=
imagenet_preprocessing
.
DEFAULT_IMAGE_SIZE
,
num_channels
=
imagenet_preprocessing
.
NUM_CHANNELS
,
num_channels
=
imagenet_preprocessing
.
NUM_CHANNELS
,
...
@@ -171,7 +171,7 @@ def run(flags_obj):
...
@@ -171,7 +171,7 @@ def run(flags_obj):
use_l2_regularizer
=
not
flags_obj
.
single_l2_loss_op
)
use_l2_regularizer
=
not
flags_obj
.
single_l2_loss_op
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
learning_rate
=
keras_
common
.
BASE_LEARNING_RATE
,
momentum
=
0.9
,
learning_rate
=
common
.
BASE_LEARNING_RATE
,
momentum
=
0.9
,
nesterov
=
True
)
nesterov
=
True
)
training_accuracy
=
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
training_accuracy
=
tf
.
keras
.
metrics
.
SparseCategoricalAccuracy
(
...
@@ -248,7 +248,7 @@ def run(flags_obj):
...
@@ -248,7 +248,7 @@ def run(flags_obj):
training_accuracy
.
reset_states
()
training_accuracy
.
reset_states
()
for
step
in
range
(
train_steps
):
for
step
in
range
(
train_steps
):
optimizer
.
lr
=
keras
_imagenet_main
.
learning_rate_schedule
(
optimizer
.
lr
=
resnet
_imagenet_main
.
learning_rate_schedule
(
epoch
,
step
,
train_steps
,
flags_obj
.
batch_size
)
epoch
,
step
,
train_steps
,
flags_obj
.
batch_size
)
time_callback
.
on_batch_begin
(
step
+
epoch
*
train_steps
)
time_callback
.
on_batch_begin
(
step
+
epoch
*
train_steps
)
...
@@ -297,7 +297,7 @@ def main(_):
...
@@ -297,7 +297,7 @@ def main(_):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
logging
.
set_verbosity
(
logging
.
INFO
)
logging
.
set_verbosity
(
logging
.
INFO
)
keras_
common
.
define_keras_flags
()
common
.
define_keras_flags
()
ctl_common
.
define_ctl_flags
()
ctl_common
.
define_ctl_flags
()
flags
.
adopt_module_key_flags
(
keras_common
)
flags
.
adopt_module_key_flags
(
keras_common
)
flags
.
adopt_module_key_flags
(
ctl_common
)
flags
.
adopt_module_key_flags
(
ctl_common
)
...
...
official/resnet/ctl/ctl_imagenet_test.py
View file @
9af989ce
...
@@ -25,8 +25,8 @@ from tensorflow.python.eager import context
...
@@ -25,8 +25,8 @@ from tensorflow.python.eager import context
from
tensorflow.python.platform
import
googletest
from
tensorflow.python.platform
import
googletest
from
official.resnet.ctl
import
ctl_common
from
official.resnet.ctl
import
ctl_common
from
official.resnet.ctl
import
ctl_imagenet_main
from
official.resnet.ctl
import
ctl_imagenet_main
from
official.
resnet.keras
import
imagenet_preprocessing
from
official.
vision.image_classification
import
imagenet_preprocessing
from
official.
resnet.keras
import
keras_
common
from
official.
vision.image_classification
import
common
from
official.utils.misc
import
keras_utils
from
official.utils.misc
import
keras_utils
from
official.utils.testing
import
integration
from
official.utils.testing
import
integration
...
@@ -49,7 +49,7 @@ class CtlImagenetTest(googletest.TestCase):
...
@@ -49,7 +49,7 @@ class CtlImagenetTest(googletest.TestCase):
@
classmethod
@
classmethod
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
def
setUpClass
(
cls
):
# pylint: disable=invalid-name
super
(
CtlImagenetTest
,
cls
).
setUpClass
()
super
(
CtlImagenetTest
,
cls
).
setUpClass
()
keras_
common
.
define_keras_flags
()
common
.
define_keras_flags
()
ctl_common
.
define_ctl_flags
()
ctl_common
.
define_ctl_flags
()
def
setUp
(
self
):
def
setUp
(
self
):
...
...
official/resnet/keras/__init__.py
deleted
100644 → 0
View file @
519ad098
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Bring in the shared Keras ResNet modules into this module.
The TensorFlow official Keras models are moved under
official/vision/image_classification
In order to be backward compatible with models that directly import its modules,
we import the Keras ResNet modules under official.resnet.keras.
New TF models should not depend on modules directly under this path.
"""
from
__future__
import
absolute_import
from
__future__
import
division
from
__future__
import
print_function
from
official.vision.image_classification
import
cifar_preprocessing
from
official.vision.image_classification
import
common
as
keras_common
from
official.vision.image_classification
import
imagenet_preprocessing
from
official.vision.image_classification
import
resnet_cifar_main
as
keras_cifar_main
from
official.vision.image_classification
import
resnet_cifar_model
from
official.vision.image_classification
import
resnet_imagenet_main
as
keras_imagenet_main
from
official.vision.image_classification
import
resnet_model
del
absolute_import
del
division
del
print_function
official/vision/image_classification/resnet_imagenet_main.py
View file @
9af989ce
...
@@ -31,7 +31,7 @@ from official.utils.misc import model_helpers
...
@@ -31,7 +31,7 @@ from official.utils.misc import model_helpers
from
official.vision.image_classification
import
common
from
official.vision.image_classification
import
common
from
official.vision.image_classification
import
imagenet_preprocessing
from
official.vision.image_classification
import
imagenet_preprocessing
from
official.vision.image_classification
import
resnet_model
from
official.vision.image_classification
import
resnet_model
from
official.
vision.image_classification
import
trivial_model
from
official.
benchmark.models
import
trivial_model
LR_SCHEDULE
=
[
# (multiplier, epoch to start) tuples
LR_SCHEDULE
=
[
# (multiplier, epoch to start) tuples
...
@@ -186,6 +186,7 @@ def run(flags_obj):
...
@@ -186,6 +186,7 @@ def run(flags_obj):
optimizer
,
loss_scale
=
flags_core
.
get_loss_scale
(
flags_obj
,
optimizer
,
loss_scale
=
flags_core
.
get_loss_scale
(
flags_obj
,
default_for_fp16
=
128
))
default_for_fp16
=
128
))
# TODO(hongkuny): Remove trivial model usage and move it to benchmark.
if
flags_obj
.
use_trivial_model
:
if
flags_obj
.
use_trivial_model
:
model
=
trivial_model
.
trivial_model
(
model
=
trivial_model
.
trivial_model
(
imagenet_preprocessing
.
NUM_CLASSES
,
dtype
)
imagenet_preprocessing
.
NUM_CLASSES
,
dtype
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment