Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
83f56818
Unverified
Commit
83f56818
authored
Feb 26, 2020
by
Dan Ellis
Committed by
GitHub
Feb 26, 2020
Browse files
Add the tf2-compatible graph wrapping idiom to inference.py. (#8203)
parent
2f9f2479
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
3 deletions
+7
-3
research/audioset/yamnet/inference.py
research/audioset/yamnet/inference.py
+7
-3
No files found.
research/audioset/yamnet/inference.py
View file @
83f56818
...
@@ -21,6 +21,7 @@ import sys
...
@@ -21,6 +21,7 @@ import sys
import
numpy
as
np
import
numpy
as
np
import
resampy
import
resampy
import
soundfile
as
sf
import
soundfile
as
sf
import
tensorflow
as
tf
import
params
import
params
import
yamnet
as
yamnet_model
import
yamnet
as
yamnet_model
...
@@ -29,8 +30,10 @@ import yamnet as yamnet_model
...
@@ -29,8 +30,10 @@ import yamnet as yamnet_model
def
main
(
argv
):
def
main
(
argv
):
assert
argv
assert
argv
yamnet
=
yamnet_model
.
yamnet_frames_model
(
params
)
graph
=
tf
.
Graph
()
yamnet
.
load_weights
(
'yamnet.h5'
)
with
graph
.
as_default
():
yamnet
=
yamnet_model
.
yamnet_frames_model
(
params
)
yamnet
.
load_weights
(
'yamnet.h5'
)
yamnet_classes
=
yamnet_model
.
class_names
(
'yamnet_class_map.csv'
)
yamnet_classes
=
yamnet_model
.
class_names
(
'yamnet_class_map.csv'
)
for
file_name
in
argv
:
for
file_name
in
argv
:
...
@@ -48,7 +51,8 @@ def main(argv):
...
@@ -48,7 +51,8 @@ def main(argv):
# Predict YAMNet classes.
# Predict YAMNet classes.
# Second output is log-mel-spectrogram array (used for visualizations).
# Second output is log-mel-spectrogram array (used for visualizations).
# (steps=1 is a work around for Keras batching limitations.)
# (steps=1 is a work around for Keras batching limitations.)
scores
,
_
=
yamnet
.
predict
(
np
.
reshape
(
waveform
,
[
1
,
-
1
]),
steps
=
1
)
with
graph
.
as_default
():
scores
,
_
=
yamnet
.
predict
(
np
.
reshape
(
waveform
,
[
1
,
-
1
]),
steps
=
1
)
# Scores is a matrix of (time_frames, num_classes) classifier scores.
# Scores is a matrix of (time_frames, num_classes) classifier scores.
# Average them along time to get an overall classifier output for the clip.
# Average them along time to get an overall classifier output for the clip.
prediction
=
np
.
mean
(
scores
,
axis
=
0
)
prediction
=
np
.
mean
(
scores
,
axis
=
0
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment