Unverified Commit c80d9a71 authored by moto, committed by GitHub

Fix docstring of masking behavior (#612)

parent 867d669b
@@ -1442,7 +1442,6 @@ def mask_along_axis_iid(
     r"""
     Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
     ``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
-    All examples will have the same mask interval.
 
     Args:
         specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
......
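Note: the line removed above was inaccurate for the iid variant, which draws an independent mask interval for every (batch, channel) slice. A minimal sketch of the sampling the docstring describes, written as a standalone PyTorch re-implementation (the helper name is made up for illustration; this is not the torchaudio source):

    import torch

    def mask_along_axis_iid_sketch(specgrams, mask_param, mask_value, axis):
        # specgrams: (batch, channel, freq, time); axis: 2 = freq, 3 = time
        max_v = specgrams.size(axis)
        # v ~ uniform(0, mask_param) and v_0 ~ uniform(0, max_v - v),
        # drawn independently for every (batch, channel) pair
        v = torch.rand(specgrams.shape[:2]) * mask_param
        v_0 = torch.rand(specgrams.shape[:2]) * (max_v - v)
        mask_start = v_0.long()[..., None, None]
        mask_end = (v_0 + v).long()[..., None, None]
        positions = torch.arange(max_v)
        # Move the masked axis last so `positions` broadcasts against it
        out = specgrams.transpose(axis, -1).clone()
        out.masked_fill_((positions >= mask_start) & (positions < mask_end), mask_value)
        return out.transpose(axis, -1)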
@@ -779,6 +779,7 @@ class _AxisMasking(torch.nn.Module):
         mask_param (int): Maximum possible length of the mask.
         axis (int): What dimension the mask is applied on.
         iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
+            This option is applicable only when the input tensor is 4D.
     """
     __constants__ = ['mask_param', 'axis', 'iid_masks']
@@ -798,7 +799,6 @@ class _AxisMasking(torch.nn.Module):
         Returns:
             Tensor: Masked spectrogram of dimensions (..., freq, time).
         """
-
         # if iid_masks flag marked and specgram has a batch dimension
         if self.iid_masks and specgram.dim() == 4:
             return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
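Note: the iid path above can be exercised directly through the functional interface shown in the diff (a usage sketch, assuming torchaudio is installed; the tensor shapes and parameter values are arbitrary):

    import torch
    import torchaudio.functional as F

    specgrams = torch.ones(4, 2, 201, 400)   # (batch, channel, freq, time)
    masked = F.mask_along_axis_iid(specgrams, mask_param=30, mask_value=0.0, axis=2)
    # Each example is typically masked over a different frequency band
    for i in range(specgrams.size(0)):
        band = (masked[i, 0] == 0.0).all(dim=-1).nonzero().flatten()
        print(i, band.tolist())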
@@ -812,10 +812,10 @@ class FrequencyMasking(_AxisMasking):
     Args:
         freq_mask_param (int): maximum possible length of the mask.
             Indices uniformly sampled from [0, freq_mask_param).
-        iid_masks (bool, optional): whether to apply the same mask to all
-            the examples/channels in the batch. (Default: ``False``)
+        iid_masks (bool, optional): whether to apply different masks to each
+            example/channel in the batch. (Default: ``False``)
+            This option is applicable only when the input tensor is 4D.
     """
     def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
         super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)
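Note: a brief usage sketch of the corrected behavior (hypothetical shapes and parameter values; assumes torchaudio is installed). With ``iid_masks=True`` a 4D batch gets an independent mask per example, while a 3D input falls back to a single shared mask:

    import torch
    import torchaudio

    masking = torchaudio.transforms.FrequencyMasking(freq_mask_param=30, iid_masks=True)
    batch = torch.ones(4, 2, 201, 400)   # 4D (batch, channel, freq, time): iid path
    clip = torch.ones(2, 201, 400)       # 3D (channel, freq, time): shared-mask path
    masked_batch = masking(batch)
    masked_clip = masking(clip)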
@@ -826,10 +826,10 @@ class TimeMasking(_AxisMasking):
     Args:
         time_mask_param (int): maximum possible length of the mask.
             Indices uniformly sampled from [0, time_mask_param).
-        iid_masks (bool, optional): whether to apply the same mask to all
-            the examples/channels in the batch. (Default: ``False``)
+        iid_masks (bool, optional): whether to apply different masks to each
+            example/channel in the batch. (Default: ``False``)
+            This option is applicable only when the input tensor is 4D.
     """
     def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:
         super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
......
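Note: the two transforms are commonly chained for SpecAugment-style training augmentation. A usage sketch under the same assumptions (arbitrary parameter values, not taken from this commit):

    import torch
    import torchaudio

    augment = torch.nn.Sequential(
        torchaudio.transforms.FrequencyMasking(freq_mask_param=27, iid_masks=True),
        torchaudio.transforms.TimeMasking(time_mask_param=100, iid_masks=True),
    )
    batch = torch.rand(8, 1, 201, 400)   # 4D input, so each example gets its own masks
    augmented = augment(batch)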