Commit 946b180a authored by bshall's avatar bshall Committed by Facebook GitHub Bot
Browse files

An implemenation of the ITU-R BS.1770-4 loudness recommendation (#2472)

Summary:
I took a stab at implementing the ITU-R BS.1770-4 loudness recommendation (closes https://github.com/pytorch/audio/issues/1205). To give some more details:
- I've implemented K-weighting following csteinmetz1 instead of BrechtDeMan since it fit well with torchaudio's already implemented filters (`treble_biquad` and `highpass_biquad`).
- I've added four audio files to test compliance with the recommendation. These are linked in [this pdf](https://www.itu.int/dms_pub/itu-r/opb/rep/R-REP-BS.2217-2-2016-PDF-E.pdf). There are many more test files there but I didn't want to bog down the assets directory with too many files. Let me know if I should add or remove anything.
- I've kept many of the constant internal to the function (e.g. the block duration, overlap, and the absolute threshold gamma). I'm not sure if these should be exposed in the signature.
- I've implemented support for up to 5 channels (following both csteinmetz1 and BrechtDeMan). The recommendation includes weights for up to 24 channels. Is there any convention for how many channels to support?

I hope this is helpful! looking forward to hearing from you.

Pull Request resolved: https://github.com/pytorch/audio/pull/2472

Reviewed By: hwangjeff

Differential Revision: D38389155

Pulled By: carolineechen

fbshipit-source-id: fcc86d864c04ab2bedaa9acd941ebc4478ca6904
parent 8e0c2a3b
...@@ -69,6 +69,11 @@ resample ...@@ -69,6 +69,11 @@ resample
.. autofunction:: resample .. autofunction:: resample
loudness
--------
.. autofunction:: loudness
:hidden:`Filtering` :hidden:`Filtering`
~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~
......
...@@ -90,6 +90,13 @@ Transforms are common audio transforms. They can be chained together using :clas ...@@ -90,6 +90,13 @@ Transforms are common audio transforms. They can be chained together using :clas
.. automethod:: forward .. automethod:: forward
:hidden:`Loudness`
-----------------
.. autoclass:: Loudness
.. automethod:: forward
:hidden:`Feature Extractions` :hidden:`Feature Extractions`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......
"""Test suite for compliance with the ITU-R BS.1770-4 recommendation"""
import zipfile
import pytest
import torch
import torchaudio
import torchaudio.functional as F
# Test files linked in https://www.itu.int/dms_pub/itu-r/opb/rep/R-REP-BS.2217-2-2016-PDF-E.pdf
@pytest.mark.parametrize(
"filename,url,expected",
[
(
"1770-2_Comp_RelGateTest",
"http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010030ZIPM.zip",
-10.0,
),
(
"1770-2_Comp_AbsGateTest",
"http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010029ZIPM.zip",
-69.5,
),
(
"1770-2_Comp_24LKFS_500Hz_2ch",
"http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010018ZIPM.zip",
-24.0,
),
(
"1770-2 Conf Mono Voice+Music-24LKFS",
"http://www.itu.int/dms_pub/itu-r/oth/11/02/R11020000010038ZIPM.zip",
-24.0,
),
],
)
def test_loudness(tmp_path, filename, url, expected):
zippath = tmp_path / filename
torch.hub.download_url_to_file(url, zippath, progress=False)
with zipfile.ZipFile(zippath) as file:
file.extractall(zippath.parent)
waveform, sample_rate = torchaudio.load(zippath.with_suffix(".wav"))
loudness = F.loudness(waveform, sample_rate)
expected = torch.tensor(expected, dtype=loudness.dtype, device=loudness.device)
assert torch.allclose(loudness, expected, rtol=0.01, atol=0.1)
...@@ -120,6 +120,14 @@ class Functional(TempDirMixin, TestBaseMixin): ...@@ -120,6 +120,14 @@ class Functional(TempDirMixin, TestBaseMixin):
self._assert_consistency(func, (waveform,)) self._assert_consistency(func, (waveform,))
def test_measure_loudness(self):
if self.dtype == torch.float64:
raise unittest.SkipTest("This test is known to fail for float64")
sample_rate = 44100
waveform = common_utils.get_sinusoid(sample_rate=sample_rate, device=self.device)
self._assert_consistency(F.loudness, (waveform, sample_rate))
def test_melscale_fbanks(self): def test_melscale_fbanks(self):
if self.device != torch.device("cpu"): if self.device != torch.device("cpu"):
raise unittest.SkipTest("No need to perform test on device other than CPU") raise unittest.SkipTest("No need to perform test on device other than CPU")
......
...@@ -35,6 +35,7 @@ from .functional import ( ...@@ -35,6 +35,7 @@ from .functional import (
griffinlim, griffinlim,
inverse_spectrogram, inverse_spectrogram,
linear_fbanks, linear_fbanks,
loudness,
mask_along_axis, mask_along_axis,
mask_along_axis_iid, mask_along_axis_iid,
melscale_fbanks, melscale_fbanks,
...@@ -62,6 +63,7 @@ __all__ = [ ...@@ -62,6 +63,7 @@ __all__ = [
"melscale_fbanks", "melscale_fbanks",
"linear_fbanks", "linear_fbanks",
"DB_to_amplitude", "DB_to_amplitude",
"loudness",
"detect_pitch_frequency", "detect_pitch_frequency",
"griffinlim", "griffinlim",
"mask_along_axis", "mask_along_axis",
......
...@@ -11,6 +11,8 @@ import torchaudio ...@@ -11,6 +11,8 @@ import torchaudio
from torch import Tensor from torch import Tensor
from torchaudio._internal import module_utils as _mod_utils from torchaudio._internal import module_utils as _mod_utils
from .filtering import highpass_biquad, treble_biquad
__all__ = [ __all__ = [
"spectrogram", "spectrogram",
"inverse_spectrogram", "inverse_spectrogram",
...@@ -35,6 +37,7 @@ __all__ = [ ...@@ -35,6 +37,7 @@ __all__ = [
"apply_codec", "apply_codec",
"resample", "resample",
"edit_distance", "edit_distance",
"loudness",
"pitch_shift", "pitch_shift",
"rnnt_loss", "rnnt_loss",
"psd", "psd",
...@@ -1640,6 +1643,67 @@ def edit_distance(seq1: Sequence, seq2: Sequence) -> int: ...@@ -1640,6 +1643,67 @@ def edit_distance(seq1: Sequence, seq2: Sequence) -> int:
return int(dold[-1]) return int(dold[-1])
def loudness(waveform: Tensor, sample_rate: int):
r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation.
.. devices:: CPU CUDA
.. properties:: TorchScript
Args:
waveform(torch.Tensor): audio waveform of dimension `(..., channels, time)`
sample_rate (int): sampling rate of the waveform
Returns:
Tensor: loudness estimates (LKFS)
Reference:
- https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
"""
if waveform.size(-2) > 5:
raise ValueError("Only up to 5 channels are supported.")
gate_duration = 0.4
overlap = 0.75
gamma_abs = -70.0
kweight_bias = -0.691
gate_samples = int(round(gate_duration * sample_rate))
step = int(round(gate_samples * (1 - overlap)))
# Apply K-weighting
waveform = treble_biquad(waveform, sample_rate, 4.0, 1500.0, 1 / math.sqrt(2))
waveform = highpass_biquad(waveform, sample_rate, 38.0, 0.5)
# Compute the energy for each block
energy = torch.square(waveform).unfold(-1, gate_samples, step)
energy = torch.mean(energy, dim=-1)
# Compute channel-weighted summation
g = torch.tensor([1.0, 1.0, 1.0, 1.41, 1.41], dtype=waveform.dtype, device=waveform.device)
g = g[: energy.size(-2)]
energy_weighted = torch.sum(g.unsqueeze(-1) * energy, dim=-2)
loudness = -0.691 + 10 * torch.log10(energy_weighted)
# Apply absolute gating of the blocks
gated_blocks = loudness > gamma_abs
gated_blocks = gated_blocks.unsqueeze(-2)
energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1)
energy_weighted = torch.sum(g * energy_filtered, dim=-1)
gamma_rel = kweight_bias + 10 * torch.log10(energy_weighted) - 10
# Apply relative gating of the blocks
gated_blocks = torch.logical_and(gated_blocks.squeeze(-2), loudness > gamma_rel.unsqueeze(-1))
gated_blocks = gated_blocks.unsqueeze(-2)
energy_filtered = torch.sum(gated_blocks * energy, dim=-1) / torch.count_nonzero(gated_blocks, dim=-1)
energy_weighted = torch.sum(g * energy_filtered, dim=-1)
LKFS = kweight_bias + 10 * torch.log10(energy_weighted)
return LKFS
def pitch_shift( def pitch_shift(
waveform: Tensor, waveform: Tensor,
sample_rate: int, sample_rate: int,
......
...@@ -8,6 +8,7 @@ from ._transforms import ( ...@@ -8,6 +8,7 @@ from ._transforms import (
InverseMelScale, InverseMelScale,
InverseSpectrogram, InverseSpectrogram,
LFCC, LFCC,
Loudness,
MelScale, MelScale,
MelSpectrogram, MelSpectrogram,
MFCC, MFCC,
...@@ -35,6 +36,7 @@ __all__ = [ ...@@ -35,6 +36,7 @@ __all__ = [
"InverseMelScale", "InverseMelScale",
"InverseSpectrogram", "InverseSpectrogram",
"LFCC", "LFCC",
"Loudness",
"MFCC", "MFCC",
"MVDR", "MVDR",
"MelScale", "MelScale",
......
...@@ -1251,6 +1251,36 @@ class TimeMasking(_AxisMasking): ...@@ -1251,6 +1251,36 @@ class TimeMasking(_AxisMasking):
super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks, p=p) super(TimeMasking, self).__init__(time_mask_param, 2, iid_masks, p=p)
class Loudness(torch.nn.Module):
r"""Measure audio loudness according to the ITU-R BS.1770-4 recommendation.
.. devices:: CPU CUDA
.. properties:: TorchScript
Args:
sample_rate (int): Sample rate of audio signal.
Reference:
- https://www.itu.int/rec/R-REC-BS.1770-4-201510-I/en
"""
__constants__ = ["sample_rate"]
def __init__(self, sample_rate: int):
super(Loudness, self).__init__()
self.sample_rate = sample_rate
def forward(self, wavefrom: Tensor):
r"""
Args:
waveform(torch.Tensor): audio waveform of dimension `(..., channels, time)`
Returns:
Tensor: loudness estimates (LKFS)
"""
return F.loudness(wavefrom, self.sample_rate)
class Vol(torch.nn.Module): class Vol(torch.nn.Module):
r"""Adjust volume of waveform. r"""Adjust volume of waveform.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment