Commit 301e2e98 authored by David Pollack's avatar David Pollack Committed by Soumith Chintala
Browse files

sox effects and documentation

parent db0da559
...@@ -7,8 +7,10 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio ...@@ -7,8 +7,10 @@ The :mod:`torchaudio` package consists of I/O, popular datasets and common audio
:maxdepth: 2 :maxdepth: 2
:caption: Package Reference :caption: Package Reference
sox_effects
datasets datasets
transforms transforms
legacy
.. automodule:: torchaudio .. automodule:: torchaudio
:members: :members:
torchaudio.legacy
======================
Legacy loading and save functions.
.. automodule:: torchaudio.legacy
:members:
torchaudio.sox_effects
======================
Create SoX effects chain for preprocessing audio.
.. currentmodule:: torchaudio.sox_effects
.. autoclass:: SoxEffect
:members:
.. autoclass:: SoxEffectsChain
:members: append_effect_to_chain, sox_build_flow_effects, clear_chain, set_input_file
...@@ -5,7 +5,7 @@ from torch.utils.cpp_extension import BuildExtension, CppExtension ...@@ -5,7 +5,7 @@ from torch.utils.cpp_extension import BuildExtension, CppExtension
setup( setup(
name="torchaudio", name="torchaudio",
version="0.1", version="0.2",
description="An audio package for PyTorch", description="An audio package for PyTorch",
url="https://github.com/pytorch/audio", url="https://github.com/pytorch/audio",
author="Soumith Chintala, David Pollack, Sean Naren, Peter Goldsborough", author="Soumith Chintala, David Pollack, Sean Naren, Peter Goldsborough",
......
...@@ -82,11 +82,29 @@ class Test_LoadSave(unittest.TestCase): ...@@ -82,11 +82,29 @@ class Test_LoadSave(unittest.TestCase):
self.assertEqual(sr, 44100) self.assertEqual(sr, 44100)
self.assertEqual(x.size(), (2, 278756)) self.assertEqual(x.size(), (2, 278756))
# check normalizing # check no normalizing
x, sr = torchaudio.load(self.test_filepath, normalization=True) x, _ = torchaudio.load(self.test_filepath, normalization=False)
self.assertEqual(x.dtype, torch.float32) self.assertTrue(x.min() <= -1.0)
self.assertTrue(x.min() >= -1.0) self.assertTrue(x.max() >= 1.0)
self.assertTrue(x.max() <= 1.0)
# check offset
offset = 15
x, _ = torchaudio.load(self.test_filepath)
x_offset, _ = torchaudio.load(self.test_filepath, offset=offset)
self.assertTrue(x[:,offset:].allclose(x_offset))
# check number of frames
n = 201
x, _ = torchaudio.load(self.test_filepath, num_frames=n)
self.assertTrue(x.size(), (2, n))
# check channels first
x, _ = torchaudio.load(self.test_filepath, channels_first=False)
self.assertEqual(x.size(), (278756, 2))
# check different input tensor type
x, _ = torchaudio.load(self.test_filepath, torch.LongTensor(), normalization=False)
self.assertTrue(isinstance(x, torch.LongTensor))
# check raising errors # check raising errors
with self.assertRaises(OSError): with self.assertRaises(OSError):
...@@ -108,8 +126,8 @@ class Test_LoadSave(unittest.TestCase): ...@@ -108,8 +126,8 @@ class Test_LoadSave(unittest.TestCase):
os.unlink(output_path) os.unlink(output_path)
def test_4_load_partial(self): def test_4_load_partial(self):
num_frames = 100 num_frames = 101
offset = 200 offset = 201
# load entire mono sinewave wav file, load a partial copy and then compare # load entire mono sinewave wav file, load a partial copy and then compare
input_sine_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav') input_sine_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
x_sine_full, sr_sine = torchaudio.load(input_sine_path) x_sine_full, sr_sine = torchaudio.load(input_sine_path)
......
...@@ -16,10 +16,10 @@ class TORCHAUDIODS(Dataset): ...@@ -16,10 +16,10 @@ class TORCHAUDIODS(Dataset):
self.data = [os.path.join(self.asset_dirpath, fn) for fn in os.listdir(self.asset_dirpath)] self.data = [os.path.join(self.asset_dirpath, fn) for fn in os.listdir(self.asset_dirpath)]
self.si, self.ei = torchaudio.info(os.path.join(self.asset_dirpath, "sinewave.wav")) self.si, self.ei = torchaudio.info(os.path.join(self.asset_dirpath, "sinewave.wav"))
self.si.precision = 16 self.si.precision = 16
self.E = torchaudio.sox_effects.SoxEffects() self.E = torchaudio.sox_effects.SoxEffectsChain()
self.E.sox_append_effect_to_chain("rate", [self.si.rate]) # resample to 16000hz self.E.append_effect_to_chain("rate", [self.si.rate]) # resample to 16000hz
self.E.sox_append_effect_to_chain("channels", [self.si.channels]) # mono signal self.E.append_effect_to_chain("channels", [self.si.channels]) # mono signal
self.E.sox_append_effect_to_chain("trim", [0, 1]) # first sec of audio self.E.append_effect_to_chain("trim", [0, "16000s"]) # first 16000 samples of audio
def __getitem__(self, index): def __getitem__(self, index):
fn = self.data[index] fn = self.data[index]
......
...@@ -5,31 +5,40 @@ import math ...@@ -5,31 +5,40 @@ import math
import os import os
class Test_SoxEffects(unittest.TestCase): class Test_SoxEffectsChain(unittest.TestCase):
test_dirpath = os.path.dirname(os.path.realpath(__file__)) test_dirpath = os.path.dirname(os.path.realpath(__file__))
test_filepath = os.path.join(test_dirpath, "assets", test_filepath = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.mp3") "steam-train-whistle-daniel_simon.mp3")
def test_single_channel(self):
    # Smoke test: push the sine-wave asset through the "echos" effect.
    # Nothing is asserted on the result; the test only requires that the
    # chain can be built and flowed without raising.
    sine_path = os.path.join(self.test_dirpath, "assets", "sinewave.wav")
    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(sine_path)
    echo_opts = [0.8, 0.7, 40, 0.25, 63, 0.3]
    chain.append_effect_to_chain("echos", echo_opts)
    x, sr = chain.sox_build_flow_effects()
def test_rate_channels(self): def test_rate_channels(self):
target_rate = 16000 target_rate = 16000
target_channels = 1 target_channels = 1
E = torchaudio.sox_effects.SoxEffects() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(self.test_filepath) E.set_input_file(self.test_filepath)
E.sox_append_effect_to_chain("rate", [target_rate]) E.append_effect_to_chain("rate", [target_rate])
E.sox_append_effect_to_chain("channels", [target_channels]) E.append_effect_to_chain("channels", [target_channels])
x, sr = E.sox_build_flow_effects() x, sr = E.sox_build_flow_effects()
# check if effects worked # check if effects worked
self.assertEqual(sr, target_rate) self.assertEqual(sr, target_rate)
self.assertEqual(x.size(0), target_channels) self.assertEqual(x.size(0), target_channels)
def test_other(self): def test_lowpass_speed(self):
speed = .8 speed = .8
si, _ = torchaudio.info(self.test_filepath) si, _ = torchaudio.info(self.test_filepath)
E = torchaudio.sox_effects.SoxEffects() E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(self.test_filepath) E.set_input_file(self.test_filepath)
E.sox_append_effect_to_chain("lowpass", 100) E.append_effect_to_chain("lowpass", 100)
E.sox_append_effect_to_chain("speed", speed) E.append_effect_to_chain("speed", speed)
E.sox_append_effect_to_chain("rate", si.rate) E.append_effect_to_chain("rate", si.rate)
x, sr = E.sox_build_flow_effects() x, sr = E.sox_build_flow_effects()
# check if effects worked # check if effects worked
self.assertEqual(x.size(1), int((si.length / si.channels) / speed)) self.assertEqual(x.size(1), int((si.length / si.channels) / speed))
...@@ -43,17 +52,145 @@ class Test_SoxEffects(unittest.TestCase): ...@@ -43,17 +52,145 @@ class Test_SoxEffects(unittest.TestCase):
ei_out.encoding = torchaudio.get_sox_encoding_t(9) ei_out.encoding = torchaudio.get_sox_encoding_t(9)
ei_out.bits_per_sample = 8 ei_out.bits_per_sample = 8
si_in, ei_in = torchaudio.info(self.test_filepath) si_in, ei_in = torchaudio.info(self.test_filepath)
E = torchaudio.sox_effects.SoxEffects(out_siginfo=si_out, out_encinfo=ei_out) E = torchaudio.sox_effects.SoxEffectsChain(out_siginfo=si_out, out_encinfo=ei_out)
E.set_input_file(self.test_filepath) E.set_input_file(self.test_filepath)
x, sr = E.sox_build_flow_effects() x, sr = E.sox_build_flow_effects()
# Note: the sample rate is reported as "changed", but no downsampling occurred # Note: the sample rate is reported as "changed", but no downsampling occurred
# also the number of channels has not changed. Run rate and channels effects # also the number of channels has not changed. Run rate and channels effects
# to make those changes # to make those changes. However, the output was encoded into ulaw because the
# number of unique values in the output is less than 256.
self.assertLess(x.unique().size(0), 2**8) self.assertLess(x.unique().size(0), 2**8)
self.assertEqual(x.size(0), si_in.channels) self.assertEqual(x.size(0), si_in.channels)
self.assertEqual(sr, si_out.rate) self.assertEqual(sr, si_out.rate)
self.assertEqual(x.numel(), si_in.length) self.assertEqual(x.numel(), si_in.length)
def test_band_chorus(self):
    # Smoke test: chain "band" and "chorus", reusing the input file's own
    # signal/encoding info for the output. No assertions are made on the
    # output; the chain only has to run without raising.
    in_sig, in_enc = torchaudio.info(self.test_filepath)
    chain = torchaudio.sox_effects.SoxEffectsChain(out_encinfo=in_enc, out_siginfo=in_sig)
    chain.set_input_file(self.test_filepath)
    band_opts = ["-n", "10k", "3.5k"]
    chorus_opts = [.5, .7, 55, 0.4, .25, 2, '-s']
    chain.append_effect_to_chain("band", band_opts)
    chain.append_effect_to_chain("chorus", chorus_opts)
    x, sr = chain.sox_build_flow_effects()
def test_synth(self):
    # Smoke test: mix synthesized pink noise into the input, then apply the
    # "rate" and "channels" effects. No assertions are made on the output.
    in_sig, in_enc = torchaudio.info(self.test_filepath)
    chain = torchaudio.sox_effects.SoxEffectsChain(out_encinfo=in_enc, out_siginfo=in_sig)
    chain.set_input_file(self.test_filepath)
    # Effects are appended in exactly this order.
    effects = (
        ("synth", ["1", "pinknoise", "mix"]),
        ("rate", [44100]),
        ("channels", [2]),
    )
    for effect_name, effect_opts in effects:
        chain.append_effect_to_chain(effect_name, effect_opts)
    x, sr = chain.sox_build_flow_effects()
def test_gain(self):
    # Runs the "gain" effect four times on the same chain object, clearing
    # the chain between runs, and checks the peak absolute sample value.
    #
    # The original positive-gain checks were written as
    # `assertTrue(peak, 1.)`. unittest treats the second argument as the
    # failure message, so those checks passed for any non-zero peak and
    # verified nothing. They are now real comparisons against full scale.
    # NOTE(review): this presumes the boosted output clips at full scale
    # for this asset -- confirm against the test file.
    E = torchaudio.sox_effects.SoxEffectsChain()
    cases = (
        # (gain options, whether the peak is expected to reach full scale)
        (["5"], True),
        (["-e", "-5"], False),
        (["-b", "8"], True),
        (["-n", "-10"], False),
    )
    for gain_opts, reaches_full_scale in cases:
        E.set_input_file(self.test_filepath)
        E.append_effect_to_chain("gain", gain_opts)
        x, sr = E.sox_build_flow_effects()
        E.clear_chain()
        peak = x.abs().max().item()
        if reaches_full_scale:
            self.assertAlmostEqual(peak, 1., places=4)
        else:
            self.assertLess(peak, 1.)
def test_tempo(self):
    # Apply the "tempo" effect and check that the output frame count is
    # the per-channel input length divided by the tempo factor.
    tempo = .8
    si, _ = torchaudio.info(self.test_filepath)
    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(self.test_filepath)
    chain.append_effect_to_chain("tempo", ["-s", tempo])
    x, sr = chain.sox_build_flow_effects()
    # si.length counts samples over all channels, hence the division
    frames_in = si.length / si.channels
    self.assertEqual(x.size(1), int(frames_in / tempo))
def test_trim(self):
    # Trim a window out of the file with the "trim" effect and compare it
    # with the same slice of the directly loaded audio. Positions are
    # passed to sox as strings with a trailing "s".
    start, count = 10000, 200
    x_orig, _ = torchaudio.load(self.test_filepath)
    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(self.test_filepath)
    trim_opts = ["{}s".format(start), "{}s".format(count)]
    chain.append_effect_to_chain("trim", trim_opts)
    x, sr = chain.sox_build_flow_effects()
    # the trimmed output must match the corresponding slice of the original
    expected = x_orig[:, start:(start + count)]
    self.assertTrue(x.allclose(expected, rtol=1e-4, atol=1e-4))
def test_silence_contrast(self):
    # Run "silence" followed by "contrast" and check that the output holds
    # fewer samples than the input file reports.
    si, _ = torchaudio.info(self.test_filepath)
    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(self.test_filepath)
    for effect_name, effect_opts in (("silence", [1, 100, 1]), ("contrast", [])):
        chain.append_effect_to_chain(effect_name, effect_opts)
    x, sr = chain.sox_build_flow_effects()
    total_out = x.numel()
    self.assertLess(total_out, si.length)
def test_reverse(self):
    # Reverse the audio with sox, flip it back by indexing, and compare the
    # result with the directly loaded original.
    x_orig, _ = torchaudio.load(self.test_filepath)
    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(self.test_filepath)
    chain.append_effect_to_chain("reverse", "")
    x_rev, _ = chain.sox_build_flow_effects()
    # index tensor counting down from the last frame to the first
    n_frames = x_orig.size(1)
    rev_idx = torch.LongTensor(list(reversed(range(n_frames))))
    restored = x_rev[:, rev_idx]
    self.assertTrue(x_orig.allclose(restored, rtol=1e-5, atol=2e-5))
def test_compand_fade(self):
    # Smoke test: "compand" followed by "fade". No assertions are made on
    # the output; the chain only has to run without raising.
    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(self.test_filepath)
    compand_opts = ["0.3,1", "6:-70,-60,-20", "-5", "-90", "0.2"]
    fade_opts = ["q", "0.25", "0", "0.33"]
    chain.append_effect_to_chain("compand", compand_opts)
    chain.append_effect_to_chain("fade", fade_opts)
    x, _ = chain.sox_build_flow_effects()
def test_biquad_delay(self):
    # Apply "biquad" then "delay" and check that the output is longer than
    # the per-channel input length by the delay amount.
    delay = 15000
    si, _ = torchaudio.info(self.test_filepath)
    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(self.test_filepath)
    biquad_opts = ["0.25136437", "0.50272873", "0.25136437", "1.0", "-0.17123075", "0.17668821"]
    chain.append_effect_to_chain("biquad", biquad_opts)
    chain.append_effect_to_chain("delay", ["{}s".format(delay)])
    x, _ = chain.sox_build_flow_effects()
    # si.length counts samples over all channels, hence the division
    expected_frames = (si.length / si.channels) + delay
    self.assertTrue(x.size(1) == expected_frames)
def test_invalid_effect_name(self):
    # Appending an effect whose name is unknown ("special") must raise
    # LookupError at append time.
    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(self.test_filepath)
    self.assertRaises(LookupError, chain.append_effect_to_chain, "special", [""])
def test_unimplemented_effect(self):
    # The sox "spectrogram" effect is not implemented in torchaudio, so
    # appending it must raise NotImplementedError.
    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(self.test_filepath)
    self.assertRaises(NotImplementedError, chain.append_effect_to_chain, "spectrogram", [""])
def test_invalid_effect_options(self):
    # Malformed options are accepted when appended and only rejected when
    # the chain is flowed. Here the first two "compand" options should have
    # been given as the single string "0.3,1".
    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(self.test_filepath)
    bad_opts = ["0.3", "1", "6:-70,-60,-20", "-5", "-90", "0.2"]
    chain.append_effect_to_chain("compand", bad_opts)
    self.assertRaises(RuntimeError, chain.sox_build_flow_effects)
if __name__ == '__main__': if __name__ == '__main__':
torchaudio.initialize_sox() torchaudio.initialize_sox()
unittest.main() unittest.main()
......
from __future__ import division, print_function
import os.path import os.path
import torch import torch
import _torch_sox import _torch_sox
from torchaudio import transforms, datasets, sox_effects from torchaudio import transforms, datasets, sox_effects, legacy
def check_input(src): def check_input(src):
...@@ -17,7 +18,7 @@ def load(filepath, ...@@ -17,7 +18,7 @@ def load(filepath,
out=None, out=None,
normalization=True, normalization=True,
channels_first=True, channels_first=True,
num_frames=-1, num_frames=0,
offset=0, offset=0,
signalinfo=None, signalinfo=None,
encodinginfo=None, encodinginfo=None,
...@@ -27,13 +28,13 @@ def load(filepath, ...@@ -27,13 +28,13 @@ def load(filepath,
Args: Args:
filepath (string): path to audio file filepath (string): path to audio file
out (Tensor, optional): an output Tensor to use instead of creating one out (Tensor, optional): an output Tensor to use instead of creating one
normalization (bool, number, or function, optional): If boolean `True`, then output is divided by `1 << 31` normalization (bool, number, or callable, optional): If boolean `True`, then output is divided by `1 << 31`
(assumes 16-bit depth audio, and normalizes to `[0, 1]`. (assumes signed 32-bit audio), and normalizes to `[-1, 1]`.
If `number`, then output is divided by that number If `number`, then output is divided by that number
If `function`, then the output is passed as a parameter If `callable`, then the output is passed as a parameter
to the given function, then the output is divided by to the given function, then the output is divided by
the result. the result.
num_frames (int, optional): number of frames to load. -1 to load everything after the offset. num_frames (int, optional): number of frames to load. 0 to load everything after the offset.
offset (int, optional): number of frames from the start of the file to begin data loading. offset (int, optional): number of frames from the start of the file to begin data loading.
signalinfo (sox_signalinfo_t, optional): a sox_signalinfo_t type, which could be helpful if the signalinfo (sox_signalinfo_t, optional): a sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined audio type cannot be automatically determined
...@@ -42,18 +43,18 @@ def load(filepath, ...@@ -42,18 +43,18 @@ def load(filepath,
filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically filetype (str, optional): a filetype or extension to be set if sox cannot determine it automatically
Returns: tuple(Tensor, int) Returns: tuple(Tensor, int)
- Tensor: output Tensor of size `[L x C]` where L is the number of audio frames, C is the number of channels - Tensor: output Tensor of size `[C x L]` or `[L x C]` where L is the number of audio frames, C is the number of channels
- int: the sample-rate of the audio (as listed in the metadata of the file) - int: the sample-rate of the audio (as listed in the metadata of the file)
Example:: Example::
>>> data, sample_rate = torchaudio.load('foo.mp3') >>> data, sample_rate = torchaudio.load('foo.mp3')
>>> print(data.size()) >>> print(data.size())
torch.Size([278756, 2]) torch.Size([2, 278756])
>>> print(sample_rate) >>> print(sample_rate)
44100 44100
>>> data_volume_normalized, _ = torchaudio.load('foo.mp3', normalization=lambda x: torch.abs(x).max()) >>> data_vol_normalized, _ = torchaudio.load('foo.mp3', normalization=lambda x: torch.abs(x).max())
>>> print(data_volume_normalized.abs().max()) >>> print(data_vol_normalized.abs().max())
1. 1.
""" """
...@@ -88,6 +89,9 @@ def load(filepath, ...@@ -88,6 +89,9 @@ def load(filepath,
def save(filepath, src, sample_rate, precision=16, channels_first=True): def save(filepath, src, sample_rate, precision=16, channels_first=True):
"""Convenience function for `save_encinfo`.
"""
si = sox_signalinfo_t() si = sox_signalinfo_t()
ch_idx = 0 if channels_first else 1 ch_idx = 0 if channels_first else 1
si.rate = sample_rate si.rate = sample_rate
...@@ -97,12 +101,17 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True): ...@@ -97,12 +101,17 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
return save_encinfo(filepath, src, channels_first, si) return save_encinfo(filepath, src, channels_first, si)
def save_encinfo(filepath, src, channels_first=True, signalinfo=None, encodinginfo=None, filetype=None): def save_encinfo(filepath,
"""Saves a Tensor with audio signal to disk as a standard format like mp3, wav, etc. src,
channels_first=True,
signalinfo=None,
encodinginfo=None,
filetype=None):
"""Saves a Tensor of an audio signal to disk as a standard format like mp3, wav, etc.
Args: Args:
filepath (string): path to audio file filepath (string): path to audio file
src (Tensor): an input 2D Tensor of shape `[L x C]` where L is src (Tensor): an input 2D Tensor of shape `[C x L]` or `[L x C]` where L is
the number of audio frames, C is the number of channels the number of audio frames, C is the number of channels
signalinfo (sox_signalinfo_t): a sox_signalinfo_t type, which could be helpful if the signalinfo (sox_signalinfo_t): a sox_signalinfo_t type, which could be helpful if the
audio type cannot be automatically determined audio type cannot be automatically determined
...@@ -129,10 +138,10 @@ def save_encinfo(filepath, src, channels_first=True, signalinfo=None, encodingin ...@@ -129,10 +138,10 @@ def save_encinfo(filepath, src, channels_first=True, signalinfo=None, encodingin
if src.dim() == 1: if src.dim() == 1:
# 1d tensors are assumed to be mono signals # 1d tensors are assumed to be mono signals
src.unsqueeze_(ch_idx) src.unsqueeze_(ch_idx)
elif src.dim() > 2 or src.size(ch_idx) > src.size(len_idx): elif src.dim() > 2 or src.size(ch_idx) > 16:
# assumes num_samples > num_channels # assumes num_channels < 16
raise ValueError( raise ValueError(
"Expected format (L x C), C < L, but found {}".format(src.size())) "Expected format where C < 16, but found {}".format(src.size()))
# sox stores the sample rate as a float, though practically sample rates are almost always integers # sox stores the sample rate as a float, though practically sample rates are almost always integers
# convert integers to floats # convert integers to floats
if not isinstance(signalinfo.rate, float): if not isinstance(signalinfo.rate, float):
...@@ -178,31 +187,10 @@ def info(filepath): ...@@ -178,31 +187,10 @@ def info(filepath):
return _torch_sox.get_info(filepath) return _torch_sox.get_info(filepath)
def effect_names():
"""Gets list of valid sox effect names
Returns: list[str]
Example::
>>> EFFECT_NAMES = torchaudio.effect_names()
"""
return _torch_sox.get_effect_names()
def SoxEffect():
"""Create a object to hold sox effect and options to pass between python and c++
Returns: SoxEffects(object)
- ename (str), name of effect
- eopts (list[str]), list of effect options
"""
return _torch_sox.SoxEffect()
def sox_signalinfo_t(): def sox_signalinfo_t():
"""Create a sox_signalinfo_t object. This object can be used to set the sample r"""Create a sox_signalinfo_t object. This object can be used to set the sample
rate, number of channels, length, bit precision and headroom multiplier rate, number of channels, length, bit precision and headroom multiplier
primarily for effects primarily for effects
Returns: sox_signalinfo_t(object) Returns: sox_signalinfo_t(object)
- rate (float), sample rate as a float, practically will likely be an integer float - rate (float), sample rate as a float, practically will likely be an integer float
...@@ -210,18 +198,25 @@ def sox_signalinfo_t(): ...@@ -210,18 +198,25 @@ def sox_signalinfo_t():
- precision (int), bit precision - precision (int), bit precision
- length (int), length of audio, 0 for unspecified and -1 for unknown - length (int), length of audio, 0 for unspecified and -1 for unknown
- mult (float, optional), headroom multiplier for effects and None for no multiplier - mult (float, optional), headroom multiplier for effects and None for no multiplier
Example::
>>> si = torchaudio.sox_signalinfo_t()
>>> si.channels = 1
>>> si.rate = 16000.
>>> si.precision = 16
>>> si.length = 0
""" """
return _torch_sox.sox_signalinfo_t() return _torch_sox.sox_signalinfo_t()
def sox_encodinginfo_t(): def sox_encodinginfo_t():
"""Create a sox_encodinginfo_t object. This object can be used to set the encoding """Create a sox_encodinginfo_t object. This object can be used to set the encoding
type, bit precision, compression factor, reverse bytes, reverse nibbles, type, bit precision, compression factor, reverse bytes, reverse nibbles,
reverse bits and endianness. This can be used in an effects chain to encode the reverse bits and endianness. This can be used in an effects chain to encode the
final output or to save a file with a specific encoding. For example, one could final output or to save a file with a specific encoding. For example, one could
use the sox ulaw encoding to do 8-bit ulaw encoding. Note in a tensor output use the sox ulaw encoding to do 8-bit ulaw encoding. Note in a tensor output
the result will be a 32-bit number, but number of unique values will be determined by the result will be a 32-bit number, but number of unique values will be determined by
the bit precision. the bit precision.
Returns: sox_encodinginfo_t(object) Returns: sox_encodinginfo_t(object)
- encoding (sox_encoding_t), output encoding - encoding (sox_encoding_t), output encoding
...@@ -231,6 +226,17 @@ def sox_encodinginfo_t(): ...@@ -231,6 +226,17 @@ def sox_encodinginfo_t():
- reverse_nibbles (sox_option_t), reverse nibbles, use sox_option_default - reverse_nibbles (sox_option_t), reverse nibbles, use sox_option_default
- reverse_bits (sox_option_t), reverse bits, use sox_option_default - reverse_bits (sox_option_t), reverse bits, use sox_option_default
- opposite_endian (sox_bool), change endianness, use sox_false - opposite_endian (sox_bool), change endianness, use sox_false
Example::
>>> ei = torchaudio.sox_encodinginfo_t()
>>> ei.encoding = torchaudio.get_sox_encoding_t(1)
>>> ei.bits_per_sample = 16
>>> ei.compression = 0
>>> ei.reverse_bytes = torchaudio.get_sox_option_t(2)
>>> ei.reverse_nibbles = torchaudio.get_sox_option_t(2)
>>> ei.reverse_bits = torchaudio.get_sox_option_t(2)
>>> ei.opposite_endian = torchaudio.get_sox_bool(0)
""" """
ei = _torch_sox.sox_encodinginfo_t() ei = _torch_sox.sox_encodinginfo_t()
sdo = get_sox_option_t(2) # sox_default_option sdo = get_sox_option_t(2) # sox_default_option
...@@ -245,7 +251,7 @@ def get_sox_encoding_t(i=None): ...@@ -245,7 +251,7 @@ def get_sox_encoding_t(i=None):
Args: Args:
i (int, optional): choose type or get a dict with all possible options i (int, optional): choose type or get a dict with all possible options
use .__members__ to see all options when not specified use `__members__` to see all options when not specified
Returns: Returns:
sox_encoding_t: a sox_encoding_t type for output encoding sox_encoding_t: a sox_encoding_t type for output encoding
""" """
...@@ -261,7 +267,7 @@ def get_sox_option_t(i=2): ...@@ -261,7 +267,7 @@ def get_sox_option_t(i=2):
Args: Args:
i (int, optional): choose type or get a dict with all possible options i (int, optional): choose type or get a dict with all possible options
use .__members__ to see all options when not specified. use `__members__` to see all options when not specified.
Defaults to sox_option_default. Defaults to sox_option_default.
Returns: Returns:
sox_option_t: a sox_option_t type sox_option_t: a sox_option_t type
...@@ -277,7 +283,7 @@ def get_sox_bool(i=0): ...@@ -277,7 +283,7 @@ def get_sox_bool(i=0):
Args: Args:
i (int, optional): choose type or get a dict with all possible options i (int, optional): choose type or get a dict with all possible options
use .__members__ to see all options when not specified. use `__members__` to see all options when not specified.
Defaults to sox_false. Defaults to sox_false.
Returns: Returns:
sox_bool: a sox_bool type sox_bool: a sox_bool type
...@@ -289,22 +295,25 @@ def get_sox_bool(i=0): ...@@ -289,22 +295,25 @@ def get_sox_bool(i=0):
def initialize_sox(): def initialize_sox():
"""Initialize sox for effects chain. Not required for simple loading. Importantly, """Initialize sox for use with effects chains. This is not required for simple
only initialize this once and do not shutdown until you have done effect chain loading. Importantly, only run `initialize_sox` once and do not shutdown
calls even when loading multiple files. after each effect chain, but rather once you are finished with all effects chains.
""" """
return _torch_sox.initialize_sox() return _torch_sox.initialize_sox()
def shutdown_sox(): def shutdown_sox():
"""Shutdown sox for effects chain. Not required for simple loading. Importantly, """Shutdown sox for effects chain. Not required for simple loading. Importantly,
only call once. Attempting to re-initialize sox will result seg faults. only call once. Attempting to re-initialize sox will result in seg faults.
""" """
return _torch_sox.shutdown_sox() return _torch_sox.shutdown_sox()
def _audio_normalization(signal, normalization): def _audio_normalization(signal, normalization):
# assumes signed 32-bit depth, which is what sox uses internally """Audio normalization of a tensor in-place. The normalization can be a bool,
a number, or a callable that takes the audio tensor as an input. SoX uses
32-bit signed integers internally, thus bool normalizes based on that assumption.
"""
if not normalization: if not normalization:
return return
......
...@@ -35,13 +35,18 @@ def make_manifest(dir): ...@@ -35,13 +35,18 @@ def make_manifest(dir):
def read_audio(fp, downsample=True): def read_audio(fp, downsample=True):
sig, sr = torchaudio.load(fp)
if downsample: if downsample:
# 48khz -> 16 khz E = torchaudio.sox_effects.SoxEffects()
if sig.size(0) % 3 == 0: E.set_input_file(fp)
sig = sig[::3].contiguous() E.sox_append_effect_to_chain("gain", ["-h"])
else: E.sox_append_effect_to_chain("channels", [1])
sig = sig[:-(sig.size(0) % 3):3].contiguous() E.sox_append_effect_to_chain("rate", [16000])
E.sox_append_effect_to_chain("gain", ["-rh"])
E.sox_append_effect_to_chain("dither", ["-s"])
sig, sr = E.sox_build_flow_effects()
else:
sig, sr = torchaudio.load(fp)
sig = sig.contiguous()
return sig, sr return sig, sr
...@@ -168,8 +173,8 @@ class VCTK(data.Dataset): ...@@ -168,8 +173,8 @@ class VCTK(data.Dataset):
# download files # download files
try: try:
os.makedirs(os.path.join(self.root, self.raw_folder))
os.makedirs(os.path.join(self.root, self.processed_folder)) os.makedirs(os.path.join(self.root, self.processed_folder))
os.makedirs(os.path.join(self.root, self.raw_folder))
except OSError as e: except OSError as e:
if e.errno == errno.EEXIST: if e.errno == errno.EEXIST:
pass pass
...@@ -191,6 +196,7 @@ class VCTK(data.Dataset): ...@@ -191,6 +196,7 @@ class VCTK(data.Dataset):
os.unlink(file_path) os.unlink(file_path)
# process and save as torch files # process and save as torch files
torchaudio.initialize_sox()
print('Processing...') print('Processing...')
shutil.copyfile( shutil.copyfile(
os.path.join(dset_abs_path, "COPYING"), os.path.join(dset_abs_path, "COPYING"),
...@@ -213,10 +219,10 @@ class VCTK(data.Dataset): ...@@ -213,10 +219,10 @@ class VCTK(data.Dataset):
f_rel_no_ext = os.path.basename(f).rsplit(".", 1)[0] f_rel_no_ext = os.path.basename(f).rsplit(".", 1)[0]
sig = read_audio(f, downsample=self.downsample)[0] sig = read_audio(f, downsample=self.downsample)[0]
tensors.append(sig) tensors.append(sig)
lengths.append(sig.size(0)) lengths.append(sig.size(1))
labels.append(utterences[f_rel_no_ext]) labels.append(utterences[f_rel_no_ext])
self.max_len = sig.size(0) if sig.size( self.max_len = sig.size(1) if sig.size(
0) > self.max_len else self.max_len 1) > self.max_len else self.max_len
# sort sigs/labels: longest -> shortest # sort sigs/labels: longest -> shortest
tensors, labels = zip(*[(b, c) for (a, b, c) in sorted( tensors, labels = zip(*[(b, c) for (a, b, c) in sorted(
zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)]) zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
...@@ -232,5 +238,5 @@ class VCTK(data.Dataset): ...@@ -232,5 +238,5 @@ class VCTK(data.Dataset):
self._write_info((n * self.chunk_size) + i + 1) self._write_info((n * self.chunk_size) + i + 1)
if not self.dev_mode: if not self.dev_mode:
shutil.rmtree(raw_abs_dir, ignore_errors=True) shutil.rmtree(raw_abs_dir, ignore_errors=True)
torchaudio.shutdown_sox()
print('Done!') print('Done!')
...@@ -128,12 +128,12 @@ class YESNO(data.Dataset): ...@@ -128,12 +128,12 @@ class YESNO(data.Dataset):
full_path = os.path.join(dset_abs_path, f) full_path = os.path.join(dset_abs_path, f)
sig, sr = torchaudio.load(full_path) sig, sr = torchaudio.load(full_path)
tensors.append(sig) tensors.append(sig)
lengths.append(sig.size(0)) lengths.append(sig.size(1))
labels.append(os.path.basename(f).split(".", 1)[0].split("_")) labels.append(os.path.basename(f).split(".", 1)[0].split("_"))
# sort sigs/labels: longest -> shortest # sort sigs/labels: longest -> shortest
tensors, labels = zip(*[(b, c) for (a, b, c) in sorted( tensors, labels = zip(*[(b, c) for (a, b, c) in sorted(
zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)]) zip(lengths, tensors, labels), key=lambda x: x[0], reverse=True)])
self.max_len = tensors[0].size(0) self.max_len = tensors[0].size(1)
torch.save( torch.save(
(tensors, labels), (tensors, labels),
os.path.join( os.path.join(
......
from __future__ import division, print_function
import os.path import os.path
import torch import torch
import _torch_sox import _torch_sox
from torchaudio import save as save_new, load as load_new import torchaudio
def load(filepath, out=None, normalization=None, num_frames=-1, offset=0): def load(filepath, out=None, normalization=None, num_frames=0, offset=0):
"""Loads an audio file from disk into a Tensor. The default options have """Loads an audio file from disk into a Tensor. The default options have
changed as of torchaudio 0.2 and this function maintains option defaults changed as of torchaudio 0.2 and this function maintains option defaults
from version 0.1. from version 0.1.
Args: Args:
filepath (string): path to audio file filepath (string): path to audio file
...@@ -26,20 +27,20 @@ def load(filepath, out=None, normalization=None, num_frames=-1, offset=0): ...@@ -26,20 +27,20 @@ def load(filepath, out=None, normalization=None, num_frames=-1, offset=0):
Example:: Example::
>>> data, sample_rate = torchaudio.load('foo.mp3') >>> data, sample_rate = torchaudio.legacy.load('foo.mp3')
>>> print(data.size()) >>> print(data.size())
torch.Size([278756, 2]) torch.Size([278756, 2])
>>> print(sample_rate) >>> print(sample_rate)
44100 44100
""" """
return load_new(filepath, out, normalization, False, num_frames, offset) return torchaudio.load(filepath, out, normalization, False, num_frames, offset)
def save(filepath, src, sample_rate, precision=32): def save(filepath, src, sample_rate, precision=32):
"""Saves a Tensor with audio signal to disk as a standard format like mp3, wav, etc. """Saves a Tensor with audio signal to disk as a standard format like mp3, wav, etc.
The default options have changed as of torchaudio 0.2 and this function maintains The default options have changed as of torchaudio 0.2 and this function maintains
option defaults from version 0.1. option defaults from version 0.1.
Args: Args:
filepath (string): path to audio file filepath (string): path to audio file
...@@ -50,8 +51,8 @@ def save(filepath, src, sample_rate, precision=32): ...@@ -50,8 +51,8 @@ def save(filepath, src, sample_rate, precision=32):
Example:: Example::
>>> data, sample_rate = torchaudio.load('foo.mp3') >>> data, sample_rate = torchaudio.legacy.load('foo.mp3')
>>> torchaudio.save('foo.wav', data, sample_rate) >>> torchaudio.legacy.save('foo.wav', data, sample_rate)
""" """
save_new(filepath, src, sample_rate, precision, False) torchaudio.save(filepath, src, sample_rate, precision, False)
from __future__ import division, print_function
import torch import torch
import _torch_sox import _torch_sox
import torchaudio import torchaudio
EFFECT_NAMES = set(_torch_sox.get_effect_names())
def effect_names():
""" """Gets list of valid sox effect names
Notes:
Returns: list[str]
sox_signalinfo_t {
sox_rate_t rate; /**< samples per second, 0 if unknown */ Example::
unsigned channels; /**< number of sound channels, 0 if unknown */ >>> EFFECT_NAMES = torchaudio.sox_effects.effect_names()
unsigned precision; /**< bits per sample, 0 if unknown */ """
sox_uint64_t length; /**< samples * chans in file, 0 if unspecified, -1 if unknown */ return _torch_sox.get_effect_names()
double * mult; /**< Effects headroom multiplier; may be null */
}
def SoxEffect():
typedef struct sox_encodinginfo_t { """Create an object for passing sox effect information between python and c++
sox_encoding_t encoding; /**< format of sample numbers */
unsigned bits_per_sample; /**< 0 if unknown or variable; uncompressed value if lossless; compressed value if lossy */ Returns: SoxEffect(object)
double compression; /**< compression factor (where applicable) */ - ename (str), name of effect
sox_option_t reverse_bytes; /** use sox_option_default */ - eopts (list[str]), list of effect options
sox_option_t reverse_nibbles; /** use sox_option_default */ """
sox_option_t reverse_bits; /** use sox_option_default */ return _torch_sox.SoxEffect()
sox_bool opposite_endian; /** use sox_false */
}
class SoxEffectsChain(object):
sox_encodings_t = { """SoX effects chain class.
"SOX_ENCODING_UNKNOWN", """
"SOX_ENCODING_SIGN2",
"SOX_ENCODING_UNSIGNED", EFFECTS_AVAILABLE = set(effect_names())
"SOX_ENCODING_FLOAT", EFFECTS_UNIMPLEMENTED = set(["spectrogram", "splice", "noiseprof", "fir"])
"SOX_ENCODING_FLOAT_TEXT",
"SOX_ENCODING_FLAC",
"SOX_ENCODING_HCOM",
"SOX_ENCODING_WAVPACK",
"SOX_ENCODING_WAVPACKF",
"SOX_ENCODING_ULAW",
"SOX_ENCODING_ALAW",
"SOX_ENCODING_G721",
"SOX_ENCODING_G723",
"SOX_ENCODING_CL_ADPCM",
"SOX_ENCODING_CL_ADPCM16",
"SOX_ENCODING_MS_ADPCM",
"SOX_ENCODING_IMA_ADPCM",
"SOX_ENCODING_OKI_ADPCM",
"SOX_ENCODING_DPCM",
"SOX_ENCODING_DWVW",
"SOX_ENCODING_DWVWN",
"SOX_ENCODING_GSM",
"SOX_ENCODING_MP3",
"SOX_ENCODING_VORBIS",
"SOX_ENCODING_AMR_WB",
"SOX_ENCODING_AMR_NB",
"SOX_ENCODING_CVSD",
"SOX_ENCODING_LPC10",
"SOX_ENCODING_OPUS",
"SOX_ENCODINGS"
}
"""
class SoxEffects(object):
def __init__(self, normalization=True, channels_first=True, out_siginfo=None, out_encinfo=None, filetype="raw"): def __init__(self, normalization=True, channels_first=True, out_siginfo=None, out_encinfo=None, filetype="raw"):
self.input_file = None self.input_file = None
...@@ -73,15 +43,12 @@ class SoxEffects(object): ...@@ -73,15 +43,12 @@ class SoxEffects(object):
self.normalization = normalization self.normalization = normalization
self.channels_first = channels_first self.channels_first = channels_first
def sox_check_effect(self, e): def append_effect_to_chain(self, ename, eargs=None):
if e.lower() not in EFFECT_NAMES: """Append effect to a sox effects chain.
raise LookupError("Effect name, {}, not valid".format(e.lower())) """
return e.lower() e = SoxEffect()
def sox_append_effect_to_chain(self, ename, eargs=None):
e = torchaudio.SoxEffect()
# check if we have a valid effect # check if we have a valid effect
ename = self.sox_check_effect(ename) ename = self._check_effect(ename)
if eargs is None or eargs == []: if eargs is None or eargs == []:
eargs = [""] eargs = [""]
elif not isinstance(eargs, list): elif not isinstance(eargs, list):
...@@ -96,13 +63,15 @@ class SoxEffects(object): ...@@ -96,13 +63,15 @@ class SoxEffects(object):
self.chain.append(e) self.chain.append(e)
def sox_build_flow_effects(self, out=None): def sox_build_flow_effects(self, out=None):
"""Build effects chain and flow effects from input file to output tensor
"""
# initialize output tensor # initialize output tensor
if out is not None: if out is not None:
torchaudio.check_input(out) torchaudio.check_input(out)
else: else:
out = torch.FloatTensor() out = torch.FloatTensor()
if not len(self.chain): if not len(self.chain):
e = torchaudio.SoxEffect() e = SoxEffect()
e.ename = "no_effects" e.ename = "no_effects"
e.eopts = [""] e.eopts = [""]
self.chain.append(e) self.chain.append(e)
...@@ -122,11 +91,22 @@ class SoxEffects(object): ...@@ -122,11 +91,22 @@ class SoxEffects(object):
return out, sr return out, sr
def clear_chain(self): def clear_chain(self):
"""Clear effects chain in python
"""
self.chain = [] self.chain = []
def set_input_file(self, input_file): def set_input_file(self, input_file):
"""Set input file for input of chain
"""
self.input_file = input_file self.input_file = input_file
def _check_effect(self, e):
    """Validate a SoX effect name and return it in lower case.

    The lookup is case-insensitive: the name is lower-cased once and then
    compared against the class-level effect sets.

    Args:
        e (str): name of the effect to validate.

    Returns:
        str: ``e`` converted to lower case.

    Raises:
        NotImplementedError: if the lower-cased name is listed in
            ``self.EFFECTS_UNIMPLEMENTED``.
        LookupError: if the lower-cased name is not in
            ``self.EFFECTS_AVAILABLE``.
    """
    ename = e.lower()
    # Check the unimplemented set first so a known-but-unsupported effect
    # gets the more specific error instead of a generic "not valid".
    if ename in self.EFFECTS_UNIMPLEMENTED:
        raise NotImplementedError("This effect ({}) is not implemented in torchaudio".format(e))
    if ename not in self.EFFECTS_AVAILABLE:
        raise LookupError("Effect name, {}, not valid".format(ename))
    return ename
# https://stackoverflow.com/questions/12472338/flattening-a-list-recursively # https://stackoverflow.com/questions/12472338/flattening-a-list-recursively
# convenience function to flatten list recursively # convenience function to flatten list recursively
def _flatten(self, x): def _flatten(self, x):
......
...@@ -109,18 +109,21 @@ int read_audio_file( ...@@ -109,18 +109,21 @@ int read_audio_file(
sox_encodinginfo_t* ei, sox_encodinginfo_t* ei,
const char* ft) { const char* ft) {
SoxDescriptor fd(sox_open_read( SoxDescriptor fd(sox_open_read(file_name.c_str(), si, ei, ft));
file_name.c_str(),
/*signal=*/si,
/*encoding=*/ei,
/*filetype=*/ft));
if (fd.get() == nullptr) { if (fd.get() == nullptr) {
throw std::runtime_error("Error opening audio file"); throw std::runtime_error("Error opening audio file");
} }
// signal info
const int number_of_channels = fd->signal.channels; const int number_of_channels = fd->signal.channels;
const int sample_rate = fd->signal.rate; const int sample_rate = fd->signal.rate;
const int64_t total_length = fd->signal.length; const int64_t total_length = fd->signal.length;
// multiply offset and number of frames by number of channels
offset *= number_of_channels;
nframes *= number_of_channels;
if (total_length == 0) { if (total_length == 0) {
throw std::runtime_error("Error reading audio file: unknown length"); throw std::runtime_error("Error reading audio file: unknown length");
} }
...@@ -133,14 +136,10 @@ int read_audio_file( ...@@ -133,14 +136,10 @@ int read_audio_file(
if (offset > 0) { if (offset > 0) {
buffer_length -= offset; buffer_length -= offset;
} }
if (nframes != -1 && buffer_length > nframes) { if (nframes > 0 && buffer_length > nframes) {
buffer_length = nframes; buffer_length = nframes;
} }
// buffer length and offset need to be multiplied by the number of channels
buffer_length *= number_of_channels;
offset *= number_of_channels;
// seek to offset point before reading data // seek to offset point before reading data
if (sox_seek(fd.get(), offset, 0) == SOX_EOF) { if (sox_seek(fd.get(), offset, 0) == SOX_EOF) {
throw std::runtime_error("sox_seek reached EOF, try reducing offset or num_samples"); throw std::runtime_error("sox_seek reached EOF, try reducing offset or num_samples");
...@@ -149,6 +148,7 @@ int read_audio_file( ...@@ -149,6 +148,7 @@ int read_audio_file(
// read data and fill output tensor // read data and fill output tensor
read_audio(fd, output, buffer_length); read_audio(fd, output, buffer_length);
// L x C -> C x L, if desired
if (ch_first) { if (ch_first) {
output.transpose_(1, 0); output.transpose_(1, 0);
} }
...@@ -167,7 +167,6 @@ void write_audio_file( ...@@ -167,7 +167,6 @@ void write_audio_file(
"Error writing audio file: input tensor must be contiguous"); "Error writing audio file: input tensor must be contiguous");
} }
// remove ?
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0 #if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
si->mult = nullptr; si->mult = nullptr;
#endif #endif
...@@ -248,20 +247,12 @@ int build_flow_effects(const std::string& file_name, ...@@ -248,20 +247,12 @@ int build_flow_effects(const std::string& file_name,
target_encoding->opposite_endian = sox_false; // Reverse endianness target_encoding->opposite_endian = sox_false; // Reverse endianness
} }
// set target precision / bits_per_sample if it's still 0
//if (target_signal->precision == 0)
// target_signal->precision = input->signal.precision;
//if (target_encoding->bits_per_sample == 0)
// target_encoding->bits_per_sample = input->signal.precision;
// check for rate or channels effect and change the output signalinfo accordingly // check for rate or channels effect and change the output signalinfo accordingly
for (SoxEffect se : pyeffs) { for (SoxEffect se : pyeffs) {
if (se.ename == "rate") { if (se.ename == "rate") {
target_signal->rate = std::stod(se.eopts[0]); target_signal->rate = std::stod(se.eopts[0]);
//se.eopts[0] = "";
} else if (se.ename == "channels") { } else if (se.ename == "channels") {
target_signal->channels = std::stoi(se.eopts[0]); target_signal->channels = std::stoi(se.eopts[0]);
//se.eopts[0] = "";
} }
} }
...@@ -271,7 +262,6 @@ int build_flow_effects(const std::string& file_name, ...@@ -271,7 +262,6 @@ int build_flow_effects(const std::string& file_name,
// create buffer and buffer_size for output in memwrite // create buffer and buffer_size for output in memwrite
char* buffer; char* buffer;
size_t buffer_size; size_t buffer_size;
//const char* otype = (file_type.empty()) ? (const char*) "raw" : file_type.c_str();
#ifdef __APPLE__ #ifdef __APPLE__
// According to Mozilla Deepspeech sox_open_memstream_write doesn't work // According to Mozilla Deepspeech sox_open_memstream_write doesn't work
// with OSX // with OSX
...@@ -287,7 +277,9 @@ int build_flow_effects(const std::string& file_name, ...@@ -287,7 +277,9 @@ int build_flow_effects(const std::string& file_name,
target_encoding, target_encoding,
file_type, nullptr); file_type, nullptr);
#endif #endif
assert(output); if (output == nullptr) {
throw std::runtime_error("Error opening output memstream/temporary file");
}
// Setup the effects chain to decode/resample // Setup the effects chain to decode/resample
sox_effects_chain_t* chain = sox_effects_chain_t* chain =
sox_create_effects_chain(&input->encoding, &output->encoding); sox_create_effects_chain(&input->encoding, &output->encoding);
...@@ -307,11 +299,12 @@ int build_flow_effects(const std::string& file_name, ...@@ -307,11 +299,12 @@ int build_flow_effects(const std::string& file_name,
} else { } else {
int num_opts = tae.eopts.size(); int num_opts = tae.eopts.size();
char* sox_args[max_num_eopts]; char* sox_args[max_num_eopts];
//for(std::string s : tae.eopts) {
for(std::vector<std::string>::size_type i = 0; i != tae.eopts.size(); i++) { for(std::vector<std::string>::size_type i = 0; i != tae.eopts.size(); i++) {
sox_args[i] = (char*) tae.eopts[i].c_str(); sox_args[i] = (char*) tae.eopts[i].c_str();
} }
sox_effect_options(e, num_opts, sox_args); if(sox_effect_options(e, num_opts, sox_args) != SOX_SUCCESS) {
throw std::runtime_error("invalid effect options, see SoX docs for details");
}
} }
sox_add_effect(chain, e, &interm_signal, &input->signal); sox_add_effect(chain, e, &interm_signal, &input->signal);
free(e); free(e);
...@@ -331,9 +324,21 @@ int build_flow_effects(const std::string& file_name, ...@@ -331,9 +324,21 @@ int build_flow_effects(const std::string& file_name,
sox_close(output); sox_close(output);
sox_close(input); sox_close(input);
// Resize output tensor to desired dimensions // Resize output tensor to desired dimensions, different effects result in output->signal.length,
int nc = interm_signal.channels; // interm_signal.length and buffer size being inconsistent with the result of the file output.
int ns = interm_signal.length; // We prioritize in the order: output->signal.length > interm_signal.length > buffer_size
int nc, ns;
if (output->signal.length == 0) {
if (interm_signal.length > (buffer_size * 10)) {
ns = buffer_size / 2;
} else {
ns = interm_signal.length;
}
nc = interm_signal.channels;
} else {
nc = output->signal.channels;
ns = output->signal.length;
}
otensor.resize_({ns/nc, nc}); otensor.resize_({ns/nc, nc});
otensor = otensor.contiguous(); otensor = otensor.contiguous();
......
...@@ -27,7 +27,6 @@ int read_audio_file( ...@@ -27,7 +27,6 @@ int read_audio_file(
void write_audio_file( void write_audio_file(
const std::string& file_name, const std::string& file_name,
at::Tensor tensor, at::Tensor tensor,
bool ch_first,
sox_signalinfo_t* si, sox_signalinfo_t* si,
sox_encodinginfo_t* ei, sox_encodinginfo_t* ei,
const char* extension) const char* extension)
...@@ -55,6 +54,7 @@ int shutdown_sox(); ...@@ -55,6 +54,7 @@ int shutdown_sox();
/// and the sample rate of the output tensor. /// and the sample rate of the output tensor.
int build_flow_effects(const std::string& file_name, int build_flow_effects(const std::string& file_name,
at::Tensor otensor, at::Tensor otensor,
bool ch_first,
sox_signalinfo_t* target_signal, sox_signalinfo_t* target_signal,
sox_encodinginfo_t* target_encoding, sox_encodinginfo_t* target_encoding,
const char* file_type, const char* file_type,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment