Unverified commit 13057829, authored by cpuhrsch, committed by GitHub
Browse files

Revert "replace reshape by view. (#409)" (#594)

This reverts commit 60fd113c.
parent d678357f
......@@ -113,7 +113,7 @@ def istft(
# pack batch
shape = stft_matrix.size()
stft_matrix = stft_matrix.view(-1, shape[-3], shape[-2], shape[-1])
stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1])
dtype = stft_matrix.dtype
device = stft_matrix.device
......@@ -196,7 +196,7 @@ def istft(
y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)
# unpack batch
y = y.view(shape[:-3] + y.shape[-1:])
y = y.reshape(shape[:-3] + y.shape[-1:])
if stft_matrix_dim == 3: # remove the channel dimension
y = y.squeeze(0)
......@@ -241,7 +241,7 @@ def spectrogram(
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
waveform = waveform.reshape(-1, shape[-1])
# default values are consistent with librosa.core.spectrum._spectrogram
spec_f = torch.stft(
......@@ -249,7 +249,7 @@ def spectrogram(
)
# unpack batch
spec_f = spec_f.view(shape[:-1] + spec_f.shape[-3:])
spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])
if normalized:
spec_f /= window.pow(2.).sum().sqrt()
......@@ -314,7 +314,7 @@ def griffinlim(
# pack batch
shape = specgram.size()
specgram = specgram.view([-1] + list(shape[-2:]))
specgram = specgram.reshape([-1] + list(shape[-2:]))
specgram = specgram.pow(1 / power)
......@@ -360,7 +360,7 @@ def griffinlim(
length=length)
# unpack batch
waveform = waveform.view(shape[:-2] + waveform.shape[-1:])
waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
return waveform
......@@ -623,7 +623,7 @@ def phase_vocoder(
# pack batch
shape = complex_specgrams.size()
complex_specgrams = complex_specgrams.view([-1] + list(shape[-3:]))
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
time_steps = torch.arange(0,
complex_specgrams.size(-2),
......@@ -663,7 +663,7 @@ def phase_vocoder(
complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
# unpack batch
complex_specgrams_stretch = complex_specgrams_stretch.view(shape[:-3] + complex_specgrams_stretch.shape[1:])
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
return complex_specgrams_stretch
......@@ -689,7 +689,7 @@ def lfilter(
"""
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
waveform = waveform.reshape(-1, shape[-1])
assert (a_coeffs.size(0) == b_coeffs.size(0))
assert (len(waveform.size()) == 2)
......@@ -732,7 +732,7 @@ def lfilter(
output = torch.clamp(padded_output_waveform[:, (n_order - 1):], min=-1., max=1.)
# unpack batch
output = output.view(shape[:-1] + output.shape[-1:])
output = output.reshape(shape[:-1] + output.shape[-1:])
return output
......@@ -1362,7 +1362,7 @@ def mask_along_axis(
# pack batch
shape = specgram.size()
specgram = specgram.view([-1] + list(shape[-2:]))
specgram = specgram.reshape([-1] + list(shape[-2:]))
value = torch.rand(1) * mask_param
min_value = torch.rand(1) * (specgram.size(axis) - value)
......@@ -1379,7 +1379,7 @@ def mask_along_axis(
raise ValueError('Only Frequency and Time masking are supported')
# unpack batch
specgram = specgram.view(shape[:-2] + specgram.shape[-2:])
specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
return specgram
......@@ -1416,7 +1416,7 @@ def compute_deltas(
# pack batch
shape = specgram.size()
specgram = specgram.view(1, -1, shape[-1])
specgram = specgram.reshape(1, -1, shape[-1])
assert win_length >= 3
......@@ -1432,7 +1432,7 @@ def compute_deltas(
output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom
# unpack batch
output = output.view(shape)
output = output.reshape(shape)
return output
......@@ -1466,10 +1466,11 @@ def _add_noise_shaping(
error[n] = dithered[n] - original[n]
noise_shaped_waveform[n] = dithered[n] + error[n-1]
"""
waveform = waveform.view(-1, waveform.size()[-1])
wf_shape = waveform.size()
waveform = waveform.reshape(-1, wf_shape[-1])
dithered_shape = dithered_waveform.size()
dithered_waveform = dithered_waveform.view(-1, dithered_shape[-1])
dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1])
error = dithered_waveform - waveform
......@@ -1480,7 +1481,7 @@ def _add_noise_shaping(
error[index] = error_offset[:waveform.size()[1]]
noise_shaped = dithered_waveform + error
return noise_shaped.view(dithered_shape[:-1] + noise_shaped.shape[-1:])
return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:])
def _apply_probability_distribution(
......@@ -1513,7 +1514,7 @@ def _apply_probability_distribution(
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
waveform = waveform.reshape(-1, shape[-1])
channel_size = waveform.size()[0] - 1
time_size = waveform.size()[-1] - 1
......@@ -1554,7 +1555,7 @@ def _apply_probability_distribution(
quantised_signal = quantised_signal_scaled / down_scaling
# unpack batch
return quantised_signal.view(shape[:-1] + quantised_signal.shape[-1:])
return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])
def dither(
......@@ -1732,7 +1733,7 @@ def detect_pitch_frequency(
"""
# pack batch
shape = list(waveform.size())
waveform = waveform.view([-1] + shape[-1:])
waveform = waveform.reshape([-1] + shape[-1:])
nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
indices = _find_max_per_frame(nccf, sample_rate, freq_high)
......@@ -1743,7 +1744,7 @@ def detect_pitch_frequency(
freq = sample_rate / (EPSILON + indices.to(torch.float))
# unpack batch
freq = freq.view(shape[:-1] + list(freq.shape[-1:]))
freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))
return freq
......
......@@ -247,7 +247,7 @@ class MelScale(torch.nn.Module):
# pack batch
shape = specgram.size()
specgram = specgram.view(-1, shape[-2], shape[-1])
specgram = specgram.reshape(-1, shape[-2], shape[-1])
if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
......@@ -260,7 +260,7 @@ class MelScale(torch.nn.Module):
mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
# unpack batch
mel_specgram = mel_specgram.view(shape[:-2] + mel_specgram.shape[-2:])
mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])
return mel_specgram
......@@ -485,7 +485,7 @@ class MFCC(torch.nn.Module):
# pack batch
shape = waveform.size()
waveform = waveform.view(-1, shape[-1])
waveform = waveform.reshape(-1, shape[-1])
mel_specgram = self.MelSpectrogram(waveform)
if self.log_mels:
......@@ -498,7 +498,7 @@ class MFCC(torch.nn.Module):
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
# unpack batch
mfcc = mfcc.view(shape[:-1] + mfcc.shape[-2:])
mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:])
return mfcc
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment