Unverified Commit e5eb4857 authored by moto's avatar moto Committed by GitHub
Browse files

Move utility Tensor functions to misc_ops module (#694)

* also deletes duplicated func
parent 9f3075c1
...@@ -12,13 +12,15 @@ from torchaudio import ( ...@@ -12,13 +12,15 @@ from torchaudio import (
transforms transforms
) )
from torchaudio._backend import ( from torchaudio._backend import (
check_input,
_get_audio_backend_module, _get_audio_backend_module,
get_audio_backend, get_audio_backend,
set_audio_backend, set_audio_backend,
) )
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
from torchaudio.sox_effects import initialize_sox, shutdown_sox from torchaudio.sox_effects import initialize_sox, shutdown_sox
try: try:
...@@ -161,7 +163,7 @@ def save_encinfo(filepath: str, ...@@ -161,7 +163,7 @@ def save_encinfo(filepath: str,
if not os.path.isdir(abs_dirpath): if not os.path.isdir(abs_dirpath):
raise OSError("Directory does not exist: {}".format(abs_dirpath)) raise OSError("Directory does not exist: {}".format(abs_dirpath))
# check that src is a CPU tensor # check that src is a CPU tensor
check_input(src) _misc_ops.check_input(src)
# Check/Fix shape of source data # Check/Fix shape of source data
if src.dim() == 1: if src.dim() == 1:
# 1d tensors as assumed to be mono signals # 1d tensors as assumed to be mono signals
...@@ -328,24 +330,3 @@ def get_sox_bool(i: int = 0) -> Any: ...@@ -328,24 +330,3 @@ def get_sox_bool(i: int = 0) -> Any:
return _torchaudio.sox_bool return _torchaudio.sox_bool
else: else:
return _torchaudio.sox_bool(i) return _torchaudio.sox_bool(i)
def _audio_normalization(signal: Tensor, normalization: Union[bool, float, Callable]) -> None:
"""Audio normalization of a tensor in-place. The normalization can be a bool,
a number, or a callable that takes the audio tensor as an input. SoX uses
32-bit signed integers internally, thus bool normalizes based on that assumption.
"""
if not normalization:
return
if isinstance(normalization, bool):
normalization = 1 << 31
if isinstance(normalization, (float, int)):
# normalize with custom value
a = normalization
signal /= a
elif callable(normalization):
a = normalization(signal)
signal /= a
from functools import wraps from typing import Any
from typing import Any, List, Union
import platform import platform
import torch
from torch import Tensor
from . import _soundfile_backend, _sox_backend from . import _soundfile_backend, _sox_backend
...@@ -43,10 +41,3 @@ def _get_audio_backend_module() -> Any: ...@@ -43,10 +41,3 @@ def _get_audio_backend_module() -> Any:
""" """
backend = get_audio_backend() backend = get_audio_backend()
return _audio_backends[backend] return _audio_backends[backend]
def check_input(src: Tensor) -> None:
    """Ensure ``src`` is a tensor that is not a CUDA tensor.

    Args:
        src: Object to validate.

    Raises:
        TypeError: If ``src`` is not a tensor, or if ``src.is_cuda`` is true.
    """
    is_tensor = torch.is_tensor(src)
    if is_tensor and not src.is_cuda:
        return

    # Pick the message template for whichever check failed.
    if is_tensor:
        template = 'Expected a CPU based tensor, got %s'
    else:
        template = 'Expected a tensor, got %s'
    raise TypeError(template % type(src))
from typing import Union, Callable
import torch
from torch import Tensor
def normalize_audio(signal: Tensor, normalization: Union[bool, float, Callable]) -> None:
    """Normalize an audio tensor in place by dividing it by a scale factor.

    How the scale factor is obtained depends on ``normalization``:

    * falsy (``False``, ``0``, ``None``): nothing happens;
    * ``True``: the factor is ``1 << 31``, since SoX uses 32-bit signed
      integers internally;
    * ``int`` / ``float``: the value itself is the factor;
    * callable: the factor is ``normalization(signal)``.

    A truthy value of any other kind leaves ``signal`` unchanged.

    Args:
        signal: Audio tensor, modified in place.
        normalization: Bool, number, or callable selecting the scale factor.
    """
    if not normalization:
        return

    scale = None
    if isinstance(normalization, bool):
        # ``False`` already returned above, so this is the ``True`` case.
        scale = 1 << 31
    elif isinstance(normalization, (float, int)):
        scale = normalization
    elif callable(normalization):
        scale = normalization(signal)
    else:
        return

    signal /= scale
def check_input(src: Tensor) -> None:
    """Validate that ``src`` is a tensor and that it is not a CUDA tensor.

    Args:
        src: Object to validate.

    Raises:
        TypeError: If ``src`` is not a tensor, or if ``src.is_cuda`` is true.
    """
    src_type = type(src)
    if torch.is_tensor(src):
        if not src.is_cuda:
            return
        message = 'Expected a CPU based tensor, got %s' % src_type
    else:
        message = 'Expected a tensor, got %s' % src_type
    raise TypeError(message)
...@@ -4,6 +4,9 @@ from typing import Any, Optional, Tuple, Union ...@@ -4,6 +4,9 @@ from typing import Any, Optional, Tuple, Union
import torch import torch
from torch import Tensor from torch import Tensor
from torchaudio._internal import misc_ops as _misc_ops
_subtype_to_precision = { _subtype_to_precision = {
'PCM_S8': 8, 'PCM_S8': 8,
'PCM_16': 16, 'PCM_16': 16,
...@@ -43,13 +46,6 @@ class EncodingInfo: ...@@ -43,13 +46,6 @@ class EncodingInfo:
self.opposite_endian = opposite_endian self.opposite_endian = opposite_endian
def check_input(src: Tensor) -> None:
    """Reject anything that is not a tensor, and reject CUDA tensors.

    Args:
        src: Object to validate.

    Raises:
        TypeError: If ``src`` is not a tensor, or if ``src.is_cuda`` is true.
    """
    if torch.is_tensor(src):
        failure = "Expected a CPU based tensor, got %s" if src.is_cuda else None
    else:
        failure = "Expected a tensor, got %s"

    if failure is not None:
        raise TypeError(failure % type(src))
def load(filepath: str, def load(filepath: str,
out: Optional[Tensor] = None, out: Optional[Tensor] = None,
normalization: Optional[bool] = True, normalization: Optional[bool] = True,
...@@ -108,7 +104,7 @@ def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, chan ...@@ -108,7 +104,7 @@ def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, chan
if not os.path.isdir(abs_dirpath): if not os.path.isdir(abs_dirpath):
raise OSError("Directory does not exist: {}".format(abs_dirpath)) raise OSError("Directory does not exist: {}".format(abs_dirpath))
# check that src is a CPU tensor # check that src is a CPU tensor
check_input(src) _misc_ops.check_input(src)
# Check/Fix shape of source data # Check/Fix shape of source data
if src.dim() == 1: if src.dim() == 1:
# 1d tensors as assumed to be mono signals # 1d tensors as assumed to be mono signals
......
...@@ -5,7 +5,10 @@ import torch ...@@ -5,7 +5,10 @@ import torch
from torch import Tensor from torch import Tensor
import torchaudio import torchaudio
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
if _mod_utils.is_module_available('torchaudio._torchaudio'): if _mod_utils.is_module_available('torchaudio._torchaudio'):
...@@ -32,7 +35,7 @@ def load(filepath: str, ...@@ -32,7 +35,7 @@ def load(filepath: str,
# initialize output tensor # initialize output tensor
if out is not None: if out is not None:
torchaudio.check_input(out) _misc_ops.check_input(out)
else: else:
out = torch.FloatTensor() out = torch.FloatTensor()
...@@ -53,7 +56,7 @@ def load(filepath: str, ...@@ -53,7 +56,7 @@ def load(filepath: str,
) )
# normalize if needed # normalize if needed
torchaudio._audio_normalization(out, normalization) _misc_ops.normalize_audio(out, normalization)
return out, sample_rate return out, sample_rate
......
...@@ -5,7 +5,10 @@ import torch ...@@ -5,7 +5,10 @@ import torch
import torchaudio import torchaudio
from torch import Tensor from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
if _mod_utils.is_module_available('torchaudio._torchaudio'): if _mod_utils.is_module_available('torchaudio._torchaudio'):
from . import _torchaudio from . import _torchaudio
...@@ -200,7 +203,7 @@ class SoxEffectsChain(object): ...@@ -200,7 +203,7 @@ class SoxEffectsChain(object):
""" """
# initialize output tensor # initialize output tensor
if out is not None: if out is not None:
torchaudio.check_input(out) _misc_ops.check_input(out)
else: else:
out = torch.FloatTensor() out = torch.FloatTensor()
if not len(self.chain): if not len(self.chain):
...@@ -220,7 +223,7 @@ class SoxEffectsChain(object): ...@@ -220,7 +223,7 @@ class SoxEffectsChain(object):
self.chain, self.chain,
self.MAX_EFFECT_OPTS) self.MAX_EFFECT_OPTS)
torchaudio._audio_normalization(out, self.normalization) _misc_ops.normalize_audio(out, self.normalization)
return out, sr return out, sr
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment