Unverified Commit 83f56818 authored by Dan Ellis's avatar Dan Ellis Committed by GitHub
Browse files

Add the tf2-compatible graph wrapping idiom to inference.py. (#8203)

parent 2f9f2479
......@@ -21,6 +21,7 @@ import sys
import numpy as np
import resampy
import soundfile as sf
import tensorflow as tf
import params
import yamnet as yamnet_model
......@@ -29,8 +30,10 @@ import yamnet as yamnet_model
def main(argv):
assert argv
yamnet = yamnet_model.yamnet_frames_model(params)
yamnet.load_weights('yamnet.h5')
graph = tf.Graph()
with graph.as_default():
yamnet = yamnet_model.yamnet_frames_model(params)
yamnet.load_weights('yamnet.h5')
yamnet_classes = yamnet_model.class_names('yamnet_class_map.csv')
for file_name in argv:
......@@ -48,7 +51,8 @@ def main(argv):
# Predict YAMNet classes.
# Second output is log-mel-spectrogram array (used for visualizations).
# (steps=1 is a work around for Keras batching limitations.)
scores, _ = yamnet.predict(np.reshape(waveform, [1, -1]), steps=1)
with graph.as_default():
scores, _ = yamnet.predict(np.reshape(waveform, [1, -1]), steps=1)
# Scores is a matrix of (time_frames, num_classes) classifier scores.
# Average them along time to get an overall classifier output for the clip.
prediction = np.mean(scores, axis=0)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment