Unverified Commit 69b2a0ad authored by nateanl's avatar nateanl Committed by GitHub
Browse files

Fix model downloading in bento (#3803)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3803

The model checkpoint path can not be created for Squim models. Use the latest download_asset method to fix it.

Reviewed By: moto-meta

Differential Revision: D59061348
parent 7f6209b4
from dataclasses import dataclass from dataclasses import dataclass
from torchaudio._internal import load_state_dict_from_url import torch
import torchaudio
from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective from torchaudio.models import squim_objective_base, squim_subjective_base, SquimObjective, SquimSubjective
...@@ -42,26 +43,16 @@ class SquimObjectiveBundle: ...@@ -42,26 +43,16 @@ class SquimObjectiveBundle:
_path: str _path: str
_sample_rate: float _sample_rate: float
def _get_state_dict(self, dl_kwargs): def get_model(self) -> SquimObjective:
url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(url, **dl_kwargs)
return state_dict
def get_model(self, *, dl_kwargs=None) -> SquimObjective:
"""Construct the SquimObjective model, and load the pretrained weight. """Construct the SquimObjective model, and load the pretrained weight.
The weight file is downloaded from the internet and cached with
:func:`torch.hub.load_state_dict_from_url`
Args:
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
Returns: Returns:
Variation of :py:class:`~torchaudio.models.SquimObjective`. Variation of :py:class:`~torchaudio.models.SquimObjective`.
""" """
model = squim_objective_base() model = squim_objective_base()
model.load_state_dict(self._get_state_dict(dl_kwargs)) path = torchaudio.utils.download_asset(f"models/{self._path}")
state_dict = torch.load(path, weights_only=True)
model.load_state_dict(state_dict)
model.eval() model.eval()
return model return model
...@@ -128,26 +119,15 @@ class SquimSubjectiveBundle: ...@@ -128,26 +119,15 @@ class SquimSubjectiveBundle:
_path: str _path: str
_sample_rate: float _sample_rate: float
def _get_state_dict(self, dl_kwargs): def get_model(self) -> SquimSubjective:
url = f"https://download.pytorch.org/torchaudio/models/{self._path}"
dl_kwargs = {} if dl_kwargs is None else dl_kwargs
state_dict = load_state_dict_from_url(url, **dl_kwargs)
return state_dict
def get_model(self, *, dl_kwargs=None) -> SquimSubjective:
"""Construct the SquimSubjective model, and load the pretrained weight. """Construct the SquimSubjective model, and load the pretrained weight.
The weight file is downloaded from the internet and cached with
:func:`torch.hub.load_state_dict_from_url`
Args:
dl_kwargs (dictionary of keyword arguments): Passed to :func:`torch.hub.load_state_dict_from_url`.
Returns: Returns:
Variation of :py:class:`~torchaudio.models.SquimObjective`. Variation of :py:class:`~torchaudio.models.SquimObjective`.
""" """
model = squim_subjective_base() model = squim_subjective_base()
model.load_state_dict(self._get_state_dict(dl_kwargs)) path = torchaudio.utils.download_asset(f"models/{self._path}")
state_dict = torch.load(path, weights_only=True)
model.load_state_dict(state_dict)
model.eval() model.eval()
return model return model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment