Unverified commit 6aafbb6d, authored by jimchen90 and committed by GitHub
Browse files

Add spectrogram normalization option (#863)



* Add spectrogram normalization option
Co-authored-by: Ji Chen <jimchen90@devfair0160.h2.fair>
parent e808225f
...@@ -163,6 +163,9 @@ def parse_args(): ...@@ -163,6 +163,9 @@ def parse_args():
parser.add_argument( parser.add_argument(
"--file-path", default="", type=str, help="the path of audio files", "--file-path", default="", type=str, help="the path of audio files",
) )
# `--normalization` is kept so existing command lines keep working, but as a
# store_true flag whose default is already True it can never switch
# normalization off. `--no-normalization` stores False into the same
# `args.normalization` attribute, which makes the option actually usable.
parser.add_argument(
    "--normalization", dest="normalization", default=True, action="store_true",
    help="normalize the spectrogram (enabled by default)",
)
parser.add_argument(
    "--no-normalization", dest="normalization", default=True, action="store_false",
    help="disable spectrogram normalization",
)
args = parser.parse_args() args = parser.parse_args()
return args return args
...@@ -273,7 +276,7 @@ def main(args): ...@@ -273,7 +276,7 @@ def main(args):
n_mels=args.n_freq, n_mels=args.n_freq,
fmin=args.f_min, fmin=args.f_min,
), ),
NormalizeDB(min_level_db=args.min_level_db), NormalizeDB(min_level_db=args.min_level_db, normalization=args.normalization),
) )
train_dataset, val_dataset = split_process_dataset(args, transforms) train_dataset, val_dataset = split_process_dataset(args, transforms)
......
...@@ -31,15 +31,18 @@ class NormalizeDB(nn.Module): ...@@ -31,15 +31,18 @@ class NormalizeDB(nn.Module):
r"""Normalize the spectrogram with a minimum db value r"""Normalize the spectrogram with a minimum db value
""" """
def __init__(self, min_level_db): def __init__(self, min_level_db, normalization):
super().__init__() super().__init__()
self.min_level_db = min_level_db self.min_level_db = min_level_db
self.normalization = normalization
def forward(self, specgram): def forward(self, specgram):
specgram = 20 * torch.log10(torch.clamp(specgram, min=1e-5)) specgram = torch.log10(torch.clamp(specgram, min=1e-5))
if self.normalization:
return torch.clamp( return torch.clamp(
(self.min_level_db - specgram) / self.min_level_db, min=0, max=1 (self.min_level_db - 20 * specgram) / self.min_level_db, min=0, max=1
) )
return specgram
def normalized_waveform_to_bits(waveform, bits): def normalized_waveform_to_bits(waveform, bits):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment