# lightning.py
import os
from functools import partial
from typing import List

import sentencepiece as spm
import torch
import torchaudio
from pytorch_lightning import LightningModule
from torchaudio.models import emformer_rnnt_base, RNNTBeamSearch

from common import (
    Batch,
    batch_by_token_count,
    FunctionalModule,
    GlobalStatsNormalization,
    piecewise_linear_log,
    post_process_hypos,
    spectrogram_transform,
    WarmupLR,
)


class CustomDataset(torch.utils.data.Dataset):
    r"""Sort LibriSpeech samples by target length and batch to max token count.

    Every item of this dataset is a whole batch: a list of samples taken from
    ``base_dataset`` at the indices chosen by ``batch_by_token_count``.

    Transcript lengths are read straight from the chapter transcript files, so no
    audio has to be decoded to build the batches.

    Args:
        base_dataset: A ``torchaudio.datasets.LIBRISPEECH``-like dataset. Its
            private ``_walker``, ``_path`` and ``_ext_txt`` attributes are read
            directly to locate the transcripts.
        max_token_limit (int): Budget handed to ``batch_by_token_count``. It must
            be at least the length of the longest transcript.

    Raises:
        ValueError: If ``base_dataset`` has no samples, or if one transcript is
            longer than ``max_token_limit`` (it could never fit in a batch).
    """

    def __init__(self, base_dataset, max_token_limit):
        super().__init__()
        self.base_dataset = base_dataset

        # Shared cache (file id -> transcript length) so each chapter file is parsed once.
        fileid_to_target_length = {}
        idx_target_lengths = [
            (idx, self._target_length(fileid, fileid_to_target_length))
            for idx, fileid in enumerate(self.base_dataset._walker)
        ]

        # Explicit checks instead of `assert`, which is stripped under `python -O`.
        if not idx_target_lengths:
            raise ValueError("base_dataset contains no samples")

        # Longest transcripts first; the sort is stable, so equal lengths keep dataset order.
        idx_target_lengths.sort(key=lambda x: x[1], reverse=True)

        longest = idx_target_lengths[0][1]
        if max_token_limit < longest:
            raise ValueError(
                f"max_token_limit ({max_token_limit}) is smaller than the longest "
                f"transcript length ({longest})"
            )

        self.batches = batch_by_token_count(idx_target_lengths, max_token_limit)

    def _target_length(self, fileid, fileid_to_target_length):
        """Return the transcript length (in characters) of utterance ``fileid``.

        On a cache miss, the whole ``<speaker>-<chapter><ext_txt>`` transcript file
        that ``fileid`` belongs to is parsed. Every utterance in it is added to
        ``fileid_to_target_length``, so later lookups from the same chapter hit the
        cache.

        Raises:
            KeyError: If ``fileid`` is not listed in its chapter transcript file.
        """
        if fileid not in fileid_to_target_length:
            speaker_id, chapter_id, _ = fileid.split("-")
            file_text = os.path.join(
                self.base_dataset._path,
                speaker_id,
                chapter_id,
                speaker_id + "-" + chapter_id + self.base_dataset._ext_txt,
            )
            with open(file_text, encoding="utf-8") as ft:
                for line in ft:
                    # Each line is "<fileid> <transcript>". `partition` (unlike
                    # `split(" ", 1)` + unpacking) tolerates blank lines and
                    # utterances with an empty transcript.
                    fileid_text, _, transcript = line.strip().partition(" ")
                    fileid_to_target_length[fileid_text] = len(transcript)

        return fileid_to_target_length[fileid]

    def __getitem__(self, idx):
        """Return batch ``idx`` as a list of samples from ``base_dataset``."""
        return [self.base_dataset[subidx] for subidx in self.batches[idx]]

    def __len__(self):
        """Return the number of batches."""
        return len(self.batches)


class LibriSpeechRNNTModule(LightningModule):
    r"""Lightning module that trains an Emformer RNN-T on LibriSpeech.

    Args:
        librispeech_path (str): Root directory passed to ``torchaudio.datasets.LIBRISPEECH``.
        sp_model_path (str): SentencePiece model used to encode transcripts into targets
            and to post-process decoded hypotheses.
        global_stats_path (str): File passed to ``GlobalStatsNormalization`` in both
            feature pipelines.
    """

    def __init__(
        self,
        *,
        librispeech_path: str,
        sp_model_path: str,
        global_stats_path: str,
    ):
        super().__init__()

        # NOTE(review): 4097 presumably equals the SentencePiece vocabulary size plus one
        # blank symbol (blank_idx below is get_piece_size()) — confirm against sp_model_path.
        self.model = emformer_rnnt_base(num_symbols=4097)
        self.loss = torchaudio.transforms.RNNTLoss(reduction="sum", clamp=1.0)
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=5e-4, betas=(0.9, 0.999), eps=1e-8)
        # Stepped once per optimizer step (see configure_optimizers).
        self.warmup_lr_scheduler = WarmupLR(self.optimizer, 10000)

        # Both pipelines apply log compression and global normalization, then transpose
        # axes 1 and 2, right-pad the last axis by 4 and transpose back. Assuming features
        # arrive as (batch, time, freq), the pad therefore adds 4 time steps — TODO confirm
        # against spectrogram_transform. The training pipeline additionally applies two
        # frequency masks and two time masks (SpecAugment-style augmentation).
        self.train_data_pipeline = torch.nn.Sequential(
            FunctionalModule(piecewise_linear_log),
            GlobalStatsNormalization(global_stats_path),
            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.FrequencyMasking(27),
            torchaudio.transforms.TimeMasking(100, p=0.2),
            torchaudio.transforms.TimeMasking(100, p=0.2),
            FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
        )
        self.valid_data_pipeline = torch.nn.Sequential(
            FunctionalModule(piecewise_linear_log),
            GlobalStatsNormalization(global_stats_path),
            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
            FunctionalModule(partial(torch.nn.functional.pad, pad=(0, 4))),
            FunctionalModule(partial(torch.transpose, dim0=1, dim1=2)),
        )

        self.librispeech_path = librispeech_path

        self.sp_model = spm.SentencePieceProcessor(model_file=sp_model_path)
        # The blank symbol takes the first index past the SentencePiece vocabulary.
        self.blank_idx = self.sp_model.get_piece_size()

    def _extract_labels(self, samples: List):
        """Encode the lower-cased transcripts (``sample[2]``) and pad them into a batch.

        Returns:
            ``(targets, lengths)``. ``targets`` is an int32 tensor of shape
            ``(batch, max_len)`` padded with 1. ``lengths`` holds each unpadded length as int32.
        """
        token_ids = [self.sp_model.encode(sample[2].lower()) for sample in samples]
        lengths = torch.tensor([len(ids) for ids in token_ids], dtype=torch.int32)
        # Build every row as int32 up front: torch.tensor([]) would otherwise default to
        # float32 for an empty transcript and mix dtypes inside pad_sequence.
        targets = torch.nn.utils.rnn.pad_sequence(
            [torch.tensor(ids, dtype=torch.int32) for ids in token_ids],
            batch_first=True,
            padding_value=1.0,
        )
        return targets, lengths

    def _extract_features(self, samples: List, pipeline: torch.nn.Module):
        """Compute, pad and transform spectrogram features for ``samples``.

        Each waveform (``sample[0]``) goes through ``spectrogram_transform`` and is
        transposed. The padded batch is then passed through ``pipeline``.

        Returns:
            ``(features, lengths)``. ``lengths`` are the per-sample sizes of axis 0 of the
            transposed spectrograms, i.e. before padding and before ``pipeline``, as int32.
        """
        mel_features = [spectrogram_transform(sample[0].squeeze()).transpose(1, 0) for sample in samples]
        features = torch.nn.utils.rnn.pad_sequence(mel_features, batch_first=True)
        features = pipeline(features)
        lengths = torch.tensor([elem.shape[0] for elem in mel_features], dtype=torch.int32)
        return features, lengths

    def _train_extract_features(self, samples: List):
        """Extract features with the augmenting training pipeline."""
        return self._extract_features(samples, self.train_data_pipeline)

    def _valid_extract_features(self, samples: List):
        """Extract features with the deterministic validation pipeline."""
        return self._extract_features(samples, self.valid_data_pipeline)

    def _train_collate_fn(self, samples: List):
        """Turn one list of raw samples into a training ``Batch``."""
        features, feature_lengths = self._train_extract_features(samples)
        targets, target_lengths = self._extract_labels(samples)
        return Batch(features, feature_lengths, targets, target_lengths)

    def _valid_collate_fn(self, samples: List):
        """Turn one list of raw samples into a validation ``Batch``."""
        features, feature_lengths = self._valid_extract_features(samples)
        targets, target_lengths = self._extract_labels(samples)
        return Batch(features, feature_lengths, targets, target_lengths)

    def _test_collate_fn(self, samples: List):
        """Return ``(validation Batch, list of raw transcripts)`` for evaluation."""
        return self._valid_collate_fn(samples), [sample[2] for sample in samples]

    def _step(self, batch, batch_idx, step_type):
        """Compute and log the RNN-T loss for one batch.

        Args:
            batch: A ``Batch``, or ``None``, in which case the step is skipped.
            batch_idx: Unused; kept to match Lightning's step signature.
            step_type: Prefix of the logged metric (``"train"``, ``"val"`` or ``"test"``).

        Returns:
            The summed loss tensor, or ``None`` when ``batch`` is ``None``.
        """
        if batch is None:
            return None

        # The model receives the targets with blank_idx prepended as a first column
        # (lengths + 1). The loss is computed against the original, un-prepended targets.
        prepended_targets = batch.targets.new_empty([batch.targets.size(0), batch.targets.size(1) + 1])
        prepended_targets[:, 1:] = batch.targets
        prepended_targets[:, 0] = self.blank_idx
        prepended_target_lengths = batch.target_lengths + 1
        output, src_lengths, _, _ = self.model(
            batch.features,
            batch.feature_lengths,
            prepended_targets,
            prepended_target_lengths,
        )
        loss = self.loss(output, batch.targets, src_lengths, batch.target_lengths)
        self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return the Adam optimizer and the warmup scheduler, stepped every optimizer step."""
        return (
            [self.optimizer],
            [
                {"scheduler": self.warmup_lr_scheduler, "interval": "step"},
            ],
        )

    def forward(self, batch: Batch, beam_width: int = 20):
        """Beam-search decode ``batch`` and return the first post-processed result.

        Args:
            batch: A ``Batch``. Only ``features`` and ``feature_lengths`` are used.
                NOTE(review): presumably a single utterance (test_dataloader uses
                batch_size=1) — confirm against RNNTBeamSearch's input contract.
            beam_width: Beam size passed to ``RNNTBeamSearch`` (default 20).

        Returns:
            ``post_process_hypos(...)[0][0]``, presumably the text of the top
            hypothesis — confirm against ``post_process_hypos``.
        """
        decoder = RNNTBeamSearch(self.model, self.blank_idx)
        hypotheses = decoder(batch.features.to(self.device), batch.feature_lengths.to(self.device), beam_width)
        return post_process_hypos(hypotheses, self.sp_model)[0][0]

    def training_step(self, batch: Batch, batch_idx):
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch, batch_idx):
        return self._step(batch, batch_idx, "val")

    def test_step(self, batch_tuple, batch_idx):
        # batch_tuple is (Batch, transcripts) as built by _test_collate_fn.
        return self._step(batch_tuple[0], batch_idx, "test")

    def _concat_librispeech(self, urls: List[str]):
        """Wrap each LibriSpeech split in ``urls`` in a 1000-token ``CustomDataset`` and concatenate."""
        return torch.utils.data.ConcatDataset(
            [
                CustomDataset(torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url=url), 1000)
                for url in urls
            ]
        )

    def train_dataloader(self):
        # Each CustomDataset item is already a full batch, so automatic batching is
        # disabled (batch_size=None) and collate_fn receives one batch at a time.
        dataset = self._concat_librispeech(["train-clean-360", "train-clean-100", "train-other-500"])
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=None,
            collate_fn=self._train_collate_fn,
            num_workers=10,
            shuffle=True,
        )

    def val_dataloader(self):
        dataset = self._concat_librispeech(["dev-clean", "dev-other"])
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=None,
            collate_fn=self._valid_collate_fn,
            num_workers=10,
        )

    def test_dataloader(self):
        dataset = torchaudio.datasets.LIBRISPEECH(self.librispeech_path, url="test-clean")
        return torch.utils.data.DataLoader(dataset, batch_size=1, collate_fn=self._test_collate_fn)