__init__.py 1.12 KB
Newer Older
SeanNaren's avatar
SeanNaren committed
1
2
import os

Soumith Chintala's avatar
Soumith Chintala committed
3
4
5
import torch

from cffi import FFI
6

Soumith Chintala's avatar
Soumith Chintala committed
7
8
9
ffi = FFI()
from ._ext import th_sox

10
11
12
13
14
15
16
17

def check_input(src):
    """Ensure *src* is a tensor whose ``__module__`` is ``'torch'``.

    The ``__module__`` comparison is what this module uses to tell a CPU
    tensor apart from other tensor types.

    Raises:
        TypeError: if *src* is not a tensor, or is a tensor whose
            ``__module__`` is not ``'torch'``.
    """
    src_type = type(src)
    if not torch.is_tensor(src):
        problem = 'Expected a tensor, got %s'
    elif src.__module__ != 'torch':
        problem = 'Expected a CPU based tensor, got %s'
    else:
        return
    raise TypeError(problem % src_type)


Soumith Chintala's avatar
Soumith Chintala committed
18
19
def load(filename, out=None):
    """Read an audio file into a tensor.

    Args:
        filename: path of the audio file. ``bytes`` paths are used as-is;
            anything else is converted with ``str()`` (so ``pathlib.Path``
            works) and encoded with the filesystem encoding, so non-ASCII
            paths are supported.
        out: optional tensor to read into, validated with ``check_input``.
            A new ``torch.FloatTensor`` is used when omitted.

    Returns:
        tuple: ``(out, sample_rate)``. ``out`` is the tensor handed to the
        native reader (presumably filled in place by it). ``sample_rate`` is
        the integer the reader wrote to its ``int*`` output argument.

    Raises:
        TypeError: if ``out`` is given but fails ``check_input``.
    """
    if out is None:
        out = torch.FloatTensor()
    else:
        check_input(out)

    # The native reader is selected by tensor class name minus "Tensor",
    # e.g. FloatTensor -> libthsox_Float_read_audio_file.
    typename = type(out).__name__.replace('Tensor', '')
    func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename))

    # Output parameter: the reader stores the sample rate here.
    sample_rate_p = ffi.new('int*')

    # Previously str(filename).encode("ascii"): that mangled bytes paths into
    # "b'...'" and raised on non-ASCII names. Encode with the filesystem
    # encoding instead; ASCII str paths produce the same bytes as before.
    path = filename if isinstance(filename, bytes) else str(filename)
    func(os.fsencode(path), out, sample_rate_p)

    return out, sample_rate_p[0]
SeanNaren's avatar
SeanNaren committed
29
30
31
32


def save(filepath, src, sample_rate):
    """Write the tensor *src* to *filepath* as an audio file.

    The file extension (without its leading dot) is passed to the native
    writer alongside the path.

    Args:
        filepath: destination path. ``bytes`` paths are used as-is; anything
            else is converted with ``str()`` (so ``pathlib.Path`` works) and
            encoded with the filesystem encoding.
        src: tensor to write, validated with ``check_input``.
        sample_rate: integral sample rate. Any ``numbers.Integral`` (e.g. a
            NumPy integer) is accepted; ``bool`` is rejected.

    Raises:
        TypeError: if ``sample_rate`` is not an integer, or ``src`` fails
            ``check_input``.
    """
    import numbers

    # Encode the path once, up front. NOTE(review): the native writer is
    # assumed to take char* for path and extension; cffi on Python 3 only
    # accepts bytes for char*, so both arguments must be bytes.
    path = os.fsencode(
        filepath if isinstance(filepath, bytes) else str(filepath))
    # Split the *encoded* path so the extension is bytes too. The old code
    # passed extension[1:] as a str, which cffi rejects for char*.
    extension = os.path.splitext(path)[1][1:]

    # bool is an int subclass but never a meaningful sample rate.
    if (isinstance(sample_rate, bool)
            or not isinstance(sample_rate, numbers.Integral)):
        raise TypeError('Sample rate should be a integer')

    check_input(src)
    # Writer is selected by tensor class name minus "Tensor",
    # e.g. FloatTensor -> libthsox_Float_write_audio_file.
    typename = type(src).__name__.replace('Tensor', '')
    func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))

    # int() normalises NumPy/other Integral types to a plain Python int.
    func(path, src, extension, int(sample_rate))