Commit bd00146d, authored by jamarshon, committed by cpuhrsch
Browse files

Quick cleanup of kaldi fbank

parent 4f7886d1
import argparse import argparse
import logging
import os import os
import random import random
import subprocess import subprocess
...@@ -7,7 +8,8 @@ import torchaudio ...@@ -7,7 +8,8 @@ import torchaudio
import utils import utils
def run(exe_path, scp_path, out_dir, wave_len, num_outputs, verbose): def run(exe_path, scp_path, out_dir, wave_len, num_outputs, remove_files, log_level):
logging.basicConfig(level=log_level)
for i in range(num_outputs): for i in range(num_outputs):
try: try:
nyquist = 16000 // 2 nyquist = 16000 // 2
...@@ -57,18 +59,18 @@ def run(exe_path, scp_path, out_dir, wave_len, num_outputs, verbose): ...@@ -57,18 +59,18 @@ def run(exe_path, scp_path, out_dir, wave_len, num_outputs, verbose):
arg += ['--' + k.replace('_', '-') + '=' + inputs[k] for k in inputs] arg += ['--' + k.replace('_', '-') + '=' + inputs[k] for k in inputs]
arg += ['--dither=0.0', scp_path, out_fn] arg += ['--dither=0.0', scp_path, out_fn]
print(fn) logging.info(fn)
print(inputs) logging.info(inputs)
print(' '.join(arg)) logging.info(' '.join(arg))
try: try:
if verbose: if log_level == 'INFO':
subprocess.call(arg) subprocess.call(arg)
else: else:
subprocess.call(arg, stderr=open(os.devnull, 'wb'), stdout=open(os.devnull, 'wb')) subprocess.call(arg, stderr=open(os.devnull, 'wb'), stdout=open(os.devnull, 'wb'))
print('success') logging.info('success')
except Exception: except Exception:
if os.path.exists(out_fn): if remove_files and os.path.exists(out_fn):
os.remove(out_fn) os.remove(out_fn)
...@@ -95,17 +97,17 @@ def decode(fn, sound_path, exe_path, scp_path, out_dir): ...@@ -95,17 +97,17 @@ def decode(fn, sound_path, exe_path, scp_path, out_dir):
# print flags for C++ # print flags for C++
s = ' '.join(['--' + arr[i].replace('_', '-') + '=' + fn_split[i] for i in range(len(arr))]) s = ' '.join(['--' + arr[i].replace('_', '-') + '=' + fn_split[i] for i in range(len(arr))])
print(exe_path + ' --dither=0.0 --debug-mel=true ' + s + ' ' + scp_path + ' ' + out_fn) logging.info(exe_path + ' --dither=0.0 --debug-mel=true ' + s + ' ' + scp_path + ' ' + out_fn)
print() logging.info()
# print args for python # print args for python
inputs['dither'] = 0.0 inputs['dither'] = 0.0
print(inputs) logging.info(inputs)
sound, sample_rate = torchaudio.load_wav(sound_path) sound, sample_rate = torchaudio.load_wav(sound_path)
kaldi_output_dict = {k: v for k, v in torchaudio.kaldi_io.read_mat_ark(out_fn)} kaldi_output_dict = {k: v for k, v in torchaudio.kaldi_io.read_mat_ark(out_fn)}
res = torchaudio.compliance.kaldi.fbank(sound, **inputs) res = torchaudio.compliance.kaldi.fbank(sound, **inputs)
torch.set_printoptions(precision=10, sci_mode=False) torch.set_printoptions(precision=10, sci_mode=False)
print(res) logging.info(res)
print(kaldi_output_dict['my_id']) logging.info(kaldi_output_dict['my_id'])
if __name__ == '__main__': if __name__ == '__main__':
...@@ -134,7 +136,10 @@ if __name__ == '__main__': ...@@ -134,7 +136,10 @@ if __name__ == '__main__':
parser.add_argument('--wave_len', type=int, default=20, parser.add_argument('--wave_len', type=int, default=20,
help='The number of samples inside the input wave file read from `scp_path`') help='The number of samples inside the input wave file read from `scp_path`')
parser.add_argument('--num_outputs', type=int, default=100, help='How many output files should be generated.') parser.add_argument('--num_outputs', type=int, default=100, help='How many output files should be generated.')
parser.add_argument('--verbose', type=bool, default=False, help='Whether to print information.') parser.add_argument('--remove_files', type=bool, default=False,
help='Whether to remove files generated from exception')
parser.add_argument('--log_level', type=str, default='WARNING',
help='Log level (DEBUG|INFO|WARNING|ERROR|CRITICAL)')
# decode arguments # decode arguments
parser.add_argument('--decode', type=bool, default=False, help='Whether to run the decode or run function.') parser.add_argument('--decode', type=bool, default=False, help='Whether to run the decode or run function.')
...@@ -145,4 +150,5 @@ if __name__ == '__main__': ...@@ -145,4 +150,5 @@ if __name__ == '__main__':
if args.decode: if args.decode:
decode(args.fn, args.sound_path, args.exe_path, args.scp_path, args.out_dir) decode(args.fn, args.sound_path, args.exe_path, args.scp_path, args.out_dir)
else: else:
run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs, args.verbose) run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs,
args.remove_files, args.log_level)
import argparse import argparse
import logging
import os import os
import random import random
import subprocess import subprocess
import utils import utils
def run(exe_path, scp_path, out_dir, wave_len, num_outputs, verbose): def run(exe_path, scp_path, out_dir, wave_len, num_outputs, remove_files, log_level):
logging.basicConfig(level=log_level)
for i in range(num_outputs): for i in range(num_outputs):
inputs = { inputs = {
'blackman_coeff': '%.4f' % (random.random() * 5), 'blackman_coeff': '%.4f' % (random.random() * 5),
...@@ -30,18 +32,18 @@ def run(exe_path, scp_path, out_dir, wave_len, num_outputs, verbose): ...@@ -30,18 +32,18 @@ def run(exe_path, scp_path, out_dir, wave_len, num_outputs, verbose):
arg += ['--' + k.replace('_', '-') + '=' + inputs[k] for k in inputs] arg += ['--' + k.replace('_', '-') + '=' + inputs[k] for k in inputs]
arg += [scp_path, out_fn] arg += [scp_path, out_fn]
print(fn) logging.info(fn)
print(inputs) logging.info(inputs)
print(' '.join(arg)) logging.info(' '.join(arg))
try: try:
if verbose: if log_level == 'INFO':
subprocess.call(arg) subprocess.call(arg)
else: else:
subprocess.call(arg, stderr=open(os.devnull, 'wb'), stdout=open(os.devnull, 'wb')) subprocess.call(arg, stderr=open(os.devnull, 'wb'), stdout=open(os.devnull, 'wb'))
print('success') logging.info('success')
except Exception: except Exception:
if os.path.exists(out_fn): if remove_files and os.path.exists(out_fn):
os.remove(out_fn) os.remove(out_fn)
...@@ -63,7 +65,11 @@ if __name__ == '__main__': ...@@ -63,7 +65,11 @@ if __name__ == '__main__':
parser.add_argument('--wave_len', type=int, default=20, parser.add_argument('--wave_len', type=int, default=20,
help='The number of samples inside the input wave file read from `scp_path`') help='The number of samples inside the input wave file read from `scp_path`')
parser.add_argument('--num_outputs', type=int, default=100, help='How many output files should be generated.') parser.add_argument('--num_outputs', type=int, default=100, help='How many output files should be generated.')
parser.add_argument('--verbose', type=bool, default=False, help='Whether to print information.') parser.add_argument('--remove_files', type=bool, default=False,
help='Whether to remove files generated from exception')
parser.add_argument('--log_level', type=str, default='WARNING',
help='Log level (DEBUG|INFO|WARNING|ERROR|CRITICAL)')
args = parser.parse_args() args = parser.parse_args()
run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs, args.verbose) run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs,
args.remove_files, args.log_level)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment