Unverified commit e061b268, authored by discort and committed by GitHub
Browse files

Use torchaudio melscale 'slaney' instead of librosa in WaveRNN pipeline preprocessing (#1444)

* Use torchaudio melscale instead of librosa
parent 48630302
...@@ -17,7 +17,7 @@ from torchaudio.models.wavernn import WaveRNN ...@@ -17,7 +17,7 @@ from torchaudio.models.wavernn import WaveRNN
from datasets import collate_factory, split_process_dataset from datasets import collate_factory, split_process_dataset
from losses import LongCrossEntropyLoss, MoLLoss from losses import LongCrossEntropyLoss, MoLLoss
from processing import LinearToMel, NormalizeDB from processing import NormalizeDB
from utils import MetricLogger, count_parameters, save_checkpoint from utils import MetricLogger, count_parameters, save_checkpoint
...@@ -269,12 +269,12 @@ def main(args): ...@@ -269,12 +269,12 @@ def main(args):
} }
transforms = torch.nn.Sequential( transforms = torch.nn.Sequential(
torchaudio.transforms.Spectrogram(**melkwargs), torchaudio.transforms.MelSpectrogram(
LinearToMel(
sample_rate=args.sample_rate, sample_rate=args.sample_rate,
n_fft=args.n_fft,
n_mels=args.n_freq, n_mels=args.n_freq,
fmin=args.f_min, f_min=args.f_min,
mel_scale='slaney',
**melkwargs,
), ),
NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization), NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization),
) )
......
import librosa
import torch import torch
import torch.nn as nn import torch.nn as nn
# TODO Replace by torchaudio, once https://github.com/pytorch/audio/pull/593 is resolved
class LinearToMel(nn.Module):
    r"""Convert a linear-frequency spectrogram to a mel spectrogram with librosa.

    The constructor arguments are stored unchanged and forwarded to
    ``librosa.feature.melspectrogram`` on every call to :meth:`forward`.

    Args:
        sample_rate: Passed to librosa as ``sr``.
        n_fft: Passed to librosa as ``n_fft``.
        n_mels: Passed to librosa as ``n_mels``.
        fmin: Passed to librosa as ``fmin``.
        htk: Passed to librosa as ``htk``. (Default: ``False``)
        norm: Passed to librosa as ``norm``. (Default: ``"slaney"``)
    """

    def __init__(self, sample_rate, n_fft, n_mels, fmin, htk=False, norm="slaney"):
        super().__init__()
        self.sample_rate = sample_rate
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.fmin = fmin
        self.htk = htk
        self.norm = norm

    def forward(self, specgram):
        r"""Apply the mel projection to ``specgram``.

        Args:
            specgram: Spectrogram tensor. Dimension 0 is squeezed away when it
                has size 1 (presumably the channel dimension -- confirm
                against the caller).

        Returns:
            A CPU tensor wrapping the array returned by librosa. The
            computation runs in NumPy, so no gradient flows through it.
        """
        # Tensor.numpy() raises for tensors that require grad or live on a
        # non-CPU device, so detach and move to host memory first.
        linear = specgram.squeeze(0).detach().cpu().numpy()
        mel = librosa.feature.melspectrogram(
            S=linear,
            sr=self.sample_rate,
            n_fft=self.n_fft,
            n_mels=self.n_mels,
            fmin=self.fmin,
            htk=self.htk,
            norm=self.norm,
        )
        return torch.from_numpy(mel)
class NormalizeDB(nn.Module): class NormalizeDB(nn.Module):
r"""Normalize the spectrogram with a minimum db value r"""Normalize the spectrogram with a minimum db value
""" """
...@@ -37,7 +12,7 @@ class NormalizeDB(nn.Module): ...@@ -37,7 +12,7 @@ class NormalizeDB(nn.Module):
self.normalization = normalization self.normalization = normalization
def forward(self, specgram): def forward(self, specgram):
specgram = torch.log10(torch.clamp(specgram, min=1e-5)) specgram = torch.log10(torch.clamp(specgram.squeeze(0), min=1e-5))
if self.normalization: if self.normalization:
return torch.clamp( return torch.clamp(
(self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1 (self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment