Commit 64b98521 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add download utility specialized for torchaudio (#2283)

Summary:
In recent updates, torchaudio added features that download assets/models from
download.pytorch.org/torchaudio.

To reduce code duplication, the implementation uses utilities from
``torch.hub``, but still, there are patterns repeated in implementing
the fetch mechanism, notably cache and local file path handling.

This commit introduces the utility function that handles
download/cache/local path management that can be used for
fetching pre-trained model data.

Pull Request resolved: https://github.com/pytorch/audio/pull/2283

Reviewed By: carolineechen

Differential Revision: D35050469

Pulled By: mthrok

fbshipit-source-id: 219dd806f9a96c54d2d31e981c1bbe282772702b
parent 8395fe65
import pytest import pytest
import torch import torch
from torchaudio._internal import download_url_to_file import torchaudio
class GreedyCTCDecoder(torch.nn.Module): class GreedyCTCDecoder(torch.nn.Module):
...@@ -49,9 +49,7 @@ def sample_speech(tmp_path, lang): ...@@ -49,9 +49,7 @@ def sample_speech(tmp_path, lang):
filename = _FILES[lang] filename = _FILES[lang]
path = tmp_path.parent / filename path = tmp_path.parent / filename
if not path.exists(): if not path.exists():
url = f"https://download.pytorch.org/torchaudio/test-assets/{filename}" torchaudio.utils.download_asset(f"test-assets/{filename}", path=path)
print(f"downloading from {url}")
download_url_to_file(url, path, progress=False)
return path return path
......
import json import json
import math import math
import pathlib
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
...@@ -8,25 +7,12 @@ from typing import Callable, List, Tuple ...@@ -8,25 +7,12 @@ from typing import Callable, List, Tuple
import torch import torch
import torchaudio import torchaudio
from torchaudio._internal import download_url_to_file, load_state_dict_from_url, module_utils from torchaudio._internal import module_utils
from torchaudio.models import RNNT, RNNTBeamSearch, emformer_rnnt_base from torchaudio.models import RNNT, RNNTBeamSearch, emformer_rnnt_base
__all__ = [] __all__ = []
_BASE_MODELS_URL = "https://download.pytorch.org/torchaudio/models"
_BASE_PIPELINES_URL = "https://download.pytorch.org/torchaudio/pipeline-assets"
def _download_asset(asset_path: str):
dst_path = pathlib.Path(torch.hub.get_dir()) / "_assets" / asset_path
if not dst_path.exists():
dst_path.parent.mkdir(parents=True, exist_ok=True)
download_url_to_file(f"{_BASE_PIPELINES_URL}/{asset_path}", dst_path)
return str(dst_path)
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max) _decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel) _gain = pow(10, 0.05 * _decibel)
...@@ -260,8 +246,8 @@ class RNNTBundle: ...@@ -260,8 +246,8 @@ class RNNTBundle:
def _get_model(self) -> RNNT: def _get_model(self) -> RNNT:
model = self._rnnt_factory_func() model = self._rnnt_factory_func()
url = f"{_BASE_MODELS_URL}/{self._rnnt_path}" path = torchaudio.utils.download_asset(self._rnnt_path)
state_dict = load_state_dict_from_url(url) state_dict = torch.load(path)
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
model.eval() model.eval()
return model return model
...@@ -329,7 +315,7 @@ class RNNTBundle: ...@@ -329,7 +315,7 @@ class RNNTBundle:
Returns: Returns:
FeatureExtractor FeatureExtractor
""" """
local_path = _download_asset(self._global_stats_path) local_path = torchaudio.utils.download_asset(self._global_stats_path)
return _ModuleFeatureExtractor( return _ModuleFeatureExtractor(
torch.nn.Sequential( torch.nn.Sequential(
torchaudio.transforms.MelSpectrogram( torchaudio.transforms.MelSpectrogram(
...@@ -348,7 +334,7 @@ class RNNTBundle: ...@@ -348,7 +334,7 @@ class RNNTBundle:
Returns: Returns:
FeatureExtractor FeatureExtractor
""" """
local_path = _download_asset(self._global_stats_path) local_path = torchaudio.utils.download_asset(self._global_stats_path)
return _ModuleFeatureExtractor( return _ModuleFeatureExtractor(
torch.nn.Sequential( torch.nn.Sequential(
torchaudio.transforms.MelSpectrogram( torchaudio.transforms.MelSpectrogram(
...@@ -366,15 +352,15 @@ class RNNTBundle: ...@@ -366,15 +352,15 @@ class RNNTBundle:
Returns: Returns:
TokenProcessor TokenProcessor
""" """
local_path = _download_asset(self._sp_model_path) local_path = torchaudio.utils.download_asset(self._sp_model_path)
return _SentencePieceTokenProcessor(local_path) return _SentencePieceTokenProcessor(local_path)
EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle( EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
_rnnt_path="emformer_rnnt_base_librispeech.pt", _rnnt_path="models/emformer_rnnt_base_librispeech.pt",
_rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=4097), _rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=4097),
_global_stats_path="global_stats_rnnt_librispeech.json", _global_stats_path="pipeline-assets/global_stats_rnnt_librispeech.json",
_sp_model_path="spm_bpe_4096_librispeech.model", _sp_model_path="pipeline-assets/spm_bpe_4096_librispeech.model",
_right_padding=4, _right_padding=4,
_blank=4096, _blank=4096,
_sample_rate=16000, _sample_rate=16000,
......
...@@ -3,7 +3,13 @@ from torchaudio._internal import module_utils as _mod_utils ...@@ -3,7 +3,13 @@ from torchaudio._internal import module_utils as _mod_utils
from . import ( from . import (
sox_utils, sox_utils,
) )
from .download import download_asset
if _mod_utils.is_sox_available(): if _mod_utils.is_sox_available():
sox_utils.set_verbosity(1) sox_utils.set_verbosity(1)
__all__ = [
"download_asset",
"sox_utils",
]
import hashlib
import logging
from os import PathLike
from pathlib import Path
from typing import Union
import torch
_LG = logging.getLogger(__name__)
def _get_local_path(key):
    """Return the default cache location for the asset identified by ``key``.

    The location is ``<torch.hub.get_dir()>/torchaudio/<key>``. The parent
    directory of the returned path is created if it does not exist yet; the
    file itself is not created.
    """
    cache_root = Path(torch.hub.get_dir()) / "torchaudio"
    local_path = cache_root.joinpath(key)
    local_path.parent.mkdir(parents=True, exist_ok=True)
    return local_path
def _download(key, path, progress):
    """Fetch the asset ``key`` from the torchaudio download server into ``path``.

    ``progress`` is forwarded to :py:func:`torch.hub.download_url_to_file`.
    """
    source_url = f"https://download.pytorch.org/torchaudio/{key}"
    torch.hub.download_url_to_file(source_url, path, progress=progress)
def _get_hash(path, hash, chunk_size=1028):
    """Return the hexadecimal SHA256 digest of the file at ``path``.

    The file is read in pieces of ``chunk_size`` bytes so that it never has
    to be held in memory as a whole. ``hash`` is accepted but not used by
    this function.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # ``read`` returns b"" at end of file, which terminates the iterator.
        for block_bytes in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block_bytes)
    return digest.hexdigest()
def download_asset(
    key: str,
    hash: str = "",
    path: Union[str, PathLike] = "",
    *,
    progress: bool = True,
) -> str:
    """Download and store torchaudio assets to local file system.

    If a file exists at the download path, then that path is returned with or without
    hash validation.

    Args:
        key (str): The asset identifier.
        hash (str, optional):
            The value of SHA256 hash of the asset. If provided, it is used to verify
            the downloaded / cached object. If not provided, then no hash validation
            is performed. This means if a file exists at the download path, then the path
            is returned as-is without verifying the identity of the file.
        path (path-like object, optional):
            By default, the downloaded asset is saved in a directory under
            :py:func:`torch.hub.get_dir` and intermediate directories based on the given `key`
            are created.
            This argument can be used to overwrite the target location.
            When this argument is provided, all the intermediate directories have to be
            created beforehand.
        progress (bool): Whether to show progress bar for downloading. Default: ``True``.

    Note:
        Currently the valid key values are the route on ``download.pytorch.org/torchaudio``,
        but this is an implementation detail.

    Returns:
        str: The path to the asset on the local file system.

    Raises:
        ValueError: If ``hash`` is provided and it does not match the SHA256 digest
            of the file at the resulting path.
    """
    # ``path`` may be a plain ``str`` or any ``os.PathLike``; normalize it to
    # ``pathlib.Path`` so that ``exists()`` is available for every accepted type.
    path = Path(path) if path else _get_local_path(key)

    if path.exists():
        _LG.info("The local file (%s) exists. Skipping the download.", path)
    else:
        _LG.info("Downloading %s to %s", key, path)
        _download(key, path, progress=progress)

    if hash:
        _LG.info("Verifying the hash value.")
        digest = _get_hash(path, hash)

        if digest != hash:
            raise ValueError(
                f"The hash value of the downloaded file ({path}), '{digest}' does not match "
                f"the provided hash value, '{hash}'."
            )

        _LG.info("Hash validated.")

    return str(path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment