Commit b976c8f1 authored by Jeff Hwang's avatar Jeff Hwang Committed by Facebook GitHub Bot
Browse files

Revise VGGish pipeline to accept arbitrary state dict function (#3531)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3531

Revises VGGish pipeline to accept arbitrary state dict function to accommodate loading weights from any source.

Reviewed By: mthrok

Differential Revision: D48056390

fbshipit-source-id: 2767699b58442ad132b518b4a6435f2772a637c3
parent b645c07b
from dataclasses import dataclass
from typing import Callable, Dict
import torch
import torchaudio
......@@ -6,6 +7,11 @@ import torchaudio
from ._vggish_impl import _SAMPLE_RATE, VGGish as _VGGish, VGGishInputProcessor as _VGGishInputProcessor
def _get_state_dict():
path = torchaudio.utils.download_asset("models/vggish.pt")
return torch.load(path)
@dataclass
class VGGishBundle:
"""VGGish :cite:`45611` inference pipeline ported from
......@@ -34,7 +40,7 @@ class VGGishBundle:
class VGGishInputProcessor(_VGGishInputProcessor):
__doc__ = _VGGishInputProcessor.__doc__
_weights_path: str
_state_dict_func: Callable[[], Dict]
@property
def sample_rate(self) -> int:
......@@ -51,8 +57,7 @@ class VGGishBundle:
VGGish: VGGish model with pre-trained weights loaded.
"""
model = self.VGGish()
path = torchaudio.utils.download_asset(self._weights_path)
state_dict = torch.load(path)
state_dict = self._state_dict_func()
model.load_state_dict(state_dict)
model.eval()
return model
......@@ -66,7 +71,7 @@ class VGGishBundle:
return self.VGGishInputProcessor()
VGGISH = VGGishBundle("models/vggish.pt")
VGGISH = VGGishBundle(_get_state_dict)
VGGISH.__doc__ = """Pre-trained VGGish :cite:`45611` inference pipeline ported from
`torchvggish <https://github.com/harritaylor/torchvggish>`__
and `tensorflow-models <https://github.com/tensorflow/models/tree/master/research/audioset>`__.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment