"git@developer.sourcefind.cn:OpenDAS/ktransformers.git" did not exist on "e5b001d76fba6da67987ca0c0d6e699226a633a1"
Unverified Commit c80d9a71 authored by moto's avatar moto Committed by GitHub
Browse files

Fix docstring of masking behavior (#612)

parent 867d669b
......@@ -1442,7 +1442,6 @@ def mask_along_axis_iid(
r"""
Apply a mask along ``axis``. Mask will be applied from indices ``[v_0, v_0 + v)``, where
``v`` is sampled from ``uniform(0, mask_param)``, and ``v_0`` from ``uniform(0, max_v - v)``.
All examples will have the same mask interval.
Args:
specgrams (Tensor): Real spectrograms (batch, channel, freq, time)
......
......@@ -779,6 +779,7 @@ class _AxisMasking(torch.nn.Module):
mask_param (int): Maximum possible length of the mask.
axis (int): What dimension the mask is applied on.
iid_masks (bool): Applies iid masks to each of the examples in the batch dimension.
This option is applicable only when the input tensor is 4D.
"""
__constants__ = ['mask_param', 'axis', 'iid_masks']
......@@ -798,7 +799,6 @@ class _AxisMasking(torch.nn.Module):
Returns:
Tensor: Masked spectrogram of dimensions (..., freq, time).
"""
# if iid_masks flag marked and specgram has a batch dimension
if self.iid_masks and specgram.dim() == 4:
return F.mask_along_axis_iid(specgram, self.mask_param, mask_value, self.axis + 1)
......@@ -812,10 +812,10 @@ class FrequencyMasking(_AxisMasking):
Args:
freq_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, freq_mask_param).
iid_masks (bool, optional): whether to apply the same mask to all
the examples/channels in the batch. (Default: ``False``)
iid_masks (bool, optional): whether to apply different masks to each
example/channel in the batch. (Default: ``False``)
This option is applicable only when the input tensor is 4D.
"""
def __init__(self, freq_mask_param: int, iid_masks: bool = False) -> None:
super(FrequencyMasking, self).__init__(freq_mask_param, 1, iid_masks)
......@@ -826,10 +826,10 @@ class TimeMasking(_AxisMasking):
Args:
time_mask_param (int): maximum possible length of the mask.
Indices uniformly sampled from [0, time_mask_param).
iid_masks (bool, optional): whether to apply the same mask to all
the examples/channels in the batch. (Default: ``False``)
iid_masks (bool, optional): whether to apply different masks to each
example/channel in the batch. (Default: ``False``)
This option is applicable only when the input tensor is 4D.
"""
def __init__(self, time_mask_param: int, iid_masks: bool = False) -> None:
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment