Commit 64b98521 authored by moto's avatar moto Committed by Facebook GitHub Bot
Browse files

Add download utility specialized for torchaudio (#2283)

Summary:
In recent updates, torchaudio added features that download assets/models from
download.pytorch.org/torchaudio.

To reduce code duplication, the implementation uses utilities from
``torch.hub``, but there are still patterns repeated when implementing
the fetch mechanism, notably cache and local file path handling.

This commit introduces a utility function that handles
download/cache/local path management and can be used for
fetching pre-trained model data.

Pull Request resolved: https://github.com/pytorch/audio/pull/2283

Reviewed By: carolineechen

Differential Revision: D35050469

Pulled By: mthrok

fbshipit-source-id: 219dd806f9a96c54d2d31e981c1bbe282772702b
parent 8395fe65
import pytest
import torch
from torchaudio._internal import download_url_to_file
import torchaudio
class GreedyCTCDecoder(torch.nn.Module):
......@@ -49,9 +49,7 @@ def sample_speech(tmp_path, lang):
filename = _FILES[lang]
path = tmp_path.parent / filename
if not path.exists():
url = f"https://download.pytorch.org/torchaudio/test-assets/{filename}"
print(f"downloading from {url}")
download_url_to_file(url, path, progress=False)
torchaudio.utils.download_asset(f"test-assets/{filename}", path=path)
return path
......
import json
import math
import pathlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import partial
......@@ -8,25 +7,12 @@ from typing import Callable, List, Tuple
import torch
import torchaudio
from torchaudio._internal import download_url_to_file, load_state_dict_from_url, module_utils
from torchaudio._internal import module_utils
from torchaudio.models import RNNT, RNNTBeamSearch, emformer_rnnt_base
__all__ = []
_BASE_MODELS_URL = "https://download.pytorch.org/torchaudio/models"
_BASE_PIPELINES_URL = "https://download.pytorch.org/torchaudio/pipeline-assets"
def _download_asset(asset_path: str):
    """Return the local path of a pipeline asset, downloading it if absent.

    The asset is stored at ``<torch.hub.get_dir()>/_assets/<asset_path>``.
    If a file already exists there it is reused as-is; otherwise the parent
    directories are created and the file is fetched from
    ``_BASE_PIPELINES_URL``.

    Args:
        asset_path (str): Path of the asset relative to ``_BASE_PIPELINES_URL``.

    Returns:
        str: The path to the asset on the local file system.
    """
    cache_root = pathlib.Path(torch.hub.get_dir()) / "_assets"
    local_file = cache_root / asset_path
    if local_file.exists():
        # Already cached: nothing to fetch.
        return str(local_file)
    local_file.parent.mkdir(parents=True, exist_ok=True)
    download_url_to_file(f"{_BASE_PIPELINES_URL}/{asset_path}", local_file)
    return str(local_file)
_decibel = 2 * 20 * math.log10(torch.iinfo(torch.int16).max)
_gain = pow(10, 0.05 * _decibel)
......@@ -260,8 +246,8 @@ class RNNTBundle:
def _get_model(self) -> RNNT:
model = self._rnnt_factory_func()
url = f"{_BASE_MODELS_URL}/{self._rnnt_path}"
state_dict = load_state_dict_from_url(url)
path = torchaudio.utils.download_asset(self._rnnt_path)
state_dict = torch.load(path)
model.load_state_dict(state_dict)
model.eval()
return model
......@@ -329,7 +315,7 @@ class RNNTBundle:
Returns:
FeatureExtractor
"""
local_path = _download_asset(self._global_stats_path)
local_path = torchaudio.utils.download_asset(self._global_stats_path)
return _ModuleFeatureExtractor(
torch.nn.Sequential(
torchaudio.transforms.MelSpectrogram(
......@@ -348,7 +334,7 @@ class RNNTBundle:
Returns:
FeatureExtractor
"""
local_path = _download_asset(self._global_stats_path)
local_path = torchaudio.utils.download_asset(self._global_stats_path)
return _ModuleFeatureExtractor(
torch.nn.Sequential(
torchaudio.transforms.MelSpectrogram(
......@@ -366,15 +352,15 @@ class RNNTBundle:
Returns:
TokenProcessor
"""
local_path = _download_asset(self._sp_model_path)
local_path = torchaudio.utils.download_asset(self._sp_model_path)
return _SentencePieceTokenProcessor(local_path)
EMFORMER_RNNT_BASE_LIBRISPEECH = RNNTBundle(
_rnnt_path="emformer_rnnt_base_librispeech.pt",
_rnnt_path="models/emformer_rnnt_base_librispeech.pt",
_rnnt_factory_func=partial(emformer_rnnt_base, num_symbols=4097),
_global_stats_path="global_stats_rnnt_librispeech.json",
_sp_model_path="spm_bpe_4096_librispeech.model",
_global_stats_path="pipeline-assets/global_stats_rnnt_librispeech.json",
_sp_model_path="pipeline-assets/spm_bpe_4096_librispeech.model",
_right_padding=4,
_blank=4096,
_sample_rate=16000,
......
......@@ -3,7 +3,13 @@ from torchaudio._internal import module_utils as _mod_utils
from . import (
sox_utils,
)
from .download import download_asset
if _mod_utils.is_sox_available():
sox_utils.set_verbosity(1)
__all__ = [
"download_asset",
"sox_utils",
]
import hashlib
import logging
from os import PathLike
from pathlib import Path
from typing import Union
import torch
_LG = logging.getLogger(__name__)
def _get_local_path(key):
    """Return the default cache location for the asset identified by ``key``.

    The location is ``<torch.hub.get_dir()>/torchaudio/<key>``. The parent
    directory of the returned path is created if it does not exist yet; the
    file itself is not created.

    Args:
        key: The asset identifier, used as a relative path under the cache root.

    Returns:
        pathlib.Path: The path at which the asset is (or will be) stored.
    """
    cache_root = Path(torch.hub.get_dir()) / "torchaudio"
    local_path = cache_root / Path(key)
    # Make sure intermediate directories derived from ``key`` exist.
    local_path.parent.mkdir(parents=True, exist_ok=True)
    return local_path
def _download(key, path, progress):
    """Fetch the asset ``key`` from the torchaudio download server into ``path``.

    Args:
        key: The asset identifier; appended to
            ``https://download.pytorch.org/torchaudio/`` to form the URL.
        path: Destination on the local file system.
        progress: Whether to show a progress bar while downloading.
    """
    torch.hub.download_url_to_file(
        f"https://download.pytorch.org/torchaudio/{key}",
        path,
        progress=progress,
    )
def _get_hash(path, hash, chunk_size=1028):
    """Compute the SHA256 hex digest of the file at ``path``.

    The file is read in pieces of ``chunk_size`` bytes so that large files
    do not have to be loaded into memory at once.

    Args:
        path: Path of the file to hash.
        hash: Accepted for interface compatibility; not used by this function.
        chunk_size (int): Number of bytes read per iteration.

    Returns:
        str: The hexadecimal SHA256 digest of the file contents.
    """
    digest = hashlib.sha256()
    with open(path, "rb") as stream:
        # ``iter`` with a sentinel stops at the first empty read (end of file).
        for block in iter(lambda: stream.read(chunk_size), b""):
            digest.update(block)
    return digest.hexdigest()
def download_asset(
    key: str,
    hash: str = "",
    path: Union[str, PathLike] = "",
    *,
    progress: bool = True,
) -> str:
    """Download and store torchaudio assets to local file system.

    If a file exists at the download path, then that path is returned with or without
    hash validation.

    Args:
        key (str): The asset identifier.
        hash (str, optional):
            The value of SHA256 hash of the asset. If provided, it is used to verify
            the downloaded / cached object. If not provided, then no hash validation
            is performed. This means if a file exists at the download path, then the path
            is returned as-is without verifying the identity of the file.
        path (path-like object, optional):
            By default, the downloaded asset is saved in a directory under
            :py:func:`torch.hub.get_dir` and intermediate directories based on the given `key`
            are created.
            This argument can be used to overwrite the target location.
            When this argument is provided, all the intermediate directories have to be
            created beforehand.
        progress (bool): Whether to show progress bar for downloading. Default: ``True``.

    Note:
        Currently the valid key values are the route on ``download.pytorch.org/torchaudio``,
        but this is an implementation detail.

    Returns:
        str: The path to the asset on the local file system.

    Raises:
        ValueError: If ``hash`` is provided and does not match the SHA256 digest
            of the downloaded / cached file.
    """
    # Normalize a caller-supplied location to ``Path``: a plain ``str`` (the
    # documented argument type) has no ``exists()`` method.
    path = Path(path) if path else _get_local_path(key)

    if path.exists():
        _LG.info("The local file (%s) exists. Skipping the download.", path)
    else:
        _LG.info("Downloading %s to %s", key, path)
        _download(key, path, progress=progress)

    if hash:
        _LG.info("Verifying the hash value.")
        digest = _get_hash(path, hash)

        if digest != hash:
            raise ValueError(
                f"The hash value of the downloaded file ({path}), '{digest}' does not match "
                f"the provided hash value, '{hash}'."
            )

        _LG.info("Hash validated.")

    return str(path)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment