Unverified Commit c90c18d7 authored by Vincent QB, committed by GitHub

Move batch from vocoder transform to functional (#350)

* fixing errors in docstring.

* move batch to functional.
parent c74e580f
@@ -469,13 +469,13 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
     factor of ``rate``.

     Args:
-        complex_specgrams (torch.Tensor): Dimension of `(channel, freq, time, complex=2)`
+        complex_specgrams (torch.Tensor): Dimension of `(..., freq, time, complex=2)`
         rate (float): Speed-up factor
         phase_advance (torch.Tensor): Expected phase advance in each bin. Dimension
             of (freq, 1)

     Returns:
-        complex_specgrams_stretch (torch.Tensor): Dimension of `(channel,
+        complex_specgrams_stretch (torch.Tensor): Dimension of `(...,
         freq, ceil(time/rate), complex=2)`

     Example
@@ -490,6 +490,10 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):
         torch.Size([2, 1025, 231, 2])
     """

+    # pack batch
+    shape = complex_specgrams.size()
+    complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
+
     time_steps = torch.arange(0,
                               complex_specgrams.size(-2),
                               rate,
@@ -527,6 +531,9 @@ def phase_vocoder(complex_specgrams, rate, phase_advance):

     complex_specgrams_stretch = torch.stack([real_stretch, imag_stretch], dim=-1)

+    # unpack batch
+    complex_specgrams_stretch = complex_specgrams_stretch.reshape(shape[:-3] + complex_specgrams_stretch.shape[1:])
+
     return complex_specgrams_stretch
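
The pack/unpack idiom added above is what lets phase_vocoder accept any number of leading dimensions. A minimal sketch of the pattern, with made-up tensor sizes:

    import torch

    # Collapse all leading dims into one batch dim, run a (freq, time)-wise
    # op on (batch, freq, time, complex=2), then restore the leading dims.
    complex_specgrams = torch.randn(3, 2, 1025, 400, 2)  # (batch, channel, freq, time, complex=2)

    # pack batch: (3, 2, 1025, 400, 2) -> (6, 1025, 400, 2)
    shape = complex_specgrams.size()
    packed = complex_specgrams.reshape([-1] + list(shape[-3:]))

    # ... the vocoder body would operate on `packed` here ...

    # unpack batch: shape[:-3] restores (3, 2); packed.shape[1:] keeps whatever
    # the op did to the trailing dims (e.g. a stretched time axis)
    unpacked = packed.reshape(shape[:-3] + packed.shape[1:])
    assert unpacked.shape == complex_specgrams.shape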
@@ -775,6 +782,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
         torch.Tensor: Masked spectrogram of dimensions (channel, freq, time)
     """

+    # pack batch
+    shape = specgram.size()
+    specgram = specgram.reshape([-1] + list(shape[-2:]))
+
     value = torch.rand(1) * mask_param
     min_value = torch.rand(1) * (specgram.size(axis) - value)
@@ -789,7 +800,10 @@ def mask_along_axis(specgram, mask_param, mask_value, axis):
     else:
         raise ValueError('Only Frequency and Time masking are supported')

+    # unpack batch
+    specgram = specgram.reshape(shape[:-2] + specgram.shape[-2:])
+
     return specgram

 def compute_deltas(specgram, win_length=5, mode="replicate"):
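
With the batching moved into the functional, mask_along_axis can be called directly on tensors with extra leading dimensions. A hedged usage sketch (tensor sizes are illustrative):

    import torch
    import torchaudio.functional as F

    specgram = torch.randn(4, 2, 201, 100)  # (batch, channel, freq, time)
    # axis=1 masks along frequency, axis=2 along time, counted on the
    # packed (batch, freq, time) view
    masked = F.mask_along_axis(specgram, mask_param=30, mask_value=0., axis=2)
    assert masked.shape == specgram.shape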
...
@@ -380,9 +380,9 @@ class ComplexNorm(torch.nn.Module):
     def forward(self, complex_tensor):
         r"""
         Args:
-            complex_tensor (Tensor): Tensor shape of `(*, complex=2)`
+            complex_tensor (Tensor): Tensor shape of `(..., complex=2)`
         Returns:
-            Tensor: norm of the input tensor, shape of `(*, )`
+            Tensor: norm of the input tensor, shape of `(..., )`
         """
         return F.complex_norm(complex_tensor, self.power)
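
The `(..., complex=2)` notation adopted here means any number of leading dimensions is accepted, with the trailing complex dimension reduced away. For instance (shapes are arbitrary):

    import torch
    import torchaudio.functional as F

    x = torch.randn(3, 2, 1025, 400, 2)  # (..., complex=2)
    norm = F.complex_norm(x, power=1.)
    assert norm.shape == x.shape[:-1]    # (..., )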
@@ -438,14 +438,14 @@ class TimeStretch(torch.jit.ScriptModule):
         # type: (Tensor, Optional[float]) -> Tensor
         r"""
         Args:
-            complex_specgrams (Tensor): complex spectrogram (*, channel, freq, time, complex=2)
+            complex_specgrams (Tensor): complex spectrogram (..., freq, time, complex=2)
             overriding_rate (float or None): speed up to apply to this batch.
                 If no rate is passed, use ``self.fixed_rate``
         Returns:
-            (Tensor): Stretched complex spectrogram of dimension (*, channel, freq, ceil(time/rate), complex=2)
+            (Tensor): Stretched complex spectrogram of dimension (..., freq, ceil(time/rate), complex=2)
         """
-        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (*, complex=2)"
+        assert complex_specgrams.size(-1) == 2, "complex_specgrams should be a complex tensor, shape (..., complex=2)"

         if overriding_rate is None:
             rate = self.fixed_rate
@@ -458,16 +458,12 @@ class TimeStretch(torch.jit.ScriptModule):
         if rate == 1.0:
             return complex_specgrams

-        shape = complex_specgrams.size()
-        complex_specgrams = complex_specgrams.reshape([-1] + list(shape[-3:]))
-        complex_specgrams = F.phase_vocoder(complex_specgrams, rate, self.phase_advance)
-        return complex_specgrams.reshape(shape[:-3] + complex_specgrams.shape[-3:])
+        return F.phase_vocoder(complex_specgrams, rate, self.phase_advance)


 class _AxisMasking(torch.nn.Module):
-    r"""
-    Apply masking to a spectrogram.
+    r"""Apply masking to a spectrogram.
+
     Args:
         mask_param (int): Maximum possible length of the mask
         axis: What dimension the mask is applied on
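
Since TimeStretch now defers batching to F.phase_vocoder, inputs with any leading dimensions pass straight through the transform. A sketch, with illustrative hop_length/n_freq values:

    import torch
    import torchaudio

    stretch = torchaudio.transforms.TimeStretch(hop_length=512, n_freq=1025, fixed_rate=1.3)
    specgrams = torch.randn(3, 2, 1025, 400, 2)  # (batch, channel, freq, time, complex=2)
    stretched = stretch(specgrams)               # time axis becomes ceil(400 / 1.3)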
@@ -486,26 +482,22 @@ class _AxisMasking(torch.nn.Module):
         # type: (Tensor, float) -> Tensor
         r"""
         Args:
-            specgram (torch.Tensor): Tensor of dimension (*, channel, freq, time)
+            specgram (torch.Tensor): Tensor of dimension (..., freq, time)
         Returns:
-            torch.Tensor: Masked spectrogram of dimensions (*, channel, freq, time)
+            torch.Tensor: Masked spectrogram of dimensions (..., freq, time)
         """
         # if iid_masks flag marked and specgram has a batch dimension
         if self.iid_masks and specgram.dim() == 4:
             return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
         else:
-            shape = specgram.size()
-            specgram = specgram.reshape([-1] + list(shape[-2:]))
-            specgram = F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)
-            return specgram.reshape(shape[:-2] + specgram.shape[-2:])
+            return F.mask_along_axis(specgram, self.mask_param, mask_value, self.axis)

 class FrequencyMasking(_AxisMasking):
-    r"""
-    Apply masking to a spectrogram in the frequency domain.
+    r"""Apply masking to a spectrogram in the frequency domain.
+
     Args:
         freq_mask_param (int): maximum possible length of the mask.
             Indices uniformly sampled from [0, freq_mask_param).
@@ -518,8 +510,8 @@ class FrequencyMasking(_AxisMasking):
 class TimeMasking(_AxisMasking):
-    r"""
-    Apply masking to a spectrogram in the time domain.
+    r"""Apply masking to a spectrogram in the time domain.
+
     Args:
         time_mask_param (int): maximum possible length of the mask.
             Indices uniformly sampled from [0, time_mask_param).
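
End to end, the masking transforms now accept the same flexible shapes as the functional. A hedged sketch (mask length and sizes are made up):

    import torch
    import torchaudio

    masking = torchaudio.transforms.TimeMasking(time_mask_param=40)
    specgram = torch.randn(2, 201, 400)  # (channel, freq, time)
    masked = masking(specgram)           # a random time span is filled with the mask value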