common_utils.py
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found at
# https://github.com/pytorch/fairseq/blob/265df7144c79446f5ea8d835bda6e727f54dad9d/LICENSE
"""
Data pre-processing: create tsv files for training (and validation).
"""
import logging
import re
from pathlib import Path
from typing import (
    Tuple,
    Union,
)

import torch
import torchaudio


_LG = logging.getLogger(__name__)


def create_tsv(
    root_dir: Union[str, Path],
    out_dir: Union[str, Path],
    dataset: str = "librispeech",
    valid_percent: float = 0.01,
    seed: int = 0,
    extension: str = "flac",
) -> None:
    """Create file lists for training and validation.
    Args:
        root_dir (str or Path): The directory of the dataset.
        out_dir (str or Path): The directory to store the file lists.
        dataset (str, optional): The dataset to use. Options:
            [``librispeech``, ``libri-light``]. (Default: ``librispeech``)
        valid_percent (float, optional): The percentage of data for validation. (Default: 0.01)
        seed (int, optional): The seed for randomly selecting the validation files. (Default: 0)
        extension (str, optional): The extension of audio files. (Default: ``flac``)

    Returns:
        None
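
    Example (illustrative only; the paths below are placeholders and assume
    LibriSpeech audio has been extracted under ``root_dir``):
        >>> create_tsv("/datasets/LibriSpeech", "tsv", valid_percent=0.01)

    This writes ``tsv/librispeech_train.tsv`` and ``tsv/librispeech_valid.tsv``,
    where the first line of each file is ``root_dir`` and each following line is
    an audio path relative to ``root_dir`` and its number of frames, tab-separated.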
    """
    assert 0 <= valid_percent <= 1.0

    torch.manual_seed(seed)
    root_dir = Path(root_dir)
    out_dir = Path(out_dir)

    if not out_dir.exists():
        out_dir.mkdir()

    valid_f = open(out_dir / f"{dataset}_valid.tsv", "w") if valid_percent > 0 else None
    # Only audio files whose path contains "train" are listed.
    search_pattern = ".*train.*"
    with open(out_dir / f"{dataset}_train.tsv", "w") as train_f:
        print(root_dir, file=train_f)

        if valid_f is not None:
            print(root_dir, file=valid_f)

        for fname in root_dir.glob(f"**/*.{extension}"):
            if re.match(search_pattern, str(fname)):
                frames = torchaudio.info(fname).num_frames
                # Randomly assign each file to the train or validation list and write
                # its path (relative to ``root_dir``) and frame count, tab-separated.
                dest = train_f if torch.rand(1) > valid_percent else valid_f
                print(f"{fname.relative_to(root_dir)}\t{frames}", file=dest)
    if valid_f is not None:
        valid_f.close()
    _LG.info("Finished creating the file lists successfully")


def _get_feat_lens_paths(feat_dir: Path, split: str, rank: int, num_rank: int) -> Tuple[Path, Path]:
    r"""Get the feature and lengths paths based on feature directory,
        data split, rank, and number of ranks.
    Args:
        feat_dir (Path): The directory that stores the feature and lengths tensors.
        split (str): The split of data. Options: [``train``, ``valid``].
        rank (int): The rank in the multi-processing.
        num_rank (int): The number of ranks for multi-processing in feature extraction.

    Returns:
        (Path, Path)
        Path: The file path of the feature tensor for the current rank.
        Path: The file path of the lengths tensor for the current rank.
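
    Example (hypothetical directory, for illustration; output shown on a POSIX system):
        >>> _get_feat_lens_paths(Path("exp/features"), "train", 0, 4)
        (PosixPath('exp/features/train_0_4.pt'), PosixPath('exp/features/len_train_0_4.pt'))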
    """
    feat_path = feat_dir / f"{split}_{rank}_{num_rank}.pt"
    len_path = feat_dir / f"len_{split}_{rank}_{num_rank}.pt"
    return feat_path, len_path


def _get_model_path(km_dir: Path) -> Path:
    r"""Get the file path of the KMeans clustering model
    Args:
        km_dir (Path): The directory to store the KMeans clustering model.

    Returns:
        Path: The file path of the model.
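
    Example (hypothetical directory, for illustration; output shown on a POSIX system):
        >>> _get_model_path(Path("exp/km_model"))
        PosixPath('exp/km_model/model.pt')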
    """
    return km_dir / "model.pt"