Commit 1e117f57 authored by Kuba Rad's avatar Kuba Rad Committed by Facebook GitHub Bot
Browse files

Optimize Torchaudio Vad (#3382)

Summary:
Pull Request resolved: https://github.com/pytorch/audio/pull/3382

The voice activity detector function was unoptimized, confusingly written, and buggy.

The optimizations created here allow for the function to run roughly 17x faster.
The main optimizations were to loop over windows of audio rather than individual audio samples. Reducing the number of copies also helped.

There was an off by one error where the array slice referenced was [1: 16001] (for the default settings) instead of [0: 16000]

Reviewed By: hwangjeff

Differential Revision: D44749359

fbshipit-source-id: c76c9412e70cdc6fcd527d113603c88f78480558
parent c3ca2562
......@@ -1390,11 +1390,11 @@ def _measure(
cepstrum_end: int,
noise_reduction_amount: float,
measure_smooth_time_mult: float,
noise_up_time_mult: float,
noise_down_time_mult: float,
index_ns: int,
noise_up_time_mult: Tensor,
noise_down_time_mult: Tensor,
boot_count: int,
) -> float:
device = samples.device
if spectrum.size(-1) != noise_spectrum.size(-1):
raise ValueError(
......@@ -1402,37 +1402,29 @@ def _measure(
f"Found: spectrum size: {spectrum.size()}, noise_spectrum size: {noise_spectrum.size()}"
)
samplesLen_ns = samples.size()[-1]
dft_len_ws = spectrum.size()[-1]
dftBuf = torch.zeros(dft_len_ws)
dftBuf = torch.zeros(dft_len_ws, device=device)
_index_ns = torch.tensor([index_ns] + [(index_ns + i) % samplesLen_ns for i in range(1, measure_len_ws)])
dftBuf[:measure_len_ws] = samples[_index_ns] * spectrum_window[:measure_len_ws]
# memset(c->dftBuf + i, 0, (p->dft_len_ws - i) * sizeof(*c->dftBuf));
dftBuf[measure_len_ws:dft_len_ws].zero_()
dftBuf[:measure_len_ws] = samples * spectrum_window[:measure_len_ws]
# lsx_safe_rdft((int)p->dft_len_ws, 1, c->dftBuf);
_dftBuf = torch.fft.rfft(dftBuf)
# memset(c->dftBuf, 0, p->spectrum_start * sizeof(*c->dftBuf));
_dftBuf[:spectrum_start].zero_()
mult: float = boot_count / (1.0 + boot_count) if boot_count >= 0 else measure_smooth_time_mult
_d = _dftBuf[spectrum_start:spectrum_end].abs()
spectrum[spectrum_start:spectrum_end].mul_(mult).add_(_d * (1 - mult))
_d = spectrum[spectrum_start:spectrum_end] ** 2
_zeros = torch.zeros(spectrum_end - spectrum_start)
_zeros = torch.zeros(spectrum_end - spectrum_start, device=device)
_mult = (
_zeros
if boot_count >= 0
else torch.where(
_d > noise_spectrum[spectrum_start:spectrum_end],
torch.tensor(noise_up_time_mult), # if
torch.tensor(noise_down_time_mult), # else
noise_up_time_mult, # if
noise_down_time_mult, # else,
)
)
......@@ -1441,10 +1433,10 @@ def _measure(
torch.max(
_zeros,
_d - noise_reduction_amount * noise_spectrum[spectrum_start:spectrum_end],
)
),
)
_cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1)
_cepstrum_Buf: Tensor = torch.zeros(dft_len_ws >> 1, device=device)
_cepstrum_Buf[spectrum_start:spectrum_end] = _d * cepstrum_window
_cepstrum_Buf[spectrum_end : dft_len_ws >> 1].zero_()
......@@ -1539,6 +1531,7 @@ def vad(
Reference:
- http://sox.sourceforge.net/sox.html
"""
device = waveform.device
if waveform.ndim > 2:
warnings.warn(
......@@ -1566,23 +1559,23 @@ def vad(
fixed_pre_trigger_len_ns = int(pre_trigger_time * sample_rate + 0.5)
samplesLen_ns = fixed_pre_trigger_len_ns + search_pre_trigger_len_ns + measure_len_ns
spectrum_window = torch.zeros(measure_len_ws)
spectrum_window = torch.zeros(measure_len_ws, device=device)
for i in range(measure_len_ws):
# sox.h:741 define SOX_SAMPLE_MIN (sox_sample_t)SOX_INT_MIN(32)
spectrum_window[i] = 2.0 / math.sqrt(float(measure_len_ws))
# lsx_apply_hann(spectrum_window, (int)measure_len_ws);
spectrum_window *= torch.hann_window(measure_len_ws, dtype=torch.float)
spectrum_window *= torch.hann_window(measure_len_ws, device=device, dtype=torch.float)
spectrum_start: int = int(hp_filter_freq / sample_rate * dft_len_ws + 0.5)
spectrum_start: int = max(spectrum_start, 1)
spectrum_end: int = int(lp_filter_freq / sample_rate * dft_len_ws + 0.5)
spectrum_end: int = min(spectrum_end, dft_len_ws // 2)
cepstrum_window = torch.zeros(spectrum_end - spectrum_start)
cepstrum_window = torch.zeros(spectrum_end - spectrum_start, device=device)
for i in range(spectrum_end - spectrum_start):
cepstrum_window[i] = 2.0 / math.sqrt(float(spectrum_end) - spectrum_start)
# lsx_apply_hann(cepstrum_window,(int)(spectrum_end - spectrum_start));
cepstrum_window *= torch.hann_window(spectrum_end - spectrum_start, dtype=torch.float)
cepstrum_window *= torch.hann_window(spectrum_end - spectrum_start, device=device, dtype=torch.float)
cepstrum_start = math.ceil(sample_rate * 0.5 / lp_lifter_freq)
cepstrum_end = math.floor(sample_rate * 0.5 / hp_lifter_freq)
......@@ -1594,14 +1587,13 @@ def vad(
f"Found: cepstrum_start: {cepstrum_start}, cepstrum_end: {cepstrum_end}."
)
noise_up_time_mult = math.exp(-1.0 / (noise_up_time * measure_freq))
noise_down_time_mult = math.exp(-1.0 / (noise_down_time * measure_freq))
noise_up_time_mult = torch.tensor(math.exp(-1.0 / (noise_up_time * measure_freq)), device=device)
noise_down_time_mult = torch.tensor(math.exp(-1.0 / (noise_down_time * measure_freq)), device=device)
measure_smooth_time_mult = math.exp(-1.0 / (measure_smooth_time * measure_freq))
trigger_meas_time_mult = math.exp(-1.0 / (trigger_time * measure_freq))
boot_count_max = int(boot_time * measure_freq - 0.5)
measure_timer_ns = measure_len_ns
boot_count = measures_index = flushedLen_ns = samplesIndex_ns = 0
boot_count = measures_index = flushedLen_ns = 0
# pack batch
shape = waveform.size()
......@@ -1609,80 +1601,65 @@ def vad(
n_channels, ilen = waveform.size()
mean_meas = torch.zeros(n_channels)
samples = torch.zeros(n_channels, samplesLen_ns)
spectrum = torch.zeros(n_channels, dft_len_ws)
noise_spectrum = torch.zeros(n_channels, dft_len_ws)
measures = torch.zeros(n_channels, measures_len)
mean_meas = torch.zeros(n_channels, device=device)
spectrum = torch.zeros(n_channels, dft_len_ws, device=device)
noise_spectrum = torch.zeros(n_channels, dft_len_ws, device=device)
measures = torch.zeros(n_channels, measures_len, device=device)
has_triggered: bool = False
num_measures_to_flush: int = 0
pos: int = 0
while pos < ilen and not has_triggered:
measure_timer_ns -= 1
pos = 0
for pos in range(measure_len_ns, ilen, measure_period_ns):
for i in range(n_channels):
samples[i, samplesIndex_ns] = waveform[i, pos]
# if (!p->measure_timer_ns) {
if measure_timer_ns == 0:
index_ns: int = (samplesIndex_ns + samplesLen_ns - measure_len_ns) % samplesLen_ns
meas: float = _measure(
measure_len_ws=measure_len_ws,
samples=samples[i],
spectrum=spectrum[i],
noise_spectrum=noise_spectrum[i],
spectrum_window=spectrum_window,
spectrum_start=spectrum_start,
spectrum_end=spectrum_end,
cepstrum_window=cepstrum_window,
cepstrum_start=cepstrum_start,
cepstrum_end=cepstrum_end,
noise_reduction_amount=noise_reduction_amount,
measure_smooth_time_mult=measure_smooth_time_mult,
noise_up_time_mult=noise_up_time_mult,
noise_down_time_mult=noise_down_time_mult,
index_ns=index_ns,
boot_count=boot_count,
)
measures[i, measures_index] = meas
mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1.0 - trigger_meas_time_mult)
has_triggered = has_triggered or (mean_meas[i] >= trigger_level)
if has_triggered:
n: int = measures_len
k: int = measures_index
jTrigger: int = n
jZero: int = n
j: int = 0
for j in range(n):
if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len):
jZero = jTrigger = j
elif (measures[i, k] == 0) and (jTrigger >= jZero):
jZero = j
k = (k + n - 1) % n
j = min(j, jZero)
# num_measures_to_flush = range_limit(j, num_measures_to_flush, n);
num_measures_to_flush = min(max(num_measures_to_flush, j), n)
# end if has_triggered
# end if (measure_timer_ns == 0):
# end for
samplesIndex_ns += 1
pos += 1
# end while
if samplesIndex_ns == samplesLen_ns:
samplesIndex_ns = 0
if measure_timer_ns == 0:
measure_timer_ns = measure_period_ns
measures_index += 1
measures_index = measures_index % measures_len
if boot_count >= 0:
boot_count = -1 if boot_count == boot_count_max else boot_count + 1
meas: float = _measure(
measure_len_ws=measure_len_ws,
samples=waveform[i, pos - measure_len_ws : pos],
spectrum=spectrum[i],
noise_spectrum=noise_spectrum[i],
spectrum_window=spectrum_window,
spectrum_start=spectrum_start,
spectrum_end=spectrum_end,
cepstrum_window=cepstrum_window,
cepstrum_start=cepstrum_start,
cepstrum_end=cepstrum_end,
noise_reduction_amount=noise_reduction_amount,
measure_smooth_time_mult=measure_smooth_time_mult,
noise_up_time_mult=noise_up_time_mult,
noise_down_time_mult=noise_down_time_mult,
boot_count=boot_count,
)
measures[i, measures_index] = meas
mean_meas[i] = mean_meas[i] * trigger_meas_time_mult + meas * (1.0 - trigger_meas_time_mult)
has_triggered = has_triggered or (mean_meas[i] >= trigger_level)
if has_triggered:
n: int = measures_len
k: int = measures_index
jTrigger: int = n
jZero: int = n
j: int = 0
for j in range(n):
if (measures[i, k] >= trigger_level) and (j <= jTrigger + gap_len):
jZero = jTrigger = j
elif (measures[i, k] == 0) and (jTrigger >= jZero):
jZero = j
k = (k + n - 1) % n
j = min(j, jZero)
# num_measures_to_flush = range_limit(j, num_measures_to_flush, n);
num_measures_to_flush = min(max(num_measures_to_flush, j), n)
# end if has_triggered
# end for channel
measures_index += 1
measures_index = measures_index % measures_len
if boot_count >= 0:
boot_count = -1 if boot_count == boot_count_max else boot_count + 1
if has_triggered:
flushedLen_ns = (measures_len - num_measures_to_flush) * measure_period_ns
samplesIndex_ns = (samplesIndex_ns + flushedLen_ns) % samplesLen_ns
break
# end for window
res = waveform[:, pos - samplesLen_ns + flushedLen_ns :]
# unpack batch
return res.view(shape[:-1] + res.shape[-1:])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment