__init__.py 1.2 KB
Newer Older
SeanNaren's avatar
SeanNaren committed
1
2
import os

Soumith Chintala's avatar
Soumith Chintala committed
3
4
5
import torch

from cffi import FFI
6

Soumith Chintala's avatar
Soumith Chintala committed
7
8
9
# Module-level cffi handle; used in this file to allocate C out-parameters
# (for example the 'int*' that receives the sample rate in load()).
ffi = FFI()
# Module whose libthsox_<Type>_read_audio_file / libthsox_<Type>_write_audio_file
# functions are looked up by name in load() and save().
from ._ext import th_sox

David Pollack's avatar
David Pollack committed
10
11
from torchaudio import transforms
from torchaudio import datasets
12
13
14
15
16
17
18
19

def check_input(src):
    """Validate that ``src`` is a CPU-based torch tensor.

    Args:
        src: Object to validate.

    Raises:
        TypeError: If ``src`` is not a tensor, or if it is a tensor that is
            not CPU based.
    """
    if not torch.is_tensor(src):
        raise TypeError('Expected a tensor, got %s' % type(src))
    # Legacy tensor classes encode their backend in the defining module
    # ('torch' for CPU, e.g. 'torch.cuda' otherwise).  Unified torch.Tensor
    # objects report 'torch' regardless of device, so also consult is_cuda
    # to keep rejecting GPU tensors there.
    if src.__module__ != 'torch' or getattr(src, 'is_cuda', False):
        raise TypeError('Expected a CPU based tensor, got %s' % type(src))


Soumith Chintala's avatar
Soumith Chintala committed
20
21
def load(filename, out=None):
    """Read an audio file into a tensor.

    Args:
        filename: Path of the audio file; it is converted with ``str()`` and
            UTF-8 encoded before being handed to the backend.
        out: Optional tensor to read into.  It is validated with
            ``check_input``; when omitted a new ``torch.FloatTensor`` is
            created.

    Returns:
        A tuple ``(out, sample_rate)`` where ``sample_rate`` is the integer
        the backend stored in its ``int*`` out-parameter.

    Raises:
        TypeError: If ``out`` is given and ``check_input`` rejects it.
    """
    if out is None:
        out = torch.FloatTensor()
    else:
        check_input(out)

    # The backend exposes one reader per tensor type, selected by the tensor
    # class name with the 'Tensor' part removed.
    type_tag = type(out).__name__.replace('Tensor', '')
    reader = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(type_tag))

    rate_ptr = ffi.new('int*')
    encoded_name = str(filename).encode("utf-8")
    reader(encoded_name, out, rate_ptr)
    return out, rate_ptr[0]
SeanNaren's avatar
SeanNaren committed
31
32
33
34


def save(filepath, src, sample_rate):
    """Write the tensor ``src`` to an audio file.

    The file extension of ``filepath`` (without the leading dot) is passed to
    the backend alongside the path.

    Args:
        filepath: Destination path.  It is converted with ``str()`` before
            encoding, matching ``load``, so path-like objects are accepted.
        src: Tensor to write; validated with ``check_input``.
        sample_rate: Sample rate as an ``int``.

    Raises:
        TypeError: If ``sample_rate`` is not an integer, or if
            ``check_input`` rejects ``src``.
    """
    # bool is a subclass of int but was never accepted as a sample rate, so
    # keep rejecting it explicitly while allowing other int subclasses.
    if isinstance(sample_rate, bool) or not isinstance(sample_rate, int):
        raise TypeError('Sample rate should be a integer')

    check_input(src)

    path = str(filepath)
    extension = os.path.splitext(path)[1]

    # The backend exposes one writer per tensor type, selected by the tensor
    # class name with the 'Tensor' part removed.
    typename = type(src).__name__.replace('Tensor', '')
    func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))

    func(path.encode("utf-8"), src, extension[1:].encode("utf-8"), sample_rate)