Unverified Commit 08f2bde4 authored by Artyom Astafurov's avatar Artyom Astafurov Committed by GitHub
Browse files

Update VAD docstring and check for input shape length (#1513)

* Update VAD docstring and check for input shape length

* Update docstring in forward for transform

* Address review feedback: merge tests, update wording
parent 22fe8026
import warnings
import torch
import torchaudio.transforms as T
from parameterized import parameterized
......@@ -61,3 +64,25 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, sample_rate = load_wav(path)
result = T.Vad(sample_rate)(data)
self.assert_sox_effect(result, path, ['vad'])
def test_vad_warning(self):
    """Vad warns only when the input has more than two dimensions.

    A 3-D input is expected to produce exactly one warning, while 2-D
    (multi-channel) and 1-D (single-channel) inputs are expected to
    produce none.
    """
    sample_rate = 41100
    # Each entry pairs an input shape with the number of warnings expected.
    cases = (
        ((5, 5, sample_rate), 1),
        ((5, sample_rate), 0),
        ((sample_rate,), 0),
    )
    for shape, expected_count in cases:
        waveform = torch.rand(*shape)
        with warnings.catch_warnings(record=True) as caught:
            # "always" stops Python's once-per-location warning filter
            # from hiding a repeated warning.
            warnings.simplefilter("always")
            T.Vad(sample_rate)(waveform)
        assert len(caught) == expected_count
import math
import warnings
from typing import Optional
import torch
......@@ -1374,7 +1375,10 @@ def vad(
so in order to trim from the back, the reverse effect must also be used.
Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)`
waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)`
Tensor of shape `(channels, time)` is treated as a multi-channel recording
of the same event and the resulting output will be trimmed to the earliest
voice activity in any channel.
sample_rate (int): Sample rate of audio signal.
trigger_level (float, optional): The measurement level used to trigger activity detection.
This may need to be changed depending on the noise level, signal level,
......@@ -1420,6 +1424,15 @@ def vad(
http://sox.sourceforge.net/sox.html
"""
if waveform.ndim > 2:
warnings.warn(
"Expected input tensor dimension of 1 for single channel"
f" or 2 for multi-channel. Got {waveform.ndim} instead. "
"Batch semantics is not supported. "
"Please refer to https://github.com/pytorch/audio/issues/1348"
" and https://github.com/pytorch/audio/issues/1468."
)
measure_duration: float = (
2.0 / measure_freq if measure_duration is None else measure_duration
)
......
......@@ -1126,7 +1126,10 @@ class Vad(torch.nn.Module):
def forward(self, waveform: Tensor) -> Tensor:
r"""
Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)`
waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)`
Tensor of shape `(channels, time)` is treated as a multi-channel recording
of the same event and the resulting output will be trimmed to the earliest
voice activity in any channel.
"""
return F.vad(
waveform=waveform,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment