Unverified Commit 60fd113c authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

replace reshape by view. (#409)

parent b32606d6
......@@ -129,7 +129,7 @@ def istft(
# pack batch
shape = stft_matrix.size()
stft_matrix = stft_matrix.reshape(-1, shape[-3], shape[-2], shape[-1])
stft_matrix = stft_matrix.view(-1, shape[-3], shape[-2], shape[-1])
dtype = stft_matrix.dtype
device = stft_matrix.device
......@@ -214,7 +214,7 @@ def istft(
y = (y / window_envelop).squeeze(1) # size (channel, expected_signal_len)
# unpack batch
y = y.reshape(shape[:-3] + y.shape[-1:])
y = y.view(shape[:-3] + y.shape[-1:])
if stft_matrix_dim == 3: # remove the channel dimension
y = y.squeeze(0)
......@@ -253,7 +253,7 @@ def spectrogram(
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
waveform = waveform.view(-1, shape[-1])
# default values are consistent with librosa.core.spectrum._spectrogram
spec_f = _stft(
......@@ -261,7 +261,7 @@ def spectrogram(
)
# unpack batch
spec_f = spec_f.reshape(shape[:-1] + spec_f.shape[-3:])
spec_f = spec_f.view(shape[:-1] + spec_f.shape[-3:])
if normalized:
spec_f /= window.pow(2.).sum().sqrt()
......@@ -317,7 +317,7 @@ def griffinlim(
# pack batch
shape = specgram.size()
specgram = specgram.reshape([-1] + list(shape[-2:]))
specgram = specgram.view([-1] + list(shape[-2:]))
specgram = specgram.pow(1 / power)
......@@ -363,7 +363,7 @@ def griffinlim(
length=length)
# unpack batch
waveform = waveform.reshape(shape[:-2] + waveform.shape[-1:])
waveform = waveform.view(shape[:-2] + waveform.shape[-1:])
return waveform
......@@ -587,7 +587,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
# pack batch
shape = complex_specgrams.size()
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
complex_specgrams = complex_specgrams.view([-1] + list(shape[-3:]))
time_steps = torch.arange(0,
complex_specgrams.size(-2),
......@@ -627,7 +627,7 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
# unpack batch
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
complex_specgrams_stretch = complex_specgrams_stretch.view(shape[:-3] + complex_specgrams_stretch.shape[1:])
return complex_specgrams_stretch
......@@ -654,7 +654,7 @@ def lfilter(waveform, a_coeffs, b_coeffs):
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
waveform = waveform.view(-1, shape[-1])
assert(a_coeffs.size(0) == b_coeffs.size(0))
assert(len(waveform.size()) == 2)
......@@ -697,7 +697,7 @@ def lfilter(waveform, a_coeffs, b_coeffs):
output = torch.clamp(padded_output_waveform[:, (n_order - 1):], min=-1., max=1.)
# unpack batch
output = output.reshape(shape[:-1] + output.shape[-1:])
output = output.view(shape[:-1] + output.shape[-1:])
return output
......@@ -876,7 +876,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
# pack batch
shape = specgram.size()
specgram = specgram.reshape([-1] + list(shape[-2:]))
specgram = specgram.view([-1] + list(shape[-2:]))
value = torch.rand(1) * mask_param
min_value = torch.rand(1) * (specgram.size(axis) - value)
......@@ -893,7 +893,7 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
raise ValueError('Only Frequency and Time masking are supported')
# unpack batch
specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
specgram = specgram.view(shape[:-2] + specgram.shape[-2:])
return specgram
......@@ -925,7 +925,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
# pack batch
shape = specgram.size()
specgram = specgram.reshape(1, -1, shape[-1])
specgram = specgram.view(1, -1, shape[-1])
assert win_length >= 3
......@@ -945,7 +945,7 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
output = torch.nn.functional.conv1d(specgram, kernel, groups=specgram.shape[1]) / denom
# unpack batch
output = output.reshape(shape)
output = output.view(shape)
return output
......@@ -974,11 +974,10 @@ def _add_noise_shaping(dithered_waveform, waveform):
error[n] = dithered[n] - original[n]
noise_shaped_waveform[n] = dithered[n] + error[n-1]
"""
wf_shape = waveform.size()
waveform = waveform.reshape(-1, wf_shape[-1])
waveform = waveform.view(-1, waveform.size()[-1])
dithered_shape = dithered_waveform.size()
dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1])
dithered_waveform = dithered_waveform.view(-1, dithered_shape[-1])
error = dithered_waveform - waveform
......@@ -989,7 +988,7 @@ def _add_noise_shaping(dithered_waveform, waveform):
error[index] = error_offset[:waveform.size()[1]]
noise_shaped = dithered_waveform + error
return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:])
return noise_shaped.view(dithered_shape[:-1] + noise_shaped.shape[-1:])
def _apply_probability_distribution(waveform, density_function="TPDF"):
......@@ -1020,7 +1019,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
waveform = waveform.view(-1, shape[-1])
channel_size = waveform.size()[0] - 1
time_size = waveform.size()[-1] - 1
......@@ -1060,7 +1059,7 @@ def _apply_probability_distribution(waveform, density_function="TPDF"):
quantised_signal = quantised_signal_scaled / down_scaling
# unpack batch
return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])
return quantised_signal.view(shape[:-1] + quantised_signal.shape[-1:])
def dither(waveform, density_function="TPDF", noise_shaping=False):
......@@ -1231,7 +1230,7 @@ def detect_pitch_frequency(
# pack batch
shape = list(waveform.size())
waveform = waveform.reshape([-1] + shape[-1:])
waveform = waveform.view([-1] + shape[-1:])
nccf = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
indices = _find_max_per_frame(nccf, sample_rate, freq_high)
......@@ -1242,6 +1241,6 @@ def detect_pitch_frequency(
freq = sample_rate / (EPSILON + indices.to(torch.float))
# unpack batch
freq = freq.reshape(shape[:-1] + list(freq.shape[-1:]))
freq = freq.view(shape[:-1] + list(freq.shape[-1:]))
return freq
......@@ -215,7 +215,7 @@ class MelScale(torch.nn.Module):
# pack batch
shape = specgram.size()
specgram = specgram.reshape(-1, shape[-2], shape[-1])
specgram = specgram.view(-1, shape[-2], shape[-1])
if self.fb.numel() == 0:
tmp_fb = F.create_fb_matrix(specgram.size(1), self.f_min, self.f_max, self.n_mels, self.sample_rate)
......@@ -228,7 +228,7 @@ class MelScale(torch.nn.Module):
mel_specgram = torch.matmul(specgram.transpose(1, 2), self.fb).transpose(1, 2)
# unpack batch
mel_specgram = mel_specgram.reshape(shape[:-2] + mel_specgram.shape[-2:])
mel_specgram = mel_specgram.view(shape[:-2] + mel_specgram.shape[-2:])
return mel_specgram
......@@ -349,7 +349,7 @@ class MFCC(torch.nn.Module):
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
waveform = waveform.view(-1, shape[-1])
mel_specgram = self.MelSpectrogram(waveform)
if self.log_mels:
......@@ -362,7 +362,7 @@ class MFCC(torch.nn.Module):
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
# unpack batch
mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:])
mfcc = mfcc.view(shape[:-1] + mfcc.shape[-2:])
return mfcc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment