Commit 1e117f57, authored by Kuba Rad, committed by Facebook GitHub Bot
Browse files

Optimize Torchaudio Vad (#3382)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3382

The voice activity detector function was unoptimized, confusingly written, and buggy.

The optimizations in this change let the function run roughly 17x faster.
The main optimization was to loop over windows of audio rather than over individual audio samples. Reducing the number of copies also helped.

There was also an off-by-one error: the array slice referenced was [1:16001] (for the default settings) instead of [0:16000].

Reviewed By: hwangjeff

Differential Revision: D44749359

fbshipit-source-id: c76c9412e70cdc6fcd527d113603c88f78480558
parent c3ca2562
...@@ -1390,11 +1390,11 @@ def _measure( ...@@ -1390,11 +1390,11 @@ def _measure(
cepstrum_end: int, cepstrum_end: int,
noise_reduction_amount: float, noise_reduction_amount: float,
measure_smooth_time_mult: float, measure_smooth_time_mult: float,
noise_up_time_mult: float, noise_up_time_mult: Tensor,
noise_down_time_mult: float, noise_down_time_mult: Tensor,
index_ns: int,
boot_count: int, boot_count: int,
) -> float: ) -> float:
device = samples.device
if spectrum.size(-1) != noise_spectrum.size(-1): if spectrum.size(-1) != noise_spectrum.size(-1):
raise ValueError( raise ValueError(
...@@ -1402,37 +1402,29 @@ def _measure( ...@@ -1402,37 +1402,29 @@ def _measure(
f"Found: spectrum size: {spectrum.size()}, noise_spectrum size: {noise_spectrum.size()}" f"Found: spectrum size: {spectrum.size()}, noise_spectrum size: {noise_spectrum.size()}"
) )
samplesLen_ns = samples.size()[-1]
dft_len_ws = spectrum.size()[-1] dft_len_ws = spectrum.size()[-1]
dftBuf = torch.zeros(dft_len_ws) dftBuf = torch.zeros(dft_len_ws, device=device)
_index_ns = torch.tensor([index_ns] + [(index_ns + i) % samplesLen_ns for i in range(1, measure_len_ws)]) dftBuf[:measure_len_ws] = samples * spectrum_window[:measure_len_ws]
dftBuf[:measure_len_ws] = samples[_index_ns] * spectrum_window[:measure_len_ws]
# memset(c->dftBuf + i, 0, (p->dft_len_ws - i) * sizeof(*c->dftBuf));
dftBuf[measure_len_ws:dft_len_ws].zero_()
# lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf); # lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
_dftBuf = torch.fft.rfft(dftBuf) _dftBuf = torch.fft.rfft(dftBuf)
# memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
_dftBuf[:spectrum_start].zero_()
mult: float = boot_count / (1.0 + boot_count) if boot_count >= 0 else measure_smooth_time_mult mult: float = boot_count / (1.0 + boot_count) if boot_count >= 0 else measure_smooth_time_mult
_d = _dftBuf[spectrum_start:spectrum_end].abs() _d = _dftBuf[spectrum_start:spectrum_end].abs()
spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult)) spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult))
_d = spectrum[spectrum_start:spectrum_end] ** 2 _d = spectrum[spectrum_start:spectrum_end] ** 2
_zeros = torch.zeros(spectrum_end - spectrum_start) _zeros = torch.zeros(spectrum_end - spectrum_start, device=device)
_mult = ( _mult = (
_zeros _zeros
if boot_count >= 0 if boot_count >= 0
else torch.where( else torch.where(
_d > noise_spectrum[spectrum_start:spectrum_end], _d > noise_spectrum[spectrum_start:spectrum_end],
torch.tensor(noise_up_time_mult), # if noise_up_time_mult, # if
torch.tensor(noise_down_time_mult), # else noise_down_time_mult, # else,
) )
) )
...@@ -1441,10 +1433,10 @@ def _measure( ...@@ -1441,10 +1433,10 @@ def _measure(
torch.max( torch.max(
_zeros, _zeros,
_d - noise_reduction_amount * noise_spectrum[spectrum_start:spectrum_end], _d - noise_reduction_amount * noise_spectrum[spectrum_start:spectrum_end],
) ),
) )
_cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1) _cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1, device=device)
_cepstrum_Buf[spectrum_start:spectrum_end] = _d * cepstrum_window _cepstrum_Buf[spectrum_start:spectrum_end] = _d * cepstrum_window
_cepstrum_Buf[spectrum_end : dft_len_ws >> 1].zero_() _cepstrum_Buf[spectrum_end : dft_len_ws >> 1].zero_()
...@@ -1539,6 +1531,7 @@ def vad( ...@@ -1539,6 +1531,7 @@ def vad(
Reference: Reference:
- http://sox.sourceforge.net/sox.html - http://sox.sourceforge.net/sox.html
""" """
device = waveform.device
if waveform.ndim > 2: if waveform.ndim > 2:
warnings.warn( warnings.warn(
...@@ -1566,23 +1559,23 @@ def vad( ...@@ -1566,23 +1559,23 @@ def vad(
fixed_pre_trigger_len_ns = int(pre_trigger_time * sample_rate + 0.5) fixed_pre_trigger_len_ns = int(pre_trigger_time * sample_rate + 0.5)
samplesLen_ns = fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns samplesLen_ns = fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns
spectrum_window = torch.zeros(measure_len_ws) spectrum_window = torch.zeros(measure_len_ws, device=device)
for i in range(measure_len_ws): for i in range(measure_len_ws):
# sox.h:741 define SOX_SAMPLE_MIN (sox_sample_t)SOX_INT_MIN(32) # sox.h:741 define SOX_SAMPLE_MIN (sox_sample_t)SOX_INT_MIN(32)
spectrum_window[i] = 2.0 / math.sqrt(float(measure_len_ws)) spectrum_window[i] = 2.0 / math.sqrt(float(measure_len_ws))
# lsx_apply_hann(spectrum_window, (int)measure_len_ws); # lsx_apply_hann(spectrum_window, (int)measure_len_ws);
spectrum_window *= torch.hann_window(measure_len_ws, dtype=torch.float) spectrum_window *= torch.hann_window(measure_len_ws, device=device, dtype=torch.float)
spectrum_start: int = int(hp_filter_freq / sample_rate * dft_len_ws + 0.5) spectrum_start: int = int(hp_filter_freq / sample_rate * dft_len_ws + 0.5)
spectrum_start: int = max(spectrum_start, 1) spectrum_start: int = max(spectrum_start, 1)
spectrum_end: int = int(lp_filter_freq / sample_rate * dft_len_ws + 0.5) spectrum_end: int = int(lp_filter_freq / sample_rate * dft_len_ws + 0.5)
spectrum_end: int = min(spectrum_end, dft_len_ws // 2) spectrum_end: int = min(spectrum_end, dft_len_ws // 2)
cepstrum_window = torch.zeros(spectrum_end - spectrum_start) cepstrum_window = torch.zeros(spectrum_end - spectrum_start, device=device)
for i in range(spectrum_end - spectrum_start): for i in range(spectrum_end - spectrum_start):
cepstrum_window[i] = 2.0 / math.sqrt(float(spectrum_end) - spectrum_start) cepstrum_window[i] = 2.0 / math.sqrt(float(spectrum_end) - spectrum_start)
# lsx_apply_hann(cepstrum_window,(int)(spectrum_end - spectrum_start)); # lsx_apply_hann(cepstrum_window,(int)(spectrum_end - spectrum_start));
cepstrum_window *= torch.hann_window(spectrum_end - spectrum_start, dtype=torch.float) cepstrum_window *= torch.hann_window(spectrum_end - spectrum_start, device=device, dtype=torch.float)
cepstrum_start = math.ceil(sample_rate * 0.5 / lp_lifter_freq) cepstrum_start = math.ceil(sample_rate * 0.5 / lp_lifter_freq)
cepstrum_end = math.floor(sample_rate * 0.5 / hp_lifter_freq) cepstrum_end = math.floor(sample_rate * 0.5 / hp_lifter_freq)
...@@ -1594,14 +1587,13 @@ def vad( ...@@ -1594,14 +1587,13 @@ def vad(
f"Found: cepstrum_start: {cepstrum_start}, cepstrum_end: {cepstrum_end}." f"Found: cepstrum_start: {cepstrum_start}, cepstrum_end: {cepstrum_end}."
) )
noise_up_time_mult = math.exp(-1.0 / (noise_up_time * measure_freq)) noise_up_time_mult = torch.tensor(math.exp(-1.0 / (noise_up_time * measure_freq)), device=device)
noise_down_time_mult = math.exp(-1.0 / (noise_down_time * measure_freq)) noise_down_time_mult = torch.tensor(math.exp(-1.0 / (noise_down_time * measure_freq)), device=device)
measure_smooth_time_mult = math.exp(-1.0 / (measure_smooth_time * measure_freq)) measure_smooth_time_mult = math.exp(-1.0 / (measure_smooth_time * measure_freq))
trigger_meas_time_mult = math.exp(-1.0 / (trigger_time * measure_freq)) trigger_meas_time_mult = math.exp(-1.0 / (trigger_time * measure_freq))
boot_count_max = int(boot_time * measure_freq - 0.5) boot_count_max = int(boot_time * measure_freq - 0.5)
measure_timer_ns = measure_len_ns boot_count = measures_index = flushedLen_ns = 0
boot_count = measures_index = flushedLen_ns = samplesIndex_ns = 0
# pack batch # pack batch
shape = waveform.size() shape = waveform.size()
...@@ -1609,80 +1601,65 @@ def vad( ...@@ -1609,80 +1601,65 @@ def vad(
n_channels, ilen = waveform.size() n_channels, ilen = waveform.size()
mean_meas = torch.zeros(n_channels) mean_meas = torch.zeros(n_channels, device=device)
samples = torch.zeros(n_channels, samplesLen_ns) spectrum = torch.zeros(n_channels, dft_len_ws, device=device)
spectrum = torch.zeros(n_channels, dft_len_ws) noise_spectrum = torch.zeros(n_channels, dft_len_ws, device=device)
noise_spectrum = torch.zeros(n_channels, dft_len_ws) measures = torch.zeros(n_channels, measures_len, device=device)
measures = torch.zeros(n_channels, measures_len)
has_triggered: bool = False has_triggered: bool = False
num_measures_to_flush: int = 0 num_measures_to_flush: int = 0
pos: int = 0
while pos < ilen and not has_triggered: pos = 0
measure_timer_ns -= 1 for pos in range(measure_len_ns, ilen, measure_period_ns):
for i in range(n_channels): for i in range(n_channels):
samples[i, samplesIndex_ns] = waveform[i, pos] meas: float = _measure(
# if (!p->measure_timer_ns) { measure_len_ws=measure_len_ws,
if measure_timer_ns == 0: samples=waveform[i, pos - measure_len_ws : pos],
index_ns: int = (samplesIndex_ns + samplesLen_ns - measure_len_ns) % samplesLen_ns spectrum=spectrum[i],
meas: float = _measure( noise_spectrum=noise_spectrum[i],
measure_len_ws=measure_len_ws, spectrum_window=spectrum_window,
samples=samples[i], spectrum_start=spectrum_start,
spectrum=spectrum[i], spectrum_end=spectrum_end,
noise_spectrum=noise_spectrum[i], cepstrum_window=cepstrum_window,
spectrum_window=spectrum_window, cepstrum_start=cepstrum_start,
spectrum_start=spectrum_start, cepstrum_end=cepstrum_end,
spectrum_end=spectrum_end, noise_reduction_amount=noise_reduction_amount,
cepstrum_window=cepstrum_window, measure_smooth_time_mult=measure_smooth_time_mult,
cepstrum_start=cepstrum_start, noise_up_time_mult=noise_up_time_mult,
cepstrum_end=cepstrum_end, noise_down_time_mult=noise_down_time_mult,
noise_reduction_amount=noise_reduction_amount, boot_count=boot_count,
measure_smooth_time_mult=measure_smooth_time_mult, )
noise_up_time_mult=noise_up_time_mult, measures[i, measures_index] = meas
noise_down_time_mult=noise_down_time_mult, mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1.0 - trigger_meas_time_mult)
index_ns=index_ns,
boot_count=boot_count, has_triggered = has_triggered or (mean_meas[i] >= trigger_level)
) if has_triggered:
measures[i, measures_index] = meas n: int = measures_len
mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1.0 - trigger_meas_time_mult) k: int = measures_index
jTrigger: int = n
has_triggered = has_triggered or (mean_meas[i] >= trigger_level) jZero: int = n
if has_triggered: j: int = 0
n: int = measures_len
k: int = measures_index for j in range(n):
jTrigger: int = n if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len):
jZero: int = n jZero = jTrigger = j
j: int = 0 elif (measures[i, k] == 0) and (jTrigger >= jZero):
jZero = j
for j in range(n): k = (k + n - 1) % n
if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len): j = min(j, jZero)
jZero = jTrigger = j # num_measures_to_flush = range_limit(j, num_measures_to_flush, n);
elif (measures[i, k] == 0) and (jTrigger >= jZero): num_measures_to_flush = min(max(num_measures_to_flush, j), n)
jZero = j # end if has_triggered
k = (k + n - 1) % n # end for channel
j = min(j, jZero) measures_index += 1
# num_measures_to_flush = range_limit(j, num_measures_to_flush, n); measures_index = measures_index % measures_len
num_measures_to_flush = min(max(num_measures_to_flush, j), n) if boot_count >= 0:
# end if has_triggered boot_count = -1 if boot_count == boot_count_max else boot_count + 1
# end if (measure_timer_ns == 0):
# end for
samplesIndex_ns += 1
pos += 1
# end while
if samplesIndex_ns == samplesLen_ns:
samplesIndex_ns = 0
if measure_timer_ns == 0:
measure_timer_ns = measure_period_ns
measures_index += 1
measures_index = measures_index % measures_len
if boot_count >= 0:
boot_count = -1 if boot_count == boot_count_max else boot_count + 1
if has_triggered: if has_triggered:
flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
samplesIndex_ns = (samplesIndex_ns + flushedLen_ns) % samplesLen_ns break
# end for window
res = waveform[:, pos - samplesLen_ns + flushedLen_ns :] res = waveform[:, pos - samplesLen_ns + flushedLen_ns :]
# unpack batch # unpack batch
return res.view(shape[:-1] + res.shape[-1:]) return res.view(shape[:-1] + res.shape[-1:])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment