Unverified Commit d1d6dbc6 authored by Zack Kneupper's avatar Zack Kneupper Committed by GitHub
Browse files

Simplify axis value checks (#1501)

parent d49e6e45
...@@ -732,7 +732,7 @@ def mask_along_axis_iid( ...@@ -732,7 +732,7 @@ def mask_along_axis_iid(
Tensor: Masked spectrograms of dimensions (batch, channel, freq, time) Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
""" """
if axis != 2 and axis != 3: if axis not in [2, 3]:
raise ValueError('Only Frequency and Time masking are supported') raise ValueError('Only Frequency and Time masking are supported')
device = specgrams.device device = specgrams.device
...@@ -774,7 +774,7 @@ def mask_along_axis( ...@@ -774,7 +774,7 @@ def mask_along_axis(
Returns: Returns:
Tensor: Masked spectrogram of dimensions (channel, freq, time) Tensor: Masked spectrogram of dimensions (channel, freq, time)
""" """
if axis != 1 and axis != 2: if axis not in [1, 2]:
raise ValueError('Only Frequency and Time masking are supported') raise ValueError('Only Frequency and Time masking are supported')
# pack batch # pack batch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment