Commit 4f7886d1 authored by jamarshon's avatar jamarshon Committed by cpuhrsch
Browse files

Kaldi Fbank (#127)

parent 9bd633e3
import argparse
import os
import random
import subprocess
import torch
import torchaudio
import utils
def run(exe_path, scp_path, out_dir, wave_len, num_outputs, verbose):
for i in range(num_outputs):
try:
nyquist = 16000 // 2
high_freq = random.randint(1, nyquist)
low_freq = random.randint(0, high_freq - 1)
vtln_low = random.randint(low_freq + 1, high_freq - 1)
vtln_high = random.randint(vtln_low + 1, high_freq - 1)
vtln_warp_factor = random.uniform(0.0, 10.0) if random.random() < 0.3 else 1.0
except Exception:
continue
if not ((0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)):
continue
if not (vtln_warp_factor == 1.0 or ((low_freq < vtln_low < high_freq) and
(0.0 < vtln_high < high_freq) and (vtln_low < vtln_high))):
continue
inputs = {
'blackman_coeff': '%.4f' % (random.random() * 5),
'energy_floor': '%.4f' % (random.random() * 5),
'frame_length': '%.4f' % (float(random.randint(3, wave_len - 1)) / 16000 * 1000),
'frame_shift': '%.4f' % (float(random.randint(1, wave_len - 1)) / 16000 * 1000),
'high_freq': str(high_freq),
'htk_compat': utils.generate_rand_boolean(),
'low_freq': str(low_freq),
'num_mel_bins': str(random.randint(4, 8)),
'preemphasis_coefficient': '%.2f' % random.random(),
'raw_energy': utils.generate_rand_boolean(),
'remove_dc_offset': utils.generate_rand_boolean(),
'round_to_power_of_two': utils.generate_rand_boolean(),
'snip_edges': utils.generate_rand_boolean(),
'subtract_mean': utils.generate_rand_boolean(),
'use_energy': utils.generate_rand_boolean(),
'use_log_fbank': utils.generate_rand_boolean(),
'use_power': utils.generate_rand_boolean(),
'vtln_high': str(vtln_high),
'vtln_low': str(vtln_low),
'vtln_warp': '%.4f' % (vtln_warp_factor),
'window_type': utils.generate_rand_window_type()
}
fn = 'fbank-' + ('-'.join(list(inputs.values())))
out_fn = out_dir + fn + '.ark'
arg = [exe_path]
arg += ['--' + k.replace('_', '-') + '=' + inputs[k] for k in inputs]
arg += ['--dither=0.0', scp_path, out_fn]
print(fn)
print(inputs)
print(' '.join(arg))
try:
if verbose:
subprocess.call(arg)
else:
subprocess.call(arg, stderr=open(os.devnull, 'wb'), stdout=open(os.devnull, 'wb'))
print('success')
except Exception:
if os.path.exists(out_fn):
os.remove(out_fn)
def decode(fn, sound_path, exe_path, scp_path, out_dir):
"""
Takes a filepath and prints out the corresponding shell command to run that specific
kaldi configuration. It also calls compliance.kaldi and prints the two outputs.
Example:
>> fn = 'fbank-1.1009-2.5985-1.1875-0.8750-5723-true-918-4-0.31-true-false-true-true-' \
'false-false-false-true-4595-4281-1.0000-hamming.ark'
>> decode(fn)
"""
out_fn = out_dir + fn
fn = fn[len('fbank-'):-len('.ark')]
arr = [
'blackman_coeff', 'energy_floor', 'frame_length', 'frame_shift', 'high_freq', 'htk_compat',
'low_freq', 'num_mel_bins', 'preemphasis_coefficient', 'raw_energy', 'remove_dc_offset',
'round_to_power_of_two', 'snip_edges', 'subtract_mean', 'use_energy', 'use_log_fbank',
'use_power', 'vtln_high', 'vtln_low', 'vtln_warp', 'window_type']
fn_split = fn.split('-')
assert len(fn_split) == len(arr), ('Len mismatch: %d and %d' % (len(fn_split), len(arr)))
inputs = {arr[i]: utils.parse(fn_split[i]) for i in range(len(arr))}
# print flags for C++
s = ' '.join(['--' + arr[i].replace('_', '-') + '=' + fn_split[i] for i in range(len(arr))])
print(exe_path + ' --dither=0.0 --debug-mel=true ' + s + ' ' + scp_path + ' ' + out_fn)
print()
# print args for python
inputs['dither'] = 0.0
print(inputs)
sound, sample_rate = torchaudio.load_wav(sound_path)
kaldi_output_dict = {k: v for k, v in torchaudio.kaldi_io.read_mat_ark(out_fn)}
res = torchaudio.compliance.kaldi.fbank(sound, **inputs)
torch.set_printoptions(precision=10, sci_mode=False)
print(res)
print(kaldi_output_dict['my_id'])
if __name__ == '__main__':
""" Examples:
>> python test/compliance/generate_fbank_data.py \
--exe_path=/scratch/jamarshon/kaldi/src/featbin/compute-fbank-feats \
--scp_path=scp:/scratch/jamarshon/downloads/a.scp \
--out_dir=ark:/scratch/jamarshon/audio/test/assets/kaldi/
>> python test/compliance/generate_fbank_data.py \
--exe_path=/scratch/jamarshon/kaldi/src/featbin/compute-fbank-feats \
--scp_path=scp:/scratch/jamarshon/downloads/a.scp \
--out_dir=ark:/scratch/jamarshon/audio/test/assets/kaldi/ \
--decode=true \
--sound_path=/scratch/jamarshon/audio/test/assets/kaldi_file.wav \
--fn="fbank-1.1009-2.5985-1.1875-0.8750-5723-true-918-4-0.31-true-false-true-
true-false-false-false-true-4595-4281-1.0000-hamming.ark"
"""
parser = argparse.ArgumentParser(description='Generate fbank data using Kaldi.')
parser.add_argument('--exe_path', type=str, required=True, help='Path to the compute-fbank-feats executable.')
parser.add_argument('--scp_path', type=str, required=True, help='Path to the scp file. An example of its contents would be \
"my_id /scratch/jamarshon/audio/test/assets/kaldi_file.wav". where the space separates an id from a wav file.')
parser.add_argument('--out_dir', type=str, required=True,
help='The directory to which the stft features will be written to.')
# run arguments
parser.add_argument('--wave_len', type=int, default=20,
help='The number of samples inside the input wave file read from `scp_path`')
parser.add_argument('--num_outputs', type=int, default=100, help='How many output files should be generated.')
parser.add_argument('--verbose', type=bool, default=False, help='Whether to print information.')
# decode arguments
parser.add_argument('--decode', type=bool, default=False, help='Whether to run the decode or run function.')
parser.add_argument('--fn', type=str, help='Filepath to decode.')
parser.add_argument('--sound_path', type=str, help='Sound filepath to decode.')
args = parser.parse_args()
if args.decode:
decode(args.fn, args.sound_path, args.exe_path, args.scp_path, args.out_dir)
else:
run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs, args.verbose)
import argparse
import os
import random import random
import subprocess
import utils
# Path to the compute-spectrogram-feats executable.
EXE_PATH = '/scratch/jamarshon/kaldi/src/featbin/compute-spectrogram-feats'
# Path to the scp file. An example of its contents would be "my_id /scratch/jamarshon/audio/test/assets/kaldi_file.wav" def run(exe_path, scp_path, out_dir, wave_len, num_outputs, verbose):
# where the space separates an id from a wav file. for i in range(num_outputs):
SCP_PATH = 'scp:/scratch/jamarshon/downloads/a.scp'
# The directory to which the stft features will be written to.
OUTPUT_DIR = 'ark:/scratch/jamarshon/audio/test/assets/kaldi/'
# The number of samples inside the input wave file read from `SCP_PATH`
WAV_LEN = 20
# How many output files should be generated.
NUMBER_OF_OUTPUTS = 100
WINDOWS = ['hamming', 'hanning', 'povey', 'rectangular', 'blackman']
def generate_rand_boolean():
# Generates a random boolean ('true', 'false')
return 'true' if random.randint(0, 1) else 'false'
def generate_rand_window_type():
# Generates a random window type
return WINDOWS[random.randint(0, len(WINDOWS) - 1)]
def run():
for i in range(NUMBER_OF_OUTPUTS):
inputs = { inputs = {
'blackman_coeff': '%.4f' % (random.random() * 5), 'blackman_coeff': '%.4f' % (random.random() * 5),
'dither': '0', 'dither': '0',
'energy_floor': '%.4f' % (random.random() * 5), 'energy_floor': '%.4f' % (random.random() * 5),
'frame_length': '%.4f' % (float(random.randint(2, WAV_LEN - 1)) / 16000 * 1000), 'frame_length': '%.4f' % (float(random.randint(2, wave_len - 1)) / 16000 * 1000),
'frame_shift': '%.4f' % (float(random.randint(1, WAV_LEN - 1)) / 16000 * 1000), 'frame_shift': '%.4f' % (float(random.randint(1, wave_len - 1)) / 16000 * 1000),
'preemphasis_coefficient': '%.2f' % random.random(), 'preemphasis_coefficient': '%.2f' % random.random(),
'raw_energy': generate_rand_boolean(), 'raw_energy': utils.generate_rand_boolean(),
'remove_dc_offset': generate_rand_boolean(), 'remove_dc_offset': utils.generate_rand_boolean(),
'round_to_power_of_two': generate_rand_boolean(), 'round_to_power_of_two': utils.generate_rand_boolean(),
'snip_edges': generate_rand_boolean(), 'snip_edges': utils.generate_rand_boolean(),
'subtract_mean': generate_rand_boolean(), 'subtract_mean': utils.generate_rand_boolean(),
'window_type': generate_rand_window_type() 'window_type': utils.generate_rand_window_type()
} }
fn = 'spec-' + ('-'.join(list(inputs.values()))) fn = 'spec-' + ('-'.join(list(inputs.values())))
arg = [ out_fn = out_dir + fn + '.ark'
EXE_PATH,
'--blackman-coeff=' + inputs['blackman_coeff'], arg = [exe_path]
'--dither=' + inputs['dither'], arg += ['--' + k.replace('_', '-') + '=' + inputs[k] for k in inputs]
'--energy-floor=' + inputs['energy_floor'], arg += [scp_path, out_fn]
'--frame-length=' + inputs['frame_length'],
'--frame-shift=' + inputs['frame_shift'],
'--preemphasis-coefficient=' + inputs['preemphasis_coefficient'],
'--raw-energy=' + inputs['raw_energy'],
'--remove-dc-offset=' + inputs['remove_dc_offset'],
'--round-to-power-of-two=' + inputs['round_to_power_of_two'],
'--sample-frequency=16000',
'--snip-edges=' + inputs['snip_edges'],
'--subtract-mean=' + inputs['subtract_mean'],
'--window-type=' + inputs['window_type'],
SCP_PATH,
OUTPUT_DIR + fn + '.ark'
]
print(fn) print(fn)
print(inputs) print(inputs)
print(' '.join(arg)) print(' '.join(arg))
try: try:
subprocess.call(arg) if verbose:
subprocess.call(arg)
else:
subprocess.call(arg, stderr=open(os.devnull, 'wb'), stdout=open(os.devnull, 'wb'))
print('success')
except Exception: except Exception:
pass if os.path.exists(out_fn):
os.remove(out_fn)
if __name__ == '__main__': if __name__ == '__main__':
run() """ Examples:
>> python test/compliance/generate_test_stft_data.py \
--exe_path=/scratch/jamarshon/kaldi/src/featbin/compute-spectrogram-feats \
--scp_path=scp:/scratch/jamarshon/downloads/a.scp \
--out_dir=ark:/scratch/jamarshon/audio/test/assets/kaldi/
"""
parser = argparse.ArgumentParser(description='Generate spectrogram data using Kaldi.')
parser.add_argument('--exe_path', type=str, required=True, help='Path to the compute-spectrogram-feats executable.')
parser.add_argument('--scp_path', type=str, required=True, help='Path to the scp file. An example of its contents would be \
"my_id /scratch/jamarshon/audio/test/assets/kaldi_file.wav". where the space separates an id from a wav file.')
parser.add_argument('--out_dir', type=str, required=True,
help='The directory to which the stft features will be written to.')
# run arguments
parser.add_argument('--wave_len', type=int, default=20,
help='The number of samples inside the input wave file read from `scp_path`')
parser.add_argument('--num_outputs', type=int, default=100, help='How many output files should be generated.')
parser.add_argument('--verbose', type=bool, default=False, help='Whether to print information.')
args = parser.parse_args()
run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs, args.verbose)
import math import math
import os import os
import test.common_utils import test.common_utils
import test.compliance.utils
import torch import torch
import torchaudio import torchaudio
import torchaudio.compliance.kaldi as kaldi import torchaudio.compliance.kaldi as kaldi
...@@ -46,6 +47,16 @@ def extract_window(window, wave, f, frame_length, frame_shift, snip_edges): ...@@ -46,6 +47,16 @@ def extract_window(window, wave, f, frame_length, frame_shift, snip_edges):
class Test_Kaldi(unittest.TestCase): class Test_Kaldi(unittest.TestCase):
test_dirpath, test_dir = test.common_utils.create_temp_assets_dir() test_dirpath, test_dir = test.common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file.wav') test_filepath = os.path.join(test_dirpath, 'assets', 'kaldi_file.wav')
kaldi_output_dir = os.path.join(test_dirpath, 'assets', 'kaldi')
test_filepaths = {'spec': [], 'fbank': []}
# separating test files by their types (e.g 'spec', 'fbank', etc.)
for f in os.listdir(kaldi_output_dir):
dash_idx = f.find('-')
assert f.endswith('.ark') and dash_idx != -1
key = f[:dash_idx]
assert key in test_filepaths
test_filepaths[key].append(f)
def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_edges): def _test_get_strided_helper(self, num_samples, window_size, window_shift, snip_edges):
waveform = torch.arange(num_samples).float() waveform = torch.arange(num_samples).float()
...@@ -94,46 +105,104 @@ class Test_Kaldi(unittest.TestCase): ...@@ -94,46 +105,104 @@ class Test_Kaldi(unittest.TestCase):
self.assertTrue(sample_rate == sr) self.assertTrue(sample_rate == sr)
self.assertTrue(torch.allclose(y, sound)) self.assertTrue(torch.allclose(y, sound))
def test_spectrogram(self): def _print_diagnostic(self, output, expect_output):
# given an output and expected output, it will print the absolute/relative errors (max and mean squared)
abs_error = output - expect_output
abs_mse = abs_error.pow(2).sum() / output.numel()
abs_max_error = torch.max(abs_error.abs())
relative_error = abs_error / expect_output
relative_mse = relative_error.pow(2).sum() / output.numel()
relative_max_error = torch.max(relative_error.abs())
print('abs_mse:', abs_mse.item(), 'abs_max_error:', abs_max_error.item())
print('relative_mse:', relative_mse.item(), 'relative_max_error:', relative_max_error.item())
def _compliance_test_helper(self, filepath_key, expected_num_files, expected_num_args, get_output_fn):
"""
Inputs:
filepath_key (str): A key to `test_filepaths` which matches which files to use
expected_num_files (int): The expected number of kaldi files to read
expected_num_args (int): The expected number of arguments used in a kaldi configuration
get_output_fn (Callable[[Tensor, List], Tensor]): A function that takes in a sound signal
and a configuration and returns an output
"""
sound, sample_rate = torchaudio.load_wav(self.test_filepath) sound, sample_rate = torchaudio.load_wav(self.test_filepath)
kaldi_output_dir = os.path.join(self.test_dirpath, 'assets', 'kaldi') files = self.test_filepaths[filepath_key]
files = list(filter(lambda x: x.startswith('spec'), os.listdir(kaldi_output_dir)))
print('Results:', len(files)) assert len(files) == expected_num_files, ('number of kaldi %s file changed to %d' % (filepath_key, len(files)))
for f in files: for f in files:
print(f) print(f)
kaldi_output_path = os.path.join(kaldi_output_dir, f)
# Read kaldi's output from file
kaldi_output_path = os.path.join(self.kaldi_output_dir, f)
kaldi_output_dict = {k: v for k, v in torchaudio.kaldi_io.read_mat_ark(kaldi_output_path)} kaldi_output_dict = {k: v for k, v in torchaudio.kaldi_io.read_mat_ark(kaldi_output_path)}
assert len(kaldi_output_dict) == 1 and 'my_id' in kaldi_output_dict, 'invalid test kaldi ark file' assert len(kaldi_output_dict) == 1 and 'my_id' in kaldi_output_dict, 'invalid test kaldi ark file'
kaldi_output = kaldi_output_dict['my_id'] kaldi_output = kaldi_output_dict['my_id']
# Construct the same configuration used by kaldi
args = f.split('-') args = f.split('-')
args[-1] = os.path.splitext(args[-1])[0] args[-1] = os.path.splitext(args[-1])[0]
assert len(args) == 13, 'invalid test kaldi file name' assert len(args) == expected_num_args, 'invalid test kaldi file name'
args = [test.compliance.utils.parse(arg) for arg in args]
spec_output = kaldi.spectrogram( output = get_output_fn(sound, args)
self._print_diagnostic(output, kaldi_output)
self.assertTrue(output.shape, kaldi_output.shape)
self.assertTrue(torch.allclose(output, kaldi_output, atol=1e-3, rtol=1e-1))
def test_spectrogram(self):
def get_output_fn(sound, args):
output = kaldi.spectrogram(
sound, sound,
blackman_coeff=float(args[1]), blackman_coeff=args[1],
dither=float(args[2]), dither=args[2],
energy_floor=float(args[3]), energy_floor=args[3],
frame_length=float(args[4]), frame_length=args[4],
frame_shift=float(args[5]), frame_shift=args[5],
preemphasis_coefficient=float(args[6]), preemphasis_coefficient=args[6],
raw_energy=args[7] == 'true', raw_energy=args[7],
remove_dc_offset=args[8] == 'true', remove_dc_offset=args[8],
round_to_power_of_two=args[9] == 'true', round_to_power_of_two=args[9],
snip_edges=args[10] == 'true', snip_edges=args[10],
subtract_mean=args[11] == 'true', subtract_mean=args[11],
window_type=args[12]) window_type=args[12])
return output
error = spec_output - kaldi_output self._compliance_test_helper('spec', 131, 13, get_output_fn)
mse = error.pow(2).sum() / spec_output.numel()
max_error = torch.max(error.abs())
print('mse:', mse.item(), 'max_error:', max_error.item()) def test_fbank(self):
self.assertTrue(spec_output.shape, kaldi_output.shape) def get_output_fn(sound, args):
self.assertTrue(torch.allclose(spec_output, kaldi_output, atol=1e-3, rtol=0)) output = kaldi.fbank(
sound,
blackman_coeff=args[1],
dither=0.0,
energy_floor=args[2],
frame_length=args[3],
frame_shift=args[4],
high_freq=args[5],
htk_compat=args[6],
low_freq=args[7],
num_mel_bins=args[8],
preemphasis_coefficient=args[9],
raw_energy=args[10],
remove_dc_offset=args[11],
round_to_power_of_two=args[12],
snip_edges=args[13],
subtract_mean=args[14],
use_energy=args[15],
use_log_fbank=args[16],
use_power=args[17],
vtln_high=args[18],
vtln_low=args[19],
vtln_warp=args[20],
window_type=args[21])
return output
self._compliance_test_helper('fbank', 97, 22, get_output_fn)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment