Unverified Commit 4e99c12d authored by Jcaw, committed by GitHub

Fix `amplitude_to_DB` clamping behaviour on batches (#1113)



Modified `amplitude_to_DB` to clamp per item when a batch is provided.
Co-authored-by: Vincent QB <vincentqb@users.noreply.github.com>
parent c539ad7d
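For context, a minimal sketch of the behaviour this commit establishes (it mirrors the new batch test below; the `F.amplitude_to_DB` signature is taken from the diff, while the shapes and constants are illustrative):

import torch
import torchaudio.functional as F

torch.manual_seed(0)
batch = torch.rand(2, 2, 100, 100) * 200        # (batch, channel, freq, time)
batch[0] += 50                                  # make one item much louder than the other

# Converting the whole batch at once...
batched = F.amplitude_to_DB(batch, 20., 1e-10, 0., top_db=40.)
# ...should now match converting each item on its own.
itemwise = torch.stack([
    F.amplitude_to_DB(item, 20., 1e-10, 0., top_db=40.)
    for item in batch
])
# Previously the clamp floor came from the loudest item in the batch, so this failed.
assert torch.allclose(batched, itemwise)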
@@ -3,6 +3,8 @@ import unittest
 import itertools
 from parameterized import parameterized
+import math
 import torch
 import torchaudio
 import torchaudio.functional as F
@@ -59,6 +61,78 @@ class TestFunctional(common_utils.TorchaudioTestCase):
                                              n_channels=n_channels, duration=5)
         self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate)

+    def test_amplitude_to_DB(self):
+        torch.manual_seed(0)
+        spec = torch.rand(2, 100, 100) * 200
+
+        amplitude_mult = 20.
+        amin = 1e-10
+        ref = 1.0
+        db_mult = math.log10(max(amin, ref))
+
+        # Test with & without a `top_db` clamp
+        self.assert_batch_consistencies(F.amplitude_to_DB, spec, amplitude_mult,
+                                        amin, db_mult, top_db=None)
+        self.assert_batch_consistencies(F.amplitude_to_DB, spec, amplitude_mult,
+                                        amin, db_mult, top_db=40.)
+
+    def test_amplitude_to_DB_itemwise_clamps(self):
+        """Ensure that the clamps are separate for each spectrogram in a batch.
+
+        The clamp was determined per-batch in a prior implementation, which
+        meant it was determined by the loudest item, thus items weren't
+        independent. See:
+
+        https://github.com/pytorch/audio/issues/994
+        """
+        amplitude_mult = 20.
+        amin = 1e-10
+        ref = 1.0
+        db_mult = math.log10(max(amin, ref))
+        top_db = 20.
+
+        # Make a batch of noise
+        torch.manual_seed(0)
+        spec = torch.rand([2, 2, 100, 100]) * 200
+        # Make one item blow out the other
+        spec[0] += 50
+
+        batchwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin,
+                                          db_mult, top_db=top_db)
+        itemwise_dbs = torch.stack([
+            F.amplitude_to_DB(item, amplitude_mult, amin,
+                              db_mult, top_db=top_db)
+            for item in spec
+        ])
+
+        self.assertEqual(batchwise_dbs, itemwise_dbs)
+
+    def test_amplitude_to_DB_not_channelwise_clamps(self):
+        """Check that clamps are applied per-item, not per channel."""
+        amplitude_mult = 20.
+        amin = 1e-10
+        ref = 1.0
+        db_mult = math.log10(max(amin, ref))
+        top_db = 40.
+
+        torch.manual_seed(0)
+        spec = torch.rand([1, 2, 100, 100]) * 200
+        # Make one channel blow out the other
+        spec[:, 0] += 50
+
+        specwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin,
+                                         db_mult, top_db=top_db)
+        channelwise_dbs = torch.stack([
+            F.amplitude_to_DB(spec[:, i], amplitude_mult, amin,
+                              db_mult, top_db=top_db)
+            for i in range(spec.size(-3))
+        ])
+
+        # Just check channelwise gives a different answer.
+        difference = (specwise_dbs - channelwise_dbs).abs()
+        assert (difference >= 1e-5).any()
+
     def test_contrast(self):
         waveform = torch.rand(2, 100) - 0.5
         self.assert_batch_consistencies(F.contrast, waveform, enhancement_amount=80.)
@@ -103,7 +177,7 @@ class TestTransforms(common_utils.TorchaudioTestCase):
     """Test suite for classes defined in `transforms` module"""

     def test_batch_AmplitudeToDB(self):
-        spec = torch.rand((6, 201))
+        spec = torch.rand((2, 6, 201))

         # Single then transform then batch
         expected = torchaudio.transforms.AmplitudeToDB()(spec).repeat(3, 1, 1)
......
@@ -83,46 +83,78 @@ class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
         self.assertFalse(s)


-class TestDB_to_amplitude(common_utils.TorchaudioTestCase):
-    def test_DB_to_amplitude(self):
-        # Make some noise
-        x = torch.rand(1000)
-        spectrogram = torchaudio.transforms.Spectrogram()
-        spec = spectrogram(x)
-
-        amin = 1e-10
-        ref = 1.0
-        db_multiplier = math.log10(max(amin, ref))
-
-        # Waveform amplitude -> DB -> amplitude
-        multiplier = 20.
-        power = 0.5
-
-        db = F.amplitude_to_DB(torch.abs(x), multiplier, amin, db_multiplier, top_db=None)
-        x2 = F.DB_to_amplitude(db, ref, power)
-        self.assertEqual(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
-
-        # Spectrogram amplitude -> DB -> amplitude
-        db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
-        x2 = F.DB_to_amplitude(db, ref, power)
-        self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)
-
-        # Waveform power -> DB -> power
-        multiplier = 10.
-        power = 1.
-
-        db = F.amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None)
-        x2 = F.DB_to_amplitude(db, ref, power)
-        self.assertEqual(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
-
-        # Spectrogram power -> DB -> power
-        db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
-        x2 = F.DB_to_amplitude(db, ref, power)
-        self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)
+class Testamplitude_to_DB(common_utils.TorchaudioTestCase):
+    @parameterized.expand([
+        ([100, 100],),
+        ([2, 100, 100],),
+        ([2, 2, 100, 100],),
+    ])
+    def test_reversible(self, shape):
+        """Round trip between amplitude and db should return the original for various shape
+
+        This implicitly also tests `DB_to_amplitude`.
+        """
+        amplitude_mult = 20.
+        power_mult = 10.
+        amin = 1e-10
+        ref = 1.0
+        db_mult = math.log10(max(amin, ref))
+
+        torch.manual_seed(0)
+        spec = torch.rand(*shape) * 200
+
+        # Spectrogram amplitude -> DB -> amplitude
+        db = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=None)
+        x2 = F.DB_to_amplitude(db, ref, 0.5)
+        self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)
+
+        # Spectrogram power -> DB -> power
+        db = F.amplitude_to_DB(spec, power_mult, amin, db_mult, top_db=None)
+        x2 = F.DB_to_amplitude(db, ref, 1.)
+        self.assertEqual(x2, spec)
+
+    @parameterized.expand([
+        ([100, 100],),
+        ([2, 100, 100],),
+        ([2, 2, 100, 100],),
+    ])
+    def test_top_db_clamp(self, shape):
+        """Ensure values are properly clamped when `top_db` is supplied."""
+        amplitude_mult = 20.
+        amin = 1e-10
+        ref = 1.0
+        db_mult = math.log10(max(amin, ref))
+        top_db = 40.
+
+        torch.manual_seed(0)
+        # A random tensor is used for increased entropy, but the max and min for
+        # each spectrogram still need to be predictable. The max determines the
+        # decibel cutoff, and the distance from the min must be large enough
+        # that it triggers a clamp.
+        spec = torch.rand(*shape)
+        # Ensure each spectrogram has a min of 0 and a max of 1.
+        spec -= spec.amin([-2, -1])[..., None, None]
+        spec /= spec.amax([-2, -1])[..., None, None]
+        # Expand the range to (0, 200) - wide enough to properly test clamping.
+        spec *= 200
+
+        decibels = F.amplitude_to_DB(spec, amplitude_mult, amin,
+                                     db_mult, top_db=top_db)
+        # Ensure the clamp was applied
+        below_limit = decibels < 6.0205
+        assert not below_limit.any(), (
+            "{} decibel values were below the expected cutoff:\n{}".format(
+                below_limit.sum().item(), decibels
+            )
+        )
+        # Ensure it didn't over-clamp
+        close_to_limit = decibels < 6.0207
+        assert close_to_limit.any(), (
+            f"No values were close to the limit. Did it over-clamp?\n{decibels}"
+        )


 class TestComplexNorm(common_utils.TorchaudioTestCase):
......
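The 6.0205 / 6.0207 bounds in `test_top_db_clamp` follow from the test's own constants: each spectrogram is scaled to a maximum amplitude of 200, `db_mult` is `log10(1.0) = 0`, and `top_db` is 40. A quick check of the expected clamp floor:

import math

max_db = 20. * math.log10(200.)      # loudest bin in dB, ~46.0206 (the db_mult term is 0)
clamp_floor = max_db - 40.           # ~6.0206, which sits between 6.0205 and 6.0207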
@@ -237,14 +237,16 @@ def amplitude_to_DB(
         db_multiplier: float,
         top_db: Optional[float] = None
 ) -> Tensor:
-    r"""Turn a tensor from the power/amplitude scale to the decibel scale.
+    r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.

-    This output depends on the maximum value in the input tensor, and so
-    may return different values for an audio clip split into snippets vs. a
-    full clip.
+    The output of each tensor in a batch depends on the maximum value of that tensor,
+    and so may return different values for an audio clip split into snippets vs. a full clip.

     Args:
-        x (Tensor): Input tensor before being converted to decibel scale
+        x (Tensor): Input spectrogram(s) before being converted to decibel scale. Input should take
+            the form `(..., freq, time)`. Batched inputs should include a channel dimension and
+            have the form `(batch, channel, freq, time)`.
         multiplier (float): Use 10. for power and 20. for amplitude
         amin (float): Number to clamp ``x``
         db_multiplier (float): Log10(max(reference value and amin))

@@ -258,7 +260,15 @@ def amplitude_to_DB(
     x_db -= multiplier * db_multiplier

     if top_db is not None:
-        x_db = x_db.clamp(min=x_db.max().item() - top_db)
+        # Expand batch
+        shape = x_db.size()
+        packed_channels = shape[-3] if x_db.dim() > 2 else 1
+        x_db = x_db.reshape(-1, packed_channels, shape[-2], shape[-1])
+
+        x_db = torch.max(x_db, (x_db.amax(dim=(-3, -2, -1)) - top_db).view(-1, 1, 1, 1))
+
+        # Repack batch
+        x_db = x_db.reshape(shape)

     return x_db
......
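With the change above, the clamp floor is computed per item over its trailing (channel, freq, time) dimensions instead of over the whole input; a 2D `(freq, time)` input is treated as a single item with one packed channel. A rough usage sketch (shapes and constants are illustrative, not from the commit):

import torch
import torchaudio.functional as F

x = torch.rand(3, 2, 201, 400) * 10             # (batch, channel, freq, time)
db = F.amplitude_to_DB(x, 20., 1e-10, 0., top_db=80.)
for i in range(x.size(0)):
    # each item is clamped against its own maximum, across all of its channels
    assert (db[i] >= db[i].max() - 80. - 1e-6).all()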
@@ -541,24 +541,16 @@ class MFCC(torch.nn.Module):
         Returns:
             Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
         """
-        # pack batch
-        shape = waveform.size()
-        waveform = waveform.reshape(-1, shape[-1])
-
         mel_specgram = self.MelSpectrogram(waveform)

         if self.log_mels:
             log_offset = 1e-6
             mel_specgram = torch.log(mel_specgram + log_offset)
         else:
             mel_specgram = self.amplitude_to_DB(mel_specgram)

-        # (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
-        # -> (channel, time, n_mfcc).tranpose(...)
-        mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
-
-        # unpack batch
-        mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:])
+        # (..., channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
+        # -> (..., channel, time, n_mfcc).tranpose(...)
+        mfcc = torch.matmul(mel_specgram.transpose(-2, -1), self.dct_mat).transpose(-2, -1)
         return mfcc
......
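The explicit pack/unpack of the batch in `MFCC.forward` can be dropped because `torch.matmul` broadcasts over any leading dimensions, and `transpose(-2, -1)` works regardless of how many there are. A small shape check (illustrative sizes, with a random stand-in for the `(n_mels, n_mfcc)` DCT matrix):

import torch

mel_specgram = torch.rand(3, 2, 128, 400)       # (batch, channel, n_mels, time)
dct_mat = torch.rand(128, 40)                   # stand-in for self.dct_mat
mfcc = torch.matmul(mel_specgram.transpose(-2, -1), dct_mat).transpose(-2, -1)
print(mfcc.shape)                               # torch.Size([3, 2, 40, 400])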