generate_test_stft_data.py 3.34 KB
Newer Older
jamarshon's avatar
jamarshon committed
1
import argparse
jamarshon's avatar
jamarshon committed
2
import logging
jamarshon's avatar
jamarshon committed
3
import os
4
import random
jamarshon's avatar
jamarshon committed
5
6
import subprocess
import utils
7
8


jamarshon's avatar
jamarshon committed
9
10
def run(exe_path, scp_path, out_dir, wave_len, num_outputs, remove_files, log_level):
    """Run the spectrogram executable `num_outputs` times with random options.

    Each iteration draws a random set of command-line options, derives an
    output name of the form ``spec-<value>-<value>-....ark`` from the option
    values (in dict order), and invokes `exe_path` with those options followed
    by `scp_path` and the output specifier.

    Args:
        exe_path: Path to the executable to invoke.
        scp_path: Input specifier passed through to the executable unchanged.
        out_dir: Prefix that the generated file name is appended to by plain
            string concatenation (so it must end with a path separator).
        wave_len: Upper bound (exclusive) for the randomly drawn frame length
            and frame shift, expressed in samples. Must be at least 3,
            otherwise ``random.randint`` raises ``ValueError``.
        num_outputs: Number of invocations to perform.
        remove_files: If true, delete the output file of an invocation that
            could not be launched or exited with a non-zero status.
        log_level: Level handed to ``logging.basicConfig``. The executable's
            own stdout/stderr are shown only when this is exactly ``'INFO'``.
    """
    logging.basicConfig(level=log_level)
    show_output = log_level == 'INFO'

    for _ in range(num_outputs):
        # NOTE(review): the utils.generate_rand_* helpers are assumed to return
        # strings, since every value is joined into the file name below.
        inputs = {
            'blackman_coeff': '%.4f' % (random.random() * 5),
            'dither': '0',
            'energy_floor': '%.4f' % (random.random() * 5),
            # Sample counts are scaled by 1000 / 16000, i.e. converted to
            # milliseconds under an assumed 16 kHz sample rate.
            'frame_length': '%.4f' % (float(random.randint(2, wave_len - 1)) / 16000 * 1000),
            'frame_shift': '%.4f' % (float(random.randint(1, wave_len - 1)) / 16000 * 1000),
            'preemphasis_coefficient': '%.2f' % random.random(),
            'raw_energy': utils.generate_rand_boolean(),
            'remove_dc_offset': utils.generate_rand_boolean(),
            'round_to_power_of_two': utils.generate_rand_boolean(),
            'snip_edges': utils.generate_rand_boolean(),
            'subtract_mean': utils.generate_rand_boolean(),
            'window_type': utils.generate_rand_window_type(),
        }

        fn = 'spec-' + '-'.join(inputs.values())
        out_fn = out_dir + fn + '.ark'

        # `out_fn` may carry a Kaldi-style specifier prefix such as "ark:" or
        # "ark,t:"; strip it to get the real filesystem path for cleanup.
        spec, sep, remainder = out_fn.partition(':')
        if sep and spec.split(',')[0] in ('ark', 'scp'):
            out_path = remainder
        else:
            out_path = out_fn

        arg = [exe_path]
        arg += ['--' + k.replace('_', '-') + '=' + v for k, v in inputs.items()]
        arg += [scp_path, out_fn]

        logging.info(fn)
        logging.info(inputs)
        logging.info(' '.join(arg))

        succeeded = False
        try:
            if show_output:
                return_code = subprocess.call(arg)
            else:
                # Close the devnull handle deterministically instead of leaking
                # two file objects per iteration.
                with open(os.devnull, 'wb') as devnull:
                    return_code = subprocess.call(arg, stderr=devnull, stdout=devnull)
        except Exception:
            # E.g. the executable does not exist or is not runnable.
            logging.exception('could not run %s', exe_path)
        else:
            # subprocess.call does not raise on failure; inspect the status.
            succeeded = return_code == 0
            if not succeeded:
                logging.warning('%s exited with status %s', exe_path, return_code)

        if succeeded:
            logging.info('success')
        elif remove_files and os.path.exists(out_path):
            os.remove(out_path)
48
49
50


if __name__ == '__main__':
    # Example:
    #   python test/compliance/generate_test_stft_data.py \
    #       --exe_path=/scratch/jamarshon/kaldi/src/featbin/compute-spectrogram-feats \
    #       --scp_path=scp:/scratch/jamarshon/downloads/a.scp \
    #       --out_dir=ark:/scratch/jamarshon/audio/test/assets/kaldi/

    def _str_to_bool(value):
        """Parse a command-line boolean.

        `type=bool` is wrong for argparse: bool('False') is True because any
        non-empty string is truthy. Accept the usual spellings explicitly.
        """
        normalized = value.strip().lower()
        if normalized in ('true', 't', 'yes', 'y', '1'):
            return True
        if normalized in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError('expected a boolean value, got %r' % value)

    parser = argparse.ArgumentParser(description='Generate spectrogram data using Kaldi.')
    parser.add_argument('--exe_path', type=str, required=True,
                        help='Path to the compute-spectrogram-feats executable.')
    parser.add_argument('--scp_path', type=str, required=True,
                        help='Path to the scp file. An example of its contents would be '
                             '"my_id /scratch/jamarshon/audio/test/assets/kaldi_file.wav". '
                             'where the space separates an id from a wav file.')
    parser.add_argument('--out_dir', type=str, required=True,
                        help='The directory to which the stft features will be written to.')

    # run arguments
    parser.add_argument('--wave_len', type=int, default=20,
                        help='The number of samples inside the input wave file read from `scp_path`')
    parser.add_argument('--num_outputs', type=int, default=100,
                        help='How many output files should be generated.')
    # `--remove_files=True|False` keeps working; a bare `--remove_files` means True.
    parser.add_argument('--remove_files', type=_str_to_bool, nargs='?', const=True, default=False,
                        help='Whether to remove files generated from exception')
    # Upper-case the value so that e.g. `--log_level=info` is accepted by logging.
    parser.add_argument('--log_level', type=str.upper, default='WARNING',
                        help='Log level (DEBUG|INFO|WARNING|ERROR|CRITICAL)')

    args = parser.parse_args()

    run(args.exe_path, args.scp_path, args.out_dir, args.wave_len, args.num_outputs,
        args.remove_files, args.log_level)