"""Generate feature statistics (per-bin mean and inverse standard deviation) for a training set.

The statistics are written as JSON to --output_path (default: ./global_stats.json).

Example:
python global_stats.py --model_type librispeech --dataset_path /home/librispeech
"""

import json
import logging
import pathlib
from argparse import ArgumentParser, RawTextHelpFormatter

import torch
import torchaudio
14
from common import GAIN, MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3, piecewise_linear_log, spectrogram_transform
# Module-level logger, named after the module (rather than the root logger) so
# records can be attributed to this script. generate_statistics uses it for
# INFO-level progress messages.
logger = logging.getLogger(__name__)


def parse_args(argv=None):
    """Parse command-line arguments for feature-statistics generation.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default) argparse falls back to ``sys.argv[1:]``, which matches the
            previous zero-argument behavior.

    Returns:
        argparse.Namespace with ``model_type`` (str), ``dataset_path``
        (pathlib.Path) and ``output_path`` (pathlib.Path) attributes.
    """
    # The module docstring doubles as the --help description; RawTextHelpFormatter
    # keeps its "Example:" layout intact.
    parser = ArgumentParser(description=__doc__, formatter_class=RawTextHelpFormatter)
    parser.add_argument("--model_type", type=str, choices=[MODEL_TYPE_LIBRISPEECH, MODEL_TYPE_TEDLIUM3], required=True)
    parser.add_argument(
        "--dataset_path",
        required=True,
        type=pathlib.Path,
        help="Path to dataset. "
        "For LibriSpeech, all of 'train-clean-360', 'train-clean-100', and 'train-other-500' must exist.",
    )
    parser.add_argument(
        "--output_path",
        default=pathlib.Path("global_stats.json"),
        type=pathlib.Path,
        help="File to save feature statistics to. (Default: './global_stats.json')",
    )
    return parser.parse_args(argv)


def generate_statistics(samples):
    """Compute the mean and standard deviation of scaled mel features, reduced over frames.

    Processing per sample:
      1. ``spectrogram_transform`` is applied to the squeezed waveform ``sample[0]``.
      2. The result is transposed so that dim 0 is treated as the frame axis.
      3. It is multiplied by ``GAIN`` and passed through ``piecewise_linear_log``.

    Running means of ``x`` and ``x**2`` are updated incrementally, weighted by
    frame count. The standard deviation is ``sqrt(E[x^2] - E[x]^2)``.

    Args:
        samples: Iterable of samples whose first element is a waveform tensor
            (presumably a DataLoader batch of size 1 -- TODO confirm).

    Returns:
        Tuple ``(mean, stddev)`` of tensors reduced over dim 0 (the frame axis).

    Raises:
        ValueError: If ``samples`` yields no frames at all. The result would
            otherwise be plain numbers rather than tensors.
    """
    mean_x = 0
    mean_x_sq = 0
    num_frames = 0

    for idx, sample in enumerate(samples):
        mel_spec = spectrogram_transform(sample[0].squeeze()).transpose(1, 0)
        scaled_mel_spec = piecewise_linear_log(mel_spec * GAIN)
        new_frames = scaled_mel_spec.size(0)

        # Skip zero-length spectrograms: with no prior frames the weights below
        # would divide 0 by 0 (ZeroDivisionError on Python ints).
        if new_frames > 0:
            total = num_frames + new_frames
            frame_sum = scaled_mel_spec.sum(0)
            frame_sq_sum = scaled_mel_spec.pow(2).sum(0)
            mean_x = mean_x * (num_frames / total) + frame_sum / total
            mean_x_sq = mean_x_sq * (num_frames / total) + frame_sq_sum / total
            num_frames = total

        if idx % 100 == 0:
            logger.info(f"Processed {idx}")

    if num_frames == 0:
        raise ValueError("Cannot compute feature statistics: no frames were found in the input samples.")

    # E[x^2] - E[x]^2 can come out slightly negative from floating-point
    # cancellation; clamp so the square root never produces NaN.
    variance = (mean_x_sq - mean_x ** 2).clamp(min=0)
    return mean_x, variance ** 0.5


def get_dataset(args):
    """Build the training dataset selected by ``args.model_type``.

    Args:
        args: Parsed arguments providing ``model_type`` and ``dataset_path``.

    Returns:
        For LibriSpeech, a ``ConcatDataset`` of the 'train-clean-360',
        'train-clean-100' and 'train-other-500' splits, in that order. For
        TED-LIUM 3, the release3 'train' subset.

    Raises:
        ValueError: If ``args.model_type`` is not a supported model type.
    """
    if args.model_type == MODEL_TYPE_LIBRISPEECH:
        splits = ("train-clean-360", "train-clean-100", "train-other-500")
        return torch.utils.data.ConcatDataset(
            [torchaudio.datasets.LIBRISPEECH(args.dataset_path, url=split) for split in splits]
        )
    if args.model_type == MODEL_TYPE_TEDLIUM3:
        return torchaudio.datasets.TEDLIUM(args.dataset_path, release="release3", subset="train")
    raise ValueError(f"Encountered unsupported model type {args.model_type}.")


def cli_main():
    """Compute training-set feature statistics and save them as JSON.

    Steps:
      1. Parse the command line and build the selected dataset.
      2. Accumulate statistics over it with a 4-worker DataLoader.
      3. Write ``{"mean": [...], "invstddev": [...]}`` to ``--output_path``.
    """
    # The root logger defaults to WARNING, so without this the INFO-level
    # progress messages from generate_statistics would never be shown.
    logging.basicConfig(level=logging.INFO)

    args = parse_args()
    dataset = get_dataset(args)
    dataloader = torch.utils.data.DataLoader(dataset, num_workers=4)
    mean, stddev = generate_statistics(dataloader)

    # NOTE(review): a bin with zero stddev yields inf here. json.dumps then
    # emits the non-standard token Infinity (allow_nan defaults to True).
    json_str = json.dumps({"mean": mean.tolist(), "invstddev": (1 / stddev).tolist()}, indent=2)

    with open(args.output_path, "w", encoding="utf-8") as f:
        f.write(json_str)


if __name__ == "__main__":
    cli_main()