Commit 9bd633e3 authored by jamarshon's avatar jamarshon Committed by cpuhrsch

Implementing Kaldi Spectrogram (#119)

parent a420cced
import random
import subprocess

# Path to the compute-spectrogram-feats executable.
EXE_PATH = '/scratch/jamarshon/kaldi/src/featbin/compute-spectrogram-feats'

# Path to the scp file. An example of its contents would be
# "my_id /scratch/jamarshon/audio/test/assets/kaldi_file.wav" where the space
# separates an id from a wav file.
SCP_PATH = 'scp:/scratch/jamarshon/downloads/a.scp'
# The directory to which the stft features will be written.
OUTPUT_DIR = 'ark:/scratch/jamarshon/audio/test/assets/kaldi/'
# The number of samples inside the input wave file read from `SCP_PATH`.
WAV_LEN = 20
# How many output files should be generated.
NUMBER_OF_OUTPUTS = 100

WINDOWS = ['hamming', 'hanning', 'povey', 'rectangular', 'blackman']


def generate_rand_boolean():
    # Generates a random boolean string ('true' or 'false')
    return 'true' if random.randint(0, 1) else 'false'


def generate_rand_window_type():
    # Picks a random window type from WINDOWS
    return WINDOWS[random.randint(0, len(WINDOWS) - 1)]


def run():
    for i in range(NUMBER_OF_OUTPUTS):
        inputs = {
            'blackman_coeff': '%.4f' % (random.random() * 5),
            'dither': '0',
            'energy_floor': '%.4f' % (random.random() * 5),
            'frame_length': '%.4f' % (float(random.randint(2, WAV_LEN - 1)) / 16000 * 1000),
            'frame_shift': '%.4f' % (float(random.randint(1, WAV_LEN - 1)) / 16000 * 1000),
            'preemphasis_coefficient': '%.2f' % random.random(),
            'raw_energy': generate_rand_boolean(),
            'remove_dc_offset': generate_rand_boolean(),
            'round_to_power_of_two': generate_rand_boolean(),
            'snip_edges': generate_rand_boolean(),
            'subtract_mean': generate_rand_boolean(),
            'window_type': generate_rand_window_type()
        }
        fn = 'spec-' + '-'.join(inputs.values())
        arg = [
            EXE_PATH,
            '--blackman-coeff=' + inputs['blackman_coeff'],
            '--dither=' + inputs['dither'],
            '--energy-floor=' + inputs['energy_floor'],
            '--frame-length=' + inputs['frame_length'],
            '--frame-shift=' + inputs['frame_shift'],
            '--preemphasis-coefficient=' + inputs['preemphasis_coefficient'],
            '--raw-energy=' + inputs['raw_energy'],
            '--remove-dc-offset=' + inputs['remove_dc_offset'],
            '--round-to-power-of-two=' + inputs['round_to_power_of_two'],
            '--sample-frequency=16000',
            '--snip-edges=' + inputs['snip_edges'],
            '--subtract-mean=' + inputs['subtract_mean'],
            '--window-type=' + inputs['window_type'],
            SCP_PATH,
            OUTPUT_DIR + fn + '.ark'
        ]
        print(fn)
        print(inputs)
        print(' '.join(arg))
        try:
            subprocess.call(arg)
        except Exception:
            pass


if __name__ == '__main__':
    run()
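# For reference, one generated invocation looks roughly like the following
# (the numeric values are illustrative; run() draws them at random, and the
# output file name is elided):
#
#   compute-spectrogram-feats --blackman-coeff=2.1404 --dither=0 \
#     --energy-floor=3.1421 --frame-length=1.1875 --frame-shift=0.6875 \
#     --preemphasis-coefficient=0.97 --raw-energy=true --remove-dc-offset=false \
#     --round-to-power-of-two=true --sample-frequency=16000 --snip-edges=true \
#     --subtract-mean=false --window-type=povey \
#     scp:/scratch/jamarshon/downloads/a.scp \
#     ark:/scratch/jamarshon/audio/test/assets/kaldi/spec-<...>.ark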
import math
import os
import test.common_utils
import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
import unittest


def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
    # just a copy of ExtractWindow from feature-window.cc in python
    def first_sample_of_frame(frame, window_size, window_shift, snip_edges):
        if snip_edges:
            return frame * window_shift
        else:
            midpoint_of_frame = frame * window_shift + window_shift // 2
            beginning_of_frame = midpoint_of_frame - window_size // 2
            return beginning_of_frame

    sample_offset = 0
    num_samples = sample_offset + wave.size(0)
    start_sample = first_sample_of_frame(f, frame_length, frame_shift, snip_edges)
    end_sample = start_sample + frame_length

    if snip_edges:
        assert start_sample >= sample_offset and end_sample <= num_samples
    else:
        assert sample_offset == 0 or start_sample >= sample_offset

    wave_start = start_sample - sample_offset
    wave_end = wave_start + frame_length
    if wave_start >= 0 and wave_end <= wave.size(0):
        # the frame fits entirely inside the waveform
        window[f, :] = wave[wave_start:(wave_start + frame_length)]
    else:
        # the frame hangs over an edge; reflect out-of-range indices back in
        wave_dim = wave.size(0)
        for s in range(frame_length):
            s_in_wave = s + wave_start
            while s_in_wave < 0 or s_in_wave >= wave_dim:
                if s_in_wave < 0:
                    s_in_wave = -s_in_wave - 1
                else:
                    s_in_wave = 2 * wave_dim - 1 - s_in_wave
            window[f, s] = wave[s_in_wave]
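# Illustrative reflection (values follow from the loop above): for wave_dim = 4,
# an out-of-range index -2 reflects to -(-2) - 1 = 1, and index 5 reflects to
# 2 * 4 - 1 - 5 = 2, matching Kaldi's mirrored edge handling.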
class Test_Kaldi(unittest.TestCase):
    test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
    test_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file.wav')

    def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_edges):
        waveform = torch.arange(num_samples).float()
        output = kaldi._get_strided(waveform, window_size, window_shift, snip_edges)

        # from NumFrames in feature-window.cc
        n = window_size
        if snip_edges:
            m = 0 if num_samples < window_size else 1 + (num_samples - window_size) // window_shift
        else:
            m = (num_samples + (window_shift // 2)) // window_shift

        self.assertTrue(output.dim() == 2)
        self.assertTrue(output.shape[0] == m and output.shape[1] == n)

        window = torch.empty((m, window_size))
        for r in range(m):
            extract_window(window, waveform, r, window_size, window_shift, snip_edges)
        self.assertTrue(torch.allclose(window, output))

    def test_get_strided(self):
        # generate any combination where 0 < window_size <= num_samples and
        # 0 < window_shift
        for num_samples in range(1, 20):
            for window_size in range(1, num_samples + 1):
                for window_shift in range(1, 2 * num_samples + 1):
                    for snip_edges in range(0, 2):
                        self._test_get_strided_helper(num_samples, window_size, window_shift, snip_edges)

    def _create_data_set(self):
        # used to generate the dataset to test on; this is not used in testing (offline procedure)
        test_dirpath = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
        test_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file.wav')
        sr = 16000
        x = torch.arange(0, 20).float()
        # in the range [-6, 6]
        y = torch.cos(2 * math.pi * x) + 3 * torch.sin(math.pi * x) + 2 * torch.cos(x)
        # in the range [-2^30, 2^30]
        y = (y / 6 * (1 << 30)).long()
        # clear the last 16 bits because they aren't used anyway
        y = ((y >> 16) << 16).float()
        torchaudio.save(test_filepath, y, sr)
        sound, sample_rate = torchaudio.load(test_filepath, normalization=False)
        print(y >> 16)
        self.assertTrue(sample_rate == sr)
        self.assertTrue(torch.allclose(y, sound))

    def test_spectrogram(self):
        sound, sample_rate = torchaudio.load_wav(self.test_filepath)
        kaldi_output_dir = os.path.join(self.test_dirpath, 'assets', 'kaldi')
        files = list(filter(lambda x: x.startswith('spec'), os.listdir(kaldi_output_dir)))
        print('Results:', len(files))
        for f in files:
            print(f)
            kaldi_output_path = os.path.join(kaldi_output_dir, f)
            kaldi_output_dict = {k: v for k, v in torchaudio.kaldi_io.read_mat_ark(kaldi_output_path)}
            assert len(kaldi_output_dict) == 1 and 'my_id' in kaldi_output_dict, 'invalid test kaldi ark file'
            kaldi_output = kaldi_output_dict['my_id']

            # the file name encodes the arguments that were passed to
            # compute-spectrogram-feats when the ark file was generated
            args = f.split('-')
            args[-1] = os.path.splitext(args[-1])[0]
            assert len(args) == 13, 'invalid test kaldi file name'

            spec_output = kaldi.spectrogram(
                sound,
                blackman_coeff=float(args[1]),
                dither=float(args[2]),
                energy_floor=float(args[3]),
                frame_length=float(args[4]),
                frame_shift=float(args[5]),
                preemphasis_coefficient=float(args[6]),
                raw_energy=args[7] == 'true',
                remove_dc_offset=args[8] == 'true',
                round_to_power_of_two=args[9] == 'true',
                snip_edges=args[10] == 'true',
                subtract_mean=args[11] == 'true',
                window_type=args[12])

            error = spec_output - kaldi_output
            mse = error.pow(2).sum() / spec_output.numel()
            max_error = torch.max(error.abs())
            print('mse:', mse.item(), 'max_error:', max_error.item())

            self.assertEqual(spec_output.shape, kaldi_output.shape)
            self.assertTrue(torch.allclose(spec_output, kaldi_output, atol=1e-3, rtol=0))


if __name__ == '__main__':
    unittest.main()
@@ -14,7 +14,7 @@ class TORCHAUDIODS(Dataset):
     def __init__(self):
         self.asset_dirpath = os.path.join(self.test_dirpath, "assets")
-        sound_files = list(filter(lambda x: '.wav' in x or '.mp3' in x, os.listdir(self.asset_dirpath)))
+        sound_files = ["sinewave.wav", "steam-train-whistle-daniel_simon.mp3"]
         self.data = [os.path.join(self.asset_dirpath, fn) for fn in sound_files]
         self.si, self.ei = torchaudio.info(os.path.join(self.asset_dirpath, "sinewave.wav"))
         self.si.precision = 16
@@ -4,7 +4,7 @@ import os.path
 import torch
 import _torch_sox

-from torchaudio import transforms, datasets, kaldi_io, sox_effects, legacy
+from torchaudio import transforms, datasets, kaldi_io, sox_effects, legacy, compliance


 def check_input(src):
@@ -92,6 +92,14 @@ def load(filepath,
     return out, sample_rate


+def load_wav(filepath, **kwargs):
+    """ Loads a wave file. It assumes that the wav file uses 16 bits per sample,
+    which needs normalization by shifting the input right by 16 bits.
+    """
+    kwargs['normalization'] = 1 << 16
+    return load(filepath, **kwargs)
+
+
 def save(filepath, src, sample_rate, precision=16, channels_first=True):
     """Convenience function for `save_encinfo`.
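# A quick sanity check of load_wav (a minimal sketch; assumes a wav whose
# 32-bit samples only use the top 16 bits, as produced by _create_data_set,
# and that load() divides its output by a numeric `normalization` value):
#
#   sound, sr = torchaudio.load(filepath, normalization=False)
#   sound_norm, sr = torchaudio.load_wav(filepath)
#   assert torch.allclose(sound / (1 << 16), sound_norm)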
import math
import random

import torch

__all__ = [
    'spectrogram'
]

# numeric_limits<float>::epsilon()
EPSILON = torch.tensor(1.19209290e-07, dtype=torch.get_default_dtype())
# 1 millisecond = 0.001 seconds
MILLISECONDS_TO_SECONDS = 0.001

# window types
HAMMING = 'hamming'
HANNING = 'hanning'
POVEY = 'povey'
RECTANGULAR = 'rectangular'
BLACKMAN = 'blackman'


def _next_power_of_2(x):
    """ Returns the smallest power of 2 that is greater than or equal to x
    """
    return 1 if x == 0 else 2 ** (x - 1).bit_length()
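# For example: _next_power_of_2(0) -> 1, _next_power_of_2(5) -> 8, and
# _next_power_of_2(8) -> 8 (a value that is already a power of two is unchanged).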
def _get_strided(waveform, window_size, window_shift, snip_edges):
    """ Given a waveform (1D tensor of size num_samples), it returns a 2D tensor (m, window_size)
    representing how the window is shifted along the waveform. Each row is a frame.

    Inputs:
        waveform (Tensor): Tensor of size num_samples
        window_size (int): Frame length
        window_shift (int): Frame shift
        snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
            in the file, and the number of frames depends on the frame_length. If False, the number of frames
            depends only on the frame_shift, and we reflect the data at the ends.

    Output:
        Tensor: 2D tensor of size (m, window_size) where each row is a frame
    """
    assert waveform.dim() == 1
    num_samples = waveform.size(0)
    strides = (window_shift * waveform.stride(0), waveform.stride(0))

    if snip_edges:
        if num_samples < window_size:
            return torch.empty((0, 0))
        else:
            m = 1 + (num_samples - window_size) // window_shift
    else:
        reversed_waveform = torch.flip(waveform, [0])
        m = (num_samples + (window_shift // 2)) // window_shift
        pad = window_size // 2 - window_shift // 2
        pad_right = reversed_waveform
        if pad > 0:
            # torch.nn.functional.pad returns [2, 1, 0, 1, 2] for 'reflect'
            # but we want [2, 1, 0, 0, 1, 2]
            pad_left = reversed_waveform[-pad:]
            waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
        else:
            # pad is negative so we want to trim the waveform at the front
            waveform = torch.cat((waveform[-pad:], pad_right), dim=0)

    sizes = (m, window_size)
    return waveform.as_strided(sizes, strides)
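# A small worked example (the values follow directly from the code above):
#
#   >>> _get_strided(torch.arange(10.), window_size=4, window_shift=2, snip_edges=True)
#   tensor([[0., 1., 2., 3.],
#           [2., 3., 4., 5.],
#           [4., 5., 6., 7.],
#           [6., 7., 8., 9.]])
#
# With snip_edges=False the same call yields m = (10 + 1) // 2 = 5 frames, where
# the first and last frames are completed by reflecting samples at the edges.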
def _feature_window_function(window_type, window_size, blackman_coeff):
    """ Returns a window function with the given type and size
    """
    if window_type == HANNING:
        return torch.hann_window(window_size, periodic=False)
    elif window_type == HAMMING:
        return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46)
    elif window_type == POVEY:
        # like hanning but goes to zero at edges
        return torch.hann_window(window_size, periodic=False).pow(0.85)
    elif window_type == RECTANGULAR:
        return torch.ones(window_size, dtype=torch.get_default_dtype())
    elif window_type == BLACKMAN:
        a = 2 * math.pi / (window_size - 1)
        window_function = torch.arange(window_size, dtype=torch.get_default_dtype())
        # can't use torch.blackman_window as they use different coefficients
        return blackman_coeff - 0.5 * torch.cos(a * window_function) + \
            (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
    else:
        raise ValueError('Invalid window type ' + window_type)
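# An observation (not stated in the source): with the default blackman_coeff of
# 0.42 the expression above reduces to the classic Blackman window
# 0.42 - 0.5 * cos(a * n) + 0.08 * cos(2 * a * n). Kaldi exposes the first and
# third coefficients through --blackman-coeff, which is why torch.blackman_window
# (fixed coefficients) cannot be used directly.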
def _get_log_energy(strided_input, epsilon, energy_floor):
    """ Returns the log energy of size (m) for a strided_input (m, *)
    """
    log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log()  # size (m)
    if energy_floor == 0.0:
        return log_energy
    else:
        return torch.max(
            log_energy, torch.tensor(math.log(energy_floor), dtype=torch.get_default_dtype()))
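# For instance, a single frame [3., 4.] has energy 3 ** 2 + 4 ** 2 = 25, so its
# log energy is log(25) ≈ 3.2189; the epsilon clamp keeps an all-zero frame from
# producing log(0) = -inf.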
def spectrogram(
sig, blackman_coeff=0.42, channel=-1, dither=1.0, energy_floor=0.0,
frame_length=25.0, frame_shift=10.0, min_duration=0.0,
preemphasis_coefficient=0.97, raw_energy=True, remove_dc_offset=True,
round_to_power_of_two=True, sample_frequency=16000.0, snip_edges=True,
subtract_mean=False, window_type=POVEY):
"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
compute-spectrogram-feats.
Inputs:
sig (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
blackman_coeff (float): Constant coefficient for generalized Blackman window. (default = 0.42)
channel (int): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (default = -1)
dither (float): Dithering constant (0.0 means no dither). If you turn this off, you should set
the energy_floor option, e.g. to 1.0 or 0.1 (default = 1.0)
energy_floor (float): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (default = 0.0)
frame_length (float): Frame length in milliseconds (default = 25.0)
frame_shift (float): Frame shift in milliseconds (default = 10.0)
min_duration (float): Minimum duration of segments to process (in seconds). (default = 0.0)
preemphasis_coefficient (float): Coefficient for use in signal preemphasis (default = 0.97)
raw_energy (bool): If True, compute energy before preemphasis and windowing (default = True)
remove_dc_offset: Subtract mean from waveform on each frame (default = True)
round_to_power_of_two (bool): If True, round window size to power of two by zero-padding input
to FFT. (default = True)
sample_frequency (float): Waveform data sample frequency (must match the waveform file, if
specified there) (default = 16000.0)
snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends. (default = true)
subtract_mean (bool): Subtract mean of each feature file [CMS]; not recommended to do
it this way. (default = False)
window_type (str): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman') (default = 'povey')
Outputs:
Tensor: a spectrogram identical to what Kaldi would output. The shape is (, `padded_window_size` // 2 + 1)
"""
waveform = sig[max(channel, 0), :] # size (n)
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
assert 2 <= window_size <= len(waveform), ('choose a window size %d that is [2, %d]' % (window_size, len(waveform)))
assert 0 < window_shift, '`window_shift` must be greater than 0'
assert padded_window_size % 2 == 0, 'the padded ' \
'`window_size` must be divisible by two. use `round_to_power_of_two` or change `frame_length`'
assert 0. <= preemphasis_coefficient <= 1.0, '`preemphasis_coefficient` must be between [0,1]'
assert sample_frequency > 0, '`sample_frequency` must be greater than zero'
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return torch.empty(0)
# size (m, window_size)
strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
if dither != 0.0:
# Returns a random number strictly between 0 and 1
x = torch.max(EPSILON, torch.rand(strided_input.shape))
rand_gauss = torch.sqrt(-2 * x.log()) * torch.cos(2 * math.pi * x)
strided_input = strided_input + rand_gauss * dither
if remove_dc_offset:
# Subtract each row/frame by its mean
row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
strided_input = strided_input - row_means
if raw_energy:
# Compute the log energy of each row/frame before applying preemphasis and
# window function
signal_log_energy = _get_log_energy(strided_input, EPSILON, energy_floor) # size (m)
if preemphasis_coefficient != 0.0:
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
offset_strided_input = torch.nn.functional.pad(
strided_input.unsqueeze(0), (1, 0), mode='replicate').squeeze(0) # size (m, window_size + 1)
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
# Apply window_function to each row/frame
window_function = _feature_window_function(
window_type, window_size, blackman_coeff).unsqueeze(0) # size (1, window_size)
strided_input = strided_input * window_function # size (m, window_size)
# Pad columns with zero until we reach size (m, padded_window_size)
if padded_window_size != window_size:
padding_right = padded_window_size - window_size
strided_input = torch.nn.functional.pad(
strided_input.unsqueeze(0), (0, padding_right), mode='constant', value=0).squeeze(0)
# Compute energy after window function (not the raw one)
if not raw_energy:
signal_log_energy = _get_log_energy(strided_input, EPSILON, energy_floor) # size (m)
# size (m, padded_window_size // 2 + 1, 2)
fft = torch.rfft(strided_input, 1, normalized=False, onesided=True)
# Convert the FFT into a power spectrum
power_spectrum = torch.max(fft.pow(2).sum(2), EPSILON).log() # size (m, padded_window_size // 2 + 1)
power_spectrum[:, 0] = signal_log_energy
if subtract_mean:
col_means = torch.mean(power_spectrum, dim=0).unsqueeze(0) # size (1, padded_window_size // 2 + 1)
power_spectrum = power_spectrum - col_means
return power_spectrum
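# Example usage (a minimal sketch; 'kaldi_file.wav' is an assumed path to a
# mono 16 kHz file such as the one produced by _create_data_set in the tests):
#
#   import torchaudio
#   import torchaudio.compliance.kaldi as kaldi
#
#   sound, sample_rate = torchaudio.load_wav('kaldi_file.wav')
#   spec = kaldi.spectrogram(sound, dither=0.0)  # dither off for reproducibility
#   print(spec.shape)  # (m, padded_window_size // 2 + 1)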