Unverified Commit 08f2bde4 authored by Artyom Astafurov's avatar Artyom Astafurov Committed by GitHub
Browse files

Update VAD docstring and check for input shape length (#1513)

* Update VAD docstring and check for input shape length

* Update docstring in forward for transform

* Address review feedback: merge tests, update wording
parent 22fe8026
import warnings
import torch
import torchaudio.transforms as T import torchaudio.transforms as T
from parameterized import parameterized from parameterized import parameterized
...@@ -61,3 +64,25 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase): ...@@ -61,3 +64,25 @@ class TestFunctionalFiltering(TempDirMixin, TorchaudioTestCase):
data, sample_rate = load_wav(path) data, sample_rate = load_wav(path)
result = T.Vad(sample_rate)(data) result = T.Vad(sample_rate)(data)
self.assert_sox_effect(result, path, ['vad']) self.assert_sox_effect(result, path, ['vad'])
def test_vad_warning(self):
    """vad should warn exactly once for >2-dim input and stay silent otherwise"""
    sample_rate = 41100
    # (tensor shape, expected number of warnings)
    cases = [
        ((5, 5, sample_rate), 1),  # 3-dim: batch semantics unsupported -> one warning
        ((5, sample_rate), 0),     # 2-dim multi-channel: supported, no warning
        ((sample_rate,), 0),       # 1-dim single channel: supported, no warning
    ]
    for shape, expected_count in cases:
        waveform = torch.rand(*shape)
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            T.Vad(sample_rate)(waveform)
            assert len(caught) == expected_count
import math import math
import warnings
from typing import Optional from typing import Optional
import torch import torch
...@@ -1374,7 +1375,10 @@ def vad( ...@@ -1374,7 +1375,10 @@ def vad(
so in order to trim from the back, the reverse effect must also be used. so in order to trim from the back, the reverse effect must also be used.
Args: Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)` waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)`
Tensor of shape `(channels, time)` is treated as a multi-channel recording
of the same event and the resulting output will be trimmed to the earliest
voice activity in any channel.
sample_rate (int): Sample rate of audio signal. sample_rate (int): Sample rate of audio signal.
trigger_level (float, optional): The measurement level used to trigger activity detection. trigger_level (float, optional): The measurement level used to trigger activity detection.
This may need to be changed depending on the noise level, signal level, This may need to be changed depending on the noise level, signal level,
...@@ -1420,6 +1424,15 @@ def vad( ...@@ -1420,6 +1424,15 @@ def vad(
http://sox.sourceforge.net/sox.html http://sox.sourceforge.net/sox.html
""" """
if waveform.ndim > 2:
warnings.warn(
"Expected input tensor dimension of 1 for single channel"
f" or 2 for multi-channel. Got {waveform.ndim} instead. "
"Batch semantics is not supported. "
"Please refer to https://github.com/pytorch/audio/issues/1348"
" and https://github.com/pytorch/audio/issues/1468."
)
measure_duration: float = ( measure_duration: float = (
2.0 / measure_freq if measure_duration is None else measure_duration 2.0 / measure_freq if measure_duration is None else measure_duration
) )
......
...@@ -1126,7 +1126,10 @@ class Vad(torch.nn.Module): ...@@ -1126,7 +1126,10 @@ class Vad(torch.nn.Module):
def forward(self, waveform: Tensor) -> Tensor: def forward(self, waveform: Tensor) -> Tensor:
r""" r"""
Args: Args:
waveform (Tensor): Tensor of audio of dimension `(..., time)` waveform (Tensor): Tensor of audio of dimension `(channels, time)` or `(time)`
Tensor of shape `(channels, time)` is treated as a multi-channel recording
of the same event and the resulting output will be trimmed to the earliest
voice activity in any channel.
""" """
return F.vad( return F.vad(
waveform=waveform, waveform=waveform,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment