Commit 3d21b437 authored by Jason Lian's avatar Jason Lian
Browse files

more

parent 101e0d5f
import torch
def scale(tensor, factor):
# type: (Tensor, int) -> Tensor
if not tensor.dtype.is_floating_point:
......@@ -7,6 +8,7 @@ def scale(tensor, factor):
return tensor / factor
def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
# type: (Tensor, int, int, int, float) -> Tensor
assert tensor.size(ch_dim) < 128, \
......@@ -22,6 +24,7 @@ def pad_trim(tensor, ch_dim, max_len, len_dim, fill_value):
tensor = tensor.narrow(len_dim, 0, max_len)
return tensor
def downmix_mono(tensor, ch_dim):
# type: (Tensor, int) -> Tensor
if not tensor.dtype.is_floating_point:
......@@ -30,10 +33,12 @@ def downmix_mono(tensor, ch_dim):
tensor = torch.mean(tensor, ch_dim, True)
return tensor
def lc2cl(tensor):
# type: (Tensor) -> Tensor
return tensor.transpose(0, 1).contiguous()
def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
# type: (Tensor, int, Tensor, int, int, int, int, bool) -> Tensor
assert sig.dim() == 2
......@@ -52,3 +57,56 @@ def spectrogram(sig, pad, window, n_fft, hop, ws, power, normalize):
spec_f /= window.pow(2).sum().sqrt()
spec_f = spec_f.pow(power).sum(-1) # get power of "complex" tensor (c, l, n_fft)
return spec_f
def create_fb_matrix(n_stft, f_min, f_max, n_mels):
# type: (int, float, float, int) -> Tensor
""" Create a frequency bin conversion matrix.
Args:
n_stft (int): number of filter banks from spectrogram
"""
def _hertz_to_mel(f):
# type: (float) -> Tensor
return 2595. * torch.log10(torch.tensor(1.) + (f / 700.))
def _mel_to_hertz(mel):
# type: (Tensor) -> Tensor
return 700. * (10**(mel / 2595.) - 1.)
# get stft freq bins
stft_freqs = torch.linspace(f_min, f_max, n_stft)
# calculate mel freq bins
m_min = 0. if f_min == 0 else _hertz_to_mel(f_min)
m_max = _hertz_to_mel(f_max)
m_pts = torch.linspace(m_min, m_max, n_mels + 2)
f_pts = _mel_to_hertz(m_pts)
# calculate the difference between each mel point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1) # (n_stft, n_mels + 2)
# create overlapping triangles
z = torch.tensor(0.)
down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1] # (n_stft, n_mels)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_stft, n_mels)
fb = torch.max(z, torch.min(down_slopes, up_slopes))
return fb
def mel_scale(spec_f, f_min, f_max, n_mels, fb=None):
# type: (Tensor, float, float, int, Optional[Tensor]) -> Tuple[Tensor, Tensor]
if fb is None:
fb = create_fb_matrix(spec_f.size(2), f_min, f_max, n_mels).to(spec_f.device)
else:
# need to ensure same device for dot product
fb = fb.to(spec_f.device)
spec_m = torch.matmul(spec_f, fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
return fb, spec_m
def spectrogram_to_DB(spec, multiplier, amin, db_multiplier, top_db):
spec_db = multiplier * torch.log10(torch.clamp(spec, min=amin))
spec_db -= multiplier * db_multiplier
if top_db is not None:
spec_db = torch.max(spec_db, spec_db.new_full((1,), spec_db.max() - top_db))
return spec_db
......@@ -205,47 +205,13 @@ class MelScale(object):
self.sr = sr
self.f_max = f_max if f_max is not None else sr // 2
self.f_min = f_min
self.fb = self._create_fb_matrix(n_stft) if n_stft is not None else n_stft
self.fb = F.create_fb_matrix(
n_stft, self.f_min, self.f_max, self.n_mels) if n_stft is not None else n_stft
def __call__(self, spec_f):
if self.fb is None:
self.fb = self._create_fb_matrix(spec_f.size(2)).to(spec_f.device)
else:
# need to ensure same device for dot product
self.fb = self.fb.to(spec_f.device)
spec_m = torch.matmul(spec_f, self.fb) # (c, l, n_fft) dot (n_fft, n_mels) -> (c, l, n_mels)
self.fb, spec_m = F.mel_scale(spec_f, self.f_min, self.f_max, self.n_mels, self.fb)
return spec_m
def _create_fb_matrix(self, n_stft):
""" Create a frequency bin conversion matrix.
Args:
n_stft (int): number of filter banks from spectrogram
"""
# get stft freq bins
stft_freqs = torch.linspace(self.f_min, self.f_max, n_stft)
# calculate mel freq bins
m_min = 0. if self.f_min == 0 else self._hertz_to_mel(self.f_min)
m_max = self._hertz_to_mel(self.f_max)
m_pts = torch.linspace(m_min, m_max, self.n_mels + 2)
f_pts = self._mel_to_hertz(m_pts)
# calculate the difference between each mel point and each stft freq point in hertz
f_diff = f_pts[1:] - f_pts[:-1] # (n_mels + 1)
slopes = f_pts.unsqueeze(0) - stft_freqs.unsqueeze(1) # (n_stft, n_mels + 2)
# create overlapping triangles
z = torch.tensor(0.)
down_slopes = (-1. * slopes[:, :-2]) / f_diff[:-1] # (n_stft, n_mels)
up_slopes = slopes[:, 2:] / f_diff[1:] # (n_stft, n_mels)
fb = torch.max(z, torch.min(down_slopes, up_slopes))
return fb
def _hertz_to_mel(self, f):
return 2595. * torch.log10(torch.tensor(1.) + (f / 700.))
def _mel_to_hertz(self, mel):
return 700. * (10**(mel / 2595.) - 1.)
class SpectrogramToDB(object):
"""Turns a spectrogram from the power/amplitude scale to the decibel scale.
......@@ -273,12 +239,7 @@ class SpectrogramToDB(object):
def __call__(self, spec):
# numerically stable implementation from librosa
# https://librosa.github.io/librosa/_modules/librosa/core/spectrum.html
spec_db = self.multiplier * torch.log10(torch.clamp(spec, min=self.amin))
spec_db -= self.multiplier * self.db_multiplier
if self.top_db is not None:
spec_db = torch.max(spec_db, spec_db.new_full((1,), spec_db.max() - self.top_db))
return spec_db
return F.spectrogram_to_DB(spec, self.multiplier, self.amin, self.db_multiplier, self.top_db)
class MFCC(object):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment