Unverified Commit e5eb4857 authored by moto's avatar moto Committed by GitHub
Browse files

Move utility Tensor functions to misc_ops module (#694)

* also deletes duplicated func
parent 9f3075c1
......@@ -12,13 +12,15 @@ from torchaudio import (
transforms
)
from torchaudio._backend import (
check_input,
_get_audio_backend_module,
get_audio_backend,
set_audio_backend,
)
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
from torchaudio._internal import module_utils as _mod_utils
from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
from torchaudio.sox_effects import initialize_sox, shutdown_sox
try:
......@@ -161,7 +163,7 @@ def save_encinfo(filepath: str,
if not os.path.isdir(abs_dirpath):
raise OSError("Directory does not exist: {}".format(abs_dirpath))
# check that src is a CPU tensor
check_input(src)
_misc_ops.check_input(src)
# Check/Fix shape of source data
if src.dim() == 1:
# 1d tensors are assumed to be mono signals
......@@ -328,24 +330,3 @@ def get_sox_bool(i: int = 0) -> Any:
return _torchaudio.sox_bool
else:
return _torchaudio.sox_bool(i)
def _audio_normalization(signal: Tensor, normalization: Union[bool, float, Callable]) -> None:
"""Audio normalization of a tensor in-place. The normalization can be a bool,
a number, or a callable that takes the audio tensor as an input. SoX uses
32-bit signed integers internally, thus bool normalizes based on that assumption.
"""
if not normalization:
return
if isinstance(normalization, bool):
normalization = 1 << 31
if isinstance(normalization, (float, int)):
# normalize with custom value
a = normalization
signal /= a
elif callable(normalization):
a = normalization(signal)
signal /= a
from functools import wraps
from typing import Any, List, Union
from typing import Any
import platform
import torch
from torch import Tensor
from . import _soundfile_backend, _sox_backend
......@@ -43,10 +41,3 @@ def _get_audio_backend_module() -> Any:
"""
backend = get_audio_backend()
return _audio_backends[backend]
def check_input(src: Tensor) -> None:
    """Validate that ``src`` is a tensor that does not live on a CUDA device.

    Args:
        src: Object to validate.

    Raises:
        TypeError: If ``src`` is not a tensor, or if ``src.is_cuda`` is true.
    """
    if not torch.is_tensor(src):
        raise TypeError('Expected a tensor, got %s' % type(src))
    if src.is_cuda:
        # Report the device rather than the Python type: the type is the same
        # tensor class for CPU and CUDA tensors and does not explain the error.
        raise TypeError('Expected a CPU based tensor, got a tensor on %s' % src.device)
from typing import Union, Callable
import torch
from torch import Tensor
def normalize_audio(signal: Tensor, normalization: Union[bool, float, Callable]) -> None:
    """Audio normalization of a tensor in-place.

    The normalization can be a bool, a number, or a callable that takes the
    audio tensor as an input. SoX uses 32-bit signed integers internally, thus
    bool normalizes based on that assumption.

    Args:
        signal: Tensor to normalize. It is divided in place (``/=``).
        normalization: How to scale ``signal``.
            * Falsy value (``False``, ``None``, ``0``): ``signal`` is left
              untouched.
            * ``True``: divide by ``1 << 31``.
            * ``int`` / ``float``: divide by that number.
            * Callable: divide by ``normalization(signal)``.
            Any other truthy value matches no branch and is silently ignored.

    Returns:
        None. The result is written into ``signal``.
    """
    if not normalization:
        return

    # ``bool`` is a subclass of ``int``, so ``True`` has to be mapped to the
    # SoX full-scale divisor before the generic numeric branch below.
    if isinstance(normalization, bool):
        normalization = 1 << 31

    if isinstance(normalization, (float, int)):
        # normalize with custom value
        signal /= normalization
    elif callable(normalization):
        # The callable computes the divisor from the signal itself.
        signal /= normalization(signal)
def check_input(src: Tensor) -> None:
    """Validate that ``src`` is a tensor that does not live on a CUDA device.

    Args:
        src: Object to validate.

    Raises:
        TypeError: If ``src`` is not a tensor, or if ``src.is_cuda`` is true.
    """
    if not torch.is_tensor(src):
        raise TypeError('Expected a tensor, got %s' % type(src))
    if src.is_cuda:
        # Report the device rather than the Python type: the type is the same
        # tensor class for CPU and CUDA tensors and does not explain the error.
        raise TypeError('Expected a CPU based tensor, got a tensor on %s' % src.device)
......@@ -4,6 +4,9 @@ from typing import Any, Optional, Tuple, Union
import torch
from torch import Tensor
from torchaudio._internal import misc_ops as _misc_ops
_subtype_to_precision = {
'PCM_S8': 8,
'PCM_16': 16,
......@@ -43,13 +46,6 @@ class EncodingInfo:
self.opposite_endian = opposite_endian
def check_input(src: Tensor) -> None:
    """Validate that ``src`` is a tensor that does not live on a CUDA device.

    Args:
        src: Object to validate.

    Raises:
        TypeError: If ``src`` is not a tensor, or if ``src.is_cuda`` is true.
    """
    if not torch.is_tensor(src):
        raise TypeError("Expected a tensor, got %s" % type(src))
    if src.is_cuda:
        # Report the device rather than the Python type: the type is the same
        # tensor class for CPU and CUDA tensors and does not explain the error.
        raise TypeError("Expected a CPU based tensor, got a tensor on %s" % src.device)
def load(filepath: str,
out: Optional[Tensor] = None,
normalization: Optional[bool] = True,
......@@ -108,7 +104,7 @@ def save(filepath: str, src: Tensor, sample_rate: int, precision: int = 16, chan
if not os.path.isdir(abs_dirpath):
raise OSError("Directory does not exist: {}".format(abs_dirpath))
# check that src is a CPU tensor
check_input(src)
_misc_ops.check_input(src)
# Check/Fix shape of source data
if src.dim() == 1:
# 1d tensors are assumed to be mono signals
......
......@@ -5,7 +5,10 @@ import torch
from torch import Tensor
import torchaudio
from torchaudio._internal import module_utils as _mod_utils
from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
from torchaudio._soundfile_backend import SignalInfo, EncodingInfo
if _mod_utils.is_module_available('torchaudio._torchaudio'):
......@@ -32,7 +35,7 @@ def load(filepath: str,
# initialize output tensor
if out is not None:
torchaudio.check_input(out)
_misc_ops.check_input(out)
else:
out = torch.FloatTensor()
......@@ -53,7 +56,7 @@ def load(filepath: str,
)
# normalize if needed
torchaudio._audio_normalization(out, normalization)
_misc_ops.normalize_audio(out, normalization)
return out, sample_rate
......
......@@ -5,7 +5,10 @@ import torch
import torchaudio
from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils
from torchaudio._internal import (
module_utils as _mod_utils,
misc_ops as _misc_ops,
)
if _mod_utils.is_module_available('torchaudio._torchaudio'):
from . import _torchaudio
......@@ -200,7 +203,7 @@ class SoxEffectsChain(object):
"""
# initialize output tensor
if out is not None:
torchaudio.check_input(out)
_misc_ops.check_input(out)
else:
out = torch.FloatTensor()
if not len(self.chain):
......@@ -220,7 +223,7 @@ class SoxEffectsChain(object):
self.chain,
self.MAX_EFFECT_OPTS)
torchaudio._audio_normalization(out, self.normalization)
_misc_ops.normalize_audio(out, self.normalization)
return out, sr
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment