Unverified Commit 536e8ac0 authored by moto's avatar moto Committed by GitHub
Browse files

Consolidate network utils (#1974)

This commit changes all the `torch.hub` network utility functions to
be imported from `torchaudio._internal`, so that later we can replace
the function within fbcode.
parent 1852d3e1
import torch import torch
import requests from torchaudio._internal import download_url_to_file
import pytest import pytest
...@@ -51,10 +51,7 @@ def sample_speech(tmp_path, lang): ...@@ -51,10 +51,7 @@ def sample_speech(tmp_path, lang):
if not path.exists(): if not path.exists():
url = f'https://download.pytorch.org/torchaudio/test-assets/{filename}' url = f'https://download.pytorch.org/torchaudio/test-assets/{filename}'
print(f'downloading from {url}') print(f'downloading from {url}')
with open(path, 'wb') as file: download_url_to_file(url, path, progress=False)
with requests.get(url) as resp:
resp.raise_for_status()
file.write(resp.content)
return path return path
......
from torch.hub import load_state_dict_from_url, download_url_to_file
__all__ = [
"load_state_dict_from_url",
"download_url_to_file",
]
...@@ -4,8 +4,8 @@ from typing import Union, Optional, Dict, Any, Tuple, List ...@@ -4,8 +4,8 @@ from typing import Union, Optional, Dict, Any, Tuple, List
import torch import torch
from torch import Tensor from torch import Tensor
from torch.hub import load_state_dict_from_url
from torchaudio._internal import load_state_dict_from_url
from torchaudio.models import Tacotron2, WaveRNN from torchaudio.models import Tacotron2, WaveRNN
from torchaudio.functional import mu_law_decoding from torchaudio.functional import mu_law_decoding
from torchaudio.transforms import InverseMelScale, GriffinLim from torchaudio.transforms import InverseMelScale, GriffinLim
......
...@@ -3,7 +3,10 @@ import logging ...@@ -3,7 +3,10 @@ import logging
import torch import torch
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import (
download_url_to_file,
module_utils as _mod_utils,
)
def _get_chars(): def _get_chars():
...@@ -174,7 +177,7 @@ def _load_phonemizer(file, dl_kwargs): ...@@ -174,7 +177,7 @@ def _load_phonemizer(file, dl_kwargs):
path = os.path.join(directory, file) path = os.path.join(directory, file)
if not os.path.exists(path): if not os.path.exists(path):
dl_kwargs = {} if dl_kwargs is None else dl_kwargs dl_kwargs = {} if dl_kwargs is None else dl_kwargs
torch.hub.download_url_to_file(url, path, **dl_kwargs) download_url_to_file(url, path, **dl_kwargs)
return Phonemizer.from_checkpoint(path) return Phonemizer.from_checkpoint(path)
finally: finally:
logger.setLevel(orig_level) logger.setLevel(orig_level)
......
...@@ -2,8 +2,8 @@ from dataclasses import dataclass ...@@ -2,8 +2,8 @@ from dataclasses import dataclass
from typing import Dict, Tuple, Any from typing import Dict, Tuple, Any
import torch import torch
from torch.hub import load_state_dict_from_url
from torchaudio._internal import load_state_dict_from_url
from torchaudio.models import wav2vec2_model, Wav2Vec2Model from torchaudio.models import wav2vec2_model, Wav2Vec2Model
from . import utils from . import utils
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment