Commit 97623fc1 authored by SeanNaren's avatar SeanNaren
Browse files

Added ability to save tensors

parent 35109196
......@@ -36,6 +36,7 @@ Quick Usage
```python
import torchaudio
sound, sample_rate = torchaudio.load('foo.mp3')
torchaudio.save('foo_save.mp3', sound, sample_rate) # saves tensor to file
```
API Reference
......@@ -49,3 +50,13 @@ audio.load(
)
```
torchaudio.save
```
saves a tensor into an audio file. The extension of the given path is used as the saving format.
audio.save(
string, # path to file
tensor, # NSamples x NChannels 2D tensor
number, # sample_rate of the audio to be saved as
)
```
import os
import torch
from cffi import FFI
......@@ -30,3 +32,27 @@ def load(filename, out=None):
func(bytes(filename), out, sample_rate_p)
sample_rate = sample_rate_p[0]
return out, sample_rate
def save(filepath, src, sample_rate):
assert torch.is_tensor(src)
assert not src.is_cuda
filename, extension = os.path.splitext(filepath)
assert type(sample_rate) == int
if isinstance(src, torch.FloatTensor):
func = th_sox.libthsox_Float_write_audio_file
elif isinstance(src, torch.DoubleTensor):
func = th_sox.libthsox_Double_write_audio_file
elif isinstance(src, torch.ByteTensor):
func = th_sox.libthsox_Byte_write_audio_file
elif isinstance(src, torch.CharTensor):
func = th_sox.libthsox_Char_write_audio_file
elif isinstance(src, torch.ShortTensor):
func = th_sox.libthsox_Short_write_audio_file
elif isinstance(src, torch.IntTensor):
func = th_sox.libthsox_Int_write_audio_file
elif isinstance(src, torch.LongTensor):
func = th_sox.libthsox_Long_write_audio_file
func(bytes(filepath), src, extension.replace('.', ''), sample_rate)
......@@ -44,4 +44,55 @@ void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sa
sox_close(fd);
}
void libthsox_(write_audio)(sox_format_t *fd, THTensor* src,
const char *extension, int sample_rate)
{
long nchannels = src->size[1];
long nsamples = src->size[0];
real* data = THTensor_(data)(src);
// convert audio to dest tensor
int x,k;
for (x=0; x<nsamples; x++) {
for (k=0; k<nchannels; k++) {
int32_t sample = (int32_t)(data[x*nchannels+k]);
size_t samples_written = sox_write(fd, &sample, 1);
if (samples_written != 1)
THError("[write_audio_file] write failed in sox_write");
}
}
}
void libthsox_(write_audio_file)(const char *file_name, THTensor* src,
const char *extension, int sample_rate)
{
if (THTensor_(isContiguous)(src) == 0)
THError("[write_audio_file] Input should be contiguous tensors");
long nchannels = src->size[1];
long nsamples = src->size[0];
sox_format_t *fd;
// Create sox objects and write into int32_t buffer
sox_signalinfo_t sinfo;
sinfo.rate = sample_rate;
sinfo.channels = nchannels;
sinfo.length = nsamples * nchannels;
sinfo.precision = sizeof(int32_t) * 8; /* precision in bits */
#if SOX_LIB_VERSION_CODE >= 918272 // >= 14.3.0
sinfo.mult = NULL;
#endif
fd = sox_open_write(file_name, &sinfo, NULL, extension, NULL, NULL);
if (fd == NULL)
THError("[write_audio_file] Failure to open file for writing");
libthsox_(write_audio)(fd, src, extension, sample_rate);
// free buffer and sox structures
sox_close(fd);
return;
}
#endif
......@@ -3,4 +3,5 @@
#else
void libthsox_(read_audio_file)(const char *file_name, THTensor* tensor, int* sample_rate);
void libthsox_(write_audio_file)(const char *file_name, THTensor* src, const char *extension, int sample_rate);
#endif
......@@ -15,3 +15,18 @@ void libthsox_Char_read_audio_file(const char *file_name, THCharTensor* tensor,
void libthsox_Short_read_audio_file(const char *file_name, THShortTensor* tensor, int* sample_rate);
void libthsox_Int_read_audio_file(const char *file_name, THIntTensor* tensor, int* sample_rate);
void libthsox_Long_read_audio_file(const char *file_name, THLongTensor* tensor, int* sample_rate);
void libthsox_Float_write_audio_file(const char *file_name, THFloatTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Double_write_audio_file(const char *file_name, THDoubleTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Byte_write_audio_file(const char *file_name, THByteTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Char_write_audio_file(const char *file_name, THCharTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Short_write_audio_file(const char *file_name, THShortTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Int_write_audio_file(const char *file_name, THIntTensor* tensor, const char *extension,
int sample_rate);
void libthsox_Long_write_audio_file(const char *file_name, THLongTensor* tensor, const char *extension,
int sample_rate);
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment