"model/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "23125648b8748d9f2ec93c3038db8689f5693f6e"
Commit b976c8f1, authored by Jeff Hwang and committed by Facebook GitHub Bot
Browse files

Revise VGGish pipeline to accept arbitrary state dict function (#3531)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3531

Revises the VGGish pipeline to accept an arbitrary state-dict function, so that weights can be loaded from any source.

Reviewed By: mthrok

Differential Revision: D48056390

fbshipit-source-id: 2767699b58442ad132b518b4a6435f2772a637c3
parent b645c07b
from dataclasses import dataclass from dataclasses import dataclass
from typing import Callable, Dict
import torch import torch
import torchaudio import torchaudio
...@@ -6,6 +7,11 @@ import torchaudio ...@@ -6,6 +7,11 @@ import torchaudio
from ._vggish_impl import _SAMPLE_RATE, VGGish as _VGGish, VGGishInputProcessor as _VGGishInputProcessor from ._vggish_impl import _SAMPLE_RATE, VGGish as _VGGish, VGGishInputProcessor as _VGGishInputProcessor
def _get_state_dict():
    """Fetch the ``models/vggish.pt`` asset and deserialize it.

    The file is obtained through ``torchaudio.utils.download_asset`` and then
    read with ``torch.load``.

    Returns:
        The object produced by ``torch.load`` for the downloaded file. It is
        passed to ``load_state_dict`` by the bundle, so it is presumably a
        state dict for the VGGish model (not verifiable from this excerpt).
    """
    # Resolve the asset to a local path first, then hand that path to the loader.
    return torch.load(torchaudio.utils.download_asset("models/vggish.pt"))
@dataclass @dataclass
class VGGishBundle: class VGGishBundle:
"""VGGish :cite:`45611` inference pipeline ported from """VGGish :cite:`45611` inference pipeline ported from
...@@ -34,7 +40,7 @@ class VGGishBundle: ...@@ -34,7 +40,7 @@ class VGGishBundle:
class VGGishInputProcessor(_VGGishInputProcessor): class VGGishInputProcessor(_VGGishInputProcessor):
__doc__ = _VGGishInputProcessor.__doc__ __doc__ = _VGGishInputProcessor.__doc__
_weights_path: str _state_dict_func: Callable[[], Dict]
@property @property
def sample_rate(self) -> int: def sample_rate(self) -> int:
...@@ -51,8 +57,7 @@ class VGGishBundle: ...@@ -51,8 +57,7 @@ class VGGishBundle:
VGGish: VGGish model with pre-trained weights loaded. VGGish: VGGish model with pre-trained weights loaded.
""" """
model = self.VGGish() model = self.VGGish()
path = torchaudio.utils.download_asset(self._weights_path) state_dict = self._state_dict_func()
state_dict = torch.load(path)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
model.eval() model.eval()
return model return model
...@@ -66,7 +71,7 @@ class VGGishBundle: ...@@ -66,7 +71,7 @@ class VGGishBundle:
return self.VGGishInputProcessor() return self.VGGishInputProcessor()
VGGISH = VGGishBundle("models/vggish.pt") VGGISH = VGGishBundle(_get_state_dict)
VGGISH.__doc__ = """Pre-trained VGGish :cite:`45611` inference pipeline ported from VGGISH.__doc__ = """Pre-trained VGGish :cite:`45611` inference pipeline ported from
`torchvggish <https://github.com/harritaylor/torchvggish>`__ `torchvggish <https://github.com/harritaylor/torchvggish>`__
and `tensorflow-models <https://github.com/tensorflow/models/tree/master/research/audioset>`__. and `tensorflow-models <https://github.com/tensorflow/models/tree/master/research/audioset>`__.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment