create_jittable_pipeline.py 2.24 KB
Newer Older
moto's avatar
moto committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
#!/usr/bin/env python3
"""
Create a data preprocess pipeline that can be run with libtorchaudio
"""
import os
import argparse

import torch
import torchaudio


class Pipeline(torch.nn.Module):
    """Example audio process pipeline.

    This example loads a waveform from a file, applies effects, and saves the result to a file.
    """
moto's avatar
moto committed
17
    def __init__(self, rir_path: str):
        """Load the room impulse response (RIR) that ``forward`` applies.

        Args:
            rir_path: Path of the RIR audio file, read with ``torchaudio.load``.
        """
        super().__init__()
        impulse_response, impulse_rate = torchaudio.load(rir_path)
        # Keep the RIR tensor as a module buffer, available as ``self.rir``.
        self.register_buffer('rir', impulse_response)
        # Sample rate reported by ``torchaudio.load`` for the RIR file; the
        # explicit ``int`` annotation is preserved on the attribute.
        self.rir_sample_rate: int = impulse_rate

    def forward(self, input_path: str, output_path: str):
        """Load audio, add noise, filter it with the RIR, and save the result.

        Args:
            input_path: Path of the audio file read with ``torchaudio.load``.
            output_path: Path the processed waveform is written to with
                ``torchaudio.save``, using the sample rate of the input file.
        """
        # Called on every invocation, before the sox effect is applied below.
        torchaudio.sox_effects.init_sox_effects()

        # 1. Load audio
        waveform, sample_rate = torchaudio.load(input_path)

        # 2. Add background noise: blend noise from ``torch.randn_like``
        #    (weight ``alpha``) with the signal (weight ``1 - alpha``).
        alpha = 0.01
        waveform = alpha * torch.randn_like(waveform) + (1 - alpha) * waveform

        # 3. Resample the RIR filter to match the audio sample rate
        rir, _ = torchaudio.sox_effects.apply_effects_tensor(
            self.rir, self.rir_sample_rate, effects=[["rate", str(sample_rate)]])
        # Scale the filter by its L2 norm, then reverse it along dim 1.
        # NOTE(review): the flip presumably compensates for conv1d computing
        # cross-correlation rather than convolution -- confirm.
        rir = rir / torch.norm(rir, p=2)
        rir = torch.flip(rir, [1])

        # 4. Apply RIR filter
        # Left-pad the last dimension by (filter length - 1) samples, then
        # add a leading dimension to both tensors for conv1d and drop it again.
        # NOTE(review): assumes waveform and rir are (channel, time) with
        # compatible channel counts -- confirm against the audio files used.
        waveform = torch.nn.functional.pad(waveform, (rir.shape[1] - 1, 0))
        waveform = torch.nn.functional.conv1d(waveform[None, ...], rir[None, ...])[0]

        # Save
        torchaudio.save(output_path, waveform, sample_rate)


moto's avatar
moto committed
47
48
def _create_jit_pipeline(rir_path, output_path):
    """Script the pipeline, print its generated code, and save it.

    Args:
        rir_path: Path of the RIR audio file passed to ``Pipeline``.
        output_path: Destination of the saved scripted module.
    """
    scripted = torch.jit.script(Pipeline(rir_path))

    banner = "*" * 40
    # Header: banner, title, banner, then one empty line.
    for line in (banner, "* Pipeline code", banner, ""):
        print(line)
    print(scripted.code)
    print(banner)

    scripted.save(output_path)


def _get_path(*paths):
    """Join ``paths`` onto the directory that contains this file."""
    script_dir = os.path.dirname(__file__)
    return os.path.join(script_dir, *paths)


def _parse_args():
    """Parse the command-line options of this script.

    Both options default to paths resolved relative to this file's
    directory via ``_get_path``.

    Returns:
        argparse.Namespace: Namespace with ``rir_path`` and ``output_path``.
    """
    # The module docstring doubles as the description shown by ``--help``.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--rir-path",
        default=_get_path("..", "data", "rir.wav"),
        # Fixed typo in the user-facing help text ("dara" -> "data").
        help="Audio data for room impulse response."
    )
    parser.add_argument(
        "--output-path",
        default=_get_path("pipeline.zip"),
        help="Output JIT file."
    )
    return parser.parse_args()


def _main():
    """Parse the command line, then build and save the scripted pipeline."""
    options = _parse_args()
    rir_path, output_path = options.rir_path, options.output_path
    _create_jit_pipeline(rir_path, output_path)
moto's avatar
moto committed
80
81
82
83


# Run the entry point only when this file is executed as a script.
if __name__ == "__main__":
    _main()