# _hubert_datamodule.py
import torch
from pytorch_lightning import LightningDataModule

from ._utils import BucketizeBatchSampler, CollateFnHubert, DistributedBatchSampler, HuBERTDataSet


class HuBERTDataModule(LightningDataModule):
    """Data module for HuBERT pre-training.

    Builds train/validation dataloaders that group variable-length audio into
    batches capped by total duration (``seconds_per_batch``) via a
    length-bucketing batch sampler.
    """

    hubert_cls = HuBERTDataSet

    def __init__(
        self,
        *,
        dataset_path,
        dataset,
        feature_type,
        seconds_per_batch,
        train_shuffle=True,
        num_workers=10,
    ):
        super().__init__()
        self.dataset_path = dataset_path  # root directory of the pre-processed dataset
        self.dataset = dataset  # name of the dataset subset to load
        self.feature_type = feature_type  # feature type used to generate the cluster labels
        self.seconds_per_batch = seconds_per_batch  # audio-duration budget per batch
        self.train_shuffle = train_shuffle  # reshuffle training batches every epoch
        self.num_workers = num_workers  # subprocesses used by each DataLoader

    def train_dataloader(self):
        dataset = self.hubert_cls(self.dataset_path, self.dataset, "train")
        # Bucket utterances by length and cap each batch by total sample count.
        # Audio is assumed to be 16 kHz, so ``seconds_per_batch * 16000`` converts
        # the duration budget into a sample budget; utterances shorter than
        # 32000 samples (2 s) or longer than 250000 samples (~15.6 s) are
        # filtered out.
        sampler = BucketizeBatchSampler(
            dataset.len_list,
            num_buckets=10000,
            max_token_count=self.seconds_per_batch * 16000,
            min_len=32000,
            max_len=250000,
            shuffle=True,
        )
        # Shard batches across distributed workers; seeding with the current
        # epoch gives each epoch a different batch order when shuffling.
        sampler = DistributedBatchSampler(sampler, shuffle=self.train_shuffle)
        sampler.set_epoch(self.trainer.current_epoch)
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            # With ``pad=False`` each utterance is cropped to the shortest one
            # in the batch; ``rand_crop=True`` picks the crop start at random.
            collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
            num_workers=self.num_workers,
        )
        return dataloader

    def val_dataloader(self):
        dataset = self.hubert_cls(self.dataset_path, self.dataset, "valid")
        # Same length-bucketing as training, but with a fixed batch order and
        # fewer buckets, which is sufficient for the smaller validation split.
        sampler = BucketizeBatchSampler(
            dataset.len_list,
            num_buckets=1000,
            max_token_count=self.seconds_per_batch * 16000,
            min_len=32000,
            max_len=250000,
            shuffle=False,
        )
        dataloader = torch.utils.data.DataLoader(
            dataset,
            batch_sampler=sampler,
            collate_fn=CollateFnHubert(feature_type=self.feature_type, pad=False, rand_crop=True),
            num_workers=self.num_workers,
        )
        return dataloader
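

# A minimal usage sketch, kept as comments so the module stays import-clean.
# This is illustrative only: the task-module name ``HuBERTPreTrainModule``,
# the paths, and the hyperparameter values below are assumptions, not part
# of this file.
#
#     import pytorch_lightning as pl
#
#     data_module = HuBERTDataModule(
#         dataset_path="./exp/data",  # hypothetical pre-processed data root
#         dataset="librispeech",      # hypothetical dataset name
#         feature_type="mfcc",        # hypothetical label feature type
#         seconds_per_batch=87.5,     # ~87.5 s of audio per batch
#     )
#     trainer = pl.Trainer(max_steps=250_000, accelerator="gpu", devices=8)
#     trainer.fit(HuBERTPreTrainModule(...), datamodule=data_module)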