"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "4a3d52850b2a1a3da47c91525b8899465b76606e"
Commit 9801caf6 authored by Karl Ostmo's avatar Karl Ostmo Committed by Vincent QB
Browse files

Fix several errors in tests run by Travis (#380)

* Declare file encoding to support special characters

* fix missing utf_8_encoder error in Travis tests

* Py 2.7 backwards-compat iterator

* ensure integer argument to torch.nn.functional.pad

* cast match.ceil result as integer
parent 805d7922
...@@ -42,6 +42,12 @@ def unicode_csv_reader(unicode_csv_data, **kwargs): ...@@ -42,6 +42,12 @@ def unicode_csv_reader(unicode_csv_data, **kwargs):
csv.field_size_limit(maxInt) csv.field_size_limit(maxInt)
if six.PY2: if six.PY2:
# Implementation borrowed from docs:
# https://docs.python.org/3.0/library/csv.html#examples
def utf_8_encoder(unicode_csv_data):
for line in unicode_csv_data:
yield line.encode('utf-8')
# csv.py doesn't do Unicode; encode temporarily as UTF-8: # csv.py doesn't do Unicode; encode temporarily as UTF-8:
csv_reader = csv.reader(utf_8_encoder(unicode_csv_data), **kwargs) csv_reader = csv.reader(utf_8_encoder(unicode_csv_data), **kwargs)
for row in csv_reader: for row in csv_reader:
...@@ -338,6 +344,10 @@ class _ThreadedIterator(threading.Thread): ...@@ -338,6 +344,10 @@ class _ThreadedIterator(threading.Thread):
raise StopIteration raise StopIteration
return next_item return next_item
# Required for Python 2.7 compatibility
def next(self):
return self.__next__()
def bg_iterator(iterable, maxsize): def bg_iterator(iterable, maxsize):
return _ThreadedIterator(iterable, maxsize=maxsize) return _ThreadedIterator(iterable, maxsize=maxsize)
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import math import math
...@@ -1093,12 +1095,12 @@ def _compute_nccf(waveform, sample_rate, frame_time, freq_low): ...@@ -1093,12 +1095,12 @@ def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
EPSILON = 10 ** (-9) EPSILON = 10 ** (-9)
# Number of lags to check # Number of lags to check
lags = math.ceil(sample_rate / freq_low) lags = int(math.ceil(sample_rate / freq_low))
frame_size = int(math.ceil(sample_rate * frame_time)) frame_size = int(math.ceil(sample_rate * frame_time))
waveform_length = waveform.size()[-1] waveform_length = waveform.size()[-1]
num_of_frames = math.ceil(waveform_length / frame_size) num_of_frames = int(math.ceil(waveform_length / frame_size))
p = lags + num_of_frames * frame_size - waveform_length p = lags + num_of_frames * frame_size - waveform_length
waveform = torch.nn.functional.pad(waveform, (0, p)) waveform = torch.nn.functional.pad(waveform, (0, p))
...@@ -1147,7 +1149,7 @@ def _find_max_per_frame(nccf, sample_rate, freq_high): ...@@ -1147,7 +1149,7 @@ def _find_max_per_frame(nccf, sample_rate, freq_high):
to the first half of lags, then the latter is taken. to the first half of lags, then the latter is taken.
""" """
lag_min = math.ceil(sample_rate / freq_high) lag_min = int(math.ceil(sample_rate / freq_high))
# Find near enough max that is smallest # Find near enough max that is smallest
......
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
from warnings import warn from warnings import warn
import math import math
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment