Commit 3d21b437 authored by Jason Lian
Browse files

more

parent 101e0d5f
import torch import torch
def scale(tensor, factor): def scale(tensor, factor):
# type: (Tensor, int) -> Tensor # type: (Tensor, int) -> Tensor
if not tensor.dtype.is_floating_point: if not tensor.dtype.is_floating_point:
...@@ -7,6 +8,7 @@ def scale(tensor, factor): ...@@ -7,6 +8,7 @@ def scale(tensor, factor):
return tensor / factor return tensor / factor
def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value): def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
# type: (Tensor, int, int, int, float) -> Tensor # type: (Tensor, int, int, int, float) -> Tensor
assert tensor.size(ch_dim) < 128, \ assert tensor.size(ch_dim) < 128, \
...@@ -22,6 +24,7 @@ def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value): ...@@ -22,6 +24,7 @@ def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
tensor = tensor.narrow(len_dim, 0, max_len) tensor = tensor.narrow(len_dim, 0, max_len)
return tensor return tensor
def downmix_mono(tensor, ch_dim): def downmix_mono(tensor, ch_dim):
# type: (Tensor, int) -> Tensor # type: (Tensor, int) -> Tensor
if not tensor.dtype.is_floating_point: if not tensor.dtype.is_floating_point:
...@@ -30,10 +33,12 @@ def downmix_mono(tensor, ch_dim): ...@@ -30,10 +33,12 @@ def downmix_mono(tensor, ch_dim):
tensor = torch.mean(tensor, ch_dim, True) tensor = torch.mean(tensor, ch_dim, True)
return tensor return tensor
def lc2cl(tensor):
    # type: (Tensor) -> Tensor
    """Swap dimensions 0 and 1 of ``tensor`` and return it in contiguous memory.

    Args:
        tensor (Tensor): tensor with at least two dimensions
            (presumably (length, channels), judging by the name -- confirm with callers)

    Returns:
        Tensor: ``tensor`` with its first two dimensions exchanged, laid out contiguously
    """
    swapped = tensor.transpose(0, 1)
    # transpose only returns a strided view; materialise a contiguous layout
    return swapped.contiguous()
def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize): def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor # type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
assert sig.dim() == 2 assert sig.dim() == 2
...@@ -52,3 +57,56 @@ def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize): ...@@ -52,3 +57,56 @@ def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
spec_f /= window.pow(2).sum().sqrt() spec_f /= window.pow(2).sum().sqrt()
spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor (c, l, n_fft) spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor (c, l, n_fft)
return spec_f return spec_f
def create_fb_matrix(n_stft, f_min, f_max, n_mels):
    # type: (int, float, float, int) -> Tensor
    """ Create a frequency bin conversion matrix.

    The matrix holds ``n_mels`` overlapping triangular filters whose corner
    frequencies are evenly spaced on the mel scale between ``f_min`` and ``f_max``.

    Args:
        n_stft (int): number of filter banks from spectrogram
        f_min (float): lowest frequency in hertz (start of both the linear and mel axes)
        f_max (float): highest frequency in hertz (end of both the linear and mel axes)
        n_mels (int): number of mel filters, i.e. columns of the result

    Returns:
        Tensor: non-negative filter bank of shape (n_stft, n_mels)
    """
    # Function-scope import keeps this block self-contained.
    import math

    def _hertz_to_mel(f):
        # type: (float) -> float
        # Plain-float math: the result is only used as a linspace endpoint, so
        # there is no need to round-trip through a 0-dim tensor.
        return 2595. * math.log10(1. + (f / 700.))

    def _mel_to_hertz(mel):
        # type: (Tensor) -> Tensor
        return 700. * (10 ** (mel / 2595.) - 1.)

    # get stft freq bins
    # NOTE(review): this assumes the spectrogram bins span exactly
    # [f_min, f_max] -- confirm against how the STFT is computed upstream.
    stft_freqs = torch.linspace(f_min, f_max, n_stft)

    # calculate mel freq bins (hertz_to_mel(0) is exactly 0, no special case needed)
    m_min = _hertz_to_mel(f_min)
    m_max = _hertz_to_mel(f_max)
    m_pts = torch.linspace(m_min, m_max, n_mels + 2)
    f_pts = _mel_to_hertz(m_pts)

    # calculate the difference between each mel point and each stft freq point in hertz
    f_diff = f_pts[1:] - f_pts[:-1]  # (n_mels + 1)
    slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1)  # (n_stft, n_mels + 2)

    # create overlapping triangles: each filter is the lower of its two edge
    # lines, floored at zero outside the triangle's support
    down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1]  # (n_stft, n_mels)
    up_slopes = slopes[:, 2:] / f_diff[1:]  # (n_stft, n_mels)
    fb = torch.clamp(torch.min(down_slopes, up_slopes), min=0.)
    return fb
def mel_scale(spec_f, f_min, f_max, n_mels, fb=None):
    # type: (Tensor, float, float, int, Optional[Tensor]) -> Tuple[Tensor, Tensor]
    """Project a spectrogram onto a mel filter bank.

    Args:
        spec_f (Tensor): spectrogram whose last dimension is the frequency axis,
            e.g. (c, l, n_fft); any number of leading dimensions is accepted
        f_min (float): lowest frequency, used only when ``fb`` has to be built
        f_max (float): highest frequency, used only when ``fb`` has to be built
        n_mels (int): number of mel bins, used only when ``fb`` has to be built
        fb (Tensor, optional): precomputed filter bank of shape (n_fft, n_mels);
            built with ``create_fb_matrix`` when ``None``

    Returns:
        Tuple[Tensor, Tensor]: the filter bank that was used (moved to
        ``spec_f``'s device) and the mel spectrogram of shape (..., n_mels)
    """
    if fb is None:
        # The frequency axis is the last one whatever the rank of spec_f.
        fb = create_fb_matrix(spec_f.size(-1), f_min, f_max, n_mels)
    # need to ensure same device for dot product
    fb = fb.to(spec_f.device)
    spec_m = torch.matmul(spec_f, fb)  # (..., n_fft) dot (n_fft, n_mels) -> (..., n_mels)
    return fb, spec_m
def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db):
    # type: (Tensor, float, float, float, Optional[float]) -> Tensor
    """Convert a spectrogram to a logarithmic (decibel-style) scale.

    Computes ``multiplier * log10(max(spec, amin)) - multiplier * db_multiplier``
    and, when ``top_db`` is given, floors the result at ``result.max() - top_db``.

    Args:
        spec (Tensor): input spectrogram; it is not modified
        multiplier (float): factor applied to the base-10 logarithm
        amin (float): lower clamp applied to ``spec`` before the logarithm
        db_multiplier (float): offset, scaled by ``multiplier``, subtracted from the result
        top_db (float or None): if not ``None``, values more than ``top_db``
            below the maximum are raised to ``max - top_db``

    Returns:
        Tensor: converted spectrogram with the same shape as ``spec``
    """
    # clamp first so log10 never sees values below amin
    floored_input = spec.clamp(min=amin)
    db = torch.log10(floored_input) * multiplier
    db = db - multiplier * db_multiplier
    if top_db is None:
        return db
    # The floor is a fresh constant tensor (no autograd link to db's maximum).
    lowest = db.new_full((1,), db.max() - top_db)
    return torch.max(db, lowest)
...@@ -205,47 +205,13 @@ class MelScale(object): ...@@ -205,47 +205,13 @@ class MelScale(object):
self.sr = sr self.sr = sr
self.f_max = f_max if f_max is not None else sr // 2 self.f_max = f_max if f_max is not None else sr // 2
self.f_min = f_min self.f_min = f_min
self.fb = self._create_fb_matrix(n_stft) if n_stft is not None else n_stft self.fb = F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels) if n_stft is not None else n_stft
def __call__(self, spec_f): def __call__(self, spec_f):
if self.fb is None: self.fb, spec_m = F.mel_scale(spec_f, self.f_min, self.f_max, self.n_mels, self.fb)
self.fb = self._create_fb_matrix(spec_f.size(2)).to(spec_f.device)
else:
# need to ensure same device for dot product
self.fb = self.fb.to(spec_f.device)
spec_m = torch.matmul(spec_f, self.fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return spec_m return spec_m
def _create_fb_matrix(self, n_stft):
""" Create a frequency bin conversion matrix.
Args:
n_stft (int): number of filter banks from spectrogram
"""
# get stft freq bins
stft_freqs = torch.linspace(self.f_min, self.f_max, n_stft)
# calculate mel freq bins
m_min = 0. if self.f_min == 0 else self._hertz_to_mel(self.f_min)
m_max = self._hertz_to_mel(self.f_max)
m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
f_pts = self._mel_to_hertz(m_pts)
# calculate the difference between each mel point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1) # (n_stft, n_mels + 2)
# create overlapping triangles
z = torch.tensor(0.)
down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1] # (n_stft, n_mels)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_stft, n_mels)
fb = torch.max(z, torch.min(down_slopes, up_slopes))
return fb
def _hertz_to_mel(self, f):
return 2595. * torch.log10(torch.tensor(1.) + (f / 700.))
def _mel_to_hertz(self, mel):
return 700. * (10**(mel / 2595.) - 1.)
class SpectrogramToDB(object): class SpectrogramToDB(object):
"""Turns a spectrogram from the power/amplitude scale to the decibel scale. """Turns a spectrogram from the power/amplitude scale to the decibel scale.
...@@ -273,12 +239,7 @@ class SpectrogramToDB(object): ...@@ -273,12 +239,7 @@ class SpectrogramToDB(object):
def __call__(self, spec): def __call__(self, spec):
# numerically stable implementation from librosa # numerically stable implementation from librosa
# https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html # https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
spec_db = self.multiplier * torch.log10(torch.clamp(spec, min=self.amin)) return F.spectrogram_to_DB(spec, self.multiplier, self.amin, self.db_multiplier, self.top_db)
spec_db -= self.multiplier * self.db_multiplier
if self.top_db is not None:
spec_db = torch.max(spec_db, spec_db.new_full((1,), spec_db.max() - self.top_db))
return spec_db
class MFCC(object): class MFCC(object):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment