# beit.py: BEiT model and dummy-data training components for the non-distributed test registry.
import torch
from timm.models.beit import Beit

from colossalai.utils.cuda import get_current_device

from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator


class DummyDataLoader(DummyDataGenerator):
    """Dummy data source producing random image batches and integer class labels.

    The class attributes below define the shape of the synthetic data. ``generate``
    reads them through ``self``, so a subclass can override any of them.
    """
    img_size = 64  # height and width of each generated image
    num_channel = 3  # channels per image
    num_class = 10  # labels are drawn from [0, num_class)
    batch_size = 4  # samples per generated batch

    def generate(self):
        """Return one random ``(data, label)`` batch on the current device.

        Returns:
            tuple: ``data`` is a floating-point tensor of shape
            ``(batch_size, num_channel, img_size, img_size)`` drawn from a standard
            normal distribution; ``label`` is an integer tensor of shape
            ``(batch_size,)`` with values in ``[0, num_class)``. Both tensors are
            allocated on the device returned by ``get_current_device()``.
        """
        # Resolve the device once so data and labels are guaranteed to share it.
        device = get_current_device()
        data = torch.randn((self.batch_size, self.num_channel, self.img_size, self.img_size), device=device)
        label = torch.randint(low=0, high=self.num_class, size=(self.batch_size,), device=device)
        return data, label


@non_distributed_component_funcs.register(name='beit')
def get_training_components():
    """Build the training components registered under the name ``'beit'``.

    Returns:
        tuple: ``(model_builder, trainloader, testloader, optimizer_class, criterion)``.
        ``model_builder`` constructs a small BEiT model. ``trainloader`` and
        ``testloader`` are two independent ``DummyDataLoader`` instances.
        ``optimizer_class`` is the ``torch.optim.Adam`` class itself (not an
        instance). ``criterion`` is a ``torch.nn.CrossEntropyLoss``.
    """

    def model_builder(checkpoint=False):
        """Create a small BEiT model sized to match ``DummyDataLoader``'s data.

        Args:
            checkpoint (bool): presumably kept for signature compatibility with
                other registered components. NOTE(review): the value is currently
                ignored, so no activation checkpointing is configured here.

        Returns:
            Beit: a model whose ``img_size`` and ``num_classes`` follow
            ``DummyDataLoader``, with ``embed_dim=32``, ``depth=2``, ``num_heads=4``.
        """
        # Small embedding width and depth keep the model lightweight.
        model = Beit(img_size=DummyDataLoader.img_size,
                     num_classes=DummyDataLoader.num_class,
                     embed_dim=32,
                     depth=2,
                     num_heads=4)
        return model

    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()

    criterion = torch.nn.CrossEntropyLoss()
    return model_builder, trainloader, testloader, torch.optim.Adam, criterion