lightning.py 13.8 KB
Newer Older
1
import math
2
3
4
from typing import Tuple

import torch
5
import torch.nn.functional as F
6
import torchaudio
7
import torchaudio.models.wav2vec2.components as components
8
from dataset import (
9
10
    _get_lengths_librilightlimited,
    _get_lengths_librispeech,
11
12
    BucketizeBatchSampler,
    CollateFnHubert,
13
    CollateFnLibriLightLimited,
14
15
16
17
18
19
20
21
22
23
24
    DistributedBatchSampler,
    HuBERTDataSet,
)
from loss import hubert_loss
from pytorch_lightning import LightningModule
from torch import Tensor
from torch.optim.optimizer import Optimizer
from torch.utils.data import DataLoader


# Pre-training batch; unpacked by the pre-training step as
# (waveforms, labels, audio_lengths).
Batch = Tuple[Tensor, Tensor, Tensor]
# Fine-tuning batch; unpacked by the fine-tuning step as
# (waveforms, labels, audio_lengths, label_lengths).
Batch_FineTune = Tuple[Tensor, Tensor, Tensor, Tensor]
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52


class LinearDecayLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Step-based schedule: linear ramp-up to the base rate, then linear decay to zero.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        warmup_updates: Step count at which the ramp reaches the base learning rate.
        max_updates: Step count from which the learning rate is held at zero.
        last_epoch: Forwarded unchanged to the parent scheduler.
        verbose: Forwarded unchanged to the parent scheduler.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_updates: int,
        max_updates: int,
        last_epoch: int = -1,
        verbose: bool = False,
    ):
        # The schedule bounds are stored before delegating to the parent
        # constructor so that ``get_lr`` is usable as soon as the parent needs it.
        self.max_updates = max_updates
        self.warmup_updates = warmup_updates
        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        """Return one learning rate per base learning rate for the current step count."""
        step = self._step_count
        if step <= self.warmup_updates:
            # Warm-up: fraction of the ramp completed so far.
            scale = step / self.warmup_updates
        elif step >= self.max_updates:
            # Past the end of the schedule: the rate is exactly zero.
            return [0.0 for _ in self.base_lrs]
        else:
            # Decay: fraction of the post-warm-up span that is still ahead.
            scale = (self.max_updates - step) / (self.max_updates - self.warmup_updates)
        return [base_lr * scale for base_lr in self.base_lrs]


53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
class TriStageLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """Three-phase step-based schedule: linear warm-up, hold, exponential decay.

    After the decay phase the rate stays at ``base_lr * final_lr_scale``.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        warmup_updates: Number of steps in the warm-up phase.
        hold_updates: Number of steps during which the base rate is kept.
        decay_updates: Number of steps in the decay phase.
        init_lr_scale: Fraction of the base rate from which warm-up ramps up.
        final_lr_scale: Fraction of the base rate reached at the end of decay.
        last_epoch: Forwarded unchanged to the parent scheduler.
        verbose: Forwarded unchanged to the parent scheduler.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        warmup_updates: int,
        hold_updates: int,
        decay_updates: int,
        init_lr_scale: float = 0.01,
        final_lr_scale: float = 0.05,
        last_epoch: int = -1,
        verbose: bool = False,
    ):
        # Phase lengths and scales are stored before delegating to the parent
        # constructor so that ``get_lr`` is usable as soon as the parent needs it.
        self.init_lr_scale = init_lr_scale
        self.final_lr_scale = final_lr_scale
        self.warmup_updates = warmup_updates
        self.hold_updates = hold_updates
        self.decay_updates = decay_updates

        super().__init__(optimizer, last_epoch=last_epoch, verbose=verbose)

    def get_lr(self):
        """Return one learning rate per base learning rate for the current step count."""
        step = self._step_count
        hold_end = self.warmup_updates + self.hold_updates

        # Phase 1: ramp linearly from init_lr_scale up to 1.
        if step <= self.warmup_updates:
            scale = self.init_lr_scale + step / self.warmup_updates * (1 - self.init_lr_scale)
            return [base_lr * scale for base_lr in self.base_lrs]

        # Phase 2: keep the base learning rates unchanged.
        if step <= hold_end:
            return list(self.base_lrs)

        # Phase 3: exponential interpolation from 1 down to final_lr_scale.
        if step <= hold_end + self.decay_updates:
            scale = math.exp(
                math.log(self.final_lr_scale)
                * (step - self.warmup_updates - self.hold_updates)
                / self.decay_updates
            )
            return [base_lr * scale for base_lr in self.base_lrs]

        # After the schedule: stay at the final scale.
        return [base_lr * self.final_lr_scale for base_lr in self.base_lrs]


97
98
99
100
101
102
103
104
class HuBERTPreTrainModule(LightningModule):
    """Lightning module for HuBERT pre-training.

    Builds the requested torchaudio HuBERT pre-training model, an AdamW
    optimizer driven by ``LinearDecayLRScheduler``, and bucketized train and
    validation data loaders over ``HuBERTDataSet``.

    Args:
        model_name: One of ``"hubert_pretrain_base"``, ``"hubert_pretrain_large"``
            or ``"hubert_pretrain_xlarge"``.
        feature_grad_mult: Forwarded to ``hubert_pretrain_base``.
        num_classes: Forwarded to ``hubert_pretrain_base``.
        dataset: Dataset identifier passed as the second argument of ``HuBERTDataSet``.
        dataset_path: Path passed as the first argument of ``HuBERTDataSet``.
        feature_type: Forwarded to ``CollateFnHubert``.
        seconds_per_batch: Multiplied by 16000 to obtain the sampler's
            ``max_token_count`` (16000 is presumably the sample rate in Hz).
        learning_rate: AdamW learning rate.
        betas: AdamW betas.
        eps: AdamW epsilon.
        weight_decay: AdamW weight decay.
        warmup_updates: Warm-up length for the learning-rate scheduler.
        max_updates: Step count at which the scheduler reaches zero.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """

    def __init__(
        self,
        *,
        model_name: str,
        feature_grad_mult: float,
        num_classes: int,
        dataset: str,
        dataset_path: str,
        feature_type: str,
        seconds_per_batch: float,
        learning_rate: float,
        betas: Tuple[float, float],
        eps: float,
        weight_decay: float,
        warmup_updates: int,
        max_updates: int,
    ):
        super().__init__()

        if model_name == "hubert_pretrain_base":
            self.model = torchaudio.models.hubert_pretrain_base(
                feature_grad_mult=feature_grad_mult, num_classes=num_classes
            )
        # NOTE(review): feature_grad_mult and num_classes are not forwarded for
        # the large/xlarge models; the factories' own defaults are used.
        elif model_name == "hubert_pretrain_large":
            self.model = torchaudio.models.hubert_pretrain_large()
        elif model_name == "hubert_pretrain_xlarge":
            self.model = torchaudio.models.hubert_pretrain_xlarge()
        else:
            raise ValueError(f"Unsupported model name: {model_name}")

        self.loss = hubert_loss
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=learning_rate, betas=betas, eps=eps, weight_decay=weight_decay
        )
        self.lr_scheduler = LinearDecayLRScheduler(self.optimizer, warmup_updates, max_updates)
        self.dataset = dataset
        self.dataset_path = dataset_path
        self.feature_type = feature_type
        self.seconds_per_batch = seconds_per_batch

    def _step(self, batch: Batch, batch_idx, step_type):
        """Run the model on one batch, log the loss and return it.

        Args:
            batch: ``(waveforms, labels, audio_lengths)`` or ``None``.
            batch_idx: Index of the batch (unused).
            step_type: Prefix used in the logged metric name, e.g. ``"train"``.

        Returns:
            The loss, or ``None`` when ``batch`` is ``None``.
        """
        if batch is None:
            return None
        waveforms, labels, audio_lengths = batch
        logit_m, logit_u, feature_penalty = self.model(
            waveforms,
            labels,
            audio_lengths,
        )
        loss = self.loss(logit_m, logit_u, feature_penalty)
        self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return the optimizer and its scheduler, stepped once per training step."""
        return (
            [self.optimizer],
            [
                {
                    "scheduler": self.lr_scheduler,
                    "interval": "step",
                },
            ],
        )

    def training_step(self, batch: Batch, batch_idx):
        """Compute and log the training loss for one batch."""
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch: Batch, batch_idx):
        """Compute and log the validation loss for one batch."""
        return self._step(batch, batch_idx, "val")

    def train_dataloader(self):
        """Build the shuffled, bucketized and distributed training data loader."""
        dataset = HuBERTDataSet(self.dataset_path, self.dataset, "train")
        sampler = BucketizeBatchSampler(
            dataset.len_list,
            num_buckets=10000,
            max_token_count=self.seconds_per_batch * 16000,
            min_len=32000,
            max_len=250000,
            shuffle=True,
        )
        sampler = DistributedBatchSampler(sampler, shuffle=True)
        sampler.set_epoch(self.current_epoch)
        dataloader = DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
            num_workers=10,
        )
        return dataloader

    def val_dataloader(self):
        """Build the unshuffled, bucketized validation data loader."""
        dataset = HuBERTDataSet(self.dataset_path, self.dataset, "valid")
        sampler = BucketizeBatchSampler(
            dataset.len_list,
            num_buckets=1000,
            max_token_count=self.seconds_per_batch * 16000,
            min_len=32000,
            max_len=250000,
            shuffle=False,
        )
        dataloader = DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
            num_workers=10,
        )
        return dataloader
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275


class HuBERTFineTuneModule(LightningModule):
    """Lightning module for CTC fine-tuning of a pre-trained HuBERT model.

    Builds the requested torchaudio HuBERT pre-training model, restores the
    ``wav2vec2`` weights from ``checkpoint``, freezes the convolutional feature
    extractor and trains the transformer encoder plus a linear output head
    with CTC loss.

    Args:
        model_name: One of ``"hubert_pretrain_base"``, ``"hubert_large"`` or
            ``"hubert_xlarge"``.
        encoder_projection_dropout: Forwarded to the model factory.
        encoder_attention_dropout: Forwarded to the model factory.
        encoder_ff_interm_dropout: Forwarded to the model factory.
        encoder_dropout: Forwarded to the model factory.
        encoder_layer_drop: Forwarded to the model factory.
        mask_prob: Forwarded to the model factory.
        mask_channel_prob: Forwarded to the model factory.
        mask_channel_length: Forwarded to the model factory.
        aux_num_out: Output size of the linear head (index 0 is the CTC blank).
        checkpoint: Path of the pre-training checkpoint; the loaded object must
            hold the weights under the ``"state_dict"`` key.
        dataset_path: Root path used for both the training and validation datasets.
        seconds_per_batch: Multiplied by 16000 to obtain the sampler's
            ``max_token_count`` (16000 is presumably the sample rate in Hz).
        subset: Subset name passed to ``torchaudio.datasets.LibriLightLimited``.
        learning_rate: AdamW learning rate.
        betas: AdamW betas.
        adam_eps: AdamW epsilon.
        weight_decay: AdamW weight decay.
        freeze_encoder_updates: While ``global_step`` is less than or equal to
            this value the encoder runs without gradient tracking.
        warmup_updates: Warm-up length for ``TriStageLRScheduler``.
        hold_updates: Hold length for ``TriStageLRScheduler``.
        decay_updates: Decay length for ``TriStageLRScheduler``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """

    def __init__(
        self,
        *,
        model_name: str,
        encoder_projection_dropout: float,
        encoder_attention_dropout: float,
        encoder_ff_interm_dropout: float,
        encoder_dropout: float,
        encoder_layer_drop: float,
        mask_prob: float,
        mask_channel_prob: float,
        mask_channel_length: float,
        aux_num_out: int,
        checkpoint: str,
        dataset_path: str,
        seconds_per_batch: float,
        subset: str,
        learning_rate: float,
        betas: Tuple[float, float],
        adam_eps: float,
        weight_decay: float,
        freeze_encoder_updates: int,
        warmup_updates: int,
        hold_updates: int,
        decay_updates: int,
    ):
        super().__init__()

        # Every supported factory receives the same regularisation/masking options.
        model_kwargs = dict(
            encoder_projection_dropout=encoder_projection_dropout,
            encoder_attention_dropout=encoder_attention_dropout,
            encoder_ff_interm_dropout=encoder_ff_interm_dropout,
            encoder_dropout=encoder_dropout,
            encoder_layer_drop=encoder_layer_drop,
            mask_prob=mask_prob,
            mask_channel_prob=mask_channel_prob,
            mask_channel_length=mask_channel_length,
        )
        # The output head must match the encoder feature size, which differs per
        # model size (768 / 1024 / 1280 in torchaudio's HuBERT configurations).
        if model_name == "hubert_pretrain_base":
            self.model = torchaudio.models.hubert_pretrain_base(**model_kwargs)
            encoder_embed_dim = 768
        elif model_name == "hubert_large":
            self.model = torchaudio.models.hubert_pretrain_large(**model_kwargs)
            encoder_embed_dim = 1024
        elif model_name == "hubert_xlarge":
            self.model = torchaudio.models.hubert_pretrain_xlarge(**model_kwargs)
            encoder_embed_dim = 1280
        else:
            raise ValueError(f"Unsupported model name: {model_name}.")
        self.aux = torch.nn.Linear(encoder_embed_dim, aux_num_out)
        self._load_checkpoint(checkpoint)
        # The convolutional feature extractor is never trained during fine-tuning.
        for p in self.model.wav2vec2.feature_extractor.parameters():
            p.requires_grad = False
        self.loss_fn = torch.nn.CTCLoss(blank=0, reduction="sum", zero_infinity=True)
        # Only the output head and the transformer encoder are optimised.
        self.optimizer = torch.optim.AdamW(
            list(self.aux.parameters()) + list(self.model.wav2vec2.encoder.parameters()),
            lr=learning_rate,
            betas=betas,
            eps=adam_eps,
            weight_decay=weight_decay,
        )
        self.freeze_encoder_updates = freeze_encoder_updates
        self.lr_scheduler = TriStageLRScheduler(self.optimizer, warmup_updates, hold_updates, decay_updates)
        self.dataset_path = dataset_path
        self.seconds_per_batch = seconds_per_batch
        self.subset = subset

    def _load_checkpoint(self, checkpoint):
        """Load the ``wav2vec2`` weights of a pre-training checkpoint into the model.

        Only entries whose key contains ``"wav2vec2"`` are used; the
        ``"model.wav2vec2."`` prefix is stripped before loading.
        """
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(checkpoint, map_location=torch.device("cpu"))
        state_dict = state_dict["state_dict"]
        encoder_state = {
            key.replace("model.wav2vec2.", ""): value for key, value in state_dict.items() if "wav2vec2" in key
        }
        self.model.wav2vec2.load_state_dict(encoder_state)

    def _run_encoder(self, x, out_len, padding_mask):
        """Apply encoder pre-processing, masking and the transformer to the features."""
        x, attention_mask = self.model.wav2vec2.encoder._preprocess(x, out_len)
        x, _ = self.model.mask_generator(x, padding_mask)
        return self.model.wav2vec2.encoder.transformer(x, attention_mask=attention_mask)

    def _step(self, batch: Batch_FineTune, batch_idx, step_type):
        """Compute and log the CTC loss for one batch.

        Args:
            batch: ``(waveforms, labels, audio_lengths, label_lengths)`` or ``None``.
            batch_idx: Index of the batch (unused).
            step_type: Prefix used in the logged metric name, e.g. ``"train"``.

        Returns:
            The summed CTC loss, or ``None`` when ``batch`` is ``None``.
        """
        if batch is None:
            return None
        waveforms, labels, audio_lengths, label_lengths = batch

        # The feature extractor is frozen, so it never needs gradient tracking.
        with torch.no_grad():
            x, out_len = self.model.wav2vec2.feature_extractor(waveforms, audio_lengths)
            padding_mask = components._get_padding_mask(x, out_len)

        if self.global_step <= self.freeze_encoder_updates:
            # Encoder still frozen: only the output head receives gradients.
            with torch.no_grad():
                x = self._run_encoder(x, out_len, padding_mask)
        else:
            x = self._run_encoder(x, out_len, padding_mask)

        logits = self.aux(x)
        # Force padded frames to emit the blank symbol (index 0). Chained
        # boolean-mask indexing (``logits[mask][..., 0] = 0``) assigns into a
        # temporary copy and leaves ``logits`` untouched, so the replacement is
        # done out of place with ``torch.where``.
        # NOTE(review): assumes padding_mask is (batch, frame) with True on
        # padded frames and logits is (batch, frame, class) — confirm.
        pad_logits = logits.new_full((logits.size(-1),), float("-inf"))
        pad_logits[0] = 0.0
        logits = torch.where(padding_mask.unsqueeze(-1), pad_logits, logits)
        log_probs = F.log_softmax(logits, dim=-1)
        # Swap the first two dimensions before handing the tensor to CTCLoss.
        log_probs = log_probs.transpose(0, 1)
        loss = self.loss_fn(
            log_probs,
            labels,
            out_len,
            label_lengths,
        )
        self.log(f"Losses/{step_type}_loss", loss, on_step=True, on_epoch=True)
        return loss

    def configure_optimizers(self):
        """Return the optimizer and its scheduler, stepped once per training step."""
        return (
            [
                self.optimizer,
            ],
            [
                {"scheduler": self.lr_scheduler, "interval": "step"},
            ],
        )

    def training_step(self, batch: Batch_FineTune, batch_idx):
        """Compute and log the training loss for one batch."""
        return self._step(batch, batch_idx, "train")

    def validation_step(self, batch: Batch_FineTune, batch_idx):
        """Compute and log the validation loss for one batch."""
        return self._step(batch, batch_idx, "val")

    def train_dataloader(self):
        """Build the shuffled, bucketized and distributed LibriLightLimited loader."""
        dataset = torchaudio.datasets.LibriLightLimited(self.dataset_path, self.subset)
        lengths = _get_lengths_librilightlimited(dataset._fileids_paths)
        sampler = BucketizeBatchSampler(
            lengths, num_buckets=100, max_token_count=self.seconds_per_batch * 16000, shuffle=True
        )
        sampler = DistributedBatchSampler(sampler, shuffle=True)
        # NOTE(review): the sampler epoch is set from global_step here, whereas
        # HuBERTPreTrainModule uses current_epoch — confirm this is intended.
        sampler.set_epoch(self.global_step)
        dataloader = DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=CollateFnLibriLightLimited(),
            num_workers=10,
        )
        return dataloader

    def val_dataloader(self):
        """Build the unshuffled, bucketized LibriSpeech ``dev-other`` loader."""
        dataset = torchaudio.datasets.LIBRISPEECH(self.dataset_path, "dev-other")
        lengths = _get_lengths_librispeech(dataset._walker, dataset._path, dataset._ext_audio)
        sampler = BucketizeBatchSampler(
            lengths, num_buckets=100, max_token_count=self.seconds_per_batch * 16000, shuffle=False
        )
        dataloader = DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=CollateFnLibriLightLimited(),
            num_workers=10,
        )
        return dataloader