# dummy_dataloader.py
import torch


class DummyDataloader:
    """Infinite iterator yielding batches of random integer tensors.

    Every batch is a ``dict`` with the keys ``text``, ``types``,
    ``is_random``, ``loss_mask``, ``labels`` and ``padding_mask``. All values
    are ``torch.int64`` tensors produced by ``torch.randint``:

    * ``text`` / ``labels``: shape ``(batch_size, seq_length)``, values in
      ``[0, vocab_size)``.
    * ``types``: shape ``(batch_size, seq_length)``, values in
      ``[0, num_token_types)`` (``[0, 3)`` by default).
    * ``loss_mask`` / ``padding_mask``: shape ``(batch_size, seq_length)``,
      values in ``{0, 1}``.
    * ``is_random``: shape ``(batch_size,)``, values in ``{0, 1}``.

    The contents are pure noise. NOTE(review): presumably used to smoke-test
    or benchmark a BERT-style training loop without real data; confirm with
    callers.
    """

    # Class-level defaults keep ``generate`` working even for subclasses whose
    # ``__init__`` does not call ``super().__init__``.
    generator = None  # None -> torch's global RNG (original behaviour)
    num_token_types = 3

    def __init__(
        self,
        batch_size,
        vocab_size,
        seq_length,
        *,
        generator=None,
        num_token_types=3,
    ):
        """Store the batch geometry used by every generated batch.

        Args:
            batch_size: Number of rows in every generated tensor.
            vocab_size: Exclusive upper bound for the ids in ``text`` and
                ``labels``.
            seq_length: Number of columns in the 2-D tensors.
            generator: Optional ``torch.Generator`` passed to every
                ``torch.randint`` call, making batches reproducible without
                touching the global RNG. ``None`` uses the global RNG.
            num_token_types: Exclusive upper bound for the values in
                ``types``. Defaults to 3, the value previously hard-coded.
        """
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        self.generator = generator
        self.num_token_types = num_token_types
        # NOTE(review): never advanced anywhere in this class; kept in case
        # external code reads or updates it — confirm before removing.
        self.step = 0

    def _randint(self, high, size):
        """Draw integers uniformly from ``[0, high)`` with the given shape."""
        return torch.randint(low=0, high=high, size=size, generator=self.generator)

    def generate(self):
        """Return one freshly sampled batch (see the class docstring).

        The tensors are drawn in a fixed order: text, types, is_random,
        loss_mask, labels, padding_mask. With the same seed, this reproduces
        the exact batches of the original implementation.
        """
        matrix = (self.batch_size, self.seq_length)
        tokens = self._randint(self.vocab_size, matrix)
        types = self._randint(self.num_token_types, matrix)
        sentence_order = self._randint(2, (self.batch_size,))
        loss_mask = self._randint(2, matrix)
        lm_labels = self._randint(self.vocab_size, matrix)
        padding_mask = self._randint(2, matrix)

        return dict(
            text=tokens,
            types=types,
            is_random=sentence_order,
            loss_mask=loss_mask,
            labels=lm_labels,
            padding_mask=padding_mask,
        )

    def __iter__(self):
        """Return ``self``; the loader is its own iterator."""
        return self

    def __next__(self):
        """Return a new random batch; never raises ``StopIteration``."""
        return self.generate()