"git@developer.sourcefind.cn:OpenDAS/colossalai.git" did not exist on "a9b27b9265c31175192643e3974187e5ea112c1d"
dummy_dataloader.py 1.23 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch


class DummyDataloader():
    """Endless iterator that yields randomly generated BERT-style batches.

    Every batch is a dict with keys ``text``, ``types``, ``is_random``,
    ``loss_mask``, ``labels`` and ``padding_mask``, filled with random
    integer tensors shaped ``(batch_size, seq_length)`` (``is_random`` is
    shaped ``(batch_size,)``). Iteration never terminates.
    """

    def __init__(self, batch_size, vocab_size, seq_length):
        # Geometry of the synthetic batches and the token id range.
        self.batch_size = batch_size
        self.vocab_size = vocab_size
        self.seq_length = seq_length
        # NOTE(review): `step` is initialized but never advanced in this
        # class — presumably external code reads/updates it; confirm before
        # removing.
        self.step = 0

    def generate(self):
        """Build and return one random batch as a dict of tensors."""
        shape_2d = (self.batch_size, self.seq_length)
        tokens = torch.randint(low=0, high=self.vocab_size, size=shape_2d)
        token_types = torch.randint(low=0, high=3, size=shape_2d)
        next_sentence = torch.randint(low=0, high=2, size=(self.batch_size,))
        mask = torch.randint(low=0, high=2, size=shape_2d)
        labels = torch.randint(low=0, high=self.vocab_size, size=shape_2d)
        pad_mask = torch.randint(low=0, high=2, size=shape_2d)
        return {
            'text': tokens,
            'types': token_types,
            'is_random': next_sentence,
            'loss_mask': mask,
            'labels': labels,
            'padding_mask': pad_mask,
        }

    def __iter__(self):
        # The dataloader is its own (infinite) iterator.
        return self

    def __next__(self):
        # Each step simply produces a fresh random batch.
        return self.generate()