Unverified commit 3ecc7016 authored by Artyom Astafurov, committed by GitHub

Port sox::vad (#578)

* initial test, stub function, transform and docstring

* add draft working implementation, update docstrings

* merge VadState into Vad class, move Channel into Vad class

* remove functional stub for vad

* add wav file for test

* refactor _measure() to improve performance

* rename argument

* replace copy_ with assignment

* refactor init, update documentation, update test for readability

* clean up default values

* move code from transforms.py to functional.py and integrate state into a function

* remove Channel state class

* fix calculation of a flush point

* make multiple channels work

* clean up multi-channel, update test

* rename variables and re-org arguments for _measure

* fix linting errors

* add torchscript consistency test and fix errors

* support and test batch consistency, fix normalization

* update documentation, switch torchscript consistency test to use transform to improve coverage

* fix linting errors

* remove un-used imports

* address PR comments

* add doc references into rst
parent bc82ffe2
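Once merged, the detector is exposed both as torchaudio.functional.vad and as the torchaudio.transforms.Vad module, as the hunks below show. A minimal usage sketch of the new API (the file path is a placeholder, not an asset from this PR):

    import torchaudio
    import torchaudio.functional as F
    import torchaudio.transforms as T

    # "speech.wav" is a hypothetical recording with leading silence
    waveform, sample_rate = torchaudio.load("speech.wav")

    # functional interface: trims silence from the front of the recording
    trimmed = F.vad(waveform, sample_rate=sample_rate)

    # equivalent module interface, usable inside an nn.Sequential pipeline
    vad = T.Vad(sample_rate=sample_rate)
    trimmed = vad(waveform)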
@@ -162,3 +162,8 @@ Functions to perform common audio operations.
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: sliding_window_cmn
:hidden:`vad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: vad
@@ -19,7 +19,7 @@ Transforms are common audio transforms. They can be chained together using :clas
:hidden:`GriffinLim`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: GriffinLim

  .. automethod:: forward
@@ -128,10 +128,17 @@ Transforms are common audio transforms. They can be chained together using :clas
.. autoclass:: Vol

  .. automethod:: forward

:hidden:`SlidingWindowCmn`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: SlidingWindowCmn

  .. automethod:: forward
:hidden:`Vad`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: Vad

  .. automethod:: forward
@@ -83,6 +83,11 @@ class TestFunctional(unittest.TestCase):
        _test_batch(F.sliding_window_cmn, waveform, center=False, norm_vars=True)
        _test_batch(F.sliding_window_cmn, waveform, center=False, norm_vars=False)

    def test_vad(self):
        filepath = common_utils.get_asset_path("vad-hello-mono-32000.wav")
        waveform, sample_rate = torchaudio.load(filepath)
        _test_batch(F.vad, waveform, sample_rate=sample_rate)
class TestTransforms(unittest.TestCase):
    """Test suite for classes defined in `transforms` module"""
...
@@ -249,6 +249,24 @@ class Test_SoxEffectsChain(unittest.TestCase):
        # check if effect worked
        self.assertTrue(x.allclose(z, rtol=1e-4, atol=1e-4))

    def test_vad(self):
        sample_files = [
            common_utils.get_asset_path("vad-hello-stereo-44100.wav"),
            common_utils.get_asset_path("vad-hello-mono-32000.wav"),
        ]
        for sample_file in sample_files:
            E = torchaudio.sox_effects.SoxEffectsChain()
            E.set_input_file(sample_file)
            E.append_effect_to_chain("vad")
            x, _ = E.sox_build_flow_effects()

            x_orig, sample_rate = torchaudio.load(sample_file)
            vad = torchaudio.transforms.Vad(sample_rate)
            y = vad(x_orig)

            self.assertTrue(x.allclose(y, rtol=1e-4, atol=1e-4))
if __name__ == '__main__':
    with AudioBackendScope("sox"):
...
@@ -557,6 +557,11 @@ class _TransformsTestMixin:
        tensor = torch.rand((1000, 10))
        self._assert_consistency(T.SlidingWindowCmn(), tensor)

    def test_Vad(self):
        filepath = common_utils.get_asset_path("vad-hello-mono-32000.wav")
        waveform, sample_rate = torchaudio.load(filepath)
        self._assert_consistency(T.Vad(sample_rate=sample_rate), waveform)
class TestFunctionalCPU(_FunctionalTestMixin, unittest.TestCase):
    """Test suite for Functional module on CPU"""
...
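The torchscript consistency test above presumably boils down to scripting the transform and comparing its output against eager mode; a standalone sketch of that idea (the random input is illustrative, not the actual _assert_consistency helper):

    import torch
    import torchaudio.transforms as T

    vad = T.Vad(sample_rate=32000)
    scripted = torch.jit.script(vad)  # the PR makes Vad torchscript-compatible

    waveform = torch.rand(1, 32000)   # illustrative input; any (..., time) tensor works
    assert torch.allclose(vad(waveform), scripted(waveform))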
@@ -37,6 +37,7 @@ __all__ = [
    'mask_along_axis',
    'mask_along_axis_iid',
    'sliding_window_cmn',
    'vad',
]
@@ -1836,3 +1837,299 @@ def sliding_window_cmn(
    if len(input_shape) == 2:
        cmn_waveform = cmn_waveform.squeeze(0)
    return cmn_waveform
def _measure(
    measure_len_ws: int,
    samples: Tensor,
    spectrum: Tensor,
    noise_spectrum: Tensor,
    spectrum_window: Tensor,
    spectrum_start: int,
    spectrum_end: int,
    cepstrum_window: Tensor,
    cepstrum_start: int,
    cepstrum_end: int,
    noise_reduction_amount: float,
    measure_smooth_time_mult: float,
    noise_up_time_mult: float,
    noise_down_time_mult: float,
    index_ns: int,
    boot_count: int
) -> float:
    assert spectrum.size()[-1] == noise_spectrum.size()[-1]

    samplesLen_ns = samples.size()[-1]
    dft_len_ws = spectrum.size()[-1]

    dftBuf = torch.zeros(dft_len_ws)

    # gather the most recent measure_len_ws samples from the ring buffer
    # and apply the spectrum window
    _index_ns = torch.tensor([index_ns] + [
        (index_ns + i) % samplesLen_ns
        for i in range(1, measure_len_ws)
    ])
    dftBuf[:measure_len_ws] = \
        samples[_index_ns] * spectrum_window[:measure_len_ws]

    # memset(c->dftBuf + i, 0, (p->dft_len_ws - i) * sizeof(*c->dftBuf));
    dftBuf[measure_len_ws:dft_len_ws].zero_()

    # lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
    _dftBuf = torch.rfft(dftBuf, 1)

    # memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
    _dftBuf[:spectrum_start].zero_()

    mult: float = boot_count / (1. + boot_count) \
        if boot_count >= 0 \
        else measure_smooth_time_mult

    # one-pole smoothing of the running spectrum estimate
    _d = complex_norm(_dftBuf[spectrum_start:spectrum_end])
    spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult))
    _d = spectrum[spectrum_start:spectrum_end] ** 2

    # adaptive noise spectrum estimate: rises and falls with different time constants
    _zeros = torch.zeros(spectrum_end - spectrum_start)
    _mult = _zeros \
        if boot_count >= 0 \
        else torch.where(
            _d > noise_spectrum[spectrum_start:spectrum_end],
            torch.tensor(noise_up_time_mult),   # if
            torch.tensor(noise_down_time_mult)  # else
        )

    noise_spectrum[spectrum_start:spectrum_end].mul_(_mult).add_(_d * (1 - _mult))

    # spectral subtraction of the noise estimate, clamped at zero
    _d = torch.sqrt(
        torch.max(
            _zeros,
            _d - noise_reduction_amount * noise_spectrum[spectrum_start:spectrum_end]))

    _cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1)
    _cepstrum_Buf[spectrum_start:spectrum_end] = _d * cepstrum_window
    _cepstrum_Buf[spectrum_end:dft_len_ws >> 1].zero_()

    # lsx_safe_rdft((int)p->dft_len_ws >> 1, 1, c->dftBuf);
    _cepstrum_Buf = torch.rfft(_cepstrum_Buf, 1)

    # cepstral power inside the lifter band, mapped onto a dB-like scale
    result: float = float(torch.sum(
        complex_norm(
            _cepstrum_Buf[cepstrum_start:cepstrum_end],
            power=2.0)))
    result = \
        math.log(result / (cepstrum_end - cepstrum_start)) \
        if result > 0 \
        else -math.inf
    return max(0, 21 + result)
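The spectrum and noise updates in _measure are one-pole exponential smoothers: each new frame is blended into the running estimate with weight (1 - mult), where mult = exp(-1 / (time_constant * measure_freq)), as computed in vad() below. A standalone sketch of that recursion, with illustrative values:

    import math

    measure_freq = 20.0   # measurements per second (default below)
    noise_up_time = 0.1   # time constant when the level rises (seconds)
    mult = math.exp(-1.0 / (noise_up_time * measure_freq))  # ~0.607

    estimate = 0.0
    for frame_power in [1.0, 1.0, 1.0]:  # illustrative per-frame power values
        estimate = estimate * mult + frame_power * (1.0 - mult)
    # after three frames the estimate has climbed to ~0.78 of the input level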
def vad(
    waveform: Tensor,
    sample_rate: int,
    trigger_level: float = 7.0,
    trigger_time: float = 0.25,
    search_time: float = 1.0,
    allowed_gap: float = 0.25,
    pre_trigger_time: float = 0.0,
    # Fine-tuning parameters
    boot_time: float = .35,
    noise_up_time: float = .1,
    noise_down_time: float = .01,
    noise_reduction_amount: float = 1.35,
    measure_freq: float = 20.0,
    measure_duration: Optional[float] = None,
    measure_smooth_time: float = .4,
    hp_filter_freq: float = 50.,
    lp_filter_freq: float = 6000.,
    hp_lifter_freq: float = 150.,
    lp_lifter_freq: float = 2000.,
) -> Tensor:
    r"""Voice Activity Detector. Similar to SoX implementation.

    Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
    The algorithm currently uses a simple cepstral power measurement to detect voice,
    so may be fooled by other things, especially music.

    The effect can trim only from the front of the audio,
    so in order to trim from the back, the reverse effect must also be used.

    Args:
        waveform (Tensor): Tensor of audio of dimension `(..., time)`
        sample_rate (int): Sample rate of audio signal.
        trigger_level (float, optional): The measurement level used to trigger activity detection.
            This may need to be changed depending on the noise level, signal level,
            and other characteristics of the input audio. (Default: 7.0)
        trigger_time (float, optional): The time constant (in seconds)
            used to help ignore short bursts of sound. (Default: 0.25)
        search_time (float, optional): The amount of audio (in seconds)
            to search for quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 1.0)
        allowed_gap (float, optional): The allowed gap (in seconds) between
            quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 0.25)
        pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
            before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
        boot_time (float, optional): The algorithm (internally) uses adaptive noise
            estimation/reduction in order to detect the start of the wanted audio.
            This option sets the time for the initial noise estimate. (Default: 0.35)
        noise_up_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is increasing. (Default: 0.1)
        noise_down_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is decreasing. (Default: 0.01)
        noise_reduction_amount (float, optional): Amount of noise reduction to use in
            the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
        measure_freq (float, optional): Frequency of the algorithm’s
            processing/measurements. (Default: 20.0)
        measure_duration (float, optional): Measurement duration.
            (Default: Twice the measurement period; i.e. with overlap.)
        measure_smooth_time (float, optional): Time constant used to smooth
            spectral measurements. (Default: 0.4)
        hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
            at the input to the detector algorithm. (Default: 50.0)
        lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
            at the input to the detector algorithm. (Default: 6000.0)
        hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
            in the detector algorithm. (Default: 150.0)
        lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
            in the detector algorithm. (Default: 2000.0)

    Returns:
        Tensor: Tensor of audio of dimension `(..., time)`.

    References:
        http://sox.sourceforge.net/sox.html
    """
    measure_duration: float = 2.0 / measure_freq \
        if measure_duration is None \
        else measure_duration

    measure_len_ws = int(sample_rate * measure_duration + .5)
    measure_len_ns = measure_len_ws
    # for (dft_len_ws = 16; dft_len_ws < measure_len_ws; dft_len_ws <<= 1);
    dft_len_ws = 16
    while dft_len_ws < measure_len_ws:
        dft_len_ws *= 2

    measure_period_ns = int(sample_rate / measure_freq + .5)
    measures_len = math.ceil(search_time * measure_freq)
    search_pre_trigger_len_ns = measures_len * measure_period_ns
    gap_len = int(allowed_gap * measure_freq + .5)

    fixed_pre_trigger_len_ns = int(pre_trigger_time * sample_rate + .5)
    samplesLen_ns = fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns

    spectrum_window = torch.zeros(measure_len_ws)
    for i in range(measure_len_ws):
        # sox.h:741 define SOX_SAMPLE_MIN (sox_sample_t)SOX_INT_MIN(32)
        spectrum_window[i] = 2. / math.sqrt(float(measure_len_ws))
    # lsx_apply_hann(spectrum_window, (int)measure_len_ws);
    spectrum_window *= torch.hann_window(measure_len_ws, dtype=torch.float)

    spectrum_start: int = int(hp_filter_freq / sample_rate * dft_len_ws + .5)
    spectrum_start: int = max(spectrum_start, 1)
    spectrum_end: int = int(lp_filter_freq / sample_rate * dft_len_ws + .5)
    spectrum_end: int = min(spectrum_end, dft_len_ws // 2)

    cepstrum_window = torch.zeros(spectrum_end - spectrum_start)
    for i in range(spectrum_end - spectrum_start):
        cepstrum_window[i] = 2. / math.sqrt(float(spectrum_end - spectrum_start))
    # lsx_apply_hann(cepstrum_window, (int)(spectrum_end - spectrum_start));
    cepstrum_window *= torch.hann_window(spectrum_end - spectrum_start, dtype=torch.float)

    cepstrum_start = math.ceil(sample_rate * .5 / lp_lifter_freq)
    cepstrum_end = math.floor(sample_rate * .5 / hp_lifter_freq)
    cepstrum_end = min(cepstrum_end, dft_len_ws // 4)
    assert cepstrum_end > cepstrum_start

    noise_up_time_mult = math.exp(-1. / (noise_up_time * measure_freq))
    noise_down_time_mult = math.exp(-1. / (noise_down_time * measure_freq))
    measure_smooth_time_mult = math.exp(-1. / (measure_smooth_time * measure_freq))
    trigger_meas_time_mult = math.exp(-1. / (trigger_time * measure_freq))

    boot_count_max = int(boot_time * measure_freq - .5)
    measure_timer_ns = measure_len_ns
    boot_count = measures_index = flushedLen_ns = samplesIndex_ns = 0

    # pack batch
    shape = waveform.size()
    waveform = waveform.view(-1, shape[-1])

    n_channels, ilen = waveform.size()

    mean_meas = torch.zeros(n_channels)
    samples = torch.zeros(n_channels, samplesLen_ns)
    spectrum = torch.zeros(n_channels, dft_len_ws)
    noise_spectrum = torch.zeros(n_channels, dft_len_ws)
    measures = torch.zeros(n_channels, measures_len)

    has_triggered: bool = False
    num_measures_to_flush: int = 0
    pos: int = 0
    while pos < ilen and not has_triggered:
        measure_timer_ns -= 1
        for i in range(n_channels):
            samples[i, samplesIndex_ns] = waveform[i, pos]
            # if (!p->measure_timer_ns) {
            if measure_timer_ns == 0:
                index_ns: int = \
                    (samplesIndex_ns + samplesLen_ns - measure_len_ns) % samplesLen_ns
                meas: float = _measure(
                    measure_len_ws=measure_len_ws,
                    samples=samples[i],
                    spectrum=spectrum[i],
                    noise_spectrum=noise_spectrum[i],
                    spectrum_window=spectrum_window,
                    spectrum_start=spectrum_start,
                    spectrum_end=spectrum_end,
                    cepstrum_window=cepstrum_window,
                    cepstrum_start=cepstrum_start,
                    cepstrum_end=cepstrum_end,
                    noise_reduction_amount=noise_reduction_amount,
                    measure_smooth_time_mult=measure_smooth_time_mult,
                    noise_up_time_mult=noise_up_time_mult,
                    noise_down_time_mult=noise_down_time_mult,
                    index_ns=index_ns,
                    boot_count=boot_count)
                measures[i, measures_index] = meas
                mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1. - trigger_meas_time_mult)

                has_triggered = has_triggered or (mean_meas[i] >= trigger_level)
                if has_triggered:
                    # scan backwards through the measurement ring buffer for
                    # quieter/shorter bursts to include before the trigger point
                    n: int = measures_len
                    k: int = measures_index
                    jTrigger: int = n
                    jZero: int = n
                    j: int = 0

                    for j in range(n):
                        if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len):
                            jZero = jTrigger = j
                        elif (measures[i, k] == 0) and (jTrigger >= jZero):
                            jZero = j
                        k = (k + n - 1) % n
                    j = min(j, jZero)
                    # num_measures_to_flush = range_limit(j, num_measures_to_flush, n);
                    num_measures_to_flush = min(max(num_measures_to_flush, j), n)
                # end if has_triggered
            # end if (measure_timer_ns == 0):
        # end for
        samplesIndex_ns += 1
        pos += 1
        if samplesIndex_ns == samplesLen_ns:
            samplesIndex_ns = 0
        if measure_timer_ns == 0:
            measure_timer_ns = measure_period_ns
            measures_index += 1
            measures_index = measures_index % measures_len
            if boot_count >= 0:
                boot_count = -1 if boot_count == boot_count_max else boot_count + 1
    # end while

    if has_triggered:
        flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
        samplesIndex_ns = (samplesIndex_ns + flushedLen_ns) % samplesLen_ns

    res = waveform[:, pos - samplesLen_ns + flushedLen_ns:]
    # unpack batch
    return res.view(shape[:-1] + res.shape[-1:])
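For concreteness, here is the size arithmetic above worked out for the 32000 Hz mono test asset with the defaults, as a short sketch mirroring the setup code:

    import math

    # derived sizes for sample_rate=32000 with the default parameters
    sample_rate, measure_freq, search_time, allowed_gap = 32000, 20.0, 1.0, 0.25

    measure_duration = 2.0 / measure_freq                      # 0.1 s (two periods, overlapping)
    measure_len_ws = int(sample_rate * measure_duration + .5)  # 3200 samples per measurement
    dft_len_ws = 16
    while dft_len_ws < measure_len_ws:
        dft_len_ws *= 2                                        # 4096, next power of two
    measure_period_ns = int(sample_rate / measure_freq + .5)   # 1600 samples between measurements
    measures_len = math.ceil(search_time * measure_freq)       # 20 measurements searched back
    gap_len = int(allowed_gap * measure_freq + .5)             # 5 measurements of allowed gap
    # ring buffer: 0 + 20 * 1600 + 3200 = 35200 samples, i.e. 1.1 s of context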
@@ -27,6 +27,7 @@ __all__ = [
    'FrequencyMasking',
    'TimeMasking',
    'SlidingWindowCmn',
    'Vad',
]
@@ -907,3 +908,120 @@ class SlidingWindowCmn(torch.nn.Module):
        cmn_waveform = F.sliding_window_cmn(
            waveform, self.cmn_window, self.min_cmn_window, self.center, self.norm_vars)
        return cmn_waveform
class Vad(torch.nn.Module):
    r"""Voice Activity Detector. Similar to SoX implementation.

    Attempts to trim silence and quiet background sounds from the ends of recordings of speech.
    The algorithm currently uses a simple cepstral power measurement to detect voice,
    so may be fooled by other things, especially music.

    The effect can trim only from the front of the audio,
    so in order to trim from the back, the reverse effect must also be used.

    Args:
        sample_rate (int): Sample rate of audio signal.
        trigger_level (float, optional): The measurement level used to trigger activity detection.
            This may need to be changed depending on the noise level, signal level,
            and other characteristics of the input audio. (Default: 7.0)
        trigger_time (float, optional): The time constant (in seconds)
            used to help ignore short bursts of sound. (Default: 0.25)
        search_time (float, optional): The amount of audio (in seconds)
            to search for quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 1.0)
        allowed_gap (float, optional): The allowed gap (in seconds) between
            quieter/shorter bursts of audio to include prior
            to the detected trigger point. (Default: 0.25)
        pre_trigger_time (float, optional): The amount of audio (in seconds) to preserve
            before the trigger point and any found quieter/shorter bursts. (Default: 0.0)
        boot_time (float, optional): The algorithm (internally) uses adaptive noise
            estimation/reduction in order to detect the start of the wanted audio.
            This option sets the time for the initial noise estimate. (Default: 0.35)
        noise_up_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is increasing. (Default: 0.1)
        noise_down_time (float, optional): Time constant used by the adaptive noise estimator
            for when the noise level is decreasing. (Default: 0.01)
        noise_reduction_amount (float, optional): Amount of noise reduction to use in
            the detection algorithm (e.g. 0, 0.5, ...). (Default: 1.35)
        measure_freq (float, optional): Frequency of the algorithm’s
            processing/measurements. (Default: 20.0)
        measure_duration (float, optional): Measurement duration.
            (Default: Twice the measurement period; i.e. with overlap.)
        measure_smooth_time (float, optional): Time constant used to smooth
            spectral measurements. (Default: 0.4)
        hp_filter_freq (float, optional): "Brick-wall" frequency of high-pass filter applied
            at the input to the detector algorithm. (Default: 50.0)
        lp_filter_freq (float, optional): "Brick-wall" frequency of low-pass filter applied
            at the input to the detector algorithm. (Default: 6000.0)
        hp_lifter_freq (float, optional): "Brick-wall" frequency of high-pass lifter used
            in the detector algorithm. (Default: 150.0)
        lp_lifter_freq (float, optional): "Brick-wall" frequency of low-pass lifter used
            in the detector algorithm. (Default: 2000.0)

    References:
        http://sox.sourceforge.net/sox.html
    """
    def __init__(self,
                 sample_rate: int,
                 trigger_level: float = 7.0,
                 trigger_time: float = 0.25,
                 search_time: float = 1.0,
                 allowed_gap: float = 0.25,
                 pre_trigger_time: float = 0.0,
                 boot_time: float = .35,
                 noise_up_time: float = .1,
                 noise_down_time: float = .01,
                 noise_reduction_amount: float = 1.35,
                 measure_freq: float = 20.0,
                 measure_duration: Optional[float] = None,
                 measure_smooth_time: float = .4,
                 hp_filter_freq: float = 50.,
                 lp_filter_freq: float = 6000.,
                 hp_lifter_freq: float = 150.,
                 lp_lifter_freq: float = 2000.) -> None:
        super().__init__()

        self.sample_rate = sample_rate
        self.trigger_level = trigger_level
        self.trigger_time = trigger_time
        self.search_time = search_time
        self.allowed_gap = allowed_gap
        self.pre_trigger_time = pre_trigger_time
        self.boot_time = boot_time
        self.noise_up_time = noise_up_time
        self.noise_down_time = noise_down_time
        self.noise_reduction_amount = noise_reduction_amount
        self.measure_freq = measure_freq
        self.measure_duration = measure_duration
        self.measure_smooth_time = measure_smooth_time
        self.hp_filter_freq = hp_filter_freq
        self.lp_filter_freq = lp_filter_freq
        self.hp_lifter_freq = hp_lifter_freq
        self.lp_lifter_freq = lp_lifter_freq
    def forward(self, waveform: Tensor) -> Tensor:
        r"""
        Args:
            waveform (Tensor): Tensor of audio of dimension `(..., time)`
        """
        return F.vad(
            waveform=waveform,
            sample_rate=self.sample_rate,
            trigger_level=self.trigger_level,
            trigger_time=self.trigger_time,
            search_time=self.search_time,
            allowed_gap=self.allowed_gap,
            pre_trigger_time=self.pre_trigger_time,
            boot_time=self.boot_time,
            noise_up_time=self.noise_up_time,
            noise_down_time=self.noise_down_time,
            noise_reduction_amount=self.noise_reduction_amount,
            measure_freq=self.measure_freq,
            measure_duration=self.measure_duration,
            measure_smooth_time=self.measure_smooth_time,
            hp_filter_freq=self.hp_filter_freq,
            lp_filter_freq=self.lp_filter_freq,
            hp_lifter_freq=self.hp_lifter_freq,
            lp_lifter_freq=self.lp_lifter_freq,
        )
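Since the docstring notes the effect trims only from the front, trimming both ends can be done by applying the transform, reversing along time, applying it again, and reversing back; a minimal sketch under that reading (trim_both_ends is a hypothetical helper, not part of this PR):

    import torch
    import torchaudio.transforms as T

    def trim_both_ends(waveform: torch.Tensor, sample_rate: int) -> torch.Tensor:
        vad = T.Vad(sample_rate=sample_rate)
        # trim leading silence
        front_trimmed = vad(waveform)
        # reverse along time, trim what was trailing silence, reverse back
        reversed_audio = torch.flip(front_trimmed, [-1])
        back_trimmed = vad(reversed_audio)
        return torch.flip(back_trimmed, [-1])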