Unverified Commit d41d30ab authored by moto's avatar moto Committed by GitHub
Browse files

Separate CPU and GPU tests for functions torchscript test + Fix devices in two functionals (#528)



* Separate CPU and GPU tests for functions torchscript test

* fix indentation
Co-authored-by: default avatarVincent QB <vincentqb@users.noreply.github.com>
parent d47f42a3
...@@ -10,6 +10,18 @@ import torchaudio.transforms as T ...@@ -10,6 +10,18 @@ import torchaudio.transforms as T
import common_utils import common_utils
def _assert_functional_consistency(func, tensor, device, shape_only=False):
tensor = tensor.to(device)
ts_func = torch.jit.script(func)
output = func(tensor)
ts_output = ts_func(tensor)
if shape_only:
assert ts_output.shape == output.shape, (ts_output.shape, output.shape)
else:
torch.testing.assert_allclose(ts_output, output)
def _assert_transforms_consistency(transform, tensor, device): def _assert_transforms_consistency(transform, tensor, device):
tensor = tensor.to(device) tensor = tensor.to(device)
transform = transform.to(device) transform = transform.to(device)
...@@ -19,295 +31,382 @@ def _assert_transforms_consistency(transform, tensor, device): ...@@ -19,295 +31,382 @@ def _assert_transforms_consistency(transform, tensor, device):
torch.testing.assert_allclose(ts_output, output) torch.testing.assert_allclose(ts_output, output)
def _assert_functional_consistency(py_method, *args, shape_only=False, **kwargs): class _FunctionalTestMixin:
jit_method = torch.jit.script(py_method) """Implements test for `functinoal` modul that are performed for different devices"""
device = None
jit_out = jit_method(*args, **kwargs)
py_out = py_method(*args, **kwargs)
if shape_only:
assert jit_out.shape == py_out.shape, (jit_out.shape, py_out.shape)
else:
torch.testing.assert_allclose(jit_out, py_out)
def _test_lfilter(waveform):
"""
Design an IIR lowpass filter using scipy.signal filter design
https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
Example
>>> from scipy.signal import iirdesign
>>> b, a = iirdesign(0.2, 0.3, 1, 60)
"""
b_coeffs = torch.tensor(
[
0.00299893,
-0.0051152,
0.00841964,
-0.00747802,
0.00841964,
-0.0051152,
0.00299893,
],
device=waveform.device,
)
a_coeffs = torch.tensor(
[
1.0,
-4.8155751,
10.2217618,
-12.14481273,
8.49018171,
-3.3066882,
0.56088705,
],
device=waveform.device,
)
_assert_functional_consistency(F.lfilter, waveform, a_coeffs, b_coeffs)
def _assert_consistency(self, func, tensor, shape_only=False):
return _assert_functional_consistency(func, tensor, self.device, shape_only=shape_only)
class TestFunctional(unittest.TestCase):
"""Test functions in `functional` module."""
def test_spectrogram(self): def test_spectrogram(self):
tensor = torch.rand((1, 1000)) def func(tensor):
n_fft = 400 n_fft = 400
ws = 400 ws = 400
hop = 200 hop = 200
pad = 0 pad = 0
window = torch.hann_window(ws) window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2 power = 2.
normalize = False normalize = False
return F.spectrogram(tensor, pad, window, n_fft, hop, ws, power, normalize)
_assert_functional_consistency( tensor = torch.rand((1, 1000))
F.spectrogram, tensor, pad, window, n_fft, hop, ws, power, normalize self._assert_consistency(func, tensor)
)
def test_griffinlim(self): def test_griffinlim(self):
tensor = torch.rand((1, 201, 6)) def func(tensor):
n_fft = 400 n_fft = 400
ws = 400 ws = 400
hop = 200 hop = 200
window = torch.hann_window(ws) window = torch.hann_window(ws, device=tensor.device, dtype=tensor.dtype)
power = 2 power = 2.
normalize = False normalize = False
momentum = 0.99 momentum = 0.99
n_iter = 32 n_iter = 32
length = 1000 length = 1000
rand_int = False
return F.griffinlim(tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, rand_int)
_assert_functional_consistency( tensor = torch.rand((1, 201, 6))
F.griffinlim, tensor, window, n_fft, hop, ws, power, normalize, n_iter, momentum, length, 0 self._assert_consistency(func, tensor)
)
def test_compute_deltas(self): def test_compute_deltas(self):
def func(tensor):
win_length = 2 * 7 + 1
return F.compute_deltas(tensor, win_length=win_length)
channel = 13 channel = 13
n_mfcc = channel * 3 n_mfcc = channel * 3
time = 1021 time = 1021
win_length = 2 * 7 + 1 tensor = torch.randn(channel, n_mfcc, time)
specgram = torch.randn(channel, n_mfcc, time) self._assert_consistency(func, tensor)
_assert_functional_consistency(F.compute_deltas, specgram, win_length=win_length)
def test_detect_pitch_frequency(self): def test_detect_pitch_frequency(self):
filepath = os.path.join( filepath = os.path.join(
common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.mp3') common_utils.TEST_DIR_PATH, 'assets', 'steam-train-whistle-daniel_simon.mp3')
waveform, sample_rate = torchaudio.load(filepath) waveform, _ = torchaudio.load(filepath)
_assert_functional_consistency(F.detect_pitch_frequency, waveform, sample_rate)
def func(tensor):
sample_rate = 44100
return F.detect_pitch_frequency(tensor, sample_rate)
self._assert_consistency(func, waveform)
def test_create_fb_matrix(self): def test_create_fb_matrix(self):
if self.device != torch.device('cpu'):
raise unittest.SkipTest('No need to perform test on device other than CPU')
def func(_):
n_stft = 100 n_stft = 100
f_min = 0.0 f_min = 0.0
f_max = 20.0 f_max = 20.0
n_mels = 10 n_mels = 10
sample_rate = 16000 sample_rate = 16000
return F.create_fb_matrix(n_stft, f_min, f_max, n_mels, sample_rate)
_assert_functional_consistency(F.create_fb_matrix, n_stft, f_min, f_max, n_mels, sample_rate) dummy = torch.zeros(1, 1)
self._assert_consistency(func, dummy)
def test_amplitude_to_DB(self): def test_amplitude_to_DB(self):
spec = torch.rand((6, 201)) def func(tensor):
multiplier = 10.0 multiplier = 10.0
amin = 1e-10 amin = 1e-10
db_multiplier = 0.0 db_multiplier = 0.0
top_db = 80.0 top_db = 80.0
return F.amplitude_to_DB(tensor, multiplier, amin, db_multiplier, top_db)
_assert_functional_consistency(F.amplitude_to_DB, spec, multiplier, amin, db_multiplier, top_db) tensor = torch.rand((6, 201))
self._assert_consistency(func, tensor)
def test_DB_to_amplitude(self): def test_DB_to_amplitude(self):
x = torch.rand((1, 100)) def func(tensor):
ref = 1. ref = 1.
power = 1. power = 1.
return F.DB_to_amplitude(tensor, ref, power)
_assert_functional_consistency(F.DB_to_amplitude, x, ref, power) tensor = torch.rand((1, 100))
self._assert_consistency(func, tensor)
def test_create_dct(self): def test_create_dct(self):
if self.device != torch.device('cpu'):
raise unittest.SkipTest('No need to perform test on device other than CPU')
def func(_):
n_mfcc = 40 n_mfcc = 40
n_mels = 128 n_mels = 128
norm = "ortho" norm = "ortho"
return F.create_dct(n_mfcc, n_mels, norm)
_assert_functional_consistency(F.create_dct, n_mfcc, n_mels, norm) dummy = torch.zeros(1, 1)
self._assert_consistency(func, dummy)
def test_mu_law_encoding(self): def test_mu_law_encoding(self):
tensor = torch.rand((1, 10)) def func(tensor):
qc = 256 qc = 256
return F.mu_law_encoding(tensor, qc)
_assert_functional_consistency(F.mu_law_encoding, tensor, qc) tensor = torch.rand((1, 10))
self._assert_consistency(func, tensor)
def test_mu_law_decoding(self): def test_mu_law_decoding(self):
tensor = torch.rand((1, 10)) def func(tensor):
qc = 256 qc = 256
return F.mu_law_decoding(tensor, qc)
_assert_functional_consistency(F.mu_law_decoding, tensor, qc) tensor = torch.rand((1, 10))
self._assert_consistency(func, tensor)
def test_complex_norm(self): def test_complex_norm(self):
complex_tensor = torch.randn(1, 2, 1025, 400, 2) def func(tensor):
power = 2 power = 2.
return F.complex_norm(tensor, power)
_assert_functional_consistency(F.complex_norm, complex_tensor, power) tensor = torch.randn(1, 2, 1025, 400, 2)
_assert_functional_consistency(func, tensor, self.device)
def test_mask_along_axis(self): def test_mask_along_axis(self):
specgram = torch.randn(2, 1025, 400) def func(tensor):
mask_param = 100 mask_param = 100
mask_value = 30. mask_value = 30.
axis = 2 axis = 2
return F.mask_along_axis(tensor, mask_param, mask_value, axis)
_assert_functional_consistency(F.mask_along_axis, specgram, mask_param, mask_value, axis) tensor = torch.randn(2, 1025, 400)
self._assert_consistency(func, tensor)
def test_mask_along_axis_iid(self): def test_mask_along_axis_iid(self):
specgrams = torch.randn(4, 2, 1025, 400) def func(tensor):
mask_param = 100 mask_param = 100
mask_value = 30. mask_value = 30.
axis = 2 axis = 2
return F.mask_along_axis_iid(tensor, mask_param, mask_value, axis)
_assert_functional_consistency(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis) tensor = torch.randn(4, 2, 1025, 400)
self._assert_consistency(func, tensor)
def test_gain(self): def test_gain(self):
tensor = torch.rand((1, 1000)) def func(tensor):
gainDB = 2.0 gainDB = 2.0
return F.gain(tensor, gainDB)
_assert_functional_consistency(F.gain, tensor, gainDB) tensor = torch.rand((1, 1000))
self._assert_consistency(func, tensor)
def test_dither_TPDF(self):
def func(tensor):
return F.dither(tensor, 'TPDF')
def test_dither(self):
tensor = torch.rand((2, 1000)) tensor = torch.rand((2, 1000))
self._assert_consistency(func, tensor, shape_only=True)
def test_dither_RPDF(self):
def func(tensor):
return F.dither(tensor, 'RPDF')
_assert_functional_consistency(F.dither, tensor, shape_only=True) tensor = torch.rand((2, 1000))
_assert_functional_consistency(F.dither, tensor, "RPDF", shape_only=True) self._assert_consistency(func, tensor, shape_only=True)
_assert_functional_consistency(F.dither, tensor, "GPDF", shape_only=True)
def test_dither_GPDF(self):
def func(tensor):
return F.dither(tensor, 'GPDF')
tensor = torch.rand((2, 1000))
self._assert_consistency(func, tensor, shape_only=True)
def test_lfilter(self): def test_lfilter(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav') filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
_test_lfilter(waveform)
@unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") def func(tensor):
def test_lfilter_cuda(self): # Design an IIR lowpass filter using scipy.signal filter design
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") # https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.iirdesign.html#scipy.signal.iirdesign
waveform, _ = torchaudio.load(filepath, normalization=True) #
_test_lfilter(waveform.cuda(device=torch.device("cuda:0"))) # Example
# >>> from scipy.signal import iirdesign
# >>> b, a = iirdesign(0.2, 0.3, 1, 60)
b_coeffs = torch.tensor(
[
0.00299893,
-0.0051152,
0.00841964,
-0.00747802,
0.00841964,
-0.0051152,
0.00299893,
],
device=tensor.device,
dtype=tensor.dtype,
)
a_coeffs = torch.tensor(
[
1.0,
-4.8155751,
10.2217618,
-12.14481273,
8.49018171,
-3.3066882,
0.56088705,
],
device=tensor.device,
dtype=tensor.dtype,
)
return F.lfilter(tensor, a_coeffs, b_coeffs)
def test_lowpass(self): self._assert_consistency(func, waveform)
cutoff_freq = 3000
def test_lowpass(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav') filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(F.lowpass_biquad, waveform, sample_rate, cutoff_freq)
def test_highpass(self): def func(tensor):
cutoff_freq = 2000 sample_rate = 44100
cutoff_freq = 3000.
return F.lowpass_biquad(tensor, sample_rate, cutoff_freq)
self._assert_consistency(func, waveform)
def test_highpass(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav') filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(F.highpass_biquad, waveform, sample_rate, cutoff_freq)
def func(tensor):
sample_rate = 44100
cutoff_freq = 2000.
return F.highpass_biquad(tensor, sample_rate, cutoff_freq)
self._assert_consistency(func, waveform)
def test_allpass(self): def test_allpass(self):
central_freq = 1000 filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav')
waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor):
sample_rate = 44100
central_freq = 1000.
q = 0.707 q = 0.707
return F.allpass_biquad(tensor, sample_rate, central_freq, q)
filepath = os.path.join(common_utils.TEST_DIR_PATH, 'assets', 'whitenoise.wav') self._assert_consistency(func, waveform)
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(F.allpass_biquad, waveform, sample_rate, central_freq, q)
def test_bandpass_with_csg(self): def test_bandpass_with_csg(self):
central_freq = 1000 filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor):
sample_rate = 44100
central_freq = 1000.
q = 0.707 q = 0.707
const_skirt_gain = True const_skirt_gain = True
return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") self._assert_consistency(func, waveform)
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
def test_bandpass_withou_csg(self): def test_bandpass_withou_csg(self):
central_freq = 1000 filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor):
sample_rate = 44100
central_freq = 1000.
q = 0.707 q = 0.707
const_skirt_gain = False const_skirt_gain = True
return F.bandpass_biquad(tensor, sample_rate, central_freq, q, const_skirt_gain)
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") self._assert_consistency(func, waveform)
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(
F.bandpass_biquad, waveform, sample_rate, central_freq, q, const_skirt_gain)
def test_bandreject(self): def test_bandreject(self):
central_freq = 1000 filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor):
sample_rate = 44100
central_freq = 1000.
q = 0.707 q = 0.707
return F.bandreject_biquad(tensor, sample_rate, central_freq, q)
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") self._assert_consistency(func, waveform)
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(
F.bandreject_biquad, waveform, sample_rate, central_freq, q)
def test_band_with_noise(self): def test_band_with_noise(self):
central_freq = 1000 filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor):
sample_rate = 44100
central_freq = 1000.
q = 0.707 q = 0.707
noise = True noise = True
return F.band_biquad(tensor, sample_rate, central_freq, q, noise)
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") self._assert_consistency(func, waveform)
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
def test_band_without_noise(self): def test_band_without_noise(self):
central_freq = 1000 filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor):
sample_rate = 44100
central_freq = 1000.
q = 0.707 q = 0.707
noise = False noise = False
return F.band_biquad(tensor, sample_rate, central_freq, q, noise)
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") self._assert_consistency(func, waveform)
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(F.band_biquad, waveform, sample_rate, central_freq, q, noise)
def test_treble(self): def test_treble(self):
gain = 40 filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
central_freq = 1000 waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor):
sample_rate = 44100
gain = 40.
central_freq = 1000.
q = 0.707 q = 0.707
return F.treble_biquad(tensor, sample_rate, gain, central_freq, q)
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") self._assert_consistency(func, waveform)
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(F.treble_biquad, waveform, sample_rate, gain, central_freq, q)
def test_deemph(self): def test_deemph(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(F.deemph_biquad, waveform, sample_rate)
def func(tensor):
sample_rate = 44100
return F.deemph_biquad(tensor, sample_rate)
self._assert_consistency(func, waveform)
def test_riaa(self): def test_riaa(self):
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, sample_rate = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(F.riaa_biquad, waveform, sample_rate)
def func(tensor):
sample_rate = 44100
return F.riaa_biquad(tensor, sample_rate)
self._assert_consistency(func, waveform)
def test_equalizer(self): def test_equalizer(self):
center_freq = 300 filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
gain = 1 waveform, _ = torchaudio.load(filepath, normalization=True)
def func(tensor):
sample_rate = 44100
center_freq = 300.
gain = 1.
q = 0.707 q = 0.707
return F.equalizer_biquad(tensor, sample_rate, center_freq, gain, q)
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") self._assert_consistency(func, waveform)
waveform, sample_rate = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(
F.equalizer_biquad, waveform, sample_rate, center_freq, gain, q)
def test_perf_biquad_filtering(self): def test_perf_biquad_filtering(self):
a = torch.tensor([0.7, 0.2, 0.6])
b = torch.tensor([0.4, 0.2, 0.9])
filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav") filepath = os.path.join(common_utils.TEST_DIR_PATH, "assets", "whitenoise.wav")
waveform, _ = torchaudio.load(filepath, normalization=True) waveform, _ = torchaudio.load(filepath, normalization=True)
_assert_functional_consistency(F.lfilter, waveform, a, b)
def func(tensor):
a = torch.tensor([0.7, 0.2, 0.6], device=tensor.device, dtype=tensor.dtype)
b = torch.tensor([0.4, 0.2, 0.9], device=tensor.device, dtype=tensor.dtype)
return F.lfilter(tensor, a, b)
self._assert_consistency(func, waveform)
class _TransformsTestMixin: class _TransformsTestMixin:
...@@ -392,6 +491,17 @@ class _TransformsTestMixin: ...@@ -392,6 +491,17 @@ class _TransformsTestMixin:
self._assert_consistency(T.Vol(1.1), waveform) self._assert_consistency(T.Vol(1.1), waveform)
class TestFunctionalCPU(_FunctionalTestMixin, unittest.TestCase):
"""Test suite for Functional module on CPU"""
device = torch.device('cpu')
@unittest.skipIf(not torch.cuda.is_available(), 'CUDA not available')
class TestFunctionalCUDA(_FunctionalTestMixin, unittest.TestCase):
"""Test suite for Functional module on GPU"""
device = torch.device('cuda')
class TestTransformsCPU(_TransformsTestMixin, unittest.TestCase): class TestTransformsCPU(_TransformsTestMixin, unittest.TestCase):
"""Test suite for Transforms module on CPU""" """Test suite for Transforms module on CPU"""
device = torch.device('cpu') device = torch.device('cpu')
......
...@@ -1210,13 +1210,16 @@ def mask_along_axis_iid( ...@@ -1210,13 +1210,16 @@ def mask_along_axis_iid(
if axis != 2 and axis != 3: if axis != 2 and axis != 3:
raise ValueError('Only Frequency and Time masking are supported') raise ValueError('Only Frequency and Time masking are supported')
value = torch.rand(specgrams.shape[:2]) * mask_param device = specgrams.device
min_value = torch.rand(specgrams.shape[:2]) * (specgrams.size(axis) - value) dtype = specgrams.dtype
value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * mask_param
min_value = torch.rand(specgrams.shape[:2], device=device, dtype=dtype) * (specgrams.size(axis) - value)
# Create broadcastable mask # Create broadcastable mask
mask_start = (min_value.long())[..., None, None].float() mask_start = min_value[..., None, None]
mask_end = (min_value.long() + value.long())[..., None, None].float() mask_end = (min_value + value)[..., None, None]
mask = torch.arange(0, specgrams.size(axis)).float() mask = torch.arange(0, specgrams.size(axis), device=device, dtype=dtype)
# Per batch example masking # Per batch example masking
specgrams = specgrams.transpose(axis, -1) specgrams = specgrams.transpose(axis, -1)
...@@ -1298,6 +1301,8 @@ def compute_deltas( ...@@ -1298,6 +1301,8 @@ def compute_deltas(
>>> delta = compute_deltas(specgram) >>> delta = compute_deltas(specgram)
>>> delta2 = compute_deltas(delta) >>> delta2 = compute_deltas(delta)
""" """
device = specgram.device
dtype = specgram.dtype
# pack batch # pack batch
shape = specgram.size() shape = specgram.size()
...@@ -1312,7 +1317,7 @@ def compute_deltas( ...@@ -1312,7 +1317,7 @@ def compute_deltas(
specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode) specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)
kernel = (torch.arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype).repeat(specgram.shape[1], 1, 1)) kernel = torch.arange(-n, n + 1, 1, device=device, dtype=dtype).repeat(specgram.shape[1], 1, 1)
output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom
...@@ -1431,7 +1436,7 @@ def _apply_probability_distribution( ...@@ -1431,7 +1436,7 @@ def _apply_probability_distribution(
signal_scaled_dis = signal_scaled + gaussian signal_scaled_dis = signal_scaled + gaussian
else: else:
# dtype needed for https://github.com/pytorch/pytorch/issues/32358 # dtype needed for https://github.com/pytorch/pytorch/issues/32358
TPDF = torch.bartlett_window(time_size + 1, dtype=torch.float) TPDF = torch.bartlett_window(time_size + 1, dtype=signal_scaled.dtype, device=signal_scaled.device)
TPDF = TPDF.repeat((channel_size + 1), 1) TPDF = TPDF.repeat((channel_size + 1), 1)
signal_scaled_dis = signal_scaled + TPDF signal_scaled_dis = signal_scaled + TPDF
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment