Commit bd00146d authored by jamarshon, committed by cpuhrsch
Browse files

Quick cleanup of kaldi fbank

parent 4f7886d1
import argparse
import logging
import os
import random
import subprocess
......@@ -7,7 +8,8 @@ import torchaudio
import utils
def run(exe_path, scp_path, out_dir, wave_len, num_outputs, verbose):
def run(exe_path, scp_path, out_dir, wave_len, num_outputs, remove_files, log_level):
logging.basicConfig(level=log_level)
for i in range(num_outputs):
try:
nyquist = 16000 // 2
......@@ -57,18 +59,18 @@ def run(exe_path, scp_path, out_dir, wave_len, num_outputs, verbose):
arg += ['--' + k.replace('_', '-') + '=' + inputs[k] for k in inputs]
arg += ['--dither=0.0', scp_path, out_fn]
print(fn)
print(inputs)
print(' '.join(arg))
logging.info(fn)
logging.info(inputs)
logging.info(' '.join(arg))
try:
if verbose:
if log_level == 'INFO':
subprocess.call(arg)
else:
subprocess.call(arg, stderr=open(os.devnull, 'wb'), stdout=open(os.devnull, 'wb'))
print('success')
logging.info('success')
except Exception:
if os.path.exists(out_fn):
if remove_files and os.path.exists(out_fn):
os.remove(out_fn)
......@@ -95,17 +97,17 @@ def decode(fn, sound_path, exe_path, scp_path, out_dir):
# print flags for C++
s = ' '.join(['--' + arr[i].replace('_', '-') + '=' + fn_split[i] for i in range(len(arr))])
print(exe_path + ' --dither=0.0 --debug-mel=true ' + s + ' ' + scp_path + ' ' + out_fn)
print()
logging.info(exe_path + ' --dither=0.0 --debug-mel=true ' + s + ' ' + scp_path + ' ' + out_fn)
logging.info()
# print args for python
inputs['dither'] = 0.0
print(inputs)
logging.info(inputs)
sound, sample_rate = torchaudio.load_wav(sound_path)
kaldi_output_dict = {k: v for k, v in torchaudio.kaldi_io.read_mat_ark(out_fn)}
res = torchaudio.compliance.kaldi.fbank(sound, **inputs)
torch.set_printoptions(precision=10, sci_mode=False)
print(res)
print(kaldi_output_dict['my_id'])
logging.info(res)
logging.info(kaldi_output_dict['my_id'])
if __name__ == '__main__':
......@@ -134,7 +136,10 @@ if __name__ == '__main__':
parser.add_argument('--wave_len', type=int, default=20,
help='The number of samples inside the input wave file read from `scp_path`')
parser.add_argument('--num_outputs', type=int, default=100, help='How many output files should be generated.')
parser.add_argument('--verbose', type=bool, default=False, help='Whether to print information.')
parser.add_argument('--remove_files', type=bool, default=False,
help='Whether to remove files generated from exception')
parser.add_argument('--log_level', type=str, default='WARNING',
help='Log level (DEBUG|INFO|WARNING|ERROR|CRITICAL)')
# decode arguments
parser.add_argument('--decode', type=bool, default=False, help='Whether to run the decode or run function.')
......@@ -145,4 +150,5 @@ if __name__ == '__main__':
if args.decode:
decode(args.fn, args.sound_path, args.exe_path, args.scp_path, args.out_dir)
else:
run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs, args.verbose)
run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs,
args.remove_files, args.log_level)
import argparse
import logging
import os
import random
import subprocess
import utils
def run(exe_path, scp_path, out_dir, wave_len, num_outputs, verbose):
def run(exe_path, scp_path, out_dir, wave_len, num_outputs, remove_files, log_level):
logging.basicConfig(level=log_level)
for i in range(num_outputs):
inputs = {
'blackman_coeff': '%.4f' % (random.random() * 5),
......@@ -30,18 +32,18 @@ def run(exe_path, scp_path, out_dir, wave_len, num_outputs, verbose):
arg += ['--' + k.replace('_', '-') + '=' + inputs[k] for k in inputs]
arg += [scp_path, out_fn]
print(fn)
print(inputs)
print(' '.join(arg))
logging.info(fn)
logging.info(inputs)
logging.info(' '.join(arg))
try:
if verbose:
if log_level == 'INFO':
subprocess.call(arg)
else:
subprocess.call(arg, stderr=open(os.devnull, 'wb'), stdout=open(os.devnull, 'wb'))
print('success')
logging.info('success')
except Exception:
if os.path.exists(out_fn):
if remove_files and os.path.exists(out_fn):
os.remove(out_fn)
......@@ -63,7 +65,11 @@ if __name__ == '__main__':
parser.add_argument('--wave_len', type=int, default=20,
help='The number of samples inside the input wave file read from `scp_path`')
parser.add_argument('--num_outputs', type=int, default=100, help='How many output files should be generated.')
parser.add_argument('--verbose', type=bool, default=False, help='Whether to print information.')
parser.add_argument('--remove_files', type=bool, default=False,
help='Whether to remove files generated from exception')
parser.add_argument('--log_level', type=str, default='WARNING',
help='Log level (DEBUG|INFO|WARNING|ERROR|CRITICAL)')
args = parser.parse_args()
run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs, args.verbose)
run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs,
args.remove_files, args.log_level)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment