Commit 873af313 authored by jamarshon, committed by cpuhrsch
Browse files

Rename SpectrogramToDB to AmplitudeToDB (#170)

parent d3fe2a77
...@@ -78,11 +78,11 @@ class Test_JIT(unittest.TestCase): ...@@ -78,11 +78,11 @@ class Test_JIT(unittest.TestCase):
self._test_script_module(spec_f, transforms.MelScale) self._test_script_module(spec_f, transforms.MelScale)
def test_torchscript_spectrogram_to_DB(self): def test_torchscript_amplitude_to_DB(self):
@torch.jit.script @torch.jit.script
def jit_method(spec, multiplier, amin, db_multiplier, top_db): def jit_method(spec, multiplier, amin, db_multiplier, top_db):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor # type: (Tensor, float, float, float, Optional[float]) -> Tensor
return F.spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db) return F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db)
spec = torch.rand((6, 201)) spec = torch.rand((6, 201))
multiplier = 10. multiplier = 10.
...@@ -91,15 +91,15 @@ class Test_JIT(unittest.TestCase): ...@@ -91,15 +91,15 @@ class Test_JIT(unittest.TestCase):
top_db = 80. top_db = 80.
jit_out = jit_method(spec, multiplier, amin, db_multiplier, top_db) jit_out = jit_method(spec, multiplier, amin, db_multiplier, top_db)
py_out = F.spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db) py_out = F.amplitude_to_DB(spec, multiplier, amin, db_multiplier, top_db)
self.assertTrue(torch.allclose(jit_out, py_out)) self.assertTrue(torch.allclose(jit_out, py_out))
@unittest.skipIf(not RUN_CUDA, "no CUDA") @unittest.skipIf(not RUN_CUDA, "no CUDA")
def test_scriptmodule_SpectrogramToDB(self): def test_scriptmodule_AmplitudeToDB(self):
spec = torch.rand((6, 201), device="cuda") spec = torch.rand((6, 201), device="cuda")
self._test_script_module(spec, transforms.SpectrogramToDB) self._test_script_module(spec, transforms.AmplitudeToDB)
def test_torchscript_create_dct(self): def test_torchscript_create_dct(self):
@torch.jit.script @torch.jit.script
......
...@@ -52,7 +52,7 @@ class Tester(unittest.TestCase): ...@@ -52,7 +52,7 @@ class Tester(unittest.TestCase):
def test_mel2(self): def test_mel2(self):
top_db = 80. top_db = 80.
s2db = transforms.SpectrogramToDB('power', top_db) s2db = transforms.AmplitudeToDB('power', top_db)
waveform = self.waveform.clone() # (1, 16000) waveform = self.waveform.clone() # (1, 16000)
waveform_scaled = self.scale(waveform) # (1, 16000) waveform_scaled = self.scale(waveform) # (1, 16000)
...@@ -155,7 +155,7 @@ class Tester(unittest.TestCase): ...@@ -155,7 +155,7 @@ class Tester(unittest.TestCase):
self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3)) self.assertTrue(torch.allclose(torch_mel.type(librosa_mel_tensor.dtype), librosa_mel_tensor, atol=5e-3))
# test s2db # test s2db
db_transform = torchaudio.transforms.SpectrogramToDB('power', 80.) db_transform = torchaudio.transforms.AmplitudeToDB('power', 80.)
db_torch = db_transform(spect_transform(sound)).squeeze().cpu() db_torch = db_transform(spect_transform(sound)).squeeze().cpu()
db_librosa = librosa.core.spectrum.power_to_db(out_librosa) db_librosa = librosa.core.spectrum.power_to_db(out_librosa)
self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3)) self.assertTrue(torch.allclose(db_torch, torch.from_numpy(db_librosa), atol=5e-3))
......
...@@ -5,7 +5,7 @@ import torch ...@@ -5,7 +5,7 @@ import torch
__all__ = [ __all__ = [
'istft', 'istft',
'spectrogram', 'spectrogram',
'spectrogram_to_DB', 'amplitude_to_DB',
'create_fb_matrix', 'create_fb_matrix',
'create_dct', 'create_dct',
'mu_law_encoding', 'mu_law_encoding',
...@@ -207,34 +207,34 @@ def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, nor ...@@ -207,34 +207,34 @@ def spectrogram(waveform, pad, window, n_fft, hop_length, win_length, power, nor
@torch.jit.script @torch.jit.script
def spectrogram_to_DB(specgram, multiplier, amin, db_multiplier, top_db=None): def amplitude_to_DB(x, multiplier, amin, db_multiplier, top_db=None):
# type: (Tensor, float, float, float, Optional[float]) -> Tensor # type: (Tensor, float, float, float, Optional[float]) -> Tensor
r"""Turns a spectrogram from the power/amplitude scale to the decibel scale. r"""Turns a tensor from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input spectrogram, and so This output depends on the maximum value in the input tensor, and so
may return different values for an audio clip split into snippets vs. a may return different values for an audio clip split into snippets vs. a
full clip. full clip.
Args: Args:
specgram (torch.Tensor): Normal STFT of size (c, f, t) x (torch.Tensor): Input tensor before being converted to decibel scale
multiplier (float): Use 10. for power and 20. for amplitude multiplier (float): Use 10. for power and 20. for amplitude
amin (float): Number to clamp specgram amin (float): Number to clamp ``x``
db_multiplier (float): Log10(max(reference value and amin)) db_multiplier (float): Log10(max(reference value and amin))
top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number top_db (Optional[float]): Minimum negative cut-off in decibels. A reasonable number
is 80. is 80.
Returns: Returns:
torch.Tensor: Spectrogram in DB of size (c, f, t) torch.Tensor: Output tensor in decibel scale
""" """
specgram_db = multiplier * torch.log10(torch.clamp(specgram, min=amin)) x_db = multiplier * torch.log10(torch.clamp(x, min=amin))
specgram_db -= multiplier * db_multiplier x_db -= multiplier * db_multiplier
if top_db is not None: if top_db is not None:
new_spec_db_max = torch.tensor(float(specgram_db.max()) - top_db, new_x_db_max = torch.tensor(float(x_db.max()) - top_db,
dtype=specgram_db.dtype, device=specgram_db.device) dtype=x_db.dtype, device=x_db.device)
specgram_db = torch.max(specgram_db, new_spec_db_max) x_db = torch.max(x_db, new_x_db_max)
return specgram_db return x_db
@torch.jit.script @torch.jit.script
......
...@@ -9,7 +9,7 @@ from .compliance import kaldi ...@@ -9,7 +9,7 @@ from .compliance import kaldi
__all__ = [ __all__ = [
'Spectrogram', 'Spectrogram',
'SpectrogramToDB', 'AmplitudeToDB',
'MelScale', 'MelScale',
'MelSpectrogram', 'MelSpectrogram',
'MFCC', 'MFCC',
...@@ -67,15 +67,15 @@ class Spectrogram(torch.jit.ScriptModule): ...@@ -67,15 +67,15 @@ class Spectrogram(torch.jit.ScriptModule):
self.win_length, self.power, self.normalized) self.win_length, self.power, self.normalized)
class SpectrogramToDB(torch.jit.ScriptModule): class AmplitudeToDB(torch.jit.ScriptModule):
r"""Turns a spectrogram from the power/amplitude scale to the decibel scale. r"""Turns a tensor from the power/amplitude scale to the decibel scale.
This output depends on the maximum value in the input spectrogram, and so This output depends on the maximum value in the input tensor, and so
may return different values for an audio clip split into snippets vs. a may return different values for an audio clip split into snippets vs. a
full clip. full clip.
Args: Args:
stype (str): scale of input spectrogram ('power' or 'magnitude'). The stype (str): scale of input tensor ('power' or 'magnitude'). The
power being the elementwise square of the magnitude. (Default: 'power') power being the elementwise square of the magnitude. (Default: 'power')
top_db (float, optional): minimum negative cut-off in decibels. A reasonable number top_db (float, optional): minimum negative cut-off in decibels. A reasonable number
is 80. is 80.
...@@ -83,7 +83,7 @@ class SpectrogramToDB(torch.jit.ScriptModule): ...@@ -83,7 +83,7 @@ class SpectrogramToDB(torch.jit.ScriptModule):
__constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier'] __constants__ = ['multiplier', 'amin', 'ref_value', 'db_multiplier']
def __init__(self, stype='power', top_db=None): def __init__(self, stype='power', top_db=None):
super(SpectrogramToDB, self).__init__() super(AmplitudeToDB, self).__init__()
self.stype = torch.jit.Attribute(stype, str) self.stype = torch.jit.Attribute(stype, str)
if top_db is not None and top_db < 0: if top_db is not None and top_db < 0:
raise ValueError('top_db must be positive value') raise ValueError('top_db must be positive value')
...@@ -94,17 +94,17 @@ class SpectrogramToDB(torch.jit.ScriptModule): ...@@ -94,17 +94,17 @@ class SpectrogramToDB(torch.jit.ScriptModule):
self.db_multiplier = math.log10(max(self.amin, self.ref_value)) self.db_multiplier = math.log10(max(self.amin, self.ref_value))
@torch.jit.script_method @torch.jit.script_method
def forward(self, specgram): def forward(self, x):
r"""Numerically stable implementation from Librosa r"""Numerically stable implementation from Librosa
https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
Args: Args:
specgram (torch.Tensor): STFT of size (c, f, t) x (torch.Tensor): Input tensor before being converted to decibel scale
Returns: Returns:
torch.Tensor: STFT after changing scale of size (c, f, t) torch.Tensor: Output tensor in decibel scale
""" """
return F.spectrogram_to_DB(specgram, self.multiplier, self.amin, self.db_multiplier, self.top_db) return F.amplitude_to_DB(x, self.multiplier, self.amin, self.db_multiplier, self.top_db)
class MelScale(torch.jit.ScriptModule): class MelScale(torch.jit.ScriptModule):
...@@ -246,7 +246,7 @@ class MFCC(torch.jit.ScriptModule): ...@@ -246,7 +246,7 @@ class MFCC(torch.jit.ScriptModule):
self.dct_type = dct_type self.dct_type = dct_type
self.norm = torch.jit.Attribute(norm, Optional[str]) self.norm = torch.jit.Attribute(norm, Optional[str])
self.top_db = 80.0 self.top_db = 80.0
self.spectrogram_to_DB = SpectrogramToDB('power', self.top_db) self.amplitude_to_DB = AmplitudeToDB('power', self.top_db)
if melkwargs is not None: if melkwargs is not None:
self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs) self.MelSpectrogram = MelSpectrogram(sample_rate=self.sample_rate, **melkwargs)
...@@ -273,7 +273,7 @@ class MFCC(torch.jit.ScriptModule): ...@@ -273,7 +273,7 @@ class MFCC(torch.jit.ScriptModule):
log_offset = 1e-6 log_offset = 1e-6
mel_specgram = torch.log(mel_specgram + log_offset) mel_specgram = torch.log(mel_specgram + log_offset)
else: else:
mel_specgram = self.spectrogram_to_DB(mel_specgram) mel_specgram = self.amplitude_to_DB(mel_specgram)
# (c, `n_mels`, t).transpose(...) dot (`n_mels`, `n_mfcc`) -> (c, t, `n_mfcc`).transpose(...) # (c, `n_mels`, t).transpose(...) dot (`n_mels`, `n_mfcc`) -> (c, t, `n_mfcc`).transpose(...)
mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2) mfcc = torch.matmul(mel_specgram.transpose(1, 2), self.dct_mat).transpose(1, 2)
return mfcc return mfcc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment