Unverified Commit 4e99c12d authored by Jcaw's avatar Jcaw Committed by GitHub
Browse files

Fix `amplitude_to_DB` clamping behaviour on batches (#1113)



modified amplitude_to_DB to clamp per-item when a batch is provided
Co-authored-by: default avatarVincent QB <vincentqb@users.noreply.github.com>
parent c539ad7d
......@@ -3,6 +3,8 @@ import unittest
import itertools
from parameterized import parameterized
import math
import torch
import torchaudio
import torchaudio.functional as F
......@@ -59,6 +61,78 @@ class TestFunctional(common_utils.TorchaudioTestCase):
n_channels=n_channels, duration=5)
self.assert_batch_consistencies(F.detect_pitch_frequency, waveform, sample_rate)
def test_amplitude_to_DB(self):
torch.manual_seed(0)
spec = torch.rand(2, 100, 100) * 200
amplitude_mult = 20.
amin = 1e-10
ref = 1.0
db_mult = math.log10(max(amin, ref))
# Test with & without a `top_db` clamp
self.assert_batch_consistencies(F.amplitude_to_DB, spec, amplitude_mult,
amin, db_mult, top_db=None)
self.assert_batch_consistencies(F.amplitude_to_DB, spec, amplitude_mult,
amin, db_mult, top_db=40.)
def test_amplitude_to_DB_itemwise_clamps(self):
"""Ensure that the clamps are separate for each spectrogram in a batch.
The clamp was determined per-batch in a prior implementation, which
meant it was determined by the loudest item, thus items weren't
independent. See:
https://github.com/pytorch/audio/issues/994
"""
amplitude_mult = 20.
amin = 1e-10
ref = 1.0
db_mult = math.log10(max(amin, ref))
top_db = 20.
# Make a batch of noise
torch.manual_seed(0)
spec = torch.rand([2, 2, 100, 100]) * 200
# Make one item blow out the other
spec[0] += 50
batchwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin,
db_mult, top_db=top_db)
itemwise_dbs = torch.stack([
F.amplitude_to_DB(item, amplitude_mult, amin,
db_mult, top_db=top_db)
for item in spec
])
self.assertEqual(batchwise_dbs, itemwise_dbs)
def test_amplitude_to_DB_not_channelwise_clamps(self):
"""Check that clamps are applied per-item, not per channel."""
amplitude_mult = 20.
amin = 1e-10
ref = 1.0
db_mult = math.log10(max(amin, ref))
top_db = 40.
torch.manual_seed(0)
spec = torch.rand([1, 2, 100, 100]) * 200
# Make one channel blow out the other
spec[:, 0] += 50
specwise_dbs = F.amplitude_to_DB(spec, amplitude_mult, amin,
db_mult, top_db=top_db)
channelwise_dbs = torch.stack([
F.amplitude_to_DB(spec[:, i], amplitude_mult, amin,
db_mult, top_db=top_db)
for i in range(spec.size(-3))
])
# Just check channelwise gives a different answer.
difference = (specwise_dbs - channelwise_dbs).abs()
assert (difference >= 1e-5).any()
def test_contrast(self):
waveform = torch.rand(2, 100) - 0.5
self.assert_batch_consistencies(F.contrast, waveform, enhancement_amount=80.)
......@@ -103,7 +177,7 @@ class TestTransforms(common_utils.TorchaudioTestCase):
"""Test suite for classes defined in `transforms` module"""
def test_batch_AmplitudeToDB(self):
spec = torch.rand((6, 201))
spec = torch.rand((2, 6, 201))
# Single then transform then batch
expected = torchaudio.transforms.AmplitudeToDB()(spec).repeat(3, 1, 1)
......
......@@ -83,46 +83,78 @@ class TestDetectPitchFrequency(common_utils.TorchaudioTestCase):
self.assertFalse(s)
class TestDB_to_amplitude(common_utils.TorchaudioTestCase):
def test_DB_to_amplitude(self):
# Make some noise
x = torch.rand(1000)
spectrogram = torchaudio.transforms.Spectrogram()
spec = spectrogram(x)
class Testamplitude_to_DB(common_utils.TorchaudioTestCase):
@parameterized.expand([
([100, 100],),
([2, 100, 100],),
([2, 2, 100, 100],),
])
def test_reversible(self, shape):
"""Round trip between amplitude and db should return the original for various shape
This implicitly also tests `DB_to_amplitude`.
"""
amplitude_mult = 20.
power_mult = 10.
amin = 1e-10
ref = 1.0
db_multiplier = math.log10(max(amin, ref))
# Waveform amplitude -> DB -> amplitude
multiplier = 20.
power = 0.5
db = F.amplitude_to_DB(torch.abs(x), multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power)
db_mult = math.log10(max(amin, ref))
self.assertEqual(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
torch.manual_seed(0)
spec = torch.rand(*shape) * 200
# Spectrogram amplitude -> DB -> amplitude
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power)
db = F.amplitude_to_DB(spec, amplitude_mult, amin, db_mult, top_db=None)
x2 = F.DB_to_amplitude(db, ref, 0.5)
self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)
# Waveform power -> DB -> power
multiplier = 10.
power = 1.
db = F.amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power)
self.assertEqual(x2, torch.abs(x), atol=5e-5, rtol=1e-5)
# Spectrogram power -> DB -> power
db = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db=None)
x2 = F.DB_to_amplitude(db, ref, power)
self.assertEqual(x2, spec, atol=5e-5, rtol=1e-5)
db = F.amplitude_to_DB(spec, power_mult, amin, db_mult, top_db=None)
x2 = F.DB_to_amplitude(db, ref, 1.)
self.assertEqual(x2, spec)
@parameterized.expand([
([100, 100],),
([2, 100, 100],),
([2, 2, 100, 100],),
])
def test_top_db_clamp(self, shape):
"""Ensure values are properly clamped when `top_db` is supplied."""
amplitude_mult = 20.
amin = 1e-10
ref = 1.0
db_mult = math.log10(max(amin, ref))
top_db = 40.
torch.manual_seed(0)
# A random tensor is used for increased entropy, but the max and min for
# each spectrogram still need to be predictable. The max determines the
# decibel cutoff, and the distance from the min must be large enough
# that it triggers a clamp.
spec = torch.rand(*shape)
# Ensure each spectrogram has a min of 0 and a max of 1.
spec -= spec.amin([-2, -1])[..., None, None]
spec /= spec.amax([-2, -1])[..., None, None]
# Expand the range to (0, 200) - wide enough to properly test clamping.
spec *= 200
decibels = F.amplitude_to_DB(spec, amplitude_mult, amin,
db_mult, top_db=top_db)
# Ensure the clamp was applied
below_limit = decibels < 6.0205
assert not below_limit.any(), (
"{} decibel values were below the expected cutoff:\n{}".format(
below_limit.sum().item(), decibels
)
)
# Ensure it didn't over-clamp
close_to_limit = decibels < 6.0207
assert close_to_limit.any(), (
f"No values were close to the limit. Did it over-clamp?\n{decibels}"
)
class TestComplexNorm(common_utils.TorchaudioTestCase):
......
......@@ -237,14 +237,16 @@ def amplitude_to_DB(
db_multiplier: float,
top_db: Optional[float] = None
) -> Tensor:
r"""Turn a tensor from the power/amplitude scale to the decibel scale.
r"""Turn a spectrogram from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input tensor, and so
may return different values for an audio clip split into snippets vs. a
full clip.
The output of each tensor in a batch depends on the maximum value of that tensor,
and so may return different values for an audio clip split into snippets vs. a full clip.
Args:
x (Tensor): Input tensor before being converted to decibel scale
x (Tensor): Input spectrogram(s) before being converted to decibel scale. Input should take
the form `(..., freq, time)`. Batched inputs should include a channel dimension and
have the form `(batch, channel, freq, time)`.
multiplier (float): Use 10. for power and 20. for amplitude
amin (float): Number to clamp ``x``
db_multiplier (float): Log10(max(reference value and amin))
......@@ -258,7 +260,15 @@ def amplitude_to_DB(
x_db -= multiplier * db_multiplier
if top_db is not None:
x_db = x_db.clamp(min=x_db.max().item() - top_db)
# Expand batch
shape = x_db.size()
packed_channels = shape[-3] if x_db.dim() > 2 else 1
x_db = x_db.reshape(-1, packed_channels, shape[-2], shape[-1])
x_db = torch.max(x_db, (x_db.amax(dim=(-3, -2, -1)) - top_db).view(-1, 1, 1, 1))
# Repack batch
x_db = x_db.reshape(shape)
return x_db
......
......@@ -541,24 +541,16 @@ class MFCC(torch.nn.Module):
Returns:
Tensor: specgram_mel_db of size (..., ``n_mfcc``, time).
"""
# pack batch
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])
mel_specgram = self.MelSpectrogram(waveform)
if self.log_mels:
log_offset = 1e-6
mel_specgram = torch.log(mel_specgram + log_offset)
else:
mel_specgram = self.amplitude_to_DB(mel_specgram)
# (channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
# -> (channel, time, n_mfcc).tranpose(...)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
# unpack batch
mfcc = mfcc.reshape(shape[:-1] + mfcc.shape[-2:])
# (..., channel, n_mels, time).tranpose(...) dot (n_mels, n_mfcc)
# -> (..., channel, time, n_mfcc).tranpose(...)
mfcc = torch.matmul(mel_specgram.transpose(-2, -1), self.dct_mat).transpose(-2, -1)
return mfcc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment