Commit d02dff76 authored by SeanNaren's avatar SeanNaren
Browse files

Refactors for better error checking/case checks

parent 97623fc1
...@@ -3,31 +3,25 @@ import os ...@@ -3,31 +3,25 @@ import os
import torch import torch
from cffi import FFI from cffi import FFI
ffi = FFI() ffi = FFI()
from ._ext import th_sox from ._ext import th_sox
def check_input(src):
if not torch.is_tensor(src):
raise TypeError('Expected a tensor, got %s' % type(src))
if not src.__module__ == 'torch':
raise TypeError('Expected a CPU based tensor, got %s' % type(src))
def load(filename, out=None): def load(filename, out=None):
if out is not None: if out is not None:
assert torch.is_tensor(out) check_input(out)
assert not out.is_cuda
else: else:
out = torch.FloatTensor() out = torch.FloatTensor()
typename = type(out).__name__.replace('Tensor', '')
if isinstance(out, torch.FloatTensor): func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename))
func = th_sox.libthsox_Float_read_audio_file
elif isinstance(out, torch.DoubleTensor):
func = th_sox.libthsox_Double_read_audio_file
elif isinstance(out, torch.ByteTensor):
func = th_sox.libthsox_Byte_read_audio_file
elif isinstance(out, torch.CharTensor):
func = th_sox.libthsox_Char_read_audio_file
elif isinstance(out, torch.ShortTensor):
func = th_sox.libthsox_Short_read_audio_file
elif isinstance(out, torch.IntTensor):
func = th_sox.libthsox_Int_read_audio_file
elif isinstance(out, torch.LongTensor):
func = th_sox.libthsox_Long_read_audio_file
sample_rate_p = ffi.new('int*') sample_rate_p = ffi.new('int*')
func(bytes(filename), out, sample_rate_p) func(bytes(filename), out, sample_rate_p)
sample_rate = sample_rate_p[0] sample_rate = sample_rate_p[0]
...@@ -35,24 +29,12 @@ def load(filename, out=None): ...@@ -35,24 +29,12 @@ def load(filename, out=None):
def save(filepath, src, sample_rate): def save(filepath, src, sample_rate):
assert torch.is_tensor(src)
assert not src.is_cuda
filename, extension = os.path.splitext(filepath) filename, extension = os.path.splitext(filepath)
assert type(sample_rate) == int if type(sample_rate) != int:
raise TypeError('Sample rate should be a integer')
if isinstance(src, torch.FloatTensor):
func = th_sox.libthsox_Float_write_audio_file check_input(src)
elif isinstance(src, torch.DoubleTensor): typename = type(src).__name__.replace('Tensor', '')
func = th_sox.libthsox_Double_write_audio_file func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))
elif isinstance(src, torch.ByteTensor):
func = th_sox.libthsox_Byte_write_audio_file func(bytes(filepath), src, extension[1:], sample_rate)
elif isinstance(src, torch.CharTensor):
func = th_sox.libthsox_Char_write_audio_file
elif isinstance(src, torch.ShortTensor):
func = th_sox.libthsox_Short_write_audio_file
elif isinstance(src, torch.IntTensor):
func = th_sox.libthsox_Int_write_audio_file
elif isinstance(src, torch.LongTensor):
func = th_sox.libthsox_Long_write_audio_file
func(bytes(filepath), src, extension.replace('.', ''), sample_rate)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment