Commit 74b5faa7 authored by David Pollack's avatar David Pollack Committed by GitHub
Browse files

Merge pull request #19 from raff/python2

Quick and dirty fix for "bytes" (in torchaudio.save) not available in python2
parents 9edce71c f010e165
from __future__ import print_function
import torch import torch
import torchaudio import torchaudio
import torchaudio.transforms as transforms import torchaudio.transforms as transforms
......
import os import os
import sys
import torch import torch
...@@ -10,6 +11,11 @@ from ._ext import th_sox ...@@ -10,6 +11,11 @@ from ._ext import th_sox
from torchaudio import transforms from torchaudio import transforms
from torchaudio import datasets from torchaudio import datasets
if sys.version_info >= (3, 0):
_bytes = bytes
else:
_bytes = lambda s, e: s.encode(e)
def check_input(src): def check_input(src):
if not torch.is_tensor(src): if not torch.is_tensor(src):
raise TypeError('Expected a tensor, got %s' % type(src)) raise TypeError('Expected a tensor, got %s' % type(src))
...@@ -64,4 +70,4 @@ def save(filepath, src, sample_rate): ...@@ -64,4 +70,4 @@ def save(filepath, src, sample_rate):
check_input(src) check_input(src)
typename = type(src).__name__.replace('Tensor', '') typename = type(src).__name__.replace('Tensor', '')
func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename)) func = getattr(th_sox, 'libthsox_{}_write_audio_file'.format(typename))
func(bytes(filepath, "utf-8"), src, bytes(extension[1:], "utf-8"), sample_rate) func(_bytes(filepath, "utf-8"), src, _bytes(extension[1:], "utf-8"), sample_rate)
from __future__ import division from __future__ import division, print_function
import torch import torch
import numpy as np import numpy as np
try: try:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment