Unverified Commit b66b0b05 authored by Dan Ellis's avatar Dan Ellis Committed by GitHub
Browse files

Explicit signatures for tflite. Using ideas from #9688 (#10248)

parent c636ea33
......@@ -44,20 +44,24 @@ def log(msg):
class YAMNet(tf.Module):
"''A TF2 Module wrapper around YAMNet."""
"""A TF2 Module wrapper around YAMNet."""
def __init__(self, weights_path, params):
super().__init__()
self._yamnet = yamnet.yamnet_frames_model(params)
self._yamnet.load_weights(weights_path)
self._class_map_asset = tf.saved_model.Asset('yamnet_class_map.csv')
@tf.function
@tf.function(input_signature=[])
def class_map_path(self):
return self._class_map_asset.asset_path
@tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),))
@tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype=tf.float32)])
def __call__(self, waveform):
return self._yamnet(waveform)
predictions, embeddings, log_mel_spectrogram = self._yamnet(waveform)
return {'predictions': predictions,
'embeddings': embeddings,
'log_mel_spectrogram': log_mel_spectrogram}
def check_model(model_fn, class_map_path, params):
......@@ -65,7 +69,10 @@ def check_model(model_fn, class_map_path, params):
"""Applies yamnet_test's sanity checks to an instance of YAMNet."""
def clip_test(waveform, expected_class_name, top_n=10):
predictions, embeddings, log_mel_spectrogram = model_fn(waveform)
results = model_fn(waveform=waveform)
predictions = results['predictions']
embeddings = results['embeddings']
log_mel_spectrogram = results['log_mel_spectrogram']
clip_predictions = np.mean(predictions, axis=0)
top_n_indices = np.argsort(clip_predictions)[-top_n:]
top_n_scores = clip_predictions[top_n_indices]
......@@ -106,7 +113,9 @@ def make_tf2_export(weights_path, export_dir):
# Make TF2 SavedModel export.
log('Making TF2 SavedModel export ...')
tf.saved_model.save(yamnet, export_dir)
tf.saved_model.save(
yamnet, export_dir,
signatures={'serving_default': yamnet.__call__.get_concrete_function()})
log('Done')
# Check export with TF-Hub in TF2.
......@@ -143,7 +152,9 @@ def make_tflite_export(weights_path, export_dir):
log('Making TF-Lite SavedModel export ...')
saved_model_dir = os.path.join(export_dir, 'saved_model')
os.makedirs(saved_model_dir)
tf.saved_model.save(yamnet, saved_model_dir)
tf.saved_model.save(
yamnet, saved_model_dir,
signatures={'serving_default': yamnet.__call__.get_concrete_function()})
log('Done')
# Check that the export can be loaded and works.
......@@ -154,7 +165,8 @@ def make_tflite_export(weights_path, export_dir):
# Make a TF-Lite model from the SavedModel.
log('Making TF-Lite model ...')
tflite_converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_converter = tf.lite.TFLiteConverter.from_saved_model(
saved_model_dir, signature_keys=['serving_default'])
tflite_model = tflite_converter.convert()
tflite_model_path = os.path.join(export_dir, 'yamnet.tflite')
with open(tflite_model_path, 'wb') as f:
......@@ -164,19 +176,8 @@ def make_tflite_export(weights_path, export_dir):
# Check the TF-Lite export.
log('Checking TF-Lite model ...')
interpreter = tf.lite.Interpreter(tflite_model_path)
audio_input_index = interpreter.get_input_details()[0]['index']
scores_output_index = interpreter.get_output_details()[0]['index']
embeddings_output_index = interpreter.get_output_details()[1]['index']
spectrogram_output_index = interpreter.get_output_details()[2]['index']
def run_model(waveform):
interpreter.resize_tensor_input(audio_input_index, [len(waveform)], strict=True)
interpreter.allocate_tensors()
interpreter.set_tensor(audio_input_index, waveform)
interpreter.invoke()
return (interpreter.get_tensor(scores_output_index),
interpreter.get_tensor(embeddings_output_index),
interpreter.get_tensor(spectrogram_output_index))
check_model(run_model, 'yamnet_class_map.csv', params)
runner = interpreter.get_signature_runner('serving_default')
check_model(runner, 'yamnet_class_map.csv', params)
log('Done')
return saved_model_dir
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment