Unverified Commit d1d6dbc6 authored by Zack Kneupper's avatar Zack Kneupper Committed by GitHub
Browse files

Simplify axis value checks (#1501)

parent d49e6e45
......@@ -732,7 +732,7 @@ def mask_along_axis_iid(
Tensor: Masked spectrograms of dimensions (batch, channel, freq, time)
"""
if axis != 2 and axis != 3:
if axis not in [2, 3]:
raise ValueError('Only Frequency and Time masking are supported')
device = specgrams.device
......@@ -774,7 +774,7 @@ def mask_along_axis(
Returns:
Tensor: Masked spectrogram of dimensions (channel, freq, time)
"""
if axis != 1 and axis != 2:
if axis not in [1, 2]:
raise ValueError('Only Frequency and Time masking are supported')
# pack batch
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment