Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
b66b0b05
Unverified
Commit
b66b0b05
authored
Sep 08, 2021
by
Dan Ellis
Committed by
GitHub
Sep 08, 2021
Browse files
Explicit signatures for tflite. Using ideas from #9688 (#10248)
parent
c636ea33
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
22 additions
and
21 deletions
+22
-21
research/audioset/yamnet/export.py
research/audioset/yamnet/export.py
+22
-21
No files found.
research/audioset/yamnet/export.py
View file @
b66b0b05
...
@@ -44,20 +44,24 @@ def log(msg):
...
@@ -44,20 +44,24 @@ def log(msg):
class
YAMNet
(
tf
.
Module
):
class
YAMNet
(
tf
.
Module
):
"
''
A TF2 Module wrapper around YAMNet."""
"
""
A TF2 Module wrapper around YAMNet."""
def
__init__
(
self
,
weights_path
,
params
):
def
__init__
(
self
,
weights_path
,
params
):
super
().
__init__
()
super
().
__init__
()
self
.
_yamnet
=
yamnet
.
yamnet_frames_model
(
params
)
self
.
_yamnet
=
yamnet
.
yamnet_frames_model
(
params
)
self
.
_yamnet
.
load_weights
(
weights_path
)
self
.
_yamnet
.
load_weights
(
weights_path
)
self
.
_class_map_asset
=
tf
.
saved_model
.
Asset
(
'yamnet_class_map.csv'
)
self
.
_class_map_asset
=
tf
.
saved_model
.
Asset
(
'yamnet_class_map.csv'
)
@tf.function
@
tf
.
function
(
input_signature
=
[])
def
class_map_path
(
self
):
def
class_map_path
(
self
):
return
self
.
_class_map_asset
.
asset_path
return
self
.
_class_map_asset
.
asset_path
@tf.function(input_signature=
(
tf.TensorSpec(shape=[None], dtype=tf.float32)
,)
)
@
tf
.
function
(
input_signature
=
[
tf
.
TensorSpec
(
shape
=
[
None
],
dtype
=
tf
.
float32
)
]
)
def
__call__
(
self
,
waveform
):
def
__call__
(
self
,
waveform
):
return self._yamnet(waveform)
predictions
,
embeddings
,
log_mel_spectrogram
=
self
.
_yamnet
(
waveform
)
return
{
'predictions'
:
predictions
,
'embeddings'
:
embeddings
,
'log_mel_spectrogram'
:
log_mel_spectrogram
}
def
check_model
(
model_fn
,
class_map_path
,
params
):
def
check_model
(
model_fn
,
class_map_path
,
params
):
...
@@ -65,7 +69,10 @@ def check_model(model_fn, class_map_path, params):
...
@@ -65,7 +69,10 @@ def check_model(model_fn, class_map_path, params):
"""Applies yamnet_test's sanity checks to an instance of YAMNet."""
"""Applies yamnet_test's sanity checks to an instance of YAMNet."""
def
clip_test
(
waveform
,
expected_class_name
,
top_n
=
10
):
def
clip_test
(
waveform
,
expected_class_name
,
top_n
=
10
):
predictions, embeddings, log_mel_spectrogram = model_fn(waveform)
results
=
model_fn
(
waveform
=
waveform
)
predictions
=
results
[
'predictions'
]
embeddings
=
results
[
'embeddings'
]
log_mel_spectrogram
=
results
[
'log_mel_spectrogram'
]
clip_predictions
=
np
.
mean
(
predictions
,
axis
=
0
)
clip_predictions
=
np
.
mean
(
predictions
,
axis
=
0
)
top_n_indices
=
np
.
argsort
(
clip_predictions
)[
-
top_n
:]
top_n_indices
=
np
.
argsort
(
clip_predictions
)[
-
top_n
:]
top_n_scores
=
clip_predictions
[
top_n_indices
]
top_n_scores
=
clip_predictions
[
top_n_indices
]
...
@@ -106,7 +113,9 @@ def make_tf2_export(weights_path, export_dir):
...
@@ -106,7 +113,9 @@ def make_tf2_export(weights_path, export_dir):
# Make TF2 SavedModel export.
# Make TF2 SavedModel export.
log
(
'Making TF2 SavedModel export ...'
)
log
(
'Making TF2 SavedModel export ...'
)
tf.saved_model.save(yamnet, export_dir)
tf
.
saved_model
.
save
(
yamnet
,
export_dir
,
signatures
=
{
'serving_default'
:
yamnet
.
__call__
.
get_concrete_function
()})
log
(
'Done'
)
log
(
'Done'
)
# Check export with TF-Hub in TF2.
# Check export with TF-Hub in TF2.
...
@@ -143,7 +152,9 @@ def make_tflite_export(weights_path, export_dir):
...
@@ -143,7 +152,9 @@ def make_tflite_export(weights_path, export_dir):
log
(
'Making TF-Lite SavedModel export ...'
)
log
(
'Making TF-Lite SavedModel export ...'
)
saved_model_dir
=
os
.
path
.
join
(
export_dir
,
'saved_model'
)
saved_model_dir
=
os
.
path
.
join
(
export_dir
,
'saved_model'
)
os
.
makedirs
(
saved_model_dir
)
os
.
makedirs
(
saved_model_dir
)
tf.saved_model.save(yamnet, saved_model_dir)
tf
.
saved_model
.
save
(
yamnet
,
saved_model_dir
,
signatures
=
{
'serving_default'
:
yamnet
.
__call__
.
get_concrete_function
()})
log
(
'Done'
)
log
(
'Done'
)
# Check that the export can be loaded and works.
# Check that the export can be loaded and works.
...
@@ -154,7 +165,8 @@ def make_tflite_export(weights_path, export_dir):
...
@@ -154,7 +165,8 @@ def make_tflite_export(weights_path, export_dir):
# Make a TF-Lite model from the SavedModel.
# Make a TF-Lite model from the SavedModel.
log
(
'Making TF-Lite model ...'
)
log
(
'Making TF-Lite model ...'
)
tflite_converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_converter
=
tf
.
lite
.
TFLiteConverter
.
from_saved_model
(
saved_model_dir
,
signature_keys
=
[
'serving_default'
])
tflite_model
=
tflite_converter
.
convert
()
tflite_model
=
tflite_converter
.
convert
()
tflite_model_path
=
os
.
path
.
join
(
export_dir
,
'yamnet.tflite'
)
tflite_model_path
=
os
.
path
.
join
(
export_dir
,
'yamnet.tflite'
)
with
open
(
tflite_model_path
,
'wb'
)
as
f
:
with
open
(
tflite_model_path
,
'wb'
)
as
f
:
...
@@ -164,19 +176,8 @@ def make_tflite_export(weights_path, export_dir):
...
@@ -164,19 +176,8 @@ def make_tflite_export(weights_path, export_dir):
# Check the TF-Lite export.
# Check the TF-Lite export.
log
(
'Checking TF-Lite model ...'
)
log
(
'Checking TF-Lite model ...'
)
interpreter
=
tf
.
lite
.
Interpreter
(
tflite_model_path
)
interpreter
=
tf
.
lite
.
Interpreter
(
tflite_model_path
)
audio_input_index = interpreter.get_input_details()[0]['index']
runner
=
interpreter
.
get_signature_runner
(
'serving_default'
)
scores_output_index = interpreter.get_output_details()[0]['index']
check_model
(
runner
,
'yamnet_class_map.csv'
,
params
)
embeddings_output_index = interpreter.get_output_details()[1]['index']
spectrogram_output_index = interpreter.get_output_details()[2]['index']
def run_model(waveform):
interpreter.resize_tensor_input(audio_input_index, [len(waveform)], strict=True)
interpreter.allocate_tensors()
interpreter.set_tensor(audio_input_index, waveform)
interpreter.invoke()
return (interpreter.get_tensor(scores_output_index),
interpreter.get_tensor(embeddings_output_index),
interpreter.get_tensor(spectrogram_output_index))
check_model(run_model, 'yamnet_class_map.csv', params)
log
(
'Done'
)
log
(
'Done'
)
return
saved_model_dir
return
saved_model_dir
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment