ModelZoo / ResNet50_tensorflow · Commits

Commit 901c4cc4, authored Aug 20, 2019 by Vinh Nguyen
Parents: ef30de93, 824ff2d6

    Merge remote-tracking branch 'upstream/master' into amp_resnet50

Changes: 86 files changed in this commit; this page shows 20 changed files with 899 additions and 139 deletions (+899, -139).
Changed files on this page:

official/vision/image_classification/cifar_preprocessing.py        +1   -1
official/vision/image_classification/common.py                     +3   -6
official/vision/image_classification/common_test.py                +7   -7
official/vision/image_classification/imagenet_preprocessing.py     +0   -0
official/vision/image_classification/resnet_cifar_main.py          +11  -11
official/vision/image_classification/resnet_cifar_model.py         +0   -0
official/vision/image_classification/resnet_cifar_test.py          +14  -15
official/vision/image_classification/resnet_imagenet_main.py       +21  -16
official/vision/image_classification/resnet_imagenet_test.py       +18  -18
official/vision/image_classification/resnet_model.py               +0   -0
official/vision/image_classification/trivial_model.py              +0   -0
official/wide_deep/__init__.py                                     +0   -0
research/lstm_object_detection/README.md                           +5   -0
research/lstm_object_detection/configs/lstm_ssd_interleaved_mobilenet_v2_imagenet.config  +239 -0
research/lstm_object_detection/eval.py                             +1   -3
research/lstm_object_detection/export_tflite_lstd_graph.py         +138 -0
research/lstm_object_detection/export_tflite_lstd_graph_lib.py     +327 -0
research/lstm_object_detection/export_tflite_lstd_model.py         +65  -0
research/lstm_object_detection/g3doc/exporting_models.md           +49  -0
research/lstm_object_detection/inputs/seq_dataset_builder_test.py  +0   -62
official/resnet/keras/cifar_preprocessing.py → official/vision/image_classification/cifar_preprocessing.py

@@ -22,7 +22,7 @@ import os
 from absl import logging
 import tensorflow as tf
-from official.resnet.keras import imagenet_preprocessing
+from official.vision.image_classification import imagenet_preprocessing

 HEIGHT = 32
 WIDTH = 32
official/resnet/keras/keras_common.py → official/vision/image_classification/common.py

@@ -20,17 +20,13 @@ from __future__ import print_function
 import multiprocessing
 import os

-import numpy as np
-
 # pylint: disable=g-bad-import-order
 from absl import flags
+import numpy as np
 import tensorflow as tf
-from tensorflow.python.keras.optimizer_v2 import gradient_descent as gradient_descent_v2

 from official.utils.flags import core as flags_core
 from official.utils.misc import keras_utils
+# pylint: disable=ungrouped-imports
+from tensorflow.python.keras.optimizer_v2 import (
+    gradient_descent as gradient_descent_v2)

 FLAGS = flags.FLAGS
 BASE_LEARNING_RATE = 0.1  # This matches Jing's version.

@@ -262,6 +258,7 @@ def define_keras_flags(dynamic_loss_scale=True):
       force_v2_in_keras_compile=True)
   flags_core.define_image()
   flags_core.define_benchmark()
+  flags_core.define_distribution()
   flags.adopt_module_key_flags(flags_core)

   flags.DEFINE_boolean(name='enable_eager', default=False, help='Enable eager?')
official/resnet/keras/keras_common_test.py → official/vision/image_classification/common_test.py

@@ -12,21 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Tests for the keras_common module."""
+"""Tests for the common module."""

 from __future__ import absolute_import
 from __future__ import print_function

 from mock import Mock
 import numpy as np
+import tensorflow as tf

 # pylint: disable=g-bad-import-order
-import tensorflow as tf
-from official.resnet.keras import keras_common
 from tensorflow.python.platform import googletest
 from official.utils.misc import keras_utils
+from official.vision.image_classification import common


 class KerasCommonTests(tf.test.TestCase):
-  """Tests for keras_common."""
+  """Tests for common."""

   @classmethod
   def setUpClass(cls):  # pylint: disable=invalid-name

@@ -42,7 +42,7 @@ class KerasCommonTests(tf.test.TestCase):
                  keras_utils.BatchTimestamp(1, 2),
                  keras_utils.BatchTimestamp(2, 3)]
     th.train_finish_time = 12345
-    stats = keras_common.build_stats(history, eval_output, [th])
+    stats = common.build_stats(history, eval_output, [th])

     self.assertEqual(1.145, stats['loss'])
     self.assertEqual(.99988, stats['training_accuracy_top_1'])

@@ -57,7 +57,7 @@ class KerasCommonTests(tf.test.TestCase):
     history = self._build_history(1.145, cat_accuracy_sparse=.99988)
     eval_output = self._build_eval_output(.928, 1.9844)
-    stats = keras_common.build_stats(history, eval_output, None)
+    stats = common.build_stats(history, eval_output, None)

     self.assertEqual(1.145, stats['loss'])
     self.assertEqual(.99988, stats['training_accuracy_top_1'])
official/resnet/keras/imagenet_preprocessing.py → official/vision/image_classification/imagenet_preprocessing.py

File moved.
official/resnet/keras/keras_cifar_main.py → official/vision/image_classification/resnet_cifar_main.py

@@ -22,13 +22,13 @@ from absl import app as absl_app
 from absl import flags
 import tensorflow as tf

-from official.resnet.keras import cifar_preprocessing
-from official.resnet.keras import keras_common
-from official.resnet.keras import resnet_cifar_model
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
+from official.vision.image_classification import cifar_preprocessing
+from official.vision.image_classification import common
+from official.vision.image_classification import resnet_cifar_model

 LR_SCHEDULE = [  # (multiplier, epoch to start) tuples

@@ -55,7 +55,7 @@ def learning_rate_schedule(current_epoch,
     Adjusted learning rate.
   """
   del current_batch, batches_per_epoch  # not used
-  initial_learning_rate = keras_common.BASE_LEARNING_RATE * batch_size / 128
+  initial_learning_rate = common.BASE_LEARNING_RATE * batch_size / 128
   learning_rate = initial_learning_rate
   for mult, start_epoch in LR_SCHEDULE:
     if current_epoch >= start_epoch:

@@ -83,8 +83,8 @@ def run(flags_obj):
   # Execute flag override logic for better model performance
   if flags_obj.tf_gpu_thread_mode:
-    keras_common.set_gpu_thread_mode_and_count(flags_obj)
-  keras_common.set_cudnn_batchnorm_mode()
+    common.set_gpu_thread_mode_and_count(flags_obj)
+  common.set_cudnn_batchnorm_mode()

   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'fp16':

@@ -116,7 +116,7 @@ def run(flags_obj):
   if flags_obj.use_synthetic_data:
     distribution_utils.set_up_synthetic_data()
-    input_fn = keras_common.get_synth_input_fn(
+    input_fn = common.get_synth_input_fn(
         height=cifar_preprocessing.HEIGHT,
         width=cifar_preprocessing.WIDTH,
         num_channels=cifar_preprocessing.NUM_CHANNELS,

@@ -150,7 +150,7 @@ def run(flags_obj):
       parse_record_fn=cifar_preprocessing.parse_record)

   with strategy_scope:
-    optimizer = keras_common.get_optimizer()
+    optimizer = common.get_optimizer()
     model = resnet_cifar_model.resnet56(classes=cifar_preprocessing.NUM_CLASSES)

     # TODO(b/138957587): Remove when force_v2_in_keras_compile is no longer

@@ -171,7 +171,7 @@ def run(flags_obj):
                 if flags_obj.report_accuracy_metrics else None),
       run_eagerly=flags_obj.run_eagerly)

-  callbacks = keras_common.get_callbacks(
+  callbacks = common.get_callbacks(
       learning_rate_schedule, cifar_preprocessing.NUM_IMAGES['train'])

   train_steps = cifar_preprocessing.NUM_IMAGES['train'] // flags_obj.batch_size

@@ -216,12 +216,12 @@ def run(flags_obj):
   if not strategy and flags_obj.explicit_gpu_placement:
     no_dist_strat_device.__exit__()

-  stats = keras_common.build_stats(history, eval_output, callbacks)
+  stats = common.build_stats(history, eval_output, callbacks)
   return stats


 def define_cifar_flags():
-  keras_common.define_keras_flags(dynamic_loss_scale=False)
+  common.define_keras_flags(dynamic_loss_scale=False)

   flags_core.set_defaults(data_dir='/tmp/cifar10_data/cifar-10-batches-bin',
                           model_dir='/tmp/cifar10_model',
official/resnet/keras/resnet_cifar_model.py → official/vision/image_classification/resnet_cifar_model.py

File moved.
official/resnet/keras/keras_cifar_test.py → official/vision/image_classification/resnet_cifar_test.py

@@ -18,17 +18,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tempfile import mkdtemp
+import tempfile
 import tensorflow as tf

-from official.resnet.keras import cifar_preprocessing
-from official.resnet.keras import keras_cifar_main
-from official.resnet.keras import keras_common
-from official.utils.misc import keras_utils
-from official.utils.testing import integration
+# pylint: disable=ungrouped-imports
 from tensorflow.python.eager import context
 from tensorflow.python.platform import googletest
+from official.utils.misc import keras_utils
+from official.utils.testing import integration
+from official.vision.image_classification import cifar_preprocessing
+from official.vision.image_classification import resnet_cifar_main


 class KerasCifarTest(googletest.TestCase):

@@ -43,13 +42,13 @@ class KerasCifarTest(googletest.TestCase):
   def get_temp_dir(self):
     if not self._tempdir:
-      self._tempdir = mkdtemp(dir=googletest.GetTempDir())
+      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
     return self._tempdir

   @classmethod
   def setUpClass(cls):  # pylint: disable=invalid-name
     super(KerasCifarTest, cls).setUpClass()
-    keras_cifar_main.define_cifar_flags()
+    resnet_cifar_main.define_cifar_flags()

   def setUp(self):
     super(KerasCifarTest, self).setUp()

@@ -72,7 +71,7 @@ class KerasCifarTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_cifar_main.run,
+        main=resnet_cifar_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -88,7 +87,7 @@ class KerasCifarTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_cifar_main.run,
+        main=resnet_cifar_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -112,7 +111,7 @@ class KerasCifarTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_cifar_main.run,
+        main=resnet_cifar_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -134,7 +133,7 @@ class KerasCifarTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_cifar_main.run,
+        main=resnet_cifar_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -157,7 +156,7 @@ class KerasCifarTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_cifar_main.run,
+        main=resnet_cifar_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -178,7 +177,7 @@ class KerasCifarTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_cifar_main.run,
+        main=resnet_cifar_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)
official/resnet/keras/keras_imagenet_main.py → official/vision/image_classification/resnet_imagenet_main.py

@@ -21,17 +21,17 @@ from __future__ import print_function
 from absl import app as absl_app
 from absl import flags
 from absl import logging
+import tensorflow as tf

 # pylint: disable=g-bad-import-order
-import tensorflow as tf
-from official.resnet.keras import imagenet_preprocessing
-from official.resnet.keras import keras_common
-from official.resnet.keras import resnet_model
-from official.resnet.keras import trivial_model
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
+from official.vision.image_classification import common
+from official.vision.image_classification import imagenet_preprocessing
+from official.vision.image_classification import resnet_model
+from official.vision.image_classification import trivial_model

 LR_SCHEDULE = [  # (multiplier, epoch to start) tuples
@@ -57,7 +57,7 @@ def learning_rate_schedule(current_epoch,
   Returns:
     Adjusted learning rate.
   """
-  initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
+  initial_lr = common.BASE_LEARNING_RATE * batch_size / 256
   epoch = current_epoch + float(current_batch) / batches_per_epoch
   warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
   if epoch < warmup_end_epoch:
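The renamed line above is the linear learning-rate scaling rule: the base rate of 0.1 (BASE_LEARNING_RATE in common.py) is scaled by the global batch size relative to a reference batch of 256. A quick arithmetic check, with batch sizes chosen only for illustration:

```python
BASE_LEARNING_RATE = 0.1  # value defined in official/vision/image_classification/common.py

def initial_lr(batch_size):
  # Linear scaling rule: the reference batch size is 256.
  return BASE_LEARNING_RATE * batch_size / 256

print(initial_lr(256))   # 0.1  (the reference batch keeps the base rate)
print(initial_lr(1024))  # 0.4  (4x the batch -> 4x the rate)
```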
@@ -89,10 +89,10 @@ def run(flags_obj):
   # Execute flag override logic for better model performance
   if flags_obj.tf_gpu_thread_mode:
-    keras_common.set_gpu_thread_mode_and_count(flags_obj)
+    common.set_gpu_thread_mode_and_count(flags_obj)
   if flags_obj.data_delay_prefetch:
-    keras_common.data_delay_prefetch()
-  keras_common.set_cudnn_batchnorm_mode()
+    common.data_delay_prefetch()
+  common.set_cudnn_batchnorm_mode()

   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'float16':

@@ -105,10 +105,14 @@ def run(flags_obj):
                  if tf.test.is_built_with_cuda() else 'channels_last')
   tf.keras.backend.set_image_data_format(data_format)

+  # Configures cluster spec for distribution strategy.
+  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
+                                                     flags_obj.task_index)
+
   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_obj.num_gpus,
-      num_workers=distribution_utils.configure_cluster(),
+      num_workers=num_workers,
       all_reduce_alg=flags_obj.all_reduce_alg,
       num_packs=flags_obj.num_packs)

@@ -125,7 +129,7 @@ def run(flags_obj):
   # pylint: disable=protected-access
   if flags_obj.use_synthetic_data:
     distribution_utils.set_up_synthetic_data()
-    input_fn = keras_common.get_synth_input_fn(
+    input_fn = common.get_synth_input_fn(
         height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
         width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
         num_channels=imagenet_preprocessing.NUM_CHANNELS,

@@ -165,7 +169,7 @@ def run(flags_obj):
   lr_schedule = 0.1
   if flags_obj.use_tensor_lr:
-    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
+    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
         batch_size=flags_obj.batch_size,
         epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
         warmup_epochs=LR_SCHEDULE[0][1],

@@ -174,7 +178,7 @@ def run(flags_obj):
         compute_lr_on_cpu=True)

   with strategy_scope:
-    optimizer = keras_common.get_optimizer(lr_schedule)
+    optimizer = common.get_optimizer(lr_schedule)
     if dtype == 'float16':
       # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
       # can be enabled with a single line of code.

@@ -212,7 +216,7 @@ def run(flags_obj):
                 if flags_obj.report_accuracy_metrics else None),
       run_eagerly=flags_obj.run_eagerly)

-  callbacks = keras_common.get_callbacks(
+  callbacks = common.get_callbacks(
       learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])

   train_steps = (

@@ -262,13 +266,14 @@ def run(flags_obj):
   if not strategy and flags_obj.explicit_gpu_placement:
     no_dist_strat_device.__exit__()

-  stats = keras_common.build_stats(history, eval_output, callbacks)
+  stats = common.build_stats(history, eval_output, callbacks)
   return stats


 def define_imagenet_keras_flags():
-  keras_common.define_keras_flags()
+  common.define_keras_flags()
   flags_core.set_defaults(train_epochs=90)
+  flags.adopt_module_key_flags(common)


 def main(_):
official/resnet/keras/keras_imagenet_test.py → official/vision/image_classification/resnet_imagenet_test.py

@@ -18,16 +18,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tempfile import mkdtemp
+import tempfile
 import tensorflow as tf

-from official.resnet.keras import imagenet_preprocessing
-from official.resnet.keras import keras_imagenet_main
-from official.utils.misc import keras_utils
-from official.utils.testing import integration
+# pylint: disable=ungrouped-imports
 from tensorflow.python.eager import context
 from tensorflow.python.platform import googletest
+from official.utils.misc import keras_utils
+from official.utils.testing import integration
+from official.vision.image_classification import imagenet_preprocessing
+from official.vision.image_classification import resnet_imagenet_main


 class KerasImagenetTest(googletest.TestCase):

@@ -42,13 +42,13 @@ class KerasImagenetTest(googletest.TestCase):
   def get_temp_dir(self):
     if not self._tempdir:
-      self._tempdir = mkdtemp(dir=googletest.GetTempDir())
+      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
     return self._tempdir

   @classmethod
   def setUpClass(cls):  # pylint: disable=invalid-name
     super(KerasImagenetTest, cls).setUpClass()
-    keras_imagenet_main.define_imagenet_keras_flags()
+    resnet_imagenet_main.define_imagenet_keras_flags()

   def setUp(self):
     super(KerasImagenetTest, self).setUp()

@@ -71,7 +71,7 @@ class KerasImagenetTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_imagenet_main.run,
+        main=resnet_imagenet_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -87,7 +87,7 @@ class KerasImagenetTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_imagenet_main.run,
+        main=resnet_imagenet_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -111,7 +111,7 @@ class KerasImagenetTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_imagenet_main.run,
+        main=resnet_imagenet_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -133,7 +133,7 @@ class KerasImagenetTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_imagenet_main.run,
+        main=resnet_imagenet_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -156,7 +156,7 @@ class KerasImagenetTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_imagenet_main.run,
+        main=resnet_imagenet_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -180,7 +180,7 @@ class KerasImagenetTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_imagenet_main.run,
+        main=resnet_imagenet_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -204,7 +204,7 @@ class KerasImagenetTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_imagenet_main.run,
+        main=resnet_imagenet_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -229,7 +229,7 @@ class KerasImagenetTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_imagenet_main.run,
+        main=resnet_imagenet_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -250,7 +250,7 @@ class KerasImagenetTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_imagenet_main.run,
+        main=resnet_imagenet_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)

@@ -272,7 +272,7 @@ class KerasImagenetTest(googletest.TestCase):
     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_imagenet_main.run,
+        main=resnet_imagenet_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags)
official/resnet/keras/resnet_model.py → official/vision/image_classification/resnet_model.py

File moved.
official/resnet/keras/trivial_model.py → official/vision/image_classification/trivial_model.py

File moved.
official/wide_deep/__init__.py (deleted, 100644 → 0)
research/lstm_object_detection/README.md

@@ -32,3 +32,8 @@ https://scholar.googleusercontent.com/scholar.bib?q=info:rLqvkztmWYgJ:scholar.go
 * yinxiao@google.com
 * menglong@google.com
 * yongzhe@google.com
+
+## Table of Contents
+
+* <a href='g3doc/exporting_models.md'>Exporting a trained model</a>
research/lstm_object_detection/configs/lstm_ssd_interleaved_mobilenet_v2_imagenet.config (new file, 0 → 100644)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# For training on Imagenet Video with LSTM Interleaved Mobilenet V2

[lstm_object_detection.protos.lstm_model] {
  train_unroll_length: 4
  eval_unroll_length: 4
  lstm_state_depth: 320
  depth_multipliers: 1.4
  depth_multipliers: 0.35
  pre_bottleneck: true
  low_res: true
  train_interleave_method: 'RANDOM_SKIP_SMALL'
  eval_interleave_method: 'SKIP3'
}

model {
  ssd {
    num_classes: 30  # Num of class for imagenet vid dataset.
    box_coder {
      faster_rcnn_box_coder {
        y_scale: 10.0
        x_scale: 10.0
        height_scale: 5.0
        width_scale: 5.0
      }
    }
    matcher {
      argmax_matcher {
        matched_threshold: 0.5
        unmatched_threshold: 0.5
        ignore_thresholds: false
        negatives_lower_than_unmatched: true
        force_match_for_each_row: true
      }
    }
    similarity_calculator {
      iou_similarity {
      }
    }
    anchor_generator {
      ssd_anchor_generator {
        num_layers: 5
        min_scale: 0.2
        max_scale: 0.95
        aspect_ratios: 1.0
        aspect_ratios: 2.0
        aspect_ratios: 0.5
        aspect_ratios: 3.0
        aspect_ratios: 0.3333
      }
    }
    image_resizer {
      fixed_shape_resizer {
        height: 320
        width: 320
      }
    }
    box_predictor {
      convolutional_box_predictor {
        min_depth: 0
        max_depth: 0
        num_layers_before_predictor: 3
        use_dropout: false
        dropout_keep_probability: 0.8
        kernel_size: 3
        box_code_size: 4
        apply_sigmoid_to_scores: false
        use_depthwise: true
        conv_hyperparams {
          activation: RELU_6,
          regularizer {
            l2_regularizer {
              weight: 0.00004
            }
          }
          initializer {
            truncated_normal_initializer {
              stddev: 0.03
              mean: 0.0
            }
          }
          batch_norm {
            train: true,
            scale: true,
            center: true,
            decay: 0.9997,
            epsilon: 0.001,
          }
        }
      }
    }
    feature_extractor {
      type: 'lstm_ssd_interleaved_mobilenet_v2'
      conv_hyperparams {
        activation: RELU_6,
        regularizer {
          l2_regularizer {
            weight: 0.00004
          }
        }
        initializer {
          truncated_normal_initializer {
            stddev: 0.03
            mean: 0.0
          }
        }
        batch_norm {
          train: true,
          scale: true,
          center: true,
          decay: 0.9997,
          epsilon: 0.001,
        }
      }
    }
    loss {
      classification_loss {
        weighted_sigmoid {
        }
      }
      localization_loss {
        weighted_smooth_l1 {
        }
      }
      hard_example_miner {
        num_hard_examples: 3000
        iou_threshold: 0.99
        loss_type: CLASSIFICATION
        max_negatives_per_positive: 3
        min_negatives_per_image: 0
      }
      classification_weight: 1.0
      localization_weight: 4.0
    }
    normalize_loss_by_num_matches: true
    post_processing {
      batch_non_max_suppression {
        score_threshold: -20.0
        iou_threshold: 0.5
        max_detections_per_class: 100
        max_total_detections: 100
      }
      score_converter: SIGMOID
    }
  }
}

train_config: {
  batch_size: 8
  optimizer {
    use_moving_average: false
    rms_prop_optimizer: {
      learning_rate: {
        exponential_decay_learning_rate {
          initial_learning_rate: 0.002
          decay_steps: 200000
          decay_factor: 0.95
        }
      }
      momentum_optimizer_value: 0.9
      decay: 0.9
      epsilon: 1.0
    }
  }
  gradient_clipping_by_norm: 10.0
  batch_queue_capacity: 12
  prefetch_queue_capacity: 4
}

train_input_reader: {
  shuffle_buffer_size: 32
  queue_capacity: 12
  prefetch_size: 12
  min_after_dequeue: 4
  label_map_path: "path/to/label_map"
  external_input_reader {
    [lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
      tf_record_video_input_reader: {
        input_path: '/data/lstm_detection/tfrecords/test.tfrecord'
        data_type: TF_SEQUENCE_EXAMPLE
        video_length: 4
      }
    }
  }
}

eval_config: {
  metrics_set: "coco_evaluation_all_frames"
  use_moving_averages: true
  min_score_threshold: 0.5
  max_num_boxes_to_visualize: 300
  visualize_groundtruth_boxes: true
  groundtruth_box_visualization_color: "red"
}

eval_input_reader {
  label_map_path: "path/to/label_map"
  shuffle: true
  num_epochs: 1
  num_parallel_batches: 1
  num_readers: 1
  external_input_reader {
    [lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
      tf_record_video_input_reader: {
        input_path: "path/to/sequence_example/data"
        data_type: TF_SEQUENCE_EXAMPLE
        video_length: 10
      }
    }
  }
}

eval_input_reader: {
  label_map_path: "path/to/label_map"
  external_input_reader {
    [lstm_object_detection.protos.GoogleInputReader.google_input_reader] {
      tf_record_video_input_reader: {
        input_path: "path/to/sequence_example/data"
        data_type: TF_SEQUENCE_EXAMPLE
        video_length: 4
      }
    }
  }
  shuffle: true
  num_readers: 1
}
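Configs like this are parsed into proto objects by the project's config utilities, the same call used by export_tflite_lstd_model.py later in this commit; a minimal sketch, with the path assumed to be wherever this file is checked out:

```python
from lstm_object_detection.utils import config_util

# Placeholder path to the config shown above.
configs = config_util.get_configs_from_pipeline_file(
    'lstm_object_detection/configs/'
    'lstm_ssd_interleaved_mobilenet_v2_imagenet.config')

lstm_config = configs['lstm_model']  # the [lstm_object_detection.protos.lstm_model] block
model_config = configs['model']      # the ssd { ... } block

print(lstm_config.eval_unroll_length)  # 4
print(model_config.ssd.num_classes)    # 30
```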
research/lstm_object_detection/eval.py

@@ -27,8 +27,6 @@ import functools
 import os
 import tensorflow as tf
 from google.protobuf import text_format
-from google3.pyglib import app
-from google3.pyglib import flags
 from lstm_object_detection import evaluator
 from lstm_object_detection import model_builder
 from lstm_object_detection.inputs import seq_dataset_builder

@@ -107,4 +105,4 @@ def main(unused_argv):
                   FLAGS.checkpoint_dir, FLAGS.eval_dir)

 if __name__ == '__main__':
-  app.run()
+  tf.app.run()
research/lstm_object_detection/export_tflite_lstd_graph.py (new file, 0 → 100644)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Exports an LSTM detection model to use with tf-lite.

Outputs file:
* A tflite compatible frozen graph - $output_directory/tflite_graph.pb

The exported graph has the following input and output nodes.

Inputs:
'input_video_tensor': a float32 tensor of shape
[unroll_length, height, width, 3] containing the normalized input image.
Note that the height and width must be compatible with the height and
width configured in the fixed_shape_resizer options in the pipeline
config proto.

Outputs:
If add_postprocessing_op is true: the frozen graph adds a
  TFLite_Detection_PostProcess custom op node, which has four outputs:
  detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
  locations
  detection_classes: a float32 tensor of shape [1, num_boxes]
  with class indices
  detection_scores: a float32 tensor of shape [1, num_boxes]
  with class scores
  num_boxes: a float32 tensor of size 1 containing the number of detected
  boxes
else:
  the graph has three outputs:
  'raw_outputs/box_encodings': a float32 tensor of shape [1, num_anchors, 4]
  containing the encoded box predictions.
  'raw_outputs/class_predictions': a float32 tensor of shape
  [1, num_anchors, num_classes] containing the class scores for each anchor
  after applying score conversion.
  'anchors': a float32 constant tensor of shape [num_anchors, 4]
  containing the anchor boxes.

Example Usage:
--------------
python lstm_object_detection/export_tflite_lstd_graph.py \
    --pipeline_config_path path/to/lstm_pipeline.config \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory

The expected output would be in the directory
path/to/exported_model_directory (which is created if it does not exist)
with contents:
 - tflite_graph.pbtxt
 - tflite_graph.pb

Config overrides (see the `config_override` flag) are text protobufs
(also of type pipeline_pb2.TrainEvalPipelineConfig) which are used to override
certain fields in the provided pipeline_config_path. These are useful for
making small changes to the inference graph that differ from the training or
eval config.

Example Usage (in which we change the NMS iou_threshold to be 0.5 and
NMS score_threshold to be 0.0):
python lstm_object_detection/export_tflite_lstd_graph.py \
    --pipeline_config_path path/to/lstm_pipeline.config \
    --trained_checkpoint_prefix path/to/model.ckpt \
    --output_directory path/to/exported_model_directory \
    --config_override " \
        model{ \
          ssd{ \
            post_processing { \
              batch_non_max_suppression { \
                score_threshold: 0.0 \
                iou_threshold: 0.5 \
              } \
            } \
          } \
        } \
        "
"""

import tensorflow as tf

from lstm_object_detection import export_tflite_lstd_graph_lib
from lstm_object_detection.utils import config_util

flags = tf.app.flags
flags.DEFINE_string('output_directory', None, 'Path to write outputs.')
flags.DEFINE_string(
    'pipeline_config_path', None,
    'Path to a pipeline_pb2.TrainEvalPipelineConfig config file.')
flags.DEFINE_string('trained_checkpoint_prefix', None, 'Checkpoint prefix.')
flags.DEFINE_integer('max_detections', 10,
                     'Maximum number of detections (boxes) to show.')
flags.DEFINE_integer('max_classes_per_detection', 1,
                     'Maximum number of classes to output per detection box.')
flags.DEFINE_integer(
    'detections_per_class', 100,
    'Number of anchors used per class in Regular Non-Max-Suppression.')
flags.DEFINE_bool('add_postprocessing_op', True,
                  'Add TFLite custom op for postprocessing to the graph.')
flags.DEFINE_bool(
    'use_regular_nms', False,
    'Flag to set postprocessing op to use Regular NMS instead of Fast NMS.')
flags.DEFINE_string(
    'config_override', '', 'pipeline_pb2.TrainEvalPipelineConfig '
    'text proto to override pipeline_config_path.')

FLAGS = flags.FLAGS


def main(argv):
  del argv  # Unused.
  flags.mark_flag_as_required('output_directory')
  flags.mark_flag_as_required('pipeline_config_path')
  flags.mark_flag_as_required('trained_checkpoint_prefix')

  pipeline_config = config_util.get_configs_from_pipeline_file(
      FLAGS.pipeline_config_path)

  export_tflite_lstd_graph_lib.export_tflite_graph(
      pipeline_config,
      FLAGS.trained_checkpoint_prefix,
      FLAGS.output_directory,
      FLAGS.add_postprocessing_op,
      FLAGS.max_detections,
      FLAGS.max_classes_per_detection,
      use_regular_nms=FLAGS.use_regular_nms)


if __name__ == '__main__':
  tf.app.run(main)
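For what the exported graph looks like from the consumer side, here is a sketch of running it in a plain TF 1.x session. This assumes the graph was exported with --add_postprocessing_op=false (the TFLite_Detection_PostProcess custom op is not a standard TensorFlow op, so a graph containing it cannot be imported into a Session); paths are placeholders, and node names follow the docstring above:

```python
import numpy as np
import tensorflow as tf

graph_def = tf.GraphDef()
# Placeholder path to the exporter's output directory.
with tf.gfile.GFile('exported_model_directory/tflite_graph.pb', 'rb') as f:
  graph_def.ParseFromString(f.read())

with tf.Graph().as_default() as graph:
  tf.import_graph_def(graph_def, name='')
  # [unroll_length, height, width, 3]; 4 frames of 320x320 as in the sample config.
  video = np.zeros((4, 320, 320, 3), dtype=np.float32)
  with tf.Session(graph=graph) as sess:
    boxes, classes = sess.run(
        ['raw_outputs/box_encodings:0', 'raw_outputs/class_predictions:0'],
        feed_dict={'input_video_tensor:0': video})
```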
research/lstm_object_detection/export_tflite_lstd_graph_lib.py (new file, 0 → 100644)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

r"""Exports detection models to use with tf-lite.

See export_tflite_lstd_graph.py for usage.
"""
import os
import tempfile
import numpy as np
import tensorflow as tf

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.tools.graph_transforms import TransformGraph
from lstm_object_detection import model_builder
from object_detection import exporter
from object_detection.builders import graph_rewriter_builder
from object_detection.builders import post_processing_builder
from object_detection.core import box_list

_DEFAULT_NUM_CHANNELS = 3
_DEFAULT_NUM_COORD_BOX = 4


def get_const_center_size_encoded_anchors(anchors):
  """Exports center-size encoded anchors as a constant tensor.

  Args:
    anchors: a float32 tensor of shape [num_anchors, 4] containing the anchor
      boxes

  Returns:
    encoded_anchors: a float32 constant tensor of shape [num_anchors, 4]
    containing the anchor boxes.
  """
  anchor_boxlist = box_list.BoxList(anchors)
  y, x, h, w = anchor_boxlist.get_center_coordinates_and_sizes()
  num_anchors = y.get_shape().as_list()

  with tf.Session() as sess:
    y_out, x_out, h_out, w_out = sess.run([y, x, h, w])
  encoded_anchors = tf.constant(
      np.transpose(np.stack((y_out, x_out, h_out, w_out))),
      dtype=tf.float32,
      shape=[num_anchors[0], _DEFAULT_NUM_COORD_BOX],
      name='anchors')
  return encoded_anchors


def append_postprocessing_op(frozen_graph_def,
                             max_detections,
                             max_classes_per_detection,
                             nms_score_threshold,
                             nms_iou_threshold,
                             num_classes,
                             scale_values,
                             detections_per_class=100,
                             use_regular_nms=False):
  """Appends postprocessing custom op.

  Args:
    frozen_graph_def: Frozen GraphDef for SSD model after freezing the
      checkpoint
    max_detections: Maximum number of detections (boxes) to show
    max_classes_per_detection: Number of classes to display per detection
    nms_score_threshold: Score threshold used in Non-maximal suppression in
      post-processing
    nms_iou_threshold: Intersection-over-union threshold used in Non-maximal
      suppression in post-processing
    num_classes: number of classes in SSD detector
    scale_values: scale values is a dict with following key-value pairs
      {y_scale: 10, x_scale: 10, h_scale: 5, w_scale: 5} that are used in
      decoding centersize boxes
    detections_per_class: In regular NonMaxSuppression, number of anchors used
      for NonMaxSuppression per class
    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
      of Fast NMS.

  Returns:
    transformed_graph_def: Frozen GraphDef with postprocessing custom op
    appended. The TFLite_Detection_PostProcess custom op node has four outputs:
    detection_boxes: a float32 tensor of shape [1, num_boxes, 4] with box
    locations
    detection_classes: a float32 tensor of shape [1, num_boxes]
    with class indices
    detection_scores: a float32 tensor of shape [1, num_boxes]
    with class scores
    num_boxes: a float32 tensor of size 1 containing the number of detected
    boxes
  """
  new_output = frozen_graph_def.node.add()
  new_output.op = 'TFLite_Detection_PostProcess'
  new_output.name = 'TFLite_Detection_PostProcess'
  new_output.attr['_output_quantized'].CopyFrom(
      attr_value_pb2.AttrValue(b=True))
  new_output.attr['_output_types'].list.type.extend([
      types_pb2.DT_FLOAT, types_pb2.DT_FLOAT, types_pb2.DT_FLOAT,
      types_pb2.DT_FLOAT
  ])
  new_output.attr['_support_output_type_float_in_quantized_op'].CopyFrom(
      attr_value_pb2.AttrValue(b=True))
  new_output.attr['max_detections'].CopyFrom(
      attr_value_pb2.AttrValue(i=max_detections))
  new_output.attr['max_classes_per_detection'].CopyFrom(
      attr_value_pb2.AttrValue(i=max_classes_per_detection))
  new_output.attr['nms_score_threshold'].CopyFrom(
      attr_value_pb2.AttrValue(f=nms_score_threshold.pop()))
  new_output.attr['nms_iou_threshold'].CopyFrom(
      attr_value_pb2.AttrValue(f=nms_iou_threshold.pop()))
  new_output.attr['num_classes'].CopyFrom(
      attr_value_pb2.AttrValue(i=num_classes))
  new_output.attr['y_scale'].CopyFrom(
      attr_value_pb2.AttrValue(f=scale_values['y_scale'].pop()))
  new_output.attr['x_scale'].CopyFrom(
      attr_value_pb2.AttrValue(f=scale_values['x_scale'].pop()))
  new_output.attr['h_scale'].CopyFrom(
      attr_value_pb2.AttrValue(f=scale_values['h_scale'].pop()))
  new_output.attr['w_scale'].CopyFrom(
      attr_value_pb2.AttrValue(f=scale_values['w_scale'].pop()))
  new_output.attr['detections_per_class'].CopyFrom(
      attr_value_pb2.AttrValue(i=detections_per_class))
  new_output.attr['use_regular_nms'].CopyFrom(
      attr_value_pb2.AttrValue(b=use_regular_nms))
  new_output.input.extend(
      ['raw_outputs/box_encodings', 'raw_outputs/class_predictions', 'anchors'])

  # Transform the graph to append new postprocessing op
  input_names = []
  output_names = ['TFLite_Detection_PostProcess']
  transforms = ['strip_unused_nodes']
  transformed_graph_def = TransformGraph(frozen_graph_def, input_names,
                                         output_names, transforms)
  return transformed_graph_def


def export_tflite_graph(pipeline_config,
                        trained_checkpoint_prefix,
                        output_dir,
                        add_postprocessing_op,
                        max_detections,
                        max_classes_per_detection,
                        detections_per_class=100,
                        use_regular_nms=False,
                        binary_graph_name='tflite_graph.pb',
                        txt_graph_name='tflite_graph.pbtxt'):
  """Exports a tflite compatible graph and anchors for ssd detection model.

  Anchors are written to a tensor and tflite compatible graph
  is written to output_dir/tflite_graph.pb.

  Args:
    pipeline_config: Dictionary of configuration objects. Keys are `model`,
      `train_config`, `train_input_config`, `eval_config`,
      `eval_input_config`, `lstm_model`. Value are the corresponding config
      objects.
    trained_checkpoint_prefix: a file prefix for the checkpoint containing the
      trained parameters of the SSD model.
    output_dir: A directory to write the tflite graph and anchor file to.
    add_postprocessing_op: If add_postprocessing_op is true: frozen graph adds
      a TFLite_Detection_PostProcess custom op
    max_detections: Maximum number of detections (boxes) to show
    max_classes_per_detection: Number of classes to display per detection
    detections_per_class: In regular NonMaxSuppression, number of anchors used
      for NonMaxSuppression per class
    use_regular_nms: Flag to set postprocessing op to use Regular NMS instead
      of Fast NMS.
    binary_graph_name: Name of the exported graph file in binary format.
    txt_graph_name: Name of the exported graph file in text format.

  Raises:
    ValueError: if the pipeline config contains models other than ssd or uses
      a fixed_shape_resizer and provides a shape as well.
  """
  model_config = pipeline_config['model']
  lstm_config = pipeline_config['lstm_model']
  eval_config = pipeline_config['eval_config']

  tf.gfile.MakeDirs(output_dir)
  if model_config.WhichOneof('model') != 'ssd':
    raise ValueError('Only ssd models are supported in tflite. '
                     'Found {} in config'.format(
                         model_config.WhichOneof('model')))

  num_classes = model_config.ssd.num_classes
  nms_score_threshold = {
      model_config.ssd.post_processing.batch_non_max_suppression
      .score_threshold
  }
  nms_iou_threshold = {
      model_config.ssd.post_processing.batch_non_max_suppression.iou_threshold
  }
  scale_values = {}
  scale_values['y_scale'] = {
      model_config.ssd.box_coder.faster_rcnn_box_coder.y_scale
  }
  scale_values['x_scale'] = {
      model_config.ssd.box_coder.faster_rcnn_box_coder.x_scale
  }
  scale_values['h_scale'] = {
      model_config.ssd.box_coder.faster_rcnn_box_coder.height_scale
  }
  scale_values['w_scale'] = {
      model_config.ssd.box_coder.faster_rcnn_box_coder.width_scale
  }

  image_resizer_config = model_config.ssd.image_resizer
  image_resizer = image_resizer_config.WhichOneof('image_resizer_oneof')
  num_channels = _DEFAULT_NUM_CHANNELS
  if image_resizer == 'fixed_shape_resizer':
    height = image_resizer_config.fixed_shape_resizer.height
    width = image_resizer_config.fixed_shape_resizer.width
    if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
      num_channels = 1

    shape = [lstm_config.eval_unroll_length, height, width, num_channels]
  else:
    raise ValueError('Only fixed_shape_resizer is supported with tflite. '
                     'Found {}'.format(
                         image_resizer_config.WhichOneof(
                             'image_resizer_oneof')))

  video_tensor = tf.placeholder(
      tf.float32, shape=shape, name='input_video_tensor')

  detection_model = model_builder.build(
      model_config, lstm_config, is_training=False)
  preprocessed_video, true_image_shapes = detection_model.preprocess(
      tf.to_float(video_tensor))
  predicted_tensors = detection_model.predict(preprocessed_video,
                                              true_image_shapes)
  # predicted_tensors = detection_model.postprocess(predicted_tensors,
  #                                                 true_image_shapes)
  # The score conversion occurs before the post-processing custom op
  _, score_conversion_fn = post_processing_builder.build(
      model_config.ssd.post_processing)
  class_predictions = score_conversion_fn(
      predicted_tensors['class_predictions_with_background'])

  with tf.name_scope('raw_outputs'):
    # 'raw_outputs/box_encodings': a float32 tensor of shape
    #  [1, num_anchors, 4] containing the encoded box predictions. Note that
    #  these are raw predictions and no Non-Max suppression is applied on
    #  them and no decode center size boxes is applied to them.
    tf.identity(predicted_tensors['box_encodings'], name='box_encodings')
    # 'raw_outputs/class_predictions': a float32 tensor of shape
    #  [1, num_anchors, num_classes] containing the class scores for each
    #  anchor after applying score conversion.
    tf.identity(class_predictions, name='class_predictions')
  # 'anchors': a float32 tensor of shape
  #   [4, num_anchors] containing the anchors as a constant node.
  tf.identity(
      get_const_center_size_encoded_anchors(predicted_tensors['anchors']),
      name='anchors')

  # Add global step to the graph, so we know the training step number when we
  # evaluate the model.
  tf.train.get_or_create_global_step()

  # graph rewriter
  is_quantized = ('graph_rewriter' in pipeline_config)
  if is_quantized:
    graph_rewriter_config = pipeline_config['graph_rewriter']
    graph_rewriter_fn = graph_rewriter_builder.build(
        graph_rewriter_config, is_training=False, is_export=True)
    graph_rewriter_fn()

  if model_config.ssd.feature_extractor.HasField('fpn'):
    exporter.rewrite_nn_resize_op(is_quantized)

  # freeze the graph
  saver_kwargs = {}
  if eval_config.use_moving_averages:
    saver_kwargs['write_version'] = saver_pb2.SaverDef.V1
    moving_average_checkpoint = tempfile.NamedTemporaryFile()
    exporter.replace_variable_values_with_moving_averages(
        tf.get_default_graph(), trained_checkpoint_prefix,
        moving_average_checkpoint.name)
    checkpoint_to_use = moving_average_checkpoint.name
  else:
    checkpoint_to_use = trained_checkpoint_prefix

  saver = tf.train.Saver(**saver_kwargs)
  input_saver_def = saver.as_saver_def()
  frozen_graph_def = exporter.freeze_graph_with_def_protos(
      input_graph_def=tf.get_default_graph().as_graph_def(),
      input_saver_def=input_saver_def,
      input_checkpoint=checkpoint_to_use,
      output_node_names=','.join([
          'raw_outputs/box_encodings', 'raw_outputs/class_predictions',
          'anchors'
      ]),
      restore_op_name='save/restore_all',
      filename_tensor_name='save/Const:0',
      clear_devices=True,
      output_graph='',
      initializer_nodes='')

  # Add new operation to do post processing in a custom op (TF Lite only)
  if add_postprocessing_op:
    transformed_graph_def = append_postprocessing_op(
        frozen_graph_def, max_detections, max_classes_per_detection,
        nms_score_threshold, nms_iou_threshold, num_classes, scale_values,
        detections_per_class, use_regular_nms)
  else:
    # Return frozen without adding post-processing custom op
    transformed_graph_def = frozen_graph_def
  binary_graph = os.path.join(output_dir, binary_graph_name)
  with tf.gfile.GFile(binary_graph, 'wb') as f:
    f.write(transformed_graph_def.SerializeToString())
  txt_graph = os.path.join(output_dir, txt_graph_name)
  with tf.gfile.GFile(txt_graph, 'w') as f:
    f.write(str(transformed_graph_def))
research/lstm_object_detection/export_tflite_lstd_model.py (new file, 0 → 100644)

# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Export an LSTD model in tflite format."""

import os
from absl import flags
import tensorflow as tf

from lstm_object_detection.utils import config_util

flags.DEFINE_string('export_path', None, 'Path to export model.')
flags.DEFINE_string('frozen_graph_path', None, 'Path to frozen graph.')
flags.DEFINE_string(
    'pipeline_config_path', '',
    'Path to a pipeline_pb2.TrainEvalPipelineConfig config file.')

FLAGS = flags.FLAGS


def main(_):
  flags.mark_flag_as_required('export_path')
  flags.mark_flag_as_required('frozen_graph_path')
  flags.mark_flag_as_required('pipeline_config_path')

  configs = config_util.get_configs_from_pipeline_file(
      FLAGS.pipeline_config_path)
  lstm_config = configs['lstm_model']

  input_arrays = ['input_video_tensor']
  output_arrays = [
      'TFLite_Detection_PostProcess',
      'TFLite_Detection_PostProcess:1',
      'TFLite_Detection_PostProcess:2',
      'TFLite_Detection_PostProcess:3',
  ]
  input_shapes = {
      'input_video_tensor': [lstm_config.eval_unroll_length, 320, 320, 3],
  }

  converter = tf.lite.TFLiteConverter.from_frozen_graph(
      FLAGS.frozen_graph_path,
      input_arrays,
      output_arrays,
      input_shapes=input_shapes)
  converter.allow_custom_ops = True
  tflite_model = converter.convert()
  ofilename = os.path.join(FLAGS.export_path)
  open(ofilename, 'wb').write(tflite_model)


if __name__ == '__main__':
  tf.app.run()
research/lstm_object_detection/g3doc/exporting_models.md (new file, 0 → 100644)

# Exporting a tflite model from a checkpoint

Starting from a trained model checkpoint, creating a tflite model requires 2
steps:

* exporting a tflite frozen graph from a checkpoint
* exporting a tflite model from a frozen graph

## Exporting a tflite frozen graph from a checkpoint

With a candidate checkpoint to export, run the following command from
tensorflow/models/research:

```bash
# from tensorflow/models/research
PIPELINE_CONFIG_PATH={path to pipeline config}
TRAINED_CKPT_PREFIX={path to model.ckpt}
EXPORT_DIR={path to folder that will be used for export}
python lstm_object_detection/export_tflite_lstd_graph.py \
    --pipeline_config_path ${PIPELINE_CONFIG_PATH} \
    --trained_checkpoint_prefix ${TRAINED_CKPT_PREFIX} \
    --output_directory ${EXPORT_DIR} \
    --add_postprocessing_op
```

After export, you should see the directory ${EXPORT_DIR} containing the
following files:

* `tflite_graph.pb`
* `tflite_graph.pbtxt`

## Exporting a tflite model from a frozen graph

We then take the exported tflite-compatible frozen graph and convert it to a
TFLite FlatBuffer file by running the following:

```bash
# from tensorflow/models/research
FROZEN_GRAPH_PATH={path to exported tflite_graph.pb}
EXPORT_PATH={path to filename that will be used for export}
PIPELINE_CONFIG_PATH={path to pipeline config}
python lstm_object_detection/export_tflite_lstd_model.py \
    --export_path ${EXPORT_PATH} \
    --frozen_graph_path ${FROZEN_GRAPH_PATH} \
    --pipeline_config_path ${PIPELINE_CONFIG_PATH}
```

After export, you should see the file ${EXPORT_PATH} containing the FlatBuffer
model to be used by an application.
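Once the FlatBuffer exists, it can be sanity-checked with the TensorFlow Lite Python interpreter; a minimal sketch with a placeholder model path. The input shape follows the sample config (unroll length 4, 320x320 frames), and the four outputs of TFLite_Detection_PostProcess are boxes, classes, scores, and box count, as documented in the exporter. This assumes a TF build whose TFLite runtime registers that custom op:

```python
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='exported_model.tflite')  # placeholder path
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# 'input_video_tensor': [unroll_length, height, width, 3].
frames = np.zeros(input_details[0]['shape'], dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], frames)
interpreter.invoke()

detection_boxes = interpreter.get_tensor(output_details[0]['index'])
detection_classes = interpreter.get_tensor(output_details[1]['index'])
```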
research/lstm_object_detection/inputs/seq_dataset_builder_test.py

@@ -33,68 +33,6 @@ from object_detection.protos import preprocessor_pb2
 class DatasetBuilderTest(tf.test.TestCase):

-  def _create_tf_record(self):
-    path = os.path.join(self.get_temp_dir(), 'tfrecord')
-    writer = tf.python_io.TFRecordWriter(path)
-
-    image_tensor = np.random.randint(255, size=(16, 16, 3)).astype(np.uint8)
-    with self.test_session():
-      encoded_jpeg = tf.image.encode_jpeg(tf.constant(image_tensor)).eval()
-
-    sequence_example = example_pb2.SequenceExample(
-        context=feature_pb2.Features(
-            feature={
-                'image/format':
-                    feature_pb2.Feature(
-                        bytes_list=feature_pb2.BytesList(
-                            value=['jpeg'.encode('utf-8')])),
-                'image/height':
-                    feature_pb2.Feature(
-                        int64_list=feature_pb2.Int64List(value=[16])),
-                'image/width':
-                    feature_pb2.Feature(
-                        int64_list=feature_pb2.Int64List(value=[16])),
-            }),
-        feature_lists=feature_pb2.FeatureLists(
-            feature_list={
-                'image/encoded':
-                    feature_pb2.FeatureList(feature=[
-                        feature_pb2.Feature(
-                            bytes_list=feature_pb2.BytesList(
-                                value=[encoded_jpeg])),
-                    ]),
-                'image/object/bbox/xmin':
-                    feature_pb2.FeatureList(feature=[
-                        feature_pb2.Feature(
-                            float_list=feature_pb2.FloatList(value=[0.0])),
-                    ]),
-                'image/object/bbox/xmax':
-                    feature_pb2.FeatureList(feature=[
-                        feature_pb2.Feature(
-                            float_list=feature_pb2.FloatList(value=[1.0]))
-                    ]),
-                'image/object/bbox/ymin':
-                    feature_pb2.FeatureList(feature=[
-                        feature_pb2.Feature(
-                            float_list=feature_pb2.FloatList(value=[0.0])),
-                    ]),
-                'image/object/bbox/ymax':
-                    feature_pb2.FeatureList(feature=[
-                        feature_pb2.Feature(
-                            float_list=feature_pb2.FloatList(value=[1.0]))
-                    ]),
-                'image/object/class/label':
-                    feature_pb2.FeatureList(feature=[
-                        feature_pb2.Feature(
-                            int64_list=feature_pb2.Int64List(value=[2]))
-                    ]),
-            }))
-    writer.write(sequence_example.SerializeToString())
-    writer.close()
-
-    return path
-
   def _get_model_configs_from_proto(self):
     """Creates a model text proto for testing.