Commit 2936245a authored by Soumith Chintala's avatar Soumith Chintala Committed by GitHub
Browse files

Merge pull request #3 from SeanNaren/save

Added ability to save tensors
parents 35109196 1fdb6eae
...@@ -36,6 +36,7 @@ Quick Usage ...@@ -36,6 +36,7 @@ Quick Usage
```python ```python
import torchaudio import torchaudio
sound, sample_rate = torchaudio.load('foo.mp3') sound, sample_rate = torchaudio.load('foo.mp3')
torchaudio.save('foo_save.mp3', sound, sample_rate) # saves tensor to file
``` ```
API Reference API Reference
...@@ -49,3 +50,13 @@ audio.load( ...@@ -49,3 +50,13 @@ audio.load(
) )
``` ```
torchaudio.save
```
Saves a tensor to an audio file. The file extension of the given path determines the output format.
audio.save(
string, # path to file
tensor, # NSamples x NChannels 2D tensor
number, # sample_rate of the audio to be saved as
)
```
import os
import torch import torch
from cffi import FFI from cffi import FFI
ffi = FFI() ffi = FFI()
from ._ext import th_sox from ._ext import th_sox
def check_input(src):
    """Validate that ``src`` is a tensor whose class lives in ``torch``.

    Raises:
        TypeError: if ``src`` is not a tensor, or if the ``__module__`` of
            the tensor is anything other than ``'torch'`` (reported to the
            caller as not being a CPU based tensor).
    """
    src_type = type(src)
    if not torch.is_tensor(src):
        raise TypeError('Expected a tensor, got %s' % src_type)
    if src.__module__ != 'torch':
        raise TypeError('Expected a CPU based tensor, got %s' % src_type)
def load(filename, out=None):
    """Read an audio file through the ``th_sox`` extension.

    Args:
        filename: path of the audio file, as ``str`` or ``bytes``.
        out: optional tensor to read into. It is validated with
            ``check_input``; when omitted a new ``torch.FloatTensor`` is
            created.

    Returns:
        A tuple ``(out, sample_rate)`` where ``sample_rate`` is the integer
        written by the native reader.

    Raises:
        TypeError: propagated from ``check_input`` when ``out`` is rejected.
    """
    if out is not None:
        check_input(out)
    else:
        out = torch.FloatTensor()
    # The native entry point is selected from the tensor class name,
    # e.g. FloatTensor -> libthsox_Float_read_audio_file.
    typename = type(out).__name__.replace('Tensor', '')
    func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename))
    # bytes(str) raises TypeError on Python 3, and the C side takes a
    # ``const char*``; encode text paths explicitly instead.
    if not isinstance(filename, bytes):
        filename = filename.encode('utf-8')
    sample_rate_p = ffi.new('int*')
    func(filename, out, sample_rate_p)
    return out, sample_rate_p[0]
def save(filepath, src, sample_rate):
    """Write the tensor ``src`` to an audio file through ``th_sox``.

    The text after the last dot of ``filepath`` is handed to the native
    writer as the format name.

    Args:
        filepath: destination path, as ``str`` or ``bytes``.
        src: tensor to save; validated with ``check_input``.
        sample_rate: integer sample rate passed to the native writer.

    Raises:
        TypeError: if ``sample_rate`` is not an integer, or if
            ``check_input`` rejects ``src``.
    """
    extension = os.path.splitext(filepath)[1]
    # bool is an int subclass; keep rejecting it as the exact-type check did.
    if isinstance(sample_rate, bool) or not isinstance(sample_rate, int):
        raise TypeError('Sample rate should be a integer')
    check_input(src)
    # The native entry point is selected from the tensor class name,
    # e.g. FloatTensor -> libthsox_Float_write_audio_file.
    typename = type(src).__name__.replace('Tensor', '')
    func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))
    # Both string arguments are ``const char*`` on the C side. bytes(str)
    # raises TypeError on Python 3, so encode text values explicitly.
    file_type = extension[1:]
    if not isinstance(filepath, bytes):
        filepath = filepath.encode('utf-8')
    if not isinstance(file_type, bytes):
        file_type = file_type.encode('utf-8')
    func(filepath, src, file_type, sample_rate)
...@@ -44,4 +44,55 @@ void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sa ...@@ -44,4 +44,55 @@ void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sa
sox_close(fd); sox_close(fd);
} }
/*
 * Write every element of `src` to the open sox stream `fd`.
 *
 * Elements are visited in storage order (row x of size[0], then column k
 * of size[1]), cast to int32_t without any scaling, and passed to
 * sox_write one sample at a time. THError is raised as soon as a write
 * does not report exactly one sample written.
 *
 * `extension` and `sample_rate` are accepted but not used here.
 */
void libthsox_(write_audio)(sox_format_t *fd, THTensor* src,
                            const char *extension, int sample_rate)
{
  long nchannels = src->size[1];
  long nsamples = src->size[0];
  real* data = THTensor_(data)(src);

  /* Use a long index: the bounds are long, and an int counter would
     overflow before reaching the end of a very large tensor. */
  long total = nsamples * nchannels;
  long i;
  for (i = 0; i < total; i++) {
    int32_t sample = (int32_t)(data[i]);
    size_t samples_written = sox_write(fd, &sample, 1);
    if (samples_written != 1)
      THError("[write_audio_file] write failed in sox_write");
  }
}
/*
 * Save the 2-D tensor `src` to `file_name` using libsox.
 *
 * size[0] is used as the number of samples and size[1] as the number of
 * channels. `extension` is passed to sox_open_write as the file type and
 * `sample_rate` as the signal rate. Output precision is fixed at 32 bits.
 *
 * Raises THError when `src` is not contiguous, is not 2-D, or when the
 * file cannot be opened for writing.
 */
void libthsox_(write_audio_file)(const char *file_name, THTensor* src,
                                 const char *extension, int sample_rate)
{
  if (THTensor_(isContiguous)(src) == 0)
    THError("[write_audio_file] Input should be contiguous tensors");
  /* size[1] is read below; reject anything that is not 2-D first so a
     1-D tensor cannot cause an out-of-bounds read. */
  if (THTensor_(nDimension)(src) != 2)
    THError("[write_audio_file] Input should be a 2D tensor (NSamples x NChannels)");

  long nchannels = src->size[1];
  long nsamples = src->size[0];

  sox_format_t *fd;

  // Describe the signal that will be written
  sox_signalinfo_t sinfo;
  sinfo.rate = sample_rate;
  sinfo.channels = nchannels;
  sinfo.length = nsamples * nchannels;
  sinfo.precision = sizeof(int32_t) * 8; /* precision in bits */
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
  sinfo.mult = NULL;
#endif

  fd = sox_open_write(file_name, &sinfo, NULL, extension, NULL, NULL);
  if (fd == NULL)
    THError("[write_audio_file] Failure to open file for writing");

  /* NOTE(review): if write_audio raises through THError, sox_close below
     is presumably skipped and fd leaks - confirm THError's behavior. */
  libthsox_(write_audio)(fd, src, extension, sample_rate);

  // release the sox stream
  sox_close(fd);
}
#endif #endif
...@@ -3,4 +3,5 @@ ...@@ -3,4 +3,5 @@
#else #else
void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sample_rate); void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sample_rate);
void libthsox_(write_audio_file)(const char *file_name, THTensor* src, const char *extension, int sample_rate);
#endif #endif
...@@ -15,3 +15,18 @@ void libthsox_Char_read_audio_file(const char *file_name, THCharTensor* tensor, ...@@ -15,3 +15,18 @@ void libthsox_Char_read_audio_file(const char *file_name, THCharTensor* tensor,
void libthsox_Short_read_audio_file(const char *file_name, THShortTensor* tensor, int* sample_rate); void libthsox_Short_read_audio_file(const char *file_name, THShortTensor* tensor, int* sample_rate);
void libthsox_Int_read_audio_file(const char *file_name, THIntTensor* tensor, int* sample_rate); void libthsox_Int_read_audio_file(const char *file_name, THIntTensor* tensor, int* sample_rate);
void libthsox_Long_read_audio_file(const char *file_name, THLongTensor* tensor, int* sample_rate); void libthsox_Long_read_audio_file(const char *file_name, THLongTensor* tensor, int* sample_rate);
void libthsox_Float_write_audio_file(const char *file_name, THFloatTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Double_write_audio_file(const char *file_name, THDoubleTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Byte_write_audio_file(const char *file_name, THByteTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Char_write_audio_file(const char *file_name, THCharTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Short_write_audio_file(const char *file_name, THShortTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Int_write_audio_file(const char *file_name, THIntTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Long_write_audio_file(const char *file_name, THLongTensor* tensor, const char *extension,
int sample_rate);
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment