"llm/git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "fcf4d60eeea12b3b25877b09aae2c3e6a38b5bbe"
Commit c844ac63 authored by Peter Goldsborough's avatar Peter Goldsborough Committed by Sam Gross
Browse files

Fixes after tensor/variable merge (#33)

parent 4a3e5006
...@@ -18,10 +18,14 @@ else: ...@@ -18,10 +18,14 @@ else:
return s.encode(e) return s.encode(e)
def get_tensor_type_name(tensor):
return tensor.type().replace('torch.', '').replace('Tensor', '')
def check_input(src): def check_input(src):
if not torch.is_tensor(src): if not torch.is_tensor(src):
raise TypeError('Expected a tensor, got %s' % type(src)) raise TypeError('Expected a tensor, got %s' % type(src))
if not src.__module__ == 'torch': if src.is_cuda:
raise TypeError('Expected a CPU based tensor, got %s' % type(src)) raise TypeError('Expected a CPU based tensor, got %s' % type(src))
...@@ -57,7 +61,7 @@ def load(filepath, out=None, normalization=None): ...@@ -57,7 +61,7 @@ def load(filepath, out=None, normalization=None):
else: else:
out = torch.FloatTensor() out = torch.FloatTensor()
# load audio signal # load audio signal
typename = type(out).__name__.replace('Tensor', '') typename = get_tensor_type_name(out)
func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename)) func = getattr(th_sox, 'libthsox_{}_read_audio_file'.format(typename))
sample_rate_p = ffi.new('int*') sample_rate_p = ffi.new('int*')
func(str(filepath).encode("utf-8"), out, sample_rate_p) func(str(filepath).encode("utf-8"), out, sample_rate_p)
...@@ -109,7 +113,7 @@ def save(filepath, src, sample_rate): ...@@ -109,7 +113,7 @@ def save(filepath, src, sample_rate):
# save data to file # save data to file
filename, extension = os.path.splitext(filepath) filename, extension = os.path.splitext(filepath)
check_input(src) check_input(src)
typename = type(src).__name__.replace('Tensor', '') typename = get_tensor_type_name(src)
func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename)) func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))
func(_bytes(filepath, "utf-8"), src, func(_bytes(filepath, "utf-8"), src,
_bytes(extension[1:], "utf-8"), sample_rate) _bytes(extension[1:], "utf-8"), sample_rate)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment