Commit 8a41ecdc authored by Peter Goldsborough's avatar Peter Goldsborough Committed by Soumith Chintala
Browse files

Rewrote C code with C++ extensions

parent 0dfcbfde
...@@ -28,7 +28,6 @@ Installation ...@@ -28,7 +28,6 @@ Installation
------------ ------------
```bash ```bash
pip install cffi
python setup.py install python setup.py install
``` ```
......
import os
import torch
from torch.utils.ffi import create_extension
this_file = os.path.dirname(__file__)
sources = ['torchaudio/src/th_sox.c']
headers = [
'torchaudio/src/th_sox.h',
]
defines = []
ffi = create_extension(
'torchaudio._ext.th_sox',
package=True,
headers=headers,
sources=sources,
define_macros=defines,
relative_to=__file__,
libraries=['sox'],
include_dirs=['torchaudio/src'],
)
if __name__ == '__main__':
ffi.build()
#!/usr/bin/env python #!/usr/bin/env python
import os
import sys
from setuptools import setup, find_packages from setuptools import setup, find_packages
from torch.utils.cpp_extension import BuildExtension, CppExtension
import build
this_file = os.path.dirname(__file__)
setup( setup(
name="torchaudio", name="torchaudio",
version="0.1", version="0.1",
description="An audio package for PyTorch", description="An audio package for PyTorch",
url="https://github.com/pytorch/audio", url="https://github.com/pytorch/audio",
author="Soumith Chintala, David Pollack, Sean Naren", author="Soumith Chintala, David Pollack, Sean Naren, Peter Goldsborough",
author_email="soumith@pytorch.org", author_email="soumith@pytorch.org",
# Require cffi.
install_requires=["cffi>=1.0.0"],
setup_requires=["cffi>=1.0.0"],
# Exclude the build files. # Exclude the build files.
packages=find_packages(exclude=["build"]), packages=find_packages(exclude=["build"]),
# Package where to put the extensions. Has to be a prefix of build.py. ext_modules=[
ext_package="", CppExtension(
# Extensions to compile. '_torch_sox', ['torchaudio/torch_sox.cpp'], libraries=['sox']),
cffi_modules=[
os.path.join(this_file, "build.py:ffi")
], ],
) cmdclass={'build_ext': BuildExtension})
No preview for this file type
File added
...@@ -15,6 +15,8 @@ class Test_LoadSave(unittest.TestCase): ...@@ -15,6 +15,8 @@ class Test_LoadSave(unittest.TestCase):
x, sr = torchaudio.load(self.test_filepath) x, sr = torchaudio.load(self.test_filepath)
self.assertEqual(sr, 44100) self.assertEqual(sr, 44100)
self.assertEqual(x.size(), (278756, 2)) self.assertEqual(x.size(), (278756, 2))
self.assertGreater(x.sum(), 0)
print
# check normalizing # check normalizing
x, sr = torchaudio.load(self.test_filepath, normalization=True) x, sr = torchaudio.load(self.test_filepath, normalization=True)
......
import os import os.path
import sys
import torch import torch
import _torch_sox
from cffi import FFI
ffi = FFI()
from ._ext import th_sox
from torchaudio import transforms from torchaudio import transforms
from torchaudio import datasets from torchaudio import datasets
if sys.version_info >= (3, 0):
_bytes = bytes
else:
def _bytes(s, e):
return s.encode(e)
def get_tensor_type_name(tensor): def get_tensor_type_name(tensor):
return tensor.type().replace('torch.', '').replace('Tensor', '') return tensor.type().replace('torch.', '').replace('Tensor', '')
...@@ -55,22 +44,20 @@ def load(filepath, out=None, normalization=None): ...@@ -55,22 +44,20 @@ def load(filepath, out=None, normalization=None):
# check if valid file # check if valid file
if not os.path.isfile(filepath): if not os.path.isfile(filepath):
raise OSError("{} not found or is a directory".format(filepath)) raise OSError("{} not found or is a directory".format(filepath))
# initialize output tensor # initialize output tensor
if out is not None: if out is not None:
check_input(out) check_input(out)
else: else:
out = torch.FloatTensor() out = torch.FloatTensor()
# load audio signal
typename = get_tensor_type_name(out) sample_rate = _torch_sox.read_audio_file(filepath, out)
func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename))
sample_rate_p = ffi.new('int*')
func(str(filepath).encode("utf-8"), out, sample_rate_p)
sample_rate = sample_rate_p[0]
# normalize if needed # normalize if needed
if isinstance(normalization, bool) and normalization: if isinstance(normalization, bool) and normalization:
out /= 1 << 31 # assuming 16-bit depth out /= 1 << 31 # assuming 16-bit depth
elif isinstance(normalization, (float, int)): elif isinstance(normalization, (float, int)):
out /= normalization # normalize with custom value out /= normalization # normalize with custom value
return out, sample_rate return out, sample_rate
...@@ -111,9 +98,6 @@ def save(filepath, src, sample_rate): ...@@ -111,9 +98,6 @@ def save(filepath, src, sample_rate):
src = src * (1 << 31) # assuming 16-bit depth src = src * (1 << 31) # assuming 16-bit depth
src = src.long() src = src.long()
# save data to file # save data to file
filename, extension = os.path.splitext(filepath) extension = os.path.splitext(filepath)[1]
check_input(src) check_input(src)
typename = get_tensor_type_name(src) _torch_sox.write_audio_file(filepath, src, extension[1:], sample_rate)
func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))
func(_bytes(filepath, "utf-8"), src,
_bytes(extension[1:], "utf-8"), sample_rate)
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/th_sox.c"
#else
void libthsox_(read_audio)(sox_format_t *fd, THTensor* tensor,
int* sample_rate, size_t nsamples)
{
int nchannels = fd->signal.channels;
long buffer_size = fd->signal.length;
if (buffer_size == 0) {
if (nsamples != -1) {
buffer_size = nsamples;
} else {
THError("[read_audio] Unknown length");
}
}
*sample_rate = (int) fd->signal.rate;
int32_t *buffer = (int32_t *)malloc(sizeof(int32_t) * buffer_size);
size_t samples_read = sox_read(fd, buffer, buffer_size);
if (samples_read == 0)
THError("[read_audio] Empty file or read failed in sox_read");
// alloc tensor
THTensor_(resize2d)(tensor, samples_read / nchannels, nchannels );
real *tensor_data = THTensor_(data)(tensor);
// convert audio to dest tensor
int x,k;
for (x=0; x<samples_read/nchannels; x++) {
for (k=0; k<nchannels; k++) {
*tensor_data++ = (real)buffer[x*nchannels+k];
}
}
// free buffer and sox structures
free(buffer);
}
void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sample_rate)
{
// Create sox objects and read into int32_t buffer
sox_format_t *fd;
fd = sox_open_read(file_name, NULL, NULL, NULL);
if (fd == NULL)
THError("[read_audio_file] Failure to read file");
libthsox_(read_audio)(fd, tensor, sample_rate, -1);
sox_close(fd);
}
void libthsox_(write_audio)(sox_format_t *fd, THTensor* src,
const char *extension, int sample_rate)
{
long nchannels = src->size[1];
long nsamples = src->size[0];
real* data = THTensor_(data)(src);
// convert audio to dest tensor
int x,k;
for (x=0; x<nsamples; x++) {
for (k=0; k<nchannels; k++) {
int32_t sample = (int32_t)(data[x*nchannels+k]);
size_t samples_written = sox_write(fd, &sample, 1);
if (samples_written != 1)
THError("[write_audio_file] write failed in sox_write");
}
}
}
void libthsox_(write_audio_file)(const char *file_name, THTensor* src,
const char *extension, int sample_rate)
{
if (THTensor_(isContiguous)(src) == 0)
THError("[write_audio_file] Input should be contiguous tensors");
long nchannels = src->size[1];
long nsamples = src->size[0];
sox_format_t *fd;
// Create sox objects and write into int32_t buffer
sox_signalinfo_t sinfo;
sinfo.rate = sample_rate;
sinfo.channels = nchannels;
sinfo.length = nsamples * nchannels;
sinfo.precision = sizeof(int32_t) * 8; /* precision in bits */
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
sinfo.mult = NULL;
#endif
fd = sox_open_write(file_name, &sinfo, NULL, extension, NULL, NULL);
if (fd == NULL)
THError("[write_audio_file] Failure to open file for writing");
libthsox_(write_audio)(fd, src, extension, sample_rate);
// free buffer and sox structures
sox_close(fd);
return;
}
#endif
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/th_sox.h"
#else
void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sample_rate);
void libthsox_(write_audio_file)(const char *file_name, THTensor* src, const char *extension, int sample_rate);
#endif
#include <TH/TH.h>
#include <sox.h>
#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME)
#define torch_Tensor TH_CONCAT_STRING_3(torch., Real, Tensor)
#define libthsox_(NAME) TH_CONCAT_4(libthsox_, Real, _, NAME)
#include "generic/th_sox.c"
#include "THGenerateAllTypes.h"
/* #include <TH/TH.h> */
/* #define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME) */
/* #define torch_Tensor TH_CONCAT_STRING_3(torch., Real, Tensor) */
/* #define libthsox_(NAME) TH_CONCAT_4(libthsox_, Real, _, NAME) */
/* #include "generic/th_sox.h" */
/* #include "THGenerateAllTypes.h" */
/* gcc -E th_sox.h -I /home/soumith/code/pytorch/torch/lib/include/TH -I /home/soumith/code/pytorch/torch/lib/include/ -I .|grep libthsox */
void libthsox_Float_read_audio_file(const char *file_name, THFloatTensor* tensor, int* sample_rate);
void libthsox_Double_read_audio_file(const char *file_name, THDoubleTensor* tensor, int* sample_rate);
void libthsox_Byte_read_audio_file(const char *file_name, THByteTensor* tensor, int* sample_rate);
void libthsox_Char_read_audio_file(const char *file_name, THCharTensor* tensor, int* sample_rate);
void libthsox_Short_read_audio_file(const char *file_name, THShortTensor* tensor, int* sample_rate);
void libthsox_Int_read_audio_file(const char *file_name, THIntTensor* tensor, int* sample_rate);
void libthsox_Long_read_audio_file(const char *file_name, THLongTensor* tensor, int* sample_rate);
void libthsox_Float_write_audio_file(const char *file_name, THFloatTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Double_write_audio_file(const char *file_name, THDoubleTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Byte_write_audio_file(const char *file_name, THByteTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Char_write_audio_file(const char *file_name, THCharTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Short_write_audio_file(const char *file_name, THShortTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Int_write_audio_file(const char *file_name, THIntTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Long_write_audio_file(const char *file_name, THLongTensor* tensor, const char *extension,
int sample_rate);
\ No newline at end of file
#include <torch/torch.h>
#include <sox.h>
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>
namespace torch { namespace audio {
int read_audio_file(const std::string& file_name, at::Tensor output) {
sox_format_t* fd = sox_open_read(
file_name.c_str(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/nullptr);
if (fd == nullptr) {
throw std::runtime_error("Error opening audio file");
}
const int64_t number_of_channels = fd->signal.channels;
const int sample_rate = fd->signal.rate;
const int64_t buffer_length = fd->signal.length;
if (buffer_length == 0) {
throw std::runtime_error("Error reading audio file: unknown length");
}
std::vector<int32_t> buffer(buffer_length);
const int64_t samples_read = sox_read(fd, buffer.data(), buffer_length);
if (samples_read == 0) {
throw std::runtime_error(
"Error reading audio file: empty file or read failed in sox_read");
}
output.resize_({samples_read / number_of_channels, number_of_channels});
output = output.contiguous();
AT_DISPATCH_ALL_TYPES(output.type(), "write_audio_buffer", [&] {
auto* data = output.data<scalar_t>();
std::copy(buffer.begin(), buffer.begin() + samples_read, data);
});
return sample_rate;
}
void write_audio_file(
const std::string& file_name,
at::Tensor tensor,
const std::string& extension,
int sample_rate) {
if (!tensor.is_contiguous()) {
throw std::runtime_error(
"Error writing audio file: input tensor must be contiguous");
}
// Create sox objects and write into int32_t buffer.
sox_signalinfo_t signal;
signal.rate = sample_rate;
signal.channels = tensor.size(1);
signal.length = tensor.numel();
signal.precision = 32; // precision in bits
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
signal.mult = nullptr;
#endif
sox_format_t* fd = sox_open_write(
file_name.c_str(),
&signal,
/*encoding=*/nullptr,
extension.c_str(),
/*filetype=*/nullptr,
/*oob=*/nullptr);
if (fd == nullptr) {
throw std::runtime_error(
"Error writing audio file: could not open file for writing");
}
std::vector<int32_t> buffer(tensor.numel());
AT_DISPATCH_ALL_TYPES(tensor.type(), "write_audio_buffer", [&] {
auto* data = tensor.data<scalar_t>();
std::copy(data, data + tensor.numel(), buffer.begin());
});
const auto samples_written = sox_write(fd, buffer.data(), buffer.size());
// Free buffer and sox structures.
sox_close(fd);
if (samples_written != buffer.size()) {
throw std::runtime_error(
"Error writing audio file: could not write entire buffer");
}
}
}} // namespace torch::audio
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def(
"read_audio_file",
&torch::audio::read_audio_file,
"Reads an audio file into a tensor");
m.def(
"write_audio_file",
&torch::audio::write_audio_file,
"Writes data from a tensor into an audio file");
}
#include <string>
#include <tuple>
namespace at {
struct Tensor;
} // namespace at
namespace torch { namespace audio {
/// Reads an audio file from the given `path` into the `output` `Tensor` and
/// returns the sample rate of the audio file.
/// Throws `std::runtime_error` if the audio file could not be opened, or an
/// error ocurred during reading of the audio data.
int read_audio_file(const std::string& path, at::Tensor output);
/// Writes the data of a `Tensor` into an audio file at the given `path`, with
/// a certain extension (e.g. `wav`or `mp3`) and sample rate.
/// Throws `std::runtime_error` when the audio file could not be opened for
/// writing, or an error ocurred during writing of the audio data.
void write_audio_file(
const std::string& path,
at::Tensor tensor,
const std::string& extension,
int sample_rate);
}} // namespace torch::audio
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment