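"""Train an Emformer RNN-T model on LibriSpeech with PyTorch Lightning.

Defines token-count batching over LibriSpeech, SpecAugment-style masking,
global feature normalization, learning-rate warmup, and ``RNNTModule``, a
``LightningModule`` that wraps ``torchaudio``'s ``emformer_rnnt_base``.
"""
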
import json
import math
import os
from collections import namedtuple
from typing import List, Tuple

import sentencepiece as spm
import torch
import torchaudio
import torchaudio.functional as F
from pytorch_lightning import LightningModule
from torchaudio.models import Hypothesis, RNNTBeamSearch, emformer_rnnt_base
from utils import GAIN, piecewise_linear_log, spectrogram_transform


Batch = namedtuple("Batch", ["features", "feature_lengths", "targets", "target_lengths"])


def _batch_by_token_count(idx_target_lengths, token_limit):
    batches = []
    current_batch = []
    current_token_count = 0
    for idx, target_length in idx_target_lengths:
        if current_token_count + target_length > token_limit:
            batches.append(current_batch)
            current_batch = [idx]
            current_token_count = target_length
        else:
            current_batch.append(idx)
            current_token_count += target_length

    if current_batch:
        batches.append(current_batch)

    return batches
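
# Illustrative walk-through (made-up lengths): with token_limit=7,
#   _batch_by_token_count([(0, 4), (1, 3), (2, 3), (3, 2)], 7)
# yields [[0, 1], [2, 3]]: indices 0 and 1 fill the first batch (4 + 3 = 7),
# and index 2 starts a new batch once the limit would be exceeded.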


class CustomDataset(torch.utils.data.Dataset):
    r"""Sort samples by target length and batch to max token count."""

    def __init__(self, base_dataset, max_token_limit):
        super().__init__()
        self.base_dataset = base_dataset

        fileid_to_target_length = {}
        idx_target_lengths = [
            (idx, self._target_length(fileid, fileid_to_target_length))
            for idx, fileid in enumerate(self.base_dataset._walker)
        ]

        assert len(idx_target_lengths) > 0

        idx_target_lengths = sorted(idx_target_lengths, key=lambda x: x[1], reverse=True)

        assert max_token_limit >= idx_target_lengths[0][1]

        self.batches = _batch_by_token_count(idx_target_lengths, max_token_limit)

    def _target_length(self, fileid, fileid_to_target_length):
        if fileid not in fileid_to_target_length:
            speaker_id, chapter_id, _ = fileid.split("-")

            file_text = speaker_id + "-" + chapter_id + self.base_dataset._ext_txt
            file_text = os.path.join(self.base_dataset._path, speaker_id, chapter_id, file_text)

            with open(file_text) as ft:
                for line in ft:
                    fileid_text, transcript = line.strip().split(" ", 1)
                    fileid_to_target_length[fileid_text] = len(transcript)

        return fileid_to_target_length[fileid]

    def __getitem__(self, idx):
        return [self.base_dataset[subidx] for subidx in self.batches[idx]]

    def __len__(self):
        return len(self.batches)
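
# Note: each CustomDataset item is already a complete batch (a list of
# LibriSpeech samples), which is why the dataloaders below pass ``batch_size=None``.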


class TimeMasking(torchaudio.transforms._AxisMasking):
    r"""Time masking whose mask width is additionally capped at ``min_mask_p`` times the sequence length."""

    def __init__(self, time_mask_param: int, min_mask_p: float, iid_masks: bool = False) -> None:
        super().__init__(time_mask_param, 2, iid_masks)
        self.min_mask_p = min_mask_p

    def forward(self, specgram: torch.Tensor, mask_value: float = 0.0) -> torch.Tensor:
        if self.iid_masks and specgram.dim() == 4:
            mask_param = min(self.mask_param, self.min_mask_p * specgram.shape[self.axis + 1])
            return F.mask_along_axis_iid(specgram, mask_param, mask_value, self.axis + 1)
        else:
            mask_param = min(self.mask_param, self.min_mask_p * specgram.shape[self.axis])
            return F.mask_along_axis(specgram, mask_param, mask_value, self.axis)
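
# Illustrative arithmetic: with TimeMasking(100, 0.2) as used below, a
# 300-frame spectrogram caps the mask width at min(100, 0.2 * 300) = 60 frames,
# so short utterances never lose more than 20% of their frames to one mask.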


class FunctionalModule(torch.nn.Module):
    def __init__(self, functional):
        super().__init__()
        self.functional = functional

    def forward(self, input):
        return self.functional(input)


class GlobalStatsNormalization(torch.nn.Module):
    def __init__(self, global_stats_path):
        super().__init__()

        with open(global_stats_path) as f:
            blob = json.load(f)

        self.mean = torch.tensor(blob["mean"])
        self.invstddev = torch.tensor(blob["invstddev"])

    def forward(self, input):
        return (input - self.mean) * self.invstddev
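
# The stats file is expected to contain JSON of the form
#   {"mean": [...], "invstddev": [...]}
# where the entries broadcast against the feature dimension of the input
# (typically one value per mel bin).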


class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
    def __init__(self, optimizer, warmup_updates, last_epoch=-1, verbose=False):
        self.warmup_updates = warmup_updates
        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        return [(min(1.0, self._step_count / self.warmup_updates)) * base_lr for base_lr in self.base_lrs]
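
# Illustrative schedule: with warmup_updates=10000 and a base lr of 5e-4 (as
# configured in RNNTModule below), step 2500 yields 0.25 * 5e-4 = 1.25e-4, and
# from step 10000 onward the base lr is returned unchanged.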


def post_process_hypos(
    hypos: List[Hypothesis], sp_model: spm.SentencePieceProcessor
) -> List[Tuple[str, List[float], List[int], List[int]]]:
    post_process_remove_list = [
        sp_model.unk_id(),
        sp_model.eos_id(),
        sp_model.pad_id(),
    ]
    filtered_hypo_tokens = [
        [token_index for token_index in h.tokens[1:] if token_index not in post_process_remove_list] for h in hypos
    ]
    hypos_str = [sp_model.decode(s) for s in filtered_hypo_tokens]
    hypos_ali = [h.alignment[1:] for h in hypos]
    hypos_ids = [h.tokens[1:] for h in hypos]
    hypos_score = [[math.exp(h.score)] for h in hypos]

    nbest_batch = list(zip(hypos_str, hypos_score, hypos_ali, hypos_ids))

    return nbest_batch
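
# Each returned entry is (transcript, [exp(score)], alignment, token ids),
# in the same order as the input hypotheses.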


class RNNTModule(LightningModule):
    def __init__(
        self,
        *,
        librispeech_path: str,
        sp_model_path: str,
        global_stats_path: str,
    ):
        super().__init__()

        # num_symbols = SentencePiece vocabulary size (4096) + 1 for the blank symbol.
        self.model = emformer_rnnt_base(num_symbols=4097)
        self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", clamp=1.0)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
        self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(self.optimizer, factor=0.96, patience=0)
        self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)

        self.train_data_pipeline = torch.nn.Sequential(
            FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
            GlobalStatsNormalization(global_stats_path),
            FunctionalModule(lambda x: x.transpose(1, 2)),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.FrequencyMasking(27),
            TimeMasking(100, 0.2),
            TimeMasking(100, 0.2),
            FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
            FunctionalModule(lambda x: x.transpose(1, 2)),
        )
        self.valid_data_pipeline = torch.nn.Sequential(
            FunctionalModule(lambda x: piecewise_linear_log(x * GAIN)),
            GlobalStatsNormalization(global_stats_path),
            FunctionalModule(lambda x: x.transpose(1, 2)),
            FunctionalModule(lambda x: torch.nn.functional.pad(x, (0, 4))),
            FunctionalModule(lambda x: x.transpose(1, 2)),
        )

        self.librispeech_path = librispeech_path

        self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
        # The blank index is one past the largest SentencePiece token id.
        self.blank_idx = self.sp_model.get_piece_size()

    def _extract_labels(self, samples: List):
        targets = [self.sp_model.encode(sample[2].lower()) for sample in samples]
        lengths = torch.tensor([len(elem) for elem in targets]).to(dtype=torch.int32)
        targets = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(elem) for elem in targets],
            batch_first=True,
            padding_value=1.0,
        ).to(dtype=torch.int32)
        return targets, lengths

    def _train_extract_features(self, samples: List):
        mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
        features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
        features = self.train_data_pipeline(features)
        lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
        return features, lengths

    def _valid_extract_features(self, samples: List):
        mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
        features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
        features = self.valid_data_pipeline(features)
        lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
        return features, lengths

    def _train_collate_fn(self, samples: List):
        features, feature_lengths = self._train_extract_features(samples)
        targets, target_lengths = self._extract_labels(samples)
        return Batch(features, feature_lengths, targets, target_lengths)

    def _valid_collate_fn(self, samples: List):
        features, feature_lengths = self._valid_extract_features(samples)
        targets, target_lengths = self._extract_labels(samples)
        return Batch(features, feature_lengths, targets, target_lengths)

    def _test_collate_fn(self, samples: List):
        return self._valid_collate_fn(samples), samples

    def _step(self, batch, batch_idx, step_type):
        if batch is None:
            return None

        # Prepend the blank symbol to each target sequence; it acts as the
        # start-of-sequence token for the RNN-T predictor network.
        prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
        prepended_targets[:, 1:] = batch.targets
        prepended_targets[:, 0] = self.blank_idx
        prepended_target_lengths = batch.target_lengths + 1
        output, src_lengths, _, _ = self.model(
            batch.features,
            batch.feature_lengths,
            prepended_targets,
            prepended_target_lengths,
        )
        loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
        self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        return (
            [self.optimizer],
            [
                {
                    "scheduler": self.lr_scheduler,
                    "monitor": "Losses/val_loss",
                    "interval": "epoch",
                },
                {"scheduler": self.warmup_lr_scheduler, "interval": "step"},
            ],
        )

    def forward(self, batch: Batch):
        decoder = RNNTBeamSearch(self.model, self.blank_idx)
        # Run beam search with a beam width of 20.
        hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), 20)
        # Return the transcript of the best hypothesis.
        return post_process_hypos(hypotheses, self.sp_model)[0][0]

    def training_step(self, batch: Batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")

    def test_step(self, batch, batch_idx):
        # ``_test_collate_fn`` returns (Batch, samples); only the Batch is needed here.
        return self._step(batch[0], batch_idx, "test")

    def train_dataloader(self):
        dataset = torch.utils.data.ConcatDataset(
            [
                CustomDataset(
                    torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-360"),
                    1000,
                ),
                CustomDataset(
                    torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-clean-100"),
                    1000,
                ),
                CustomDataset(
                    torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="train-other-500"),
                    1000,
                ),
            ]
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=None,
            collate_fn=self._train_collate_fn,
            num_workers=10,
            shuffle=True,
        )
        return dataloader

    def val_dataloader(self):
        dataset = torch.utils.data.ConcatDataset(
            [
                CustomDataset(
                    torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-clean"),
                    1000,
                ),
                CustomDataset(
                    torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="dev-other"),
                    1000,
                ),
            ]
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_size=None,
            collate_fn=self._valid_collate_fn,
            num_workers=10,
        )
        return dataloader

    def test_dataloader(self):
        dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean")
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)
        return dataloader
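
# Minimal usage sketch (paths and trainer arguments are illustrative):
#
#   from pytorch_lightning import Trainer
#
#   model = RNNTModule(
#       librispeech_path="/datasets/librispeech",
#       sp_model_path="spm_model.model",
#       global_stats_path="global_stats.json",
#   )
#   Trainer(max_epochs=120).fit(model)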