Unverified Commit b66b0b05 authored by Dan Ellis's avatar Dan Ellis Committed by GitHub
Browse files

Explicit signatures for tflite. Using ideas from #9688 (#10248)

parent c636ea33
...@@ -44,20 +44,24 @@ def log(msg): ...@@ -44,20 +44,24 @@ def log(msg):
class YAMNet(tf.Module): class YAMNet(tf.Module):
"''A TF2 Module wrapper around YAMNet.""" """A TF2 Module wrapper around YAMNet."""
def __init__(self, weights_path, params): def __init__(self, weights_path, params):
super().__init__() super().__init__()
self._yamnet = yamnet.yamnet_frames_model(params) self._yamnet = yamnet.yamnet_frames_model(params)
self._yamnet.load_weights(weights_path) self._yamnet.load_weights(weights_path)
self._class_map_asset = tf.saved_model.Asset('yamnet_class_map.csv') self._class_map_asset = tf.saved_model.Asset('yamnet_class_map.csv')
@tf.function @tf.function(input_signature=[])
def class_map_path(self): def class_map_path(self):
return self._class_map_asset.asset_path return self._class_map_asset.asset_path
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),)) @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def __call__(self, waveform): def __call__(self, waveform):
return self._yamnet(waveform) predictions, embeddings, log_mel_spectrogram = self._yamnet(waveform)
return {'predictions': predictions,
'embeddings': embeddings,
'log_mel_spectrogram': log_mel_spectrogram}
def check_model(model_fn, class_map_path, params): def check_model(model_fn, class_map_path, params):
...@@ -65,7 +69,10 @@ def check_model(model_fn, class_map_path, params): ...@@ -65,7 +69,10 @@ def check_model(model_fn, class_map_path, params):
"""Applies yamnet_test's sanity checks to an instance of YAMNet.""" """Applies yamnet_test's sanity checks to an instance of YAMNet."""
def clip_test(waveform, expected_class_name, top_n=10): def clip_test(waveform, expected_class_name, top_n=10):
predictions, embeddings, log_mel_spectrogram = model_fn(waveform) results = model_fn(waveform=waveform)
predictions = results['predictions']
embeddings = results['embeddings']
log_mel_spectrogram = results['log_mel_spectrogram']
clip_predictions = np.mean(predictions, axis=0) clip_predictions = np.mean(predictions, axis=0)
top_n_indices = np.argsort(clip_predictions)[-top_n:] top_n_indices = np.argsort(clip_predictions)[-top_n:]
top_n_scores = clip_predictions[top_n_indices] top_n_scores = clip_predictions[top_n_indices]
...@@ -106,7 +113,9 @@ def make_tf2_export(weights_path, export_dir): ...@@ -106,7 +113,9 @@ def make_tf2_export(weights_path, export_dir):
# Make TF2 SavedModel export. # Make TF2 SavedModel export.
log('Making TF2 SavedModel export ...') log('Making TF2 SavedModel export ...')
tf.saved_model.save(yamnet, export_dir) tf.saved_model.save(
yamnet, export_dir,
signatures={'serving_default': yamnet.__call__.get_concrete_function()})
log('Done') log('Done')
# Check export with TF-Hub in TF2. # Check export with TF-Hub in TF2.
...@@ -143,7 +152,9 @@ def make_tflite_export(weights_path, export_dir): ...@@ -143,7 +152,9 @@ def make_tflite_export(weights_path, export_dir):
log('Making TF-Lite SavedModel export ...') log('Making TF-Lite SavedModel export ...')
saved_model_dir = os.path.join(export_dir, 'saved_model') saved_model_dir = os.path.join(export_dir, 'saved_model')
os.makedirs(saved_model_dir) os.makedirs(saved_model_dir)
tf.saved_model.save(yamnet, saved_model_dir) tf.saved_model.save(
yamnet, saved_model_dir,
signatures={'serving_default': yamnet.__call__.get_concrete_function()})
log('Done') log('Done')
# Check that the export can be loaded and works. # Check that the export can be loaded and works.
...@@ -154,7 +165,8 @@ def make_tflite_export(weights_path, export_dir): ...@@ -154,7 +165,8 @@ def make_tflite_export(weights_path, export_dir):
# Make a TF-Lite model from the SavedModel. # Make a TF-Lite model from the SavedModel.
log('Making TF-Lite model ...') log('Making TF-Lite model ...')
tflite_converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir) tflite_converter = tf.lite.TFLiteConverter.from_saved_model(
saved_model_dir, signature_keys=['serving_default'])
tflite_model = tflite_converter.convert() tflite_model = tflite_converter.convert()
tflite_model_path = os.path.join(export_dir, 'yamnet.tflite') tflite_model_path = os.path.join(export_dir, 'yamnet.tflite')
with open(tflite_model_path, 'wb') as f: with open(tflite_model_path, 'wb') as f:
...@@ -164,19 +176,8 @@ def make_tflite_export(weights_path, export_dir): ...@@ -164,19 +176,8 @@ def make_tflite_export(weights_path, export_dir):
# Check the TF-Lite export. # Check the TF-Lite export.
log('Checking TF-Lite model ...') log('Checking TF-Lite model ...')
interpreter = tf.lite.Interpreter(tflite_model_path) interpreter = tf.lite.Interpreter(tflite_model_path)
audio_input_index = interpreter.get_input_details()[0]['index'] runner = interpreter.get_signature_runner('serving_default')
scores_output_index = interpreter.get_output_details()[0]['index'] check_model(runner, 'yamnet_class_map.csv', params)
embeddings_output_index = interpreter.get_output_details()[1]['index']
spectrogram_output_index = interpreter.get_output_details()[2]['index']
def run_model(waveform):
interpreter.resize_tensor_input(audio_input_index, [len(waveform)], strict=True)
interpreter.allocate_tensors()
interpreter.set_tensor(audio_input_index, waveform)
interpreter.invoke()
return (interpreter.get_tensor(scores_output_index),
interpreter.get_tensor(embeddings_output_index),
interpreter.get_tensor(spectrogram_output_index))
check_model(run_model, 'yamnet_class_map.csv', params)
log('Done') log('Done')
return saved_model_dir return saved_model_dir
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment