Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
901c4cc4
Commit
901c4cc4
authored
Aug 20, 2019
by
Vinh Nguyen
Browse files
Merge remote-tracking branch 'upstream/master' into amp_resnet50
parents
ef30de93
824ff2d6
Changes
86
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
73 additions
and
15 deletions
+73
-15
research/lstm_object_detection/test_tflite_model.py
research/lstm_object_detection/test_tflite_model.py
+53
-0
research/lstm_object_detection/tflite/BUILD
research/lstm_object_detection/tflite/BUILD
+9
-2
research/lstm_object_detection/tflite/WORKSPACE
research/lstm_object_detection/tflite/WORKSPACE
+0
-7
research/lstm_object_detection/tflite/mobile_lstd_tflite_client.cc
...lstm_object_detection/tflite/mobile_lstd_tflite_client.cc
+5
-0
research/lstm_object_detection/tflite/mobile_ssd_tflite_client.h
...h/lstm_object_detection/tflite/mobile_ssd_tflite_client.h
+5
-5
research/lstm_object_detection/utils/config_util.py
research/lstm_object_detection/utils/config_util.py
+1
-1
No files found.
research/lstm_object_detection/test_tflite_model.py
0 → 100644
View file @
901c4cc4
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Test a tflite model using random input data."""
from
__future__
import
print_function
from
absl
import
flags
import
numpy
as
np
import
tensorflow
as
tf
flags
.
DEFINE_string
(
'model_path'
,
None
,
'Path to model.'
)
FLAGS
=
flags
.
FLAGS
def
main
(
_
):
flags
.
mark_flag_as_required
(
'model_path'
)
# Load TFLite model and allocate tensors.
interpreter
=
tf
.
lite
.
Interpreter
(
model_path
=
FLAGS
.
model_path
)
interpreter
.
allocate_tensors
()
# Get input and output tensors.
input_details
=
interpreter
.
get_input_details
()
print
(
'input_details:'
,
input_details
)
output_details
=
interpreter
.
get_output_details
()
print
(
'output_details:'
,
output_details
)
# Test model on random input data.
input_shape
=
input_details
[
0
][
'shape'
]
# change the following line to feed into your own data.
input_data
=
np
.
array
(
np
.
random
.
random_sample
(
input_shape
),
dtype
=
np
.
float32
)
interpreter
.
set_tensor
(
input_details
[
0
][
'index'
],
input_data
)
interpreter
.
invoke
()
output_data
=
interpreter
.
get_tensor
(
output_details
[
0
][
'index'
])
print
(
output_data
)
if
__name__
==
'__main__'
:
tf
.
app
.
run
()
research/lstm_object_detection/tflite/BUILD
View file @
901c4cc4
...
...
@@ -59,12 +59,19 @@ cc_library(
name
=
"mobile_lstd_tflite_client"
,
srcs
=
[
"mobile_lstd_tflite_client.cc"
],
hdrs
=
[
"mobile_lstd_tflite_client.h"
],
defines
=
select
({
"//conditions:default"
:
[],
"enable_edgetpu"
:
[
"ENABLE_EDGETPU"
],
}),
deps
=
[
":mobile_ssd_client"
,
":mobile_ssd_tflite_client"
,
"@com_google_absl//absl/base:core_headers"
,
"@com_google_glog//:glog"
,
"@com_google_absl//absl/base:core_headers"
,
"@org_tensorflow//tensorflow/lite/kernels:builtin_ops"
,
],
]
+
select
({
"//conditions:default"
:
[],
"enable_edgetpu"
:
[
"@libedgetpu//libedgetpu:header"
],
}),
alwayslink
=
1
,
)
research/lstm_object_detection/tflite/WORKSPACE
View file @
901c4cc4
...
...
@@ -90,13 +90,6 @@ http_archive(
sha256
=
"79d102c61e2a479a0b7e5fc167bcfaa4832a0c6aad4a75fa7da0480564931bcc"
,
)
#
# http_archive(
# name = "com_google_protobuf",
# strip_prefix = "protobuf-master",
# urls = ["https://github.com/protocolbuffers/protobuf/archive/master.zip"],
# )
# Needed by TensorFlow
http_archive
(
name
=
"io_bazel_rules_closure"
,
...
...
research/lstm_object_detection/tflite/mobile_lstd_tflite_client.cc
View file @
901c4cc4
...
...
@@ -66,6 +66,11 @@ bool MobileLSTDTfLiteClient::InitializeInterpreter(
interpreter_
->
UseNNAPI
(
false
);
}
#ifdef ENABLE_EDGETPU
interpreter_
->
SetExternalContext
(
kTfLiteEdgeTpuContext
,
edge_tpu_context_
.
get
());
#endif
// Inputs are: normalized_input_image_tensor, raw_inputs/init_lstm_c,
// raw_inputs/init_lstm_h
if
(
interpreter_
->
inputs
().
size
()
!=
3
)
{
...
...
research/lstm_object_detection/tflite/mobile_ssd_tflite_client.h
View file @
901c4cc4
...
...
@@ -26,7 +26,7 @@ limitations under the License.
#include "mobile_ssd_client.h"
#include "protos/anchor_generation_options.pb.h"
#ifdef ENABLE_EDGETPU
#include "libedgetpu/
lib
edgetpu.h"
#include "libedgetpu/edgetpu.h"
#endif // ENABLE_EDGETPU
namespace
lstm_object_detection
{
...
...
@@ -76,6 +76,10 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
std
::
unique_ptr
<::
tflite
::
MutableOpResolver
>
resolver_
;
std
::
unique_ptr
<::
tflite
::
Interpreter
>
interpreter_
;
#ifdef ENABLE_EDGETPU
std
::
unique_ptr
<
edgetpu
::
EdgeTpuContext
>
edge_tpu_context_
;
#endif
private:
// MobileSSDTfLiteClient is neither copyable nor movable.
MobileSSDTfLiteClient
(
const
MobileSSDTfLiteClient
&
)
=
delete
;
...
...
@@ -103,10 +107,6 @@ class MobileSSDTfLiteClient : public MobileSSDClient {
bool
FloatInference
(
const
uint8_t
*
input_data
);
bool
QuantizedInference
(
const
uint8_t
*
input_data
);
void
GetOutputBoxesAndScoreTensorsFromUInt8
();
#ifdef ENABLE_EDGETPU
std
::
unique_ptr
<
edgetpu
::
EdgeTpuContext
>
edge_tpu_context_
;
#endif
};
}
// namespace tflite
...
...
research/lstm_object_detection/utils/config_util.py
View file @
901c4cc4
...
...
@@ -37,7 +37,7 @@ def get_configs_from_pipeline_file(pipeline_config_path):
Returns:
Dictionary of configuration objects. Keys are `model`, `train_config`,
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_
confg
`.
`train_input_config`, `eval_config`, `eval_input_config`, `lstm_
model
`.
Value are the corresponding config objects.
"""
pipeline_config
=
pipeline_pb2
.
TrainEvalPipelineConfig
()
...
...
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment