ModelZoo / ResNet50_tensorflow / Commits / 5b0ef1fc
"vscode:/vscode.git/clone" did not exist on "c2e8cbaa140986b6a27f2c795e2fb9b38e74f094"
Commit 5b0ef1fc, authored Aug 23, 2019 by Nimit Nigania

    Merge branch 'master' into ncf_f16

Parents: 1cba90f3, bf748370
Changes: 92

Showing 12 changed files with 545 additions and 136 deletions.
Files changed:

  official/vision/image_classification/resnet_imagenet_main.py        +22  -16
  official/vision/image_classification/resnet_imagenet_test.py        +18  -18
  official/vision/image_classification/resnet_model.py                +389 -0
  research/lstm_object_detection/export_tflite_lstd_graph.py          +10  -6
  research/lstm_object_detection/export_tflite_lstd_graph_lib.py      +22  -24
  research/lstm_object_detection/export_tflite_lstd_model.py          +31  -28
  research/lstm_object_detection/g3doc/exporting_models.md            +13  -13
  research/lstm_object_detection/test_tflite_model.py                 +21  -18
  research/lstm_object_detection/tflite/BUILD                         +9   -2
  research/lstm_object_detection/tflite/WORKSPACE                     +0   -6
  research/lstm_object_detection/tflite/mobile_lstd_tflite_client.cc  +5   -0
  research/lstm_object_detection/tflite/mobile_ssd_tflite_client.h    +5   -5
official/resnet/keras/keras_imagenet_main.py → official/vision/image_classification/resnet_imagenet_main.py

@@ -21,17 +21,17 @@ from __future__ import print_function
 from absl import app as absl_app
 from absl import flags
 from absl import logging
-import tensorflow as tf  # pylint: disable=g-bad-import-order
+import tensorflow as tf
-from official.resnet.keras import imagenet_preprocessing
-from official.resnet.keras import keras_common
-from official.resnet.keras import resnet_model
-from official.resnet.keras import trivial_model
 from official.utils.flags import core as flags_core
 from official.utils.logs import logger
 from official.utils.misc import distribution_utils
 from official.utils.misc import keras_utils
 from official.utils.misc import model_helpers
+from official.vision.image_classification import common
+from official.vision.image_classification import imagenet_preprocessing
+from official.vision.image_classification import resnet_model
+from official.benchmark.models import trivial_model

 LR_SCHEDULE = [    # (multiplier, epoch to start) tuples

@@ -57,7 +57,7 @@ def learning_rate_schedule(current_epoch,
   Returns:
     Adjusted learning rate.
   """
-  initial_lr = keras_common.BASE_LEARNING_RATE * batch_size / 256
+  initial_lr = common.BASE_LEARNING_RATE * batch_size / 256
   epoch = current_epoch + float(current_batch) / batches_per_epoch
   warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
   if epoch < warmup_end_epoch:

@@ -89,10 +89,10 @@ def run(flags_obj):
   # Execute flag override logic for better model performance
   if flags_obj.tf_gpu_thread_mode:
-    keras_common.set_gpu_thread_mode_and_count(flags_obj)
+    common.set_gpu_thread_mode_and_count(flags_obj)
   if flags_obj.data_delay_prefetch:
-    keras_common.data_delay_prefetch()
+    common.data_delay_prefetch()
-  keras_common.set_cudnn_batchnorm_mode()
+  common.set_cudnn_batchnorm_mode()

   dtype = flags_core.get_tf_dtype(flags_obj)
   if dtype == 'float16':

@@ -105,10 +105,14 @@ def run(flags_obj):
                  if tf.test.is_built_with_cuda() else 'channels_last')
   tf.keras.backend.set_image_data_format(data_format)

+  # Configures cluster spec for distribution strategy.
+  num_workers = distribution_utils.configure_cluster(flags_obj.worker_hosts,
+                                                     flags_obj.task_index)
+
   strategy = distribution_utils.get_distribution_strategy(
       distribution_strategy=flags_obj.distribution_strategy,
       num_gpus=flags_obj.num_gpus,
-      num_workers=distribution_utils.configure_cluster(),
+      num_workers=num_workers,
       all_reduce_alg=flags_obj.all_reduce_alg,
       num_packs=flags_obj.num_packs)

@@ -125,7 +129,7 @@ def run(flags_obj):
   # pylint: disable=protected-access
   if flags_obj.use_synthetic_data:
     distribution_utils.set_up_synthetic_data()
-    input_fn = keras_common.get_synth_input_fn(
+    input_fn = common.get_synth_input_fn(
        height=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        width=imagenet_preprocessing.DEFAULT_IMAGE_SIZE,
        num_channels=imagenet_preprocessing.NUM_CHANNELS,

@@ -165,7 +169,7 @@ def run(flags_obj):
   lr_schedule = 0.1
   if flags_obj.use_tensor_lr:
-    lr_schedule = keras_common.PiecewiseConstantDecayWithWarmup(
+    lr_schedule = common.PiecewiseConstantDecayWithWarmup(
        batch_size=flags_obj.batch_size,
        epoch_size=imagenet_preprocessing.NUM_IMAGES['train'],
        warmup_epochs=LR_SCHEDULE[0][1],

@@ -174,7 +178,7 @@ def run(flags_obj):
        compute_lr_on_cpu=True)

   with strategy_scope:
-    optimizer = keras_common.get_optimizer(lr_schedule)
+    optimizer = common.get_optimizer(lr_schedule)
     if dtype == 'float16':
       # TODO(reedwm): Remove manually wrapping optimizer once mixed precision
       # can be enabled with a single line of code.

@@ -182,6 +186,7 @@ def run(flags_obj):
          optimizer, loss_scale=flags_core.get_loss_scale(flags_obj,
                                                          default_for_fp16=128))

+    # TODO(hongkuny): Remove trivial model usage and move it to benchmark.
     if flags_obj.use_trivial_model:
       model = trivial_model.trivial_model(
           imagenet_preprocessing.NUM_CLASSES, dtype)

@@ -207,7 +212,7 @@ def run(flags_obj):
               if flags_obj.report_accuracy_metrics else None),
      run_eagerly=flags_obj.run_eagerly)

-  callbacks = keras_common.get_callbacks(
+  callbacks = common.get_callbacks(
      learning_rate_schedule, imagenet_preprocessing.NUM_IMAGES['train'])

   train_steps = (

@@ -257,13 +262,14 @@ def run(flags_obj):
   if not strategy and flags_obj.explicit_gpu_placement:
     no_dist_strat_device.__exit__()

-  stats = keras_common.build_stats(history, eval_output, callbacks)
+  stats = common.build_stats(history, eval_output, callbacks)
   return stats


 def define_imagenet_keras_flags():
-  keras_common.define_keras_flags()
+  common.define_keras_flags()
   flags_core.set_defaults(train_epochs=90)
+  flags.adopt_module_key_flags(common)


 def main(_):
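A note on the hunk at @@ -57,7 +57,7 @@ above: apart from the keras_common to common rename, the function it sits in implements the usual linear-scaling rule with warmup for the ResNet learning rate. Below is a minimal, self-contained sketch of that rule; BASE_LEARNING_RATE and the (multiplier, start-epoch) schedule are assumed example values, not taken verbatim from this commit.

```python
# Sketch of the batch-size-scaled learning rate with linear warmup used in
# resnet_imagenet_main.py. The constants are assumed for illustration; the
# real BASE_LEARNING_RATE lives in the `common` module.
BASE_LEARNING_RATE = 0.1
LR_SCHEDULE = [(1.0, 5), (0.1, 30), (0.01, 60), (0.001, 80)]  # (multiplier, epoch to start)


def learning_rate_schedule(current_epoch, current_batch, batches_per_epoch,
                           batch_size):
  """Returns the learning rate for one step: warmup, then stepwise decay."""
  initial_lr = BASE_LEARNING_RATE * batch_size / 256
  epoch = current_epoch + float(current_batch) / batches_per_epoch
  warmup_lr_multiplier, warmup_end_epoch = LR_SCHEDULE[0]
  if epoch < warmup_end_epoch:
    # Ramp linearly from 0 to the full scaled rate over the warmup epochs.
    return initial_lr * warmup_lr_multiplier * epoch / warmup_end_epoch
  learning_rate = initial_lr
  for multiplier, start_epoch in LR_SCHEDULE:
    if epoch >= start_epoch:
      learning_rate = initial_lr * multiplier
  return learning_rate


print(learning_rate_schedule(current_epoch=0, current_batch=100,
                             batches_per_epoch=1000, batch_size=1024))
```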
official/resnet/keras/keras_imagenet_test.py → official/vision/image_classification/resnet_imagenet_test.py

@@ -18,16 +18,16 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function

-from tempfile import mkdtemp
+import tempfile

 import tensorflow as tf

-from official.resnet.keras import imagenet_preprocessing
-from official.resnet.keras import keras_imagenet_main
-from official.utils.misc import keras_utils
-from official.utils.testing import integration
-# pylint: disable=ungrouped-imports
 from tensorflow.python.eager import context
 from tensorflow.python.platform import googletest
+from official.utils.misc import keras_utils
+from official.utils.testing import integration
+from official.vision.image_classification import imagenet_preprocessing
+from official.vision.image_classification import resnet_imagenet_main


 class KerasImagenetTest(googletest.TestCase):

@@ -42,13 +42,13 @@ class KerasImagenetTest(googletest.TestCase):
   def get_temp_dir(self):
     if not self._tempdir:
-      self._tempdir = mkdtemp(dir=googletest.GetTempDir())
+      self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
     return self._tempdir

   @classmethod
   def setUpClass(cls):  # pylint: disable=invalid-name
     super(KerasImagenetTest, cls).setUpClass()
-    keras_imagenet_main.define_imagenet_keras_flags()
+    resnet_imagenet_main.define_imagenet_keras_flags()

   def setUp(self):
     super(KerasImagenetTest, self).setUp()

The remaining hunks (at old lines 71, 87, 111, 133, 156, 180, 204, 229, 250 and 272 of the test class) all apply the same one-line change inside each test case:

     extra_flags = extra_flags + self._extra_flags

     integration.run_synthetic(
-        main=keras_imagenet_main.run,
+        main=resnet_imagenet_main.run,
         tmp_root=self.get_temp_dir(),
         extra_flags=extra_flags
     )
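The only non-mechanical change in this file is switching from `from tempfile import mkdtemp` to `import tempfile` and qualifying the call. A minimal sketch of the lazily created temp-dir helper the test relies on, with the TensorFlow-internal `googletest.GetTempDir()` left out so plain `tempfile` chooses the parent directory:

```python
import tempfile


class TempDirMixin(object):
  """Sketch of the lazily created per-test temp dir used by the test above."""

  _tempdir = None

  def get_temp_dir(self):
    if not self._tempdir:
      # Created on first use and then cached for the rest of the test.
      self._tempdir = tempfile.mkdtemp()
    return self._tempdir


print(TempDirMixin().get_temp_dir())
```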
official/resnet/keras/resnet_model.py → official/vision/image_classification/resnet_model.py

@@ -39,7 +39,16 @@ BATCH_NORM_DECAY = 0.9
 BATCH_NORM_EPSILON = 1e-5


-def identity_block(input_tensor, kernel_size, filters, stage, block):
+def _gen_l2_regularizer(use_l2_regularizer=True):
+  return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None
+
+
+def identity_block(input_tensor,
+                   kernel_size,
+                   filters,
+                   stage,
+                   block,
+                   use_l2_regularizer=True):
   """The identity block is the block that has no conv layer at shortcut.

   Args:

@@ -48,6 +57,7 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
     filters: list of integers, the filters of 3 conv layer at main path
     stage: integer, current stage label, used for generating layer names
     block: 'a','b'..., current block label, used for generating layer names
+    use_l2_regularizer: whether to use L2 regularizer on Conv layer.

   Returns:
     Output tensor for the block.

@@ -60,35 +70,51 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
   conv_name_base = 'res' + str(stage) + block + '_branch'
   bn_name_base = 'bn' + str(stage) + block + '_branch'

-  x = layers.Conv2D(filters1, (1, 1), use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name=conv_name_base + '2a')(input_tensor)
-  x = layers.BatchNormalization(axis=bn_axis,
-                                momentum=BATCH_NORM_DECAY,
-                                epsilon=BATCH_NORM_EPSILON,
-                                name=bn_name_base + '2a')(x)
+  x = layers.Conv2D(
+      filters1, (1, 1),
+      use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name=conv_name_base + '2a')(input_tensor)
+  x = layers.BatchNormalization(
+      axis=bn_axis,
+      momentum=BATCH_NORM_DECAY,
+      epsilon=BATCH_NORM_EPSILON,
+      name=bn_name_base + '2a')(x)
   x = layers.Activation('relu')(x)

   [The '2b' Conv2D (filters2, kernel_size, padding='same') and '2c' Conv2D
   (filters3, (1, 1)) layers and their BatchNormalization layers receive the
   same two changes: kernel_regularizer switches from
   regularizers.l2(L2_WEIGHT_DECAY) to _gen_l2_regularizer(use_l2_regularizer),
   and the calls are reflowed to one argument per line.]

   x = layers.add([x, input_tensor])
   x = layers.Activation('relu')(x)

@@ -100,7 +126,8 @@ def conv_block(input_tensor,
                filters,
                stage,
                block,
-               strides=(2, 2)):
+               strides=(2, 2),
+               use_l2_regularizer=True):
   """A block that has a conv layer at shortcut.

   Note that from stage 3,

@@ -114,6 +141,7 @@ def conv_block(input_tensor,
     stage: integer, current stage label, used for generating layer names
     block: 'a','b'..., current block label, used for generating layer names
     strides: Strides for the second conv layer in the block.
+    use_l2_regularizer: whether to use L2 regularizer on Conv layer.

   Returns:
     Output tensor for the block.

@@ -126,114 +154,231 @@ def conv_block(input_tensor,
   conv_name_base = 'res' + str(stage) + block + '_branch'
   bn_name_base = 'bn' + str(stage) + block + '_branch'

   [conv_block body: the '2a', '2b' (strided) and '2c' Conv2D layers, the
   shortcut Conv2D, and their BatchNormalization layers get the same treatment
   as in identity_block: kernel_regularizer becomes
   _gen_l2_regularizer(use_l2_regularizer) and each call is reflowed to one
   argument per line. The final lines are unchanged:]

   x = layers.add([x, shortcut])
   x = layers.Activation('relu')(x)
   return x


-def resnet50(num_classes, dtype='float32', batch_size=None):
+def resnet50(num_classes,
+             dtype='float32',
+             batch_size=None,
+             use_l2_regularizer=True):
   """Instantiates the ResNet50 architecture.

   Args:
     num_classes: `int` number of classes for image classification.
     dtype: dtype to use float32 or float16 are most common.
     batch_size: Size of the batches for each step.
+    use_l2_regularizer: whether to use L2 regularizer on Conv/Dense layer.

   Returns:
     A Keras model instance.
   """
   input_shape = (224, 224, 3)
-  img_input = layers.Input(shape=input_shape, dtype=dtype,
-                           batch_size=batch_size)
+  img_input = layers.Input(
+      shape=input_shape, dtype=dtype, batch_size=batch_size)

   if backend.image_data_format() == 'channels_first':
-    x = layers.Lambda(lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
-                      name='transpose')(img_input)
+    x = layers.Lambda(
+        lambda x: backend.permute_dimensions(x, (0, 3, 1, 2)),
+        name='transpose')(img_input)
     bn_axis = 1
   else:  # channels_last
     x = img_input
     bn_axis = 3

   x = layers.ZeroPadding2D(padding=(3, 3), name='conv1_pad')(x)
-  x = layers.Conv2D(64, (7, 7),
-                    strides=(2, 2),
-                    padding='valid', use_bias=False,
-                    kernel_initializer='he_normal',
-                    kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-                    name='conv1')(x)
+  x = layers.Conv2D(
+      64, (7, 7),
+      strides=(2, 2),
+      padding='valid', use_bias=False,
+      kernel_initializer='he_normal',
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      name='conv1')(x)
   x = layers.BatchNormalization(axis=bn_axis,
                                 momentum=BATCH_NORM_DECAY,
                                 epsilon=BATCH_NORM_EPSILON,
                                 name='bn_conv1')(x)
   x = layers.Activation('relu')(x)
   x = layers.MaxPooling2D((3, 3), strides=(2, 2), padding='same')(x)

-  x = conv_block(x, 3, [64, 64, 256], stage=2, block='a', strides=(1, 1))
-  x = identity_block(x, 3, [64, 64, 256], stage=2, block='b')
-  x = identity_block(x, 3, [64, 64, 256], stage=2, block='c')
+  x = conv_block(
+      x,
+      3, [64, 64, 256],
+      stage=2,
+      block='a',
+      strides=(1, 1),
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [64, 64, 256],
+      stage=2,
+      block='b',
+      use_l2_regularizer=use_l2_regularizer)
+  x = identity_block(
+      x,
+      3, [64, 64, 256],
+      stage=2,
+      block='c',
+      use_l2_regularizer=use_l2_regularizer)

   [The conv_block/identity_block calls for stage 3 (blocks 'a' through 'd'
   with [128, 128, 512]), stage 4 (blocks 'a' through 'f' with
   [256, 256, 1024]) and stage 5 (blocks 'a' through 'c' with [512, 512, 2048])
   are rewritten the same way, each gaining
   use_l2_regularizer=use_l2_regularizer.]

   rm_axes = [1, 2] if backend.image_data_format() == 'channels_last' else [2, 3]
   x = layers.Lambda(lambda x: backend.mean(x, rm_axes), name='reduce_mean')(x)
   x = layers.Dense(
       num_classes,
       kernel_initializer=initializers.RandomNormal(stddev=0.01),
-      kernel_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
-      bias_regularizer=regularizers.l2(L2_WEIGHT_DECAY),
+      kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer),
+      bias_regularizer=_gen_l2_regularizer(use_l2_regularizer),
       name='fc1000')(x)

   # TODO(reedwm): Remove manual casts once mixed precision can be enabled with a
   # single line of code.
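The new `_gen_l2_regularizer` helper is what makes the `use_l2_regularizer` switch work: with the flag off, every `kernel_regularizer`/`bias_regularizer` becomes `None` and no weight-decay loss is added to the model. A minimal sketch (the `L2_WEIGHT_DECAY` value is an assumed stand-in for the module constant):

```python
from tensorflow.keras import layers, regularizers

L2_WEIGHT_DECAY = 1e-4  # assumed stand-in for the module-level constant


def _gen_l2_regularizer(use_l2_regularizer=True):
  # An L2 penalty when enabled, otherwise no regularizer at all.
  return regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None


# With the flag on, the layer carries weight decay exactly as before.
conv_with_decay = layers.Conv2D(
    64, (1, 1),
    use_bias=False,
    kernel_initializer='he_normal',
    kernel_regularizer=_gen_l2_regularizer(True))

# With the flag off, kernel_regularizer is None, so no L2 term is added.
conv_without_decay = layers.Conv2D(
    64, (1, 1),
    use_bias=False,
    kernel_initializer='he_normal',
    kernel_regularizer=_gen_l2_regularizer(False))
```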
research/lstm_object_detection/export_tflite_lstd_graph.py

@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
 r"""Exports an LSTM detection model to use with tf-lite.

 Outputs file:

@@ -86,8 +85,9 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
 """

 import tensorflow as tf

-from lstm_object_detection.utils import config_util
 from lstm_object_detection import export_tflite_lstd_graph_lib
+from lstm_object_detection.utils import config_util

 flags = tf.app.flags
 flags.DEFINE_string('output_directory', None, 'Path to write outputs.')

@@ -125,9 +125,13 @@ def main(argv):
                                                   FLAGS.pipeline_config_path)
   export_tflite_lstd_graph_lib.export_tflite_graph(
-      pipeline_config, FLAGS.trained_checkpoint_prefix, FLAGS.output_directory,
-      FLAGS.add_postprocessing_op, FLAGS.max_detections,
-      FLAGS.max_classes_per_detection, use_regular_nms=FLAGS.use_regular_nms)
+      pipeline_config,
+      FLAGS.trained_checkpoint_prefix,
+      FLAGS.output_directory,
+      FLAGS.add_postprocessing_op,
+      FLAGS.max_detections,
+      FLAGS.max_classes_per_detection,
+      use_regular_nms=FLAGS.use_regular_nms)


 if __name__ == '__main__':
research/lstm_object_detection/export_tflite_lstd_graph_lib.py

@@ -12,26 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
 r"""Exports detection models to use with tf-lite.

 See export_tflite_lstd_graph.py for usage.
 """
 import os
 import tempfile
 import numpy as np
 import tensorflow as tf

 from tensorflow.core.framework import attr_value_pb2
 from tensorflow.core.framework import types_pb2
 from tensorflow.core.protobuf import saver_pb2
 from tensorflow.tools.graph_transforms import TransformGraph
+from lstm_object_detection import model_builder
 from object_detection import exporter
 from object_detection.builders import graph_rewriter_builder
 from object_detection.builders import post_processing_builder
 from object_detection.core import box_list
-from lstm_object_detection import model_builder

 _DEFAULT_NUM_CHANNELS = 3
 _DEFAULT_NUM_COORD_BOX = 4

@@ -87,8 +87,8 @@ def append_postprocessing_op(frozen_graph_def,
  [Docstring-only change: the line wrapping of "use_regular_nms: Flag to set
  postprocessing op to use Regular NMS instead of Fast NMS." is adjusted; the
  wording is unchanged.]

@@ -165,9 +165,9 @@ def export_tflite_graph(pipeline_config,
  [Docstring-only change: the pipeline_config entry ("Dictionary of
  configuration objects. Keys are `model`, `train_config`,
  `train_input_config`, `eval_config`, `eval_input_config`, `lstm_model`.
  Value are the corresponding config objects.") is re-wrapped; the wording is
  unchanged.]

@@ -177,8 +177,8 @@ def export_tflite_graph(pipeline_config,
  [Docstring-only change: same re-wrapping of the use_regular_nms description
  as in the hunk at line 87 above.]

@@ -197,12 +197,10 @@ def export_tflite_graph(pipeline_config,
   num_classes = model_config.ssd.num_classes
   nms_score_threshold = {
-      model_config.ssd.post_processing.batch_non_max_suppression.
-      score_threshold
+      model_config.ssd.post_processing.batch_non_max_suppression.score_threshold
   }
   nms_iou_threshold = {
-      model_config.ssd.post_processing.batch_non_max_suppression.
-      iou_threshold
+      model_config.ssd.post_processing.batch_non_max_suppression.iou_threshold
   }
   scale_values = {}
   scale_values['y_scale'] = {

@@ -226,7 +224,7 @@ def export_tflite_graph(pipeline_config,
     width = image_resizer_config.fixed_shape_resizer.width
     if image_resizer_config.fixed_shape_resizer.convert_to_grayscale:
       num_channels = 1
-    #TODO(richardbrks) figure out how to make with a None defined batch size
     shape = [lstm_config.eval_unroll_length, height, width, num_channels]
   else:
     raise ValueError(

@@ -237,8 +235,8 @@ def export_tflite_graph(pipeline_config,
   video_tensor = tf.placeholder(
       tf.float32, shape=shape, name='input_video_tensor')
-  detection_model = model_builder.build(model_config, lstm_config,
-                                        is_training=False)
+  detection_model = model_builder.build(
+      model_config, lstm_config, is_training=False)
   preprocessed_video, true_image_shapes = detection_model.preprocess(
       tf.to_float(video_tensor))
   predicted_tensors = detection_model.predict(preprocessed_video,

@@ -311,7 +309,7 @@ def export_tflite_graph(pipeline_config,
       initializer_nodes='')

   # Add new operation to do post processing in a custom op (TF Lite only)
-  #(richardbrks) Do use this or detection_model.postprocess?
   if add_postprocessing_op:
     transformed_graph_def = append_postprocessing_op(
         frozen_graph_def, max_detections, max_classes_per_detection,
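One detail worth calling out in the hunks above: the exporter feeds the detection model a fixed-shape placeholder whose leading dimension is the LSTM unroll length rather than a batch size (which is what the removed TODO about a None batch size referred to). A small graph-mode sketch of that input tensor, with the unroll length and image size as assumed example values:

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Assumed example values; in the exporter they come from the pipeline config
# (lstm_config.eval_unroll_length and the fixed_shape_resizer settings).
unroll_length, height, width, num_channels = 4, 256, 256, 3

# The unroll length takes the place of a batch dimension.
video_tensor = tf.placeholder(
    tf.float32,
    shape=[unroll_length, height, width, num_channels],
    name='input_video_tensor')
print(video_tensor)
```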
research/lstm_object_detection/export_tflite_lstd_model.py

@@ -13,6 +13,8 @@
 # limitations under the License.
 # ==============================================================================
+"""Export a LSTD model in tflite format."""
+
 import os

 from absl import flags
 import tensorflow as tf

@@ -49,13 +51,14 @@ def main(_):
   }

   converter = tf.lite.TFLiteConverter.from_frozen_graph(
-      FLAGS.frozen_graph_path, input_arrays, output_arrays,
-      input_shapes=input_shapes
-  )
+      FLAGS.frozen_graph_path,
+      input_arrays,
+      output_arrays,
+      input_shapes=input_shapes)
   converter.allow_custom_ops = True
   tflite_model = converter.convert()
   ofilename = os.path.join(FLAGS.export_path)
-  open(ofilename, "wb").write(tflite_model)
+  open(ofilename, 'wb').write(tflite_model)


 if __name__ == '__main__':
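For reference, the conversion in `main()` above is the standard TF 1.x frozen-graph path. A self-contained sketch with placeholder paths, tensor names and shapes (the real values come from the script's flags and the exported graph):

```python
import tensorflow.compat.v1 as tf

# Placeholder values for illustration; the script takes these from flags.
frozen_graph_path = 'tflite_graph.pb'
input_arrays = ['input_video_tensor']
output_arrays = ['raw_outputs/lstm_c']  # assumed output tensor name
input_shapes = {'input_video_tensor': [4, 256, 256, 3]}  # assumed shape

converter = tf.lite.TFLiteConverter.from_frozen_graph(
    frozen_graph_path,
    input_arrays,
    output_arrays,
    input_shapes=input_shapes)
# The exported graph may contain the custom post-processing op, so custom
# ops have to be allowed during conversion.
converter.allow_custom_ops = True
tflite_model = converter.convert()
with open('model.tflite', 'wb') as f:
  f.write(tflite_model)
```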
research/lstm_object_detection/g3doc/exporting_models.md

 # Exporting a tflite model from a checkpoint

-Starting from a trained model checkpoint, creating a tflite model requires 2 steps:
+Starting from a trained model checkpoint, creating a tflite model requires 2
+steps:

 * exporting a tflite frozen graph from a checkpoint
 * exporting a tflite model from a frozen graph

 ## Exporting a tflite frozen graph from a checkpoint

 With a candidate checkpoint to export, run the following command from

@@ -23,12 +23,12 @@ python lstm_object_detection/export_tflite_lstd_graph.py \
   --add_preprocessing_op
 ```

-After export, you should see the directory ${EXPORT_DIR} containing the following files:
+After export, you should see the directory ${EXPORT_DIR} containing the
+following files:

 * `tflite_graph.pb`
 * `tflite_graph.pbtxt`

 ## Exporting a tflite model from a frozen graph

 We then take the exported tflite-compatable tflite model, and convert it to a
research/lstm_object_detection/test_tflite_model.py

@@ -13,6 +13,9 @@
 # limitations under the License.
 # ==============================================================================
+"""Test a tflite model using random input data."""
+
+from __future__ import print_function
 from absl import flags
 import numpy as np
 import tensorflow as tf

@@ -31,9 +34,9 @@ def main(_):
   # Get input and output tensors.
   input_details = interpreter.get_input_details()
-  print 'input_details:', input_details
+  print('input_details:', input_details)
   output_details = interpreter.get_output_details()
-  print 'output_details:', output_details
+  print('output_details:', output_details)

   # Test model on random input data.
   input_shape = input_details[0]['shape']

@@ -43,7 +46,7 @@ def main(_):
   interpreter.invoke()
   output_data = interpreter.get_tensor(output_details[0]['index'])
-  print output_data
+  print(output_data)


 if __name__ == '__main__':
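Taken together, the hunks above make the script a Python 3-friendly smoke test: load a .tflite file, feed random data shaped like its first input, run one inference and print the result. A runnable sketch (the model path is assumed, and a model that uses the detection custom op would additionally need custom-op support):

```python
from __future__ import print_function

import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path='model.tflite')  # assumed path
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
print('input_details:', input_details)
output_details = interpreter.get_output_details()
print('output_details:', output_details)

# Test the model on random input data shaped like its first input.
input_shape = input_details[0]['shape']
input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], input_data)

interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)
```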
research/lstm_object_detection/tflite/BUILD

@@ -59,12 +59,19 @@ cc_library(
     name = "mobile_lstd_tflite_client",
     srcs = ["mobile_lstd_tflite_client.cc"],
     hdrs = ["mobile_lstd_tflite_client.h"],
+    defines = select({
+        "//conditions:default": [],
+        "enable_edgetpu": ["ENABLE_EDGETPU"],
+    }),
     deps = [
         ":mobile_ssd_client",
         ":mobile_ssd_tflite_client",
+        "@com_google_absl//absl/base:core_headers",
         "@com_google_glog//:glog",
-        "@com_google_absl//absl/base:core_headers",
         "@org_tensorflow//tensorflow/lite/kernels:builtin_ops",
-    ],
+    ] + select({
+        "//conditions:default": [],
+        "enable_edgetpu": ["@libedgetpu//libedgetpu:header"],
+    }),
     alwayslink = 1,
 )
View file @
5b0ef1fc
...
@@ -90,12 +90,6 @@ http_archive(
...
@@ -90,12 +90,6 @@ http_archive(
sha256
=
"79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc"
,
sha256
=
"79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc"
,
)
)
#
# http_archive(
# name = "com_google_protobuf",
# strip_prefix = "protobuf-master",
# urls = ["https://github.com/protocolbuffers/protobuf/archive/master.zip"],
# )
# Needed by TensorFlow
# Needed by TensorFlow
http_archive
(
http_archive
(
...
...
research/lstm_object_detection/tflite/mobile_lstd_tflite_client.cc

@@ -66,6 +66,11 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
     interpreter_->UseNNAPI(false);
   }

+#ifdef ENABLE_EDGETPU
+  interpreter_->SetExternalContext(kTfLiteEdgeTpuContext,
+                                   edge_tpu_context_.get());
+#endif
+
   // Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
   // raw_inputs/init_lstm_h
   if (interpreter_->inputs().size() != 3) {
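This hunk hands the interpreter an Edge TPU context when the client is built with ENABLE_EDGETPU. Purely as an aside, the rough Python-side analogue is attaching the Edge TPU delegate when constructing the interpreter; a sketch that assumes the libedgetpu runtime and an Edge-TPU-compiled model are installed:

```python
import tensorflow as tf

# Assumes libedgetpu is installed and the model was compiled for the Edge TPU.
delegate = tf.lite.experimental.load_delegate('libedgetpu.so.1')
interpreter = tf.lite.Interpreter(
    model_path='model_edgetpu.tflite',  # assumed path
    experimental_delegates=[delegate])
interpreter.allocate_tensors()
```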
research/lstm_object_detection/tflite/mobile_ssd_tflite_client.h

@@ -26,7 +26,7 @@ limitations under the License.
 #include "mobile_ssd_client.h"
 #include "protos/anchor_generation_options.pb.h"
 #ifdef ENABLE_EDGETPU
-#include "libedgetpu/libedgetpu.h"
+#include "libedgetpu/edgetpu.h"
 #endif  // ENABLE_EDGETPU

 namespace lstm_object_detection {

@@ -76,6 +76,10 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
   std::unique_ptr<::tflite::MutableOpResolver> resolver_;
   std::unique_ptr<::tflite::Interpreter> interpreter_;
+#ifdef ENABLE_EDGETPU
+  std::unique_ptr<edgetpu::EdgeTpuContext> edge_tpu_context_;
+#endif

  private:
   // MobileSSDTfLiteClient is neither copyable nor movable.
   MobileSSDTfLiteClient(const MobileSSDTfLiteClient&) = delete;

@@ -103,10 +107,6 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
   bool FloatInference(const uint8_t* input_data);
   bool QuantizedInference(const uint8_t* input_data);
   void GetOutputBoxesAndScoreTensorsFromUInt8();
-#ifdef ENABLE_EDGETPU
-  std::unique_ptr<edgetpu::EdgeTpuContext> edge_tpu_context_;
-#endif
 };

 }  // namespace tflite