Unverified commit c90c18d7, authored by Vincent QB and committed by GitHub
Browse files

Move batch from vocoder transform to functional (#350)

* Fix errors in docstrings.

* Move batch packing/unpacking (reshape) from the transforms into the functional implementations.
parent c74e580f
......@@ -469,13 +469,13 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
factor of ``rate``.
Args:
complex_specgrams (torch.Tensor): Dimension of `(channel, freq, time, complex=2)`
complex_specgrams (torch.Tensor): Dimension of `(..., freq, time, complex=2)`
rate (float): Speed-up factor
phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension
of (freq, 1)
Returns:
complex_specgrams_stretch (torch.Tensor): Dimension of `(channel,
complex_specgrams_stretch (torch.Tensor): Dimension of `(...,
freq, ceil(time/rate), complex=2)`
Example
......@@ -490,6 +490,10 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
torch.Size([2, 1025, 231, 2])
"""
# pack batch
shape = complex_specgrams.size()
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
time_steps = torch.arange(0,
complex_specgrams.size(-2),
rate,
......@@ -527,6 +531,9 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)
# unpack batch
complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
return complex_specgrams_stretch
......@@ -775,6 +782,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
"""
# pack batch
shape = specgram.size()
specgram = specgram.reshape([-1] + list(shape[-2:]))
value = torch.rand(1) * mask_param
min_value = torch.rand(1) * (specgram.size(axis) - value)
......@@ -789,7 +800,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
else:
raise ValueError('Only Frequency and Time masking are supported')
return specgram
# unpack batch
specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
return specgram.reshape(shape[:-2] + specgram.shape[-2:])
def compute_deltas(specgram, win_length=5, mode="replicate"):
......
......@@ -380,9 +380,9 @@ class ComplexNorm(torch.nn.Module):
def forward(self, complex_tensor):
r"""
Args:
complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
Returns:
Tensor: norm of the input tensor, shape of `(*, )`
Tensor: norm of the input tensor, shape of `(..., )`
"""
return F.complex_norm(complex_tensor, self.power)
......@@ -438,14 +438,14 @@ class TimeStretch(torch.jit.ScriptModule):
# type: (Tensor, Optional[float]) -> Tensor
r"""
Args:
complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2)
complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2)
overriding_rate (float or None): speed up to apply to this batch.
If no rate is passed, use ``self.fixed_rate``
Returns:
(Tensor): Stretched complex spectrogram of dimension (*, channel, freq, ceil(time/rate), complex=2)
(Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2)
"""
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)"
assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"
if overriding_rate is None:
rate = self.fixed_rate
......@@ -458,16 +458,12 @@ class TimeStretch(torch.jit.ScriptModule):
if rate == 1.0:
return complex_specgrams
shape = complex_specgrams.size()
complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
complex_specgrams = F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:])
return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
class _AxisMasking(torch.nn.Module):
r"""
Apply masking to a spectrogram.
r"""Apply masking to a spectrogram.
Args:
mask_param (int): Maximum possible length of the mask
axis: What dimension the mask is applied on
......@@ -486,26 +482,22 @@ class _AxisMasking(torch.nn.Module):
# type: (Tensor, float) -> Tensor
r"""
Args:
specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time)
specgram (torch.Tensor): Tensor of dimension (..., freq, time)
Returns:
torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time)
torch.Tensor: Masked spectrogram of dimensions (..., freq, time)
"""
# if iid_masks flag marked and specgram has a batch dimension
if self.iid_masks and specgram.dim() == 4:
return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
else:
shape = specgram.size()
specgram = specgram.reshape([-1] + list(shape[-2:]))
specgram = F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
return specgram.reshape(shape[:-2] + specgram.shape[-2:])
return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
class FrequencyMasking(_AxisMasking):
r"""
Apply masking to a spectrogram in the frequency domain.
r"""Apply masking to a spectrogram in the frequency domain.
Args:
freq_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, freq_mask_param).
......@@ -518,8 +510,8 @@ class FrequencyMasking(_AxisMasking):
class TimeMasking(_AxisMasking):
r"""
Apply masking to a spectrogram in the time domain.
r"""Apply masking to a spectrogram in the time domain.
Args:
time_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, time_mask_param).
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment