import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import torchaudio


def normalize_wav(waveform):
    # Remove the DC offset, then peak-normalize so the signal lies in [-0.5, 0.5].
    waveform = waveform - torch.mean(waveform)
    waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8)
    return waveform * 0.5
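

# A minimal sanity-check sketch, not part of the original module: after
# normalize_wav, the waveform is strictly inside [-0.5, 0.5]. The random
# input is an arbitrary stand-in for a real waveform.
def _example_normalize_wav():
    x = torch.randn(44100)
    y = normalize_wav(x)
    assert y.abs().max().item() <= 0.5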


def pad_wav(waveform, segment_length):
    waveform_length = len(waveform)

    if segment_length is None or waveform_length == segment_length:
        # Already the target length (or no target given): return unchanged.
        return waveform
    elif waveform_length > segment_length:
        # Too long: truncate to the target length.
        return waveform[:segment_length]
    else:
        # Too short: zero-pad up to the target length.
        padding = torch.zeros(segment_length - waveform_length, device=waveform.device)
        return torch.cat([waveform, padding])
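

# A short behavioral sketch, not part of the original module: pad_wav
# truncates or zero-pads to the requested length, and None is a no-op.
def _example_pad_wav():
    assert pad_wav(torch.ones(20), 10).shape == (10,)  # truncated
    assert pad_wav(torch.ones(5), 10).shape == (10,)   # zero-padded
    assert pad_wav(torch.ones(5), None).shape == (5,)  # unchanged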


def read_wav_file(filename, duration_sec):
    info = torchaudio.info(filename)
    sample_rate = info.sample_rate

    # Number of frames corresponding to the desired duration.
    num_frames = int(sample_rate * duration_sec)

    # Loading only the first num_frames is faster than decoding the whole
    # file and slicing afterwards.
    waveform, sr = torchaudio.load(filename, num_frames=num_frames)

    if waveform.shape[0] == 2:  # Stereo audio
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=44100)
        resampled_waveform = resampler(waveform)

        # Pad the left and right channels separately.
        padded_left = pad_wav(resampled_waveform[0], int(44100 * duration_sec))
        padded_right = pad_wav(resampled_waveform[1], int(44100 * duration_sec))

        return torch.stack([padded_left, padded_right])
    else:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sr, new_freq=44100
        )[0]
        waveform = pad_wav(waveform, int(44100 * duration_sec)).unsqueeze(0)

        return waveform
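

# A minimal usage sketch, not part of the original module. "sample.wav" is a
# placeholder path; read_wav_file returns a (channels, 44100 * duration_sec)
# tensor resampled to 44.1 kHz.
def _example_read_wav_file():
    wav = read_wav_file("sample.wav", duration_sec=10.0)
    assert wav.shape[-1] == int(44100 * 10.0)
    wav = normalize_wav(wav)  # optionally scale into [-0.5, 0.5]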


class DPOText2AudioDataset(Dataset):
    def __init__(
        self,
        dataset,
        prefix,
        text_column,
        audio_w_column,
        audio_l_column,
        duration,
        num_examples=-1,
    ):
        inputs = list(dataset[text_column])
        self.inputs = [prefix + inp for inp in inputs]
        self.audios_w = list(dataset[audio_w_column])
        self.audios_l = list(dataset[audio_l_column])
        self.durations = list(dataset[duration])
        self.indices = list(range(len(self.inputs)))

        self.mapper = {}
        for index, audio_w, audio_l, duration, text in zip(
            self.indices, self.audios_w, self.audios_l, self.durations, inputs
        ):
            self.mapper[index] = [audio_w, audio_l, duration, text]

        if num_examples != -1:
            self.inputs, self.audios_w, self.audios_l, self.durations = (
                self.inputs[:num_examples],
                self.audios_w[:num_examples],
                self.audios_l[:num_examples],
                self.durations[:num_examples],
            )
            self.indices = self.indices[:num_examples]

    def __len__(self):
        return len(self.inputs)

    def get_num_instances(self):
        return len(self.inputs)

    def __getitem__(self, index):
        return (
            self.inputs[index],
            self.audios_w[index],
            self.audios_l[index],
            self.durations[index],
            self.indices[index],
        )

    def collate_fn(self, data):
        # Transpose a batch of (text, audio_w, audio_l, duration, index)
        # tuples into per-field lists.
        dat = pd.DataFrame(data)
        return [dat[i].tolist() for i in dat]
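

# A minimal construction sketch, not part of the original module. The column
# names and file paths are illustrative placeholders; "chosen"/"rejected" hold
# the preferred and dispreferred audio paths of each DPO preference pair. Any
# mapping from column name to a list of values (a dict, a DataFrame, or a
# HuggingFace split) works as the dataset argument.
def _example_dpo_dataset():
    data = {
        "captions": ["dog barking", "rain on a window"],
        "chosen": ["a.wav", "b.wav"],
        "rejected": ["c.wav", "d.wav"],
        "duration": [10.0, 10.0],
    }
    ds = DPOText2AudioDataset(
        data,
        prefix="",
        text_column="captions",
        audio_w_column="chosen",
        audio_l_column="rejected",
        duration="duration",
    )
    texts, paths_w, paths_l, durs, idxs = ds.collate_fn([ds[0], ds[1]])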

class Text2AudioDataset(Dataset):
    def __init__(
        self, dataset, prefix, text_column, audio_column, duration, num_examples=-1
    ):
        inputs = list(dataset[text_column])
        self.inputs = [prefix + inp for inp in inputs]
        self.audios = list(dataset[audio_column])
        self.durations = list(dataset[duration])
        self.indices = list(range(len(self.inputs)))

        self.mapper = {}
        for index, audio, duration, text in zip(
            self.indices, self.audios, self.durations, inputs
        ):
            self.mapper[index] = [audio, text, duration]

        if num_examples != -1:
            self.inputs, self.audios, self.durations = (
                self.inputs[:num_examples],
                self.audios[:num_examples],
                self.durations[:num_examples],
            )
            self.indices = self.indices[:num_examples]

    def __len__(self):
        return len(self.inputs)

    def get_num_instances(self):
        return len(self.inputs)

    def __getitem__(self, index):
        return (
            self.inputs[index],
            self.audios[index],
            self.durations[index],
            self.indices[index],
        )

    def collate_fn(self, data):
        # Transpose a batch of (text, audio, duration, index) tuples into
        # per-field lists.
        dat = pd.DataFrame(data)
        return [dat[i].tolist() for i in dat]
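

# A minimal end-to-end sketch, not part of the original module. Column names
# and paths are illustrative placeholders; any mapping from column name to a
# list of values works as the dataset argument.
def _example_text2audio_loader():
    data = {
        "captions": ["dog barking", "rain on a window"],
        "location": ["a.wav", "b.wav"],
        "duration": [10.0, 10.0],
    }
    ds = Text2AudioDataset(
        data,
        prefix="",
        text_column="captions",
        audio_column="location",
        duration="duration",
    )
    loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn)
    texts, paths, durations, indices = next(iter(loader))
    # Audio can then be loaded on demand:
    # wavs = [read_wav_file(p, d) for p, d in zip(paths, durations)]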