Unverified Commit 26237c8b authored by Vincent QB's avatar Vincent QB Committed by GitHub
Browse files

Pitch detection (#313)

* pitch detection validation.
* make torchscriptable.
parent d88c2449
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import math import math
import os
import torch import torch
import torchaudio import torchaudio
...@@ -247,6 +248,30 @@ class TestFunctional(unittest.TestCase): ...@@ -247,6 +248,30 @@ class TestFunctional(unittest.TestCase):
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0) self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0) self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)
def test_pitch(self):
    """Check that detect_pitch_frequency recovers the pitch of two pure-tone files.

    Each reference file is loaded, given an extra leading dimension of size 2,
    and passed through ``torchaudio.functional.detect_pitch_frequency``. The
    test fails if any returned value differs from the reference frequency by
    more than ``tolerance`` Hz.
    """
    assets_root, test_dir = common_utils.create_temp_assets_dir()

    # Files from https://www.mediacollege.com/audio/tone/download/
    cases = [
        (os.path.join(assets_root, 'assets', "100Hz_44100Hz_16bit_05sec.wav"), 100),
        (os.path.join(assets_root, 'assets', "440Hz_44100Hz_16bit_05sec.wav"), 440),
    ]

    # Maximum allowed absolute deviation from the reference frequency.
    tolerance = 1

    for path, expected_hz in cases:
        audio, rate = torchaudio.load(path)

        # Duplicate along a new leading dimension ("convert to stereo" for
        # testing purposes), so the function is exercised on batched input.
        audio = audio.repeat(2, 1, 1)

        detected = torchaudio.functional.detect_pitch_frequency(audio, rate)

        # Count the entries that fall outside the tolerance; none are allowed.
        num_outliers = ((detected - expected_hz).abs() > tolerance).sum()
        self.assertFalse(num_outliers)
def _num_stft_bins(signal_len, fft_len, hop_length, pad): def _num_stft_bins(signal_len, fft_len, hop_length, pad):
return (signal_len + 2 * pad - fft_len + hop_length) // hop_length return (signal_len + 2 * pad - fft_len + hop_length) // hop_length
......
from __future__ import absolute_import, division, print_function, unicode_literals from __future__ import absolute_import, division, print_function, unicode_literals
import math import math
import torch import torch
__all__ = [ __all__ = [
...@@ -801,3 +803,153 @@ def compute_deltas(specgram, win_length=5, mode="replicate"): ...@@ -801,3 +803,153 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
return torch.nn.functional.conv1d( return torch.nn.functional.conv1d(
specgram, kernel, groups=specgram.shape[1] // specgram.shape[0] specgram, kernel, groups=specgram.shape[1] // specgram.shape[0]
) / denom ) / denom
def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
# type: (Tensor, int, float, int) -> Tensor
r"""
Compute Normalized Cross-Correlation Function (NCCF).
.. math::
\phi_i(m) = \frac{\sum_{n=b_i}^{b_i + N-1} w(n) w(m+n)}{\sqrt{E(b_i) E(m+b_i)}},
where
:math:`\phi_i(m)` is the NCCF at frame :math:`i` with lag :math:`m`,
:math:`w` is the waveform,
:math:`N` is the lenght of a frame,
:math:`b_i` is the beginning of frame :math:`i`,
:math:`E(j)` is the energy :math:`\sum_{n=j}^{j+N-1} w^2(n)`.
"""
EPSILON = 10 ** (-9)
# Number of lags to check
lags = math.ceil(sample_rate / freq_low)
frame_size = int(math.ceil(sample_rate * frame_time))
waveform_length = waveform.size()[-1]
num_of_frames = math.ceil(waveform_length / frame_size)
p = lags + num_of_frames * frame_size - waveform_length
waveform = torch.nn.functional.pad(waveform, (0, p))
# Compute lags
output_lag = []
for lag in range(1, lags + 1):
s1 = waveform[..., :-lag].unfold(-1, frame_size, frame_size)[
..., :num_of_frames, :
]
s2 = waveform[..., lag:].unfold(-1, frame_size, frame_size)[
..., :num_of_frames, :
]
output_frames = (
(s1 * s2).sum(-1)
/ (EPSILON + torch.norm(s1, p=2, dim=-1)).pow(2)
/ (EPSILON + torch.norm(s2, p=2, dim=-1)).pow(2)
)
output_lag.append(output_frames.unsqueeze(-1))
nccf = torch.cat(output_lag, -1)
return nccf
def _combine_max(a, b, thresh=0.99):
# type: (Tuple[Tensor, Tensor], Tuple[Tensor, Tensor], float) -> Tuple[Tensor, Tensor]
"""
Take value from first if bigger than a multiplicative factor of the second, elementwise.
"""
mask = (a[0] > thresh * b[0])
values = mask * a[0] + ~mask * b[0]
indices = mask * a[1] + ~mask * b[1]
return values, indices
def _find_max_per_frame(nccf, sample_rate, freq_high):
    # type: (Tensor, int, int) -> Tensor
    r"""
    For each frame, pick the lag with the highest value of NCCF.

    Note: If the max among all the lags is very close
    to the max over the first half of lags, then the latter is taken.

    Args:
        nccf (torch.Tensor): Tensor whose last dimension indexes lags.
        sample_rate (int): Sample rate of the waveform (Hz).
        freq_high (int): Highest frequency of interest (Hz); lags below
            ``ceil(sample_rate / freq_high)`` are ignored.

    Returns:
        torch.Tensor: Selected lag index per frame, shifted by the minimal
        lag and by an empirical offset of 1. No smoothing or conversion to
        frequency is done here.
    """
    # Cast to int: math.ceil may return a float (e.g. on Python 2) and the
    # value is used as a slice bound below.
    lag_min = int(math.ceil(sample_rate / freq_high))

    # Find near enough max that is smallest
    best = torch.max(nccf[..., lag_min:], -1)

    half_size = nccf.shape[-1] // 2
    half = torch.max(nccf[..., lag_min:half_size], -1)

    # Prefer the maximum over the first half of the lags when it is within
    # the threshold used by _combine_max of the overall maximum.
    best = _combine_max(half, best)
    indices = best[1]

    # Add back minimal lag
    indices += lag_min
    # Add 1 empirical calibration offset
    indices += 1

    return indices
def _median_smoothing(indices, win_length):
# type: (Tensor, int) -> Tensor
r"""
Apply median smoothing to the 1D tensor over the given window.
"""
# Centered windowed
pad_length = (win_length - 1) // 2
# "replicate" padding in any dimension
indices = torch.nn.functional.pad(
indices, (pad_length, 0), mode="constant", value=0.
)
indices[..., :pad_length] = torch.cat(pad_length * [indices[..., pad_length].unsqueeze(-1)], dim=-1)
roll = indices.unfold(-1, win_length, 1)
values, _ = torch.median(roll, -1)
return values
@torch.jit.script
def detect_pitch_frequency(
    waveform,
    sample_rate,
    frame_time=10 ** (-2),
    win_length=30,
    freq_low=85,
    freq_high=3400,
):
    # type: (Tensor, int, float, int, int, int) -> Tensor
    r"""Estimate the pitch frequency of a waveform, frame by frame.

    The estimate is built from a normalized cross-correlation function (NCCF),
    a per-frame selection of the best lag, and median smoothing of those lags.

    Args:
        waveform (torch.Tensor): Tensor of audio of dimension (channel, freq, time);
            only the last dimension is treated as time.
        sample_rate (int): The sample rate of the waveform (Hz)
        frame_time (float): Duration of one analysis frame (seconds)
        win_length (int): The window length for median smoothing (in number of frames)
        freq_low (int): Lowest frequency that can be detected (Hz)
        freq_high (int): Highest frequency that can be detected (Hz)

    Returns:
        freq (torch.Tensor): Detected frequencies (Hz); the leading dimensions
        of ``waveform`` are kept and the last dimension indexes frames.
    """
    # Guards the division below against a zero lag.
    EPSILON = 10 ** (-9)

    correlation = _compute_nccf(waveform, sample_rate, frame_time, freq_low)
    best_lags = _find_max_per_frame(correlation, sample_rate, freq_high)
    smoothed_lags = _median_smoothing(best_lags, win_length)

    # A lag of k samples corresponds to a frequency of sample_rate / k.
    return sample_rate / (EPSILON + smoothed_lags.to(torch.float))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment