Commit 3bd4db86 authored by David Pollack's avatar David Pollack Committed by Soumith Chintala
Browse files

refactoring and cleaning up code

parent 0e0d1e59
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
BeforeCatch: false
BeforeElse: false
IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
CompactNamespaces: false
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH_RANGE, FOR_EACH, ]
IncludeCategories:
- Regex: '^<.*\.h(pp)?>'
Priority: 1
- Regex: '^<.*'
Priority: 2
- Regex: '.*'
Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 2000000
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
---
# NOTE there must be no spaces before the '-', so put the comma first.
Checks: '
-*
,bugprone-*
,-bugprone-forward-declaration-namespace
,-bugprone-macro-parentheses
,cppcoreguidelines-*
,-cppcoreguidelines-interfaces-global-init
,-cppcoreguidelines-owning-memory
,-cppcoreguidelines-pro-bounds-array-to-pointer-decay
,-cppcoreguidelines-pro-bounds-constant-array-index
,-cppcoreguidelines-pro-bounds-pointer-arithmetic
,-cppcoreguidelines-pro-type-cstyle-cast
,-cppcoreguidelines-pro-type-reinterpret-cast
,-cppcoreguidelines-pro-type-static-cast-downcast
,-cppcoreguidelines-pro-type-union-access
,-cppcoreguidelines-pro-type-vararg
,-cppcoreguidelines-special-member-functions
,hicpp-exception-baseclass
,hicpp-avoid-goto
,modernize-*
,-modernize-return-braced-init-list
,-modernize-use-auto
,-modernize-use-default-member-init
,-modernize-use-using
,performance-unnecessary-value-param
'
WarningsAsErrors: '*'
HeaderFilterRegex: 'torchaudio/.*'
AnalyzeTemporaryDtors: false
CheckOptions:
...
[flake8]
max-line-length = 120
ignore = E305,E402,E721,E741,F401,F403,F405,F821,F841,F999,W503,W504
exclude = build,docs/source,_ext
#!/usr/bin/env python #!/usr/bin/env python
import os
import platform
from setuptools import setup, find_packages from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CppExtension from torch.utils.cpp_extension import BuildExtension, CppExtension
def check_env_flag(name, default=''):
    """Return True if the environment variable *name* holds a truthy value.

    Truthy values (compared case-insensitively): 'ON', '1', 'YES', 'TRUE', 'Y'.

    Args:
        name (str): name of the environment variable to inspect.
        default (str): value to use when the variable is unset.

    Returns:
        bool: True if the (upper-cased) value is one of the truthy strings.
    """
    # Set literal instead of set([...]): same semantics, clearer and avoids
    # building a throwaway list.
    return os.getenv(name, default).upper() in {'ON', '1', 'YES', 'TRUE', 'Y'}
DEBUG = check_env_flag('DEBUG')
eca = []
ela = []
if DEBUG:
if platform.system() == 'Windows':
ela += ['/DEBUG:FULL']
else:
eca += ['-O0', '-g']
ela += ['-O0', '-g']
setup( setup(
name="torchaudio", name="torchaudio",
version="0.2", version="0.2",
...@@ -14,6 +30,10 @@ setup( ...@@ -14,6 +30,10 @@ setup(
packages=find_packages(exclude=["build"]), packages=find_packages(exclude=["build"]),
ext_modules=[ ext_modules=[
CppExtension( CppExtension(
'_torch_sox', ['torchaudio/torch_sox.cpp'], libraries=['sox']), '_torch_sox',
['torchaudio/torch_sox.cpp'],
libraries=['sox'],
extra_compile_args=eca,
extra_link_args=ela),
], ],
cmdclass={'build_ext': BuildExtension}) cmdclass={'build_ext': BuildExtension})
...@@ -27,7 +27,6 @@ class Test_LoadSave(unittest.TestCase): ...@@ -27,7 +27,6 @@ class Test_LoadSave(unittest.TestCase):
os.unlink(new_filepath) os.unlink(new_filepath)
# test save 1d tensor # test save 1d tensor
#x = x[:, 0] # get mono signal
x = x[0, :] # get mono signal x = x[0, :] # get mono signal
x.squeeze_() # remove channel dim x.squeeze_() # remove channel dim
torchaudio.save(new_filepath, x, sr) torchaudio.save(new_filepath, x, sr)
...@@ -91,7 +90,7 @@ class Test_LoadSave(unittest.TestCase): ...@@ -91,7 +90,7 @@ class Test_LoadSave(unittest.TestCase):
offset = 15 offset = 15
x, _ = torchaudio.load(self.test_filepath) x, _ = torchaudio.load(self.test_filepath)
x_offset, _ = torchaudio.load(self.test_filepath, offset=offset) x_offset, _ = torchaudio.load(self.test_filepath, offset=offset)
self.assertTrue(x[:,offset:].allclose(x_offset)) self.assertTrue(x[:, offset:].allclose(x_offset))
# check number of frames # check number of frames
n = 201 n = 201
...@@ -132,7 +131,7 @@ class Test_LoadSave(unittest.TestCase): ...@@ -132,7 +131,7 @@ class Test_LoadSave(unittest.TestCase):
input_sine_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') input_sine_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
x_sine_full, sr_sine = torchaudio.load(input_sine_path) x_sine_full, sr_sine = torchaudio.load(input_sine_path)
x_sine_part, _ = torchaudio.load(input_sine_path, num_frames=num_frames, offset=offset) x_sine_part, _ = torchaudio.load(input_sine_path, num_frames=num_frames, offset=offset)
l1_error = x_sine_full[:, offset:(num_frames+offset)].sub(x_sine_part).abs().sum().item() l1_error = x_sine_full[:, offset:(num_frames + offset)].sub(x_sine_part).abs().sum().item()
# test for the correct number of samples and that the correct portion was loaded # test for the correct number of samples and that the correct portion was loaded
self.assertEqual(x_sine_part.size(1), num_frames) self.assertEqual(x_sine_part.size(1), num_frames)
self.assertEqual(l1_error, 0.) self.assertEqual(l1_error, 0.)
...@@ -148,7 +147,7 @@ class Test_LoadSave(unittest.TestCase): ...@@ -148,7 +147,7 @@ class Test_LoadSave(unittest.TestCase):
# test with two channel mp3 # test with two channel mp3
x_2ch_full, sr_2ch = torchaudio.load(self.test_filepath, normalization=True) x_2ch_full, sr_2ch = torchaudio.load(self.test_filepath, normalization=True)
x_2ch_part, _ = torchaudio.load(self.test_filepath, normalization=True, num_frames=num_frames, offset=offset) x_2ch_part, _ = torchaudio.load(self.test_filepath, normalization=True, num_frames=num_frames, offset=offset)
l1_error = x_2ch_full[:, offset:(offset+num_frames)].sub(x_2ch_part).abs().sum().item() l1_error = x_2ch_full[:, offset:(offset + num_frames)].sub(x_2ch_part).abs().sum().item()
self.assertEqual(x_2ch_part.size(1), num_frames) self.assertEqual(x_2ch_part.size(1), num_frames)
self.assertEqual(l1_error, 0.) self.assertEqual(l1_error, 0.)
......
...@@ -30,13 +30,14 @@ class TORCHAUDIODS(Dataset): ...@@ -30,13 +30,14 @@ class TORCHAUDIODS(Dataset):
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
class Test_DataLoader(unittest.TestCase): class Test_DataLoader(unittest.TestCase):
def test_1(self): def test_1(self):
expected_size = (2, 1, 16000) expected_size = (2, 1, 16000)
ds = TORCHAUDIODS() ds = TORCHAUDIODS()
dl = DataLoader(ds, batch_size=2) dl = DataLoader(ds, batch_size=2)
for x in dl: for x in dl:
#print(x.size()) # print(x.size())
continue continue
self.assertTrue(x.size() == expected_size) self.assertTrue(x.size() == expected_size)
......
...@@ -120,7 +120,7 @@ class Test_LoadSave(unittest.TestCase): ...@@ -120,7 +120,7 @@ class Test_LoadSave(unittest.TestCase):
input_sine_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') input_sine_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
x_sine_full, sr_sine = load(input_sine_path) x_sine_full, sr_sine = load(input_sine_path)
x_sine_part, _ = load(input_sine_path, num_frames=num_frames, offset=offset) x_sine_part, _ = load(input_sine_path, num_frames=num_frames, offset=offset)
l1_error = x_sine_full[offset:(num_frames+offset)].sub(x_sine_part).abs().sum().item() l1_error = x_sine_full[offset:(num_frames + offset)].sub(x_sine_part).abs().sum().item()
# test for the correct number of samples and that the correct portion was loaded # test for the correct number of samples and that the correct portion was loaded
self.assertEqual(x_sine_part.size(0), num_frames) self.assertEqual(x_sine_part.size(0), num_frames)
self.assertEqual(l1_error, 0.) self.assertEqual(l1_error, 0.)
...@@ -137,7 +137,7 @@ class Test_LoadSave(unittest.TestCase): ...@@ -137,7 +137,7 @@ class Test_LoadSave(unittest.TestCase):
# test with two channel mp3 # test with two channel mp3
x_2ch_full, sr_2ch = load(self.test_filepath, normalization=True) x_2ch_full, sr_2ch = load(self.test_filepath, normalization=True)
x_2ch_part, _ = load(self.test_filepath, normalization=True, num_frames=num_frames, offset=offset) x_2ch_part, _ = load(self.test_filepath, normalization=True, num_frames=num_frames, offset=offset)
l1_error = x_2ch_full[offset:(offset+num_frames)].sub(x_2ch_part).abs().sum().item() l1_error = x_2ch_full[offset:(offset + num_frames)].sub(x_2ch_part).abs().sum().item()
self.assertEqual(x_2ch_part.size(0), num_frames) self.assertEqual(x_2ch_part.size(0), num_frames)
self.assertEqual(l1_error, 0.) self.assertEqual(l1_error, 0.)
......
...@@ -17,7 +17,7 @@ class Test_SoxEffectsChain(unittest.TestCase): ...@@ -17,7 +17,7 @@ class Test_SoxEffectsChain(unittest.TestCase):
E.append_effect_to_chain("echos", [0.8, 0.7, 40, 0.25, 63, 0.3]) E.append_effect_to_chain("echos", [0.8, 0.7, 40, 0.25, 63, 0.3])
x, sr = E.sox_build_flow_effects() x, sr = E.sox_build_flow_effects()
# check if effects worked # check if effects worked
#print(x.size()) # print(x.size())
def test_rate_channels(self): def test_rate_channels(self):
target_rate = 16000 target_rate = 16000
...@@ -154,7 +154,7 @@ class Test_SoxEffectsChain(unittest.TestCase): ...@@ -154,7 +154,7 @@ class Test_SoxEffectsChain(unittest.TestCase):
E.append_effect_to_chain("trim", [offset, num_frames]) E.append_effect_to_chain("trim", [offset, num_frames])
x, sr = E.sox_build_flow_effects() x, sr = E.sox_build_flow_effects()
# check if effect worked # check if effect worked
self.assertTrue(x.allclose(x_orig[:,offset_int:(offset_int+num_frames_int)], rtol=1e-4, atol=1e-4)) self.assertTrue(x.allclose(x_orig[:, offset_int:(offset_int + num_frames_int)], rtol=1e-4, atol=1e-4))
def test_silence_contrast(self): def test_silence_contrast(self):
si, _ = torchaudio.info(self.test_filepath) si, _ = torchaudio.info(self.test_filepath)
...@@ -183,13 +183,14 @@ class Test_SoxEffectsChain(unittest.TestCase): ...@@ -183,13 +183,14 @@ class Test_SoxEffectsChain(unittest.TestCase):
E.append_effect_to_chain("fade", ["q", "0.25", "0", "0.33"]) E.append_effect_to_chain("fade", ["q", "0.25", "0", "0.33"])
x, _ = E.sox_build_flow_effects() x, _ = E.sox_build_flow_effects()
# check if effect worked # check if effect worked
#print(x.size()) # print(x.size())
def test_biquad_delay(self): def test_biquad_delay(self):
si, _ = torchaudio.info(self.test_filepath) si, _ = torchaudio.info(self.test_filepath)
E = torchaudio.sox_effects.SoxEffectsChain() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(self.test_filepath) E.set_input_file(self.test_filepath)
E.append_effect_to_chain("biquad", ["0.25136437", "0.50272873", "0.25136437", "1.0", "-0.17123075", "0.17668821"]) E.append_effect_to_chain("biquad", ["0.25136437", "0.50272873", "0.25136437",
"1.0", "-0.17123075", "0.17668821"])
E.append_effect_to_chain("delay", ["15000s"]) E.append_effect_to_chain("delay", ["15000s"])
x, _ = E.sox_build_flow_effects() x, _ = E.sox_build_flow_effects()
# check if effect worked # check if effect worked
......
...@@ -38,9 +38,11 @@ class Tester(unittest.TestCase): ...@@ -38,9 +38,11 @@ class Tester(unittest.TestCase):
length_new = int(length_orig * 1.2) length_new = int(length_orig * 1.2)
result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig) result = transforms.PadTrim(max_len=length_new, channels_first=False)(audio_orig)
self.assertEqual(result.size(0), length_new) self.assertEqual(result.size(0), length_new)
result = transforms.PadTrim(max_len=length_new, channels_first=True)(audio_orig.transpose(0, 1))
self.assertEqual(result.size(1), length_new)
audio_orig = self.sig.clone() audio_orig = self.sig.clone()
length_orig = audio_orig.size(0) length_orig = audio_orig.size(0)
length_new = int(length_orig * 0.8) length_new = int(length_orig * 0.8)
...@@ -147,7 +149,7 @@ class Tester(unittest.TestCase): ...@@ -147,7 +149,7 @@ class Tester(unittest.TestCase):
audio_orig = self.sig.clone() # (16000, 1) audio_orig = self.sig.clone() # (16000, 1)
audio_scaled = transforms.Scale()(audio_orig) # (16000, 1) audio_scaled = transforms.Scale()(audio_orig) # (16000, 1)
audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000) audio_scaled = transforms.LC2CL()(audio_scaled) # (1, 16000)
spectrogram_torch = transforms.MEL2()(audio_scaled) # (1, 319, 40) spectrogram_torch = transforms.MEL2(window_fn=torch.hamming_window, pad=10)(audio_scaled) # (1, 319, 40)
self.assertTrue(spectrogram_torch.dim() == 3) self.assertTrue(spectrogram_torch.dim() == 3)
self.assertTrue(spectrogram_torch.max() <= 0.) self.assertTrue(spectrogram_torch.max() <= 0.)
......
...@@ -44,7 +44,8 @@ def load(filepath, ...@@ -44,7 +44,8 @@ def load(filepath,
filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
Returns: tuple(Tensor, int) Returns: tuple(Tensor, int)
- Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames, C is the number of channels - Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames and
C is the number of channels
- int: the sample rate of the audio (as listed in the metadata of the file) - int: the sample rate of the audio (as listed in the metadata of the file)
Example:: Example::
...@@ -127,8 +128,7 @@ def save_encinfo(filepath, ...@@ -127,8 +128,7 @@ def save_encinfo(filepath,
>>> torchaudio.save('foo.wav', data, sample_rate) >>> torchaudio.save('foo.wav', data, sample_rate)
""" """
ch_idx = 0 if channels_first else 1 ch_idx, len_idx = (0, 1) if channels_first else (1, 0)
len_idx = 1 if channels_first else 0
# check if save directory exists # check if save directory exists
abs_dirpath = os.path.dirname(os.path.abspath(filepath)) abs_dirpath = os.path.dirname(os.path.abspath(filepath))
......
...@@ -44,7 +44,8 @@ class SoxEffectsChain(object): ...@@ -44,7 +44,8 @@ class SoxEffectsChain(object):
filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
Returns: tuple(Tensor, int) Returns: tuple(Tensor, int)
- Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames, C is the number of channels - Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames and
C is the number of channels
- int: the sample rate of the audio (as listed in the metadata of the file) - int: the sample rate of the audio (as listed in the metadata of the file)
Example:: Example::
......
...@@ -158,7 +158,7 @@ int read_audio_file( ...@@ -158,7 +158,7 @@ int read_audio_file(
void write_audio_file( void write_audio_file(
const std::string& file_name, const std::string& file_name,
at::Tensor tensor, const at::Tensor& tensor,
sox_signalinfo_t* si, sox_signalinfo_t* si,
sox_encodinginfo_t* ei, sox_encodinginfo_t* ei,
const char* file_type) { const char* file_type) {
...@@ -332,16 +332,9 @@ int build_flow_effects(const std::string& file_name, ...@@ -332,16 +332,9 @@ int build_flow_effects(const std::string& file_name,
int sr; int sr;
// Read the in-memory audio buffer or temp file that we just wrote. // Read the in-memory audio buffer or temp file that we just wrote.
#ifdef __APPLE__ #ifdef __APPLE__
/* certain effects will result in a target signal length of 0. /*
if (target_signal->length > 0) { Temporary filetype must have a valid header. Wav seems to work here while
if (target_signal->channels != output->signal.channels) { raw does not. Certain effects like chorus caused strange behavior on the mac.
std::cout << "output: " << output->signal.channels << "|" << output->signal.length << "\n";
std::cout << "interm: " << interm_signal.channels << "|" << interm_signal.length << "\n";
std::cout << "target: " << target_signal->channels << "|" << target_signal->length << "\n";
unlink(tmp_name);
throw std::runtime_error("unexpected number of audio channels");
}
}
*/ */
// read_audio_file reads the temporary file and returns the sr and otensor // read_audio_file reads the temporary file and returns the sr and otensor
sr = read_audio_file(tmp_name, otensor, ch_first, 0, 0, sr = read_audio_file(tmp_name, otensor, ch_first, 0, 0,
......
...@@ -26,10 +26,10 @@ int read_audio_file( ...@@ -26,10 +26,10 @@ int read_audio_file(
/// writing, or an error occurred during writing of the audio data. /// writing, or an error occurred during writing of the audio data.
void write_audio_file( void write_audio_file(
const std::string& file_name, const std::string& file_name,
at::Tensor tensor, at::Tensor& tensor,
sox_signalinfo_t* si, sox_signalinfo_t* si,
sox_encodinginfo_t* ei, sox_encodinginfo_t* ei,
const char* extension) const char* file_type)
/// Reads an audio file from the given `path` and returns a tuple of /// Reads an audio file from the given `path` and returns a tuple of
/// sox_signalinfo_t and sox_encodinginfo_t, which contain information about /// sox_signalinfo_t and sox_encodinginfo_t, which contain information about
...@@ -46,6 +46,13 @@ std::vector<std::string> get_effect_names(); ...@@ -46,6 +46,13 @@ std::vector<std::string> get_effect_names();
int initialize_sox(); int initialize_sox();
int shutdown_sox(); int shutdown_sox();
// Struct for build_flow_effects function
struct SoxEffect {
SoxEffect() : ename(""), eopts({""}) { }
std::string ename;
std::vector<std::string> eopts;
};
/// Build a SoX chain, flow the effects, and capture the results in a tensor. /// Build a SoX chain, flow the effects, and capture the results in a tensor.
/// An audio file from the given `path` flows through an effects chain given /// An audio file from the given `path` flows through an effects chain given
/// by a list of effects and effect options to an output buffer which is encoded /// by a list of effects and effect options to an output buffer which is encoded
......
from __future__ import division, print_function from __future__ import division, print_function
import torch import torch
from torch.autograd import Variable
import numpy as np import numpy as np
try: try:
import librosa import librosa
...@@ -8,18 +7,6 @@ except ImportError: ...@@ -8,18 +7,6 @@ except ImportError:
librosa = None librosa = None
def _check_is_variable(tensor):
if isinstance(tensor, torch.Tensor):
is_variable = False
tensor = Variable(tensor, requires_grad=False)
elif isinstance(tensor, Variable):
is_variable = True
else:
raise TypeError("tensor should be a Variable or Tensor, but is {}".format(type(tensor)))
return tensor, is_variable
class Compose(object): class Compose(object):
"""Composes several transforms together. """Composes several transforms together.
...@@ -73,8 +60,8 @@ class Scale(object): ...@@ -73,8 +60,8 @@ class Scale(object):
Tensor: Scaled by the scale factor. (default between -1.0 and 1.0) Tensor: Scaled by the scale factor. (default between -1.0 and 1.0)
""" """
if isinstance(tensor, (torch.LongTensor, torch.IntTensor)): if not tensor.is_floating_point():
tensor = tensor.float() tensor = tensor.to(torch.float32)
return tensor / self.factor return tensor / self.factor
...@@ -101,18 +88,18 @@ class PadTrim(object): ...@@ -101,18 +88,18 @@ class PadTrim(object):
""" """
Returns: Returns:
Tensor: (c x Ln or (n x c) Tensor: (c x n) or (n x c)
""" """
assert tensor.size(self.ch_dim) < 128, \ assert tensor.size(self.ch_dim) < 128, \
"Too many channels ({}) detected, look at channels_first param.".format(tensor.size(self.ch_dim)) "Too many channels ({}) detected, see channels_first param.".format(tensor.size(self.ch_dim))
if self.max_len > tensor.size(self.len_dim): if self.max_len > tensor.size(self.len_dim):
padding = [self.max_len - tensor.size(self.len_dim)
padding_size = [self.max_len - tensor.size(self.len_dim) if i == self.len_dim if (i % 2 == 1) and (i // 2 != self.len_dim)
else tensor.size(self.ch_dim) else 0
for i in range(2)] for i in range(4)]
pad = torch.empty(padding_size, dtype=tensor.dtype).fill_(self.fill_value) with torch.no_grad():
tensor = torch.cat((tensor, pad), dim=self.len_dim) tensor = torch.nn.functional.pad(tensor, padding, "constant", self.fill_value)
elif self.max_len < tensor.size(self.len_dim): elif self.max_len < tensor.size(self.len_dim):
tensor = tensor.narrow(self.len_dim, 0, self.max_len) tensor = tensor.narrow(self.len_dim, 0, self.max_len)
return tensor return tensor
...@@ -138,8 +125,8 @@ class DownmixMono(object): ...@@ -138,8 +125,8 @@ class DownmixMono(object):
self.ch_dim = int(not channels_first) self.ch_dim = int(not channels_first)
def __call__(self, tensor): def __call__(self, tensor):
if isinstance(tensor, (torch.LongTensor, torch.IntTensor)): if not tensor.is_floating_point():
tensor = tensor.float() tensor = tensor.to(torch.float32)
tensor = torch.mean(tensor, self.ch_dim, True) tensor = torch.mean(tensor, self.ch_dim, True)
return tensor return tensor
...@@ -182,12 +169,8 @@ class SPECTROGRAM(object): ...@@ -182,12 +169,8 @@ class SPECTROGRAM(object):
""" """
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
pad=0, window=torch.hann_window, wkwargs=None): pad=0, window_fn=torch.hann_window, wkwargs=None):
if isinstance(window, Variable): self.window = window_fn(ws) if wkwargs is None else window_fn(ws, **wkwargs)
self.window = window
else:
self.window = window(ws) if wkwargs is None else window(ws, **wkwargs)
self.window = Variable(self.window, volatile=True)
self.sr = sr self.sr = sr
self.ws = ws self.ws = ws
self.hop = hop if hop is not None else ws // 2 self.hop = hop if hop is not None else ws // 2
...@@ -200,33 +183,27 @@ class SPECTROGRAM(object): ...@@ -200,33 +183,27 @@ class SPECTROGRAM(object):
def __call__(self, sig): def __call__(self, sig):
""" """
Args: Args:
sig (Tensor or Variable): Tensor of audio of size (c, n) sig (Tensor): Tensor of audio of size (c, n)
Returns: Returns:
spec_f (Tensor or Variable): channels x hops x n_fft (c, l, f), where channels spec_f (Tensor): channels x hops x n_fft (c, l, f), where channels
is unchanged, hops is the number of hops, and n_fft is the is unchanged, hops is the number of hops, and n_fft is the
number of fourier bins, which should be the window size divided number of fourier bins, which should be the window size divided
by 2 plus 1. by 2 plus 1.
""" """
sig, is_variable = _check_is_variable(sig)
assert sig.dim() == 2 assert sig.dim() == 2
if self.pad > 0: if self.pad > 0:
c, n = sig.size() with torch.no_grad():
new_sig = sig.new_empty(c, n + self.pad * 2) sig = torch.nn.functional.pad(sig, (self.pad, self.pad), "constant")
new_sig[:, :self.pad].zero_()
new_sig[:, -self.pad:].zero_()
new_sig.narrow(1, self.pad, n).copy_(sig)
sig = new_sig
spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws, spec_f = torch.stft(sig, self.n_fft, self.hop, self.ws,
self.window, center=False, self.window, center=False,
normalized=True, onesided=True).transpose(1, 2) normalized=True, onesided=True).transpose(1, 2)
spec_f /= self.window.pow(2).sum().sqrt() spec_f /= self.window.pow(2).sum().sqrt()
spec_f = spec_f.pow(2).sum(-1) # get power of "complex" tensor (c, l, n_fft) spec_f = spec_f.pow(2).sum(-1) # get power of "complex" tensor (c, l, n_fft)
return spec_f if is_variable else spec_f.data return spec_f
class F2M(object): class F2M(object):
...@@ -247,7 +224,6 @@ class F2M(object): ...@@ -247,7 +224,6 @@ class F2M(object):
def __call__(self, spec_f): def __call__(self, spec_f):
spec_f, is_variable = _check_is_variable(spec_f)
n_fft = spec_f.size(2) n_fft = spec_f.size(2)
m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700)) m_min = 0. if self.f_min == 0 else 2595 * np.log10(1. + (self.f_min / 700))
...@@ -269,9 +245,8 @@ class F2M(object): ...@@ -269,9 +245,8 @@ class F2M(object):
if f_m != f_m_plus: if f_m != f_m_plus:
fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m) fb[f_m:f_m_plus, m - 1] = (f_m_plus - torch.arange(f_m, f_m_plus)) / (f_m_plus - f_m)
fb = Variable(fb)
spec_m = torch.matmul(spec_f, fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels) spec_m = torch.matmul(spec_f, fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m if is_variable else spec_m.data return spec_m
class SPEC2DB(object): class SPEC2DB(object):
...@@ -290,11 +265,10 @@ class SPEC2DB(object): ...@@ -290,11 +265,10 @@ class SPEC2DB(object):
def __call__(self, spec): def __call__(self, spec):
spec, is_variable = _check_is_variable(spec)
spec_db = self.multiplier * torch.log10(spec / spec.max()) # power -> dB spec_db = self.multiplier * torch.log10(spec / spec.max()) # power -> dB
if self.top_db is not None: if self.top_db is not None:
spec_db = torch.max(spec_db, spec_db.new([self.top_db])) spec_db = torch.max(spec_db, spec_db.new([self.top_db]))
return spec_db if is_variable else spec_db.data return spec_db
class MEL2(object): class MEL2(object):
...@@ -322,9 +296,8 @@ class MEL2(object): ...@@ -322,9 +296,8 @@ class MEL2(object):
>>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m) >>> spec_mel = transforms.MEL2(sr)(sig) # (c, l, m)
""" """
def __init__(self, sr=16000, ws=400, hop=None, n_fft=None, def __init__(self, sr=16000, ws=400, hop=None, n_fft=None,
pad=0, n_mels=40, window=torch.hann_window, wkwargs=None): pad=0, n_mels=40, window_fn=torch.hann_window, wkwargs=None):
self.window = window(ws) if wkwargs is None else window(ws, **wkwargs) self.window_fn = window_fn
self.window = Variable(self.window, requires_grad=False)
self.sr = sr self.sr = sr
self.ws = ws self.ws = ws
self.hop = hop if hop is not None else ws // 2 self.hop = hop if hop is not None else ws // 2
...@@ -348,18 +321,16 @@ class MEL2(object): ...@@ -348,18 +321,16 @@ class MEL2(object):
""" """
sig, is_variable = _check_is_variable(sig)
transforms = Compose([ transforms = Compose([
SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft, SPECTROGRAM(self.sr, self.ws, self.hop, self.n_fft,
self.pad, self.window), self.pad, self.window_fn, self.wkwargs),
F2M(self.n_mels, self.sr, self.f_max, self.f_min), F2M(self.n_mels, self.sr, self.f_max, self.f_min),
SPEC2DB("power", self.top_db), SPEC2DB("power", self.top_db),
]) ])
spec_mel_db = transforms(sig) spec_mel_db = transforms(sig)
return spec_mel_db if is_variable else spec_mel_db.data return spec_mel_db
class MEL(object): class MEL(object):
...@@ -454,10 +425,10 @@ class MuLawEncoding(object): ...@@ -454,10 +425,10 @@ class MuLawEncoding(object):
if isinstance(x, np.ndarray): if isinstance(x, np.ndarray):
x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu) x_mu = np.sign(x) * np.log1p(mu * np.abs(x)) / np.log1p(mu)
x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int) x_mu = ((x_mu + 1) / 2 * mu + 0.5).astype(int)
elif isinstance(x, (torch.Tensor, torch.LongTensor)): elif isinstance(x, torch.Tensor):
if isinstance(x, torch.LongTensor): if not x.is_floating_point():
x = x.float() x = x.to(torch.float)
mu = torch.FloatTensor([mu]) mu = torch.tensor(mu, dtype=x.dtype)
x_mu = torch.sign(x) * torch.log1p(mu * x_mu = torch.sign(x) * torch.log1p(mu *
torch.abs(x)) / torch.log1p(mu) torch.abs(x)) / torch.log1p(mu)
x_mu = ((x_mu + 1) / 2 * mu + 0.5).long() x_mu = ((x_mu + 1) / 2 * mu + 0.5).long()
...@@ -496,10 +467,10 @@ class MuLawExpanding(object): ...@@ -496,10 +467,10 @@ class MuLawExpanding(object):
if isinstance(x_mu, np.ndarray): if isinstance(x_mu, np.ndarray):
x = ((x_mu) / mu) * 2 - 1. x = ((x_mu) / mu) * 2 - 1.
x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu x = np.sign(x) * (np.exp(np.abs(x) * np.log1p(mu)) - 1.) / mu
elif isinstance(x_mu, (torch.Tensor, torch.LongTensor)): elif isinstance(x_mu, torch.Tensor):
if isinstance(x_mu, torch.LongTensor): if not x_mu.is_floating_point():
x_mu = x_mu.float() x_mu = x_mu.to(torch.float)
mu = torch.FloatTensor([mu]) mu = torch.tensor(mu, dtype=x_mu.dtype)
x = ((x_mu) / mu) * 2 - 1. x = ((x_mu) / mu) * 2 - 1.
x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.) / mu
return x return x
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment