Commit 18c01bef authored by Peter Goldsborough's avatar Peter Goldsborough Committed by Soumith Chintala
Browse files

Bring back C++ extensions again (#47)

parent d5eadbdc
......@@ -28,7 +28,6 @@ Installation
------------
```bash
pip install cffi
python setup.py install
```
......
import os
import torch
from torch.utils.ffi import create_extension
this_file = os.path.dirname(__file__)
sources = ['torchaudio/src/th_sox.c']
headers = [
'torchaudio/src/th_sox.h',
]
defines = []
ffi = create_extension(
'torchaudio._ext.th_sox',
package=True,
headers=headers,
sources=sources,
define_macros=defines,
relative_to=__file__,
libraries=['sox'],
include_dirs=['torchaudio/src'],
)
if __name__ == '__main__':
ffi.build()
#!/usr/bin/env python
import os
import sys
from setuptools import setup, find_packages
import build
this_file = os.path.dirname(__file__)
from torch.utils.cpp_extension import BuildExtension, CppExtension
setup(
name="torchaudio",
version="0.1",
description="An audio package for PyTorch",
url="https://github.com/pytorch/audio",
author="Soumith Chintala, David Pollack, Sean Naren",
author="Soumith Chintala, David Pollack, Sean Naren, Peter Goldsborough",
author_email="soumith@pytorch.org",
# Require cffi.
install_requires=["cffi>=1.0.0", "torch>=0.4"],
setup_requires=["cffi>=1.0.0", "torch>=0.4"],
# Exclude the build files.
packages=find_packages(exclude=["build"]),
# Package where to put the extensions. Has to be a prefix of build.py.
ext_package="",
# Extensions to compile.
cffi_modules=[
os.path.join(this_file, "build.py:ffi")
ext_modules=[
CppExtension(
'_torch_sox', ['torchaudio/torch_sox.cpp'], libraries=['sox']),
],
)
cmdclass={'build_ext': BuildExtension})
......@@ -7,17 +7,19 @@ import os
class Test_LoadSave(unittest.TestCase):
test_dirpath = os.path.dirname(os.path.realpath(__file__))
test_filepath = os.path.join(
test_dirpath, "assets", "steam-train-whistle-daniel_simon.mp3")
test_filepath = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.mp3")
def test_load(self):
# check normal loading
x, sr = torchaudio.load(self.test_filepath)
self.assertEqual(sr, 44100)
self.assertEqual(x.size(), (278756, 2))
self.assertGreater(x.sum(), 0)
# check normalizing
x, sr = torchaudio.load(self.test_filepath, normalization=True)
self.assertEqual(x.dtype, torch.float32)
self.assertTrue(x.min() >= -1.0)
self.assertTrue(x.max() <= 1.0)
......@@ -26,8 +28,8 @@ class Test_LoadSave(unittest.TestCase):
torchaudio.load("file-does-not-exist.mp3")
with self.assertRaises(OSError):
tdir = os.path.join(os.path.dirname(
self.test_dirpath), "torchaudio")
tdir = os.path.join(
os.path.dirname(self.test_dirpath), "torchaudio")
torchaudio.load(tdir)
def test_save(self):
......@@ -78,24 +80,35 @@ class Test_LoadSave(unittest.TestCase):
# don't save to folders that don't exist
with self.assertRaises(OSError):
new_filepath = os.path.join(
self.test_dirpath, "no-path", "test.wav")
new_filepath = os.path.join(self.test_dirpath, "no-path",
"test.wav")
torchaudio.save(new_filepath, x, sr)
# save created file
sinewave_filepath = os.path.join(
self.test_dirpath, "assets", "sinewave.wav")
sinewave_filepath = os.path.join(self.test_dirpath, "assets",
"sinewave.wav")
sr = 16000
freq = 440
volume = 0.3
y = (torch.cos(2 * math.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
y = (torch.cos(
2 * math.pi * torch.arange(0, 4 * sr) * freq / sr)).float()
y.unsqueeze_(1)
# y is between -1 and 1, so must scale
y = (y * volume * 2**31).long()
torchaudio.save(sinewave_filepath, y, sr)
self.assertTrue(os.path.isfile(sinewave_filepath))
def test_load_and_save_is_identity(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
tensor, sample_rate = torchaudio.load(input_path)
output_path = os.path.join(self.test_dirpath, 'test.wav')
torchaudio.save(output_path, tensor, sample_rate)
tensor2, sample_rate2 = torchaudio.load(output_path)
self.assertTrue(tensor.allclose(tensor2))
self.assertEqual(sample_rate, sample_rate2)
os.unlink(output_path)
if __name__ == '__main__':
unittest.main()
import os
import sys
import os.path
import torch
from cffi import FFI
ffi = FFI()
from ._ext import th_sox
import _torch_sox
from torchaudio import transforms
from torchaudio import datasets
if sys.version_info >= (3, 0):
_bytes = bytes
else:
def _bytes(s, e):
return s.encode(e)
def get_tensor_type_name(tensor):
return tensor.type().replace('torch.', '').replace('Tensor', '')
......@@ -55,22 +44,20 @@ def load(filepath, out=None, normalization=None):
# check if valid file
if not os.path.isfile(filepath):
raise OSError("{} not found or is a directory".format(filepath))
# initialize output tensor
if out is not None:
check_input(out)
else:
out = torch.FloatTensor()
# load audio signal
typename = get_tensor_type_name(out)
func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename))
sample_rate_p = ffi.new('int*')
func(str(filepath).encode("utf-8"), out, sample_rate_p)
sample_rate = sample_rate_p[0]
sample_rate = _torch_sox.read_audio_file(filepath, out)
# normalize if needed
if isinstance(normalization, bool) and normalization:
out /= 1 << 31 # assuming 16-bit depth
elif isinstance(normalization, (float, int)):
out /= normalization # normalize with custom value
return out, sample_rate
......@@ -111,9 +98,6 @@ def save(filepath, src, sample_rate):
src = src * (1 << 31) # assuming 16-bit depth
src = src.long()
# save data to file
filename, extension = os.path.splitext(filepath)
extension = os.path.splitext(filepath)[1]
check_input(src)
typename = get_tensor_type_name(src)
func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))
func(_bytes(filepath, "utf-8"), src,
_bytes(extension[1:], "utf-8"), sample_rate)
_torch_sox.write_audio_file(filepath, src, extension[1:], sample_rate)
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/th_sox.c"
#else
void libthsox_(read_audio)(sox_format_t *fd, THTensor* tensor,
int* sample_rate, size_t nsamples)
{
int nchannels = fd->signal.channels;
long buffer_size = fd->signal.length;
if (buffer_size == 0) {
if (nsamples != -1) {
buffer_size = nsamples;
} else {
THError("[read_audio] Unknown length");
}
}
*sample_rate = (int) fd->signal.rate;
int32_t *buffer = (int32_t *)malloc(sizeof(int32_t) * buffer_size);
size_t samples_read = sox_read(fd, buffer, buffer_size);
if (samples_read == 0)
THError("[read_audio] Empty file or read failed in sox_read");
// alloc tensor
THTensor_(resize2d)(tensor, samples_read / nchannels, nchannels );
real *tensor_data = THTensor_(data)(tensor);
// convert audio to dest tensor
int x,k;
for (x=0; x<samples_read/nchannels; x++) {
for (k=0; k<nchannels; k++) {
*tensor_data++ = (real)buffer[x*nchannels+k];
}
}
// free buffer and sox structures
free(buffer);
}
void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sample_rate)
{
// Create sox objects and read into int32_t buffer
sox_format_t *fd;
fd = sox_open_read(file_name, NULL, NULL, NULL);
if (fd == NULL)
THError("[read_audio_file] Failure to read file");
libthsox_(read_audio)(fd, tensor, sample_rate, -1);
sox_close(fd);
}
void libthsox_(write_audio)(sox_format_t *fd, THTensor* src,
const char *extension, int sample_rate)
{
long nchannels = src->size[1];
long nsamples = src->size[0];
real* data = THTensor_(data)(src);
// convert audio to dest tensor
int x,k;
for (x=0; x<nsamples; x++) {
for (k=0; k<nchannels; k++) {
int32_t sample = (int32_t)(data[x*nchannels+k]);
size_t samples_written = sox_write(fd, &sample, 1);
if (samples_written != 1)
THError("[write_audio_file] write failed in sox_write");
}
}
}
void libthsox_(write_audio_file)(const char *file_name, THTensor* src,
const char *extension, int sample_rate)
{
if (THTensor_(isContiguous)(src) == 0)
THError("[write_audio_file] Input should be contiguous tensors");
long nchannels = src->size[1];
long nsamples = src->size[0];
sox_format_t *fd;
// Create sox objects and write into int32_t buffer
sox_signalinfo_t sinfo;
sinfo.rate = sample_rate;
sinfo.channels = nchannels;
sinfo.length = nsamples * nchannels;
sinfo.precision = sizeof(int32_t) * 8; /* precision in bits */
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
sinfo.mult = NULL;
#endif
fd = sox_open_write(file_name, &sinfo, NULL, extension, NULL, NULL);
if (fd == NULL)
THError("[write_audio_file] Failure to open file for writing");
libthsox_(write_audio)(fd, src, extension, sample_rate);
// free buffer and sox structures
sox_close(fd);
return;
}
#endif
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "generic/th_sox.h"
#else
void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sample_rate);
void libthsox_(write_audio_file)(const char *file_name, THTensor* src, const char *extension, int sample_rate);
#endif
#include <TH/TH.h>
#include <sox.h>
#define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME)
#define torch_Tensor TH_CONCAT_STRING_3(torch., Real, Tensor)
#define libthsox_(NAME) TH_CONCAT_4(libthsox_, Real, _, NAME)
#include "generic/th_sox.c"
#include "THGenerateAllTypes.h"
/* #include <TH/TH.h> */
/* #define torch_(NAME) TH_CONCAT_3(torch_, Real, NAME) */
/* #define torch_Tensor TH_CONCAT_STRING_3(torch., Real, Tensor) */
/* #define libthsox_(NAME) TH_CONCAT_4(libthsox_, Real, _, NAME) */
/* #include "generic/th_sox.h" */
/* #include "THGenerateAllTypes.h" */
/* gcc -E th_sox.h -I /home/soumith/code/pytorch/torch/lib/include/TH -I /home/soumith/code/pytorch/torch/lib/include/ -I .|grep libthsox */
void libthsox_Float_read_audio_file(const char *file_name, THFloatTensor* tensor, int* sample_rate);
void libthsox_Double_read_audio_file(const char *file_name, THDoubleTensor* tensor, int* sample_rate);
void libthsox_Byte_read_audio_file(const char *file_name, THByteTensor* tensor, int* sample_rate);
void libthsox_Char_read_audio_file(const char *file_name, THCharTensor* tensor, int* sample_rate);
void libthsox_Short_read_audio_file(const char *file_name, THShortTensor* tensor, int* sample_rate);
void libthsox_Int_read_audio_file(const char *file_name, THIntTensor* tensor, int* sample_rate);
void libthsox_Long_read_audio_file(const char *file_name, THLongTensor* tensor, int* sample_rate);
void libthsox_Float_write_audio_file(const char *file_name, THFloatTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Double_write_audio_file(const char *file_name, THDoubleTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Byte_write_audio_file(const char *file_name, THByteTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Char_write_audio_file(const char *file_name, THCharTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Short_write_audio_file(const char *file_name, THShortTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Int_write_audio_file(const char *file_name, THIntTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Long_write_audio_file(const char *file_name, THLongTensor* tensor, const char *extension,
int sample_rate);
\ No newline at end of file
#include <torch/torch.h>
#include <sox.h>
#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>
namespace torch {
namespace audio {
namespace {
/// Helper struct to safely close the sox_format_t descriptor.
struct SoxDescriptor {
explicit SoxDescriptor(sox_format_t* fd) noexcept : fd_(fd) {}
SoxDescriptor(const SoxDescriptor& other) = delete;
SoxDescriptor(SoxDescriptor&& other) = delete;
SoxDescriptor& operator=(const SoxDescriptor& other) = delete;
SoxDescriptor& operator=(SoxDescriptor&& other) = delete;
~SoxDescriptor() {
sox_close(fd_);
}
sox_format_t* operator->() noexcept {
return fd_;
}
sox_format_t* get() noexcept {
return fd_;
}
private:
sox_format_t* fd_;
};
void read_audio(
SoxDescriptor& fd,
at::Tensor output,
int64_t number_of_channels,
int64_t buffer_length) {
std::vector<sox_sample_t> buffer(buffer_length);
const int64_t samples_read = sox_read(fd.get(), buffer.data(), buffer_length);
if (samples_read == 0) {
throw std::runtime_error(
"Error reading audio file: empty file or read failed in sox_read");
}
output.resize_({samples_read / number_of_channels, number_of_channels});
output = output.contiguous();
AT_DISPATCH_ALL_TYPES(output.type(), "read_audio_buffer", [&] {
auto* data = output.data<scalar_t>();
std::copy(buffer.begin(), buffer.begin() + samples_read, data);
});
}
int64_t write_audio(SoxDescriptor& fd, at::Tensor tensor) {
std::vector<sox_sample_t> buffer(tensor.numel());
AT_DISPATCH_ALL_TYPES(tensor.type(), "write_audio_buffer", [&] {
auto* data = tensor.data<scalar_t>();
std::copy(data, data + tensor.numel(), buffer.begin());
});
const auto samples_written =
sox_write(fd.get(), buffer.data(), buffer.size());
return samples_written;
}
} // namespace
int read_audio_file(const std::string& file_name, at::Tensor output) {
SoxDescriptor fd(sox_open_read(
file_name.c_str(),
/*signal=*/nullptr,
/*encoding=*/nullptr,
/*filetype=*/nullptr));
if (fd.get() == nullptr) {
throw std::runtime_error("Error opening audio file");
}
const int64_t number_of_channels = fd->signal.channels;
const int sample_rate = fd->signal.rate;
const int64_t buffer_length = fd->signal.length;
if (buffer_length == 0) {
throw std::runtime_error("Error reading audio file: unknown length");
}
read_audio(fd, output, number_of_channels, buffer_length);
return sample_rate;
}
void write_audio_file(
const std::string& file_name,
at::Tensor tensor,
const std::string& extension,
int sample_rate) {
if (!tensor.is_contiguous()) {
throw std::runtime_error(
"Error writing audio file: input tensor must be contiguous");
}
sox_signalinfo_t signal;
signal.rate = sample_rate;
signal.channels = tensor.size(1);
signal.length = tensor.numel();
signal.precision = 32; // precision in bits
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
signal.mult = nullptr;
#endif
SoxDescriptor fd(sox_open_write(
file_name.c_str(),
&signal,
/*encoding=*/nullptr,
extension.c_str(),
/*filetype=*/nullptr,
/*oob=*/nullptr));
if (fd.get() == nullptr) {
throw std::runtime_error(
"Error writing audio file: could not open file for writing");
}
const auto samples_written = write_audio(fd, tensor);
if (samples_written != tensor.numel()) {
throw std::runtime_error(
"Error writing audio file: could not write entire buffer");
}
}
} // namespace audio
} // namespace torch
PYBIND11_MODULE(_torch_sox, m) {
m.def(
"read_audio_file",
&torch::audio::read_audio_file,
"Reads an audio file into a tensor");
m.def(
"write_audio_file",
&torch::audio::write_audio_file,
"Writes data from a tensor into an audio file");
}
#include <string>
namespace at {
struct Tensor;
} // namespace at
namespace torch { namespace audio {
/// Reads an audio file from the given `path` into the `output` `Tensor` and
/// returns the sample rate of the audio file.
/// Throws `std::runtime_error` if the audio file could not be opened, or an
/// error ocurred during reading of the audio data.
int read_audio_file(const std::string& path, at::Tensor output);
/// Writes the data of a `Tensor` into an audio file at the given `path`, with
/// a certain extension (e.g. `wav`or `mp3`) and sample rate.
/// Throws `std::runtime_error` when the audio file could not be opened for
/// writing, or an error ocurred during writing of the audio data.
void write_audio_file(
const std::string& path,
at::Tensor tensor,
const std::string& extension,
int sample_rate);
}} // namespace torch::audio
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment