# beit.py
import torch
from timm.models.beit import Beit

from colossalai.utils.cuda import get_current_device

from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator


class DummyDataLoader(DummyDataGenerator):
    """Random image-classification batches sized for the tiny BEiT test model.

    The class attributes below fix the shape of every batch produced by
    ``generate``; ``get_training_components`` also reads ``img_size`` and
    ``num_class`` to size the model to match.
    """

    img_size = 64  # height and width of each generated image
    num_channel = 3  # channels per image
    num_class = 10  # labels are drawn from [0, num_class)
    batch_size = 4  # samples per generated batch

    def generate(self):
        """Sample one random batch.

        Returns:
            tuple: ``(data, label)``.
                ``data`` is a tensor of shape
                ``(batch_size, num_channel, img_size, img_size)`` drawn from
                ``torch.randn``.
                ``label`` is a tensor of shape ``(batch_size,)`` drawn from
                ``torch.randint`` with values in ``[0, num_class)``.
                Both tensors are allocated on ``get_current_device()``.
        """
        # Resolve the device once so data and labels are guaranteed to match.
        device = get_current_device()
        data = torch.randn(
            (
                DummyDataLoader.batch_size,
                DummyDataLoader.num_channel,
                DummyDataLoader.img_size,
                DummyDataLoader.img_size,
            ),
            device=device,
        )
        label = torch.randint(
            low=0, high=DummyDataLoader.num_class, size=(DummyDataLoader.batch_size,), device=device
        )
        return data, label


@non_distributed_component_funcs.register(name="beit")
def get_training_components():
    """Assemble everything needed to train the tiny BEiT test model.

    Returns:
        tuple: ``(model_builder, trainloader, testloader, optimizer_class, criterion)``.
            ``model_builder`` is a callable returning a fresh ``Beit``.
            ``trainloader`` and ``testloader`` are independent
            ``DummyDataLoader`` instances.
            ``optimizer_class`` is ``torch.optim.Adam`` itself (the class, not
            an instance).
            ``criterion`` is a ``torch.nn.CrossEntropyLoss`` instance.
    """

    def model_builder(checkpoint=False):
        """Build a new, deliberately small ``Beit`` matching ``DummyDataLoader``.

        The image size and class count come from ``DummyDataLoader``, so the
        model accepts its batches directly.

        Args:
            checkpoint: Accepted but currently ignored; the returned model is
                built the same way regardless of its value.
        """
        # NOTE(review): `checkpoint` is never applied. If gradient checkpointing
        # is wanted, timm's `Beit.set_grad_checkpointing(checkpoint)` may be the
        # hook. Confirm against the installed timm version before wiring it.
        return Beit(
            img_size=DummyDataLoader.img_size,
            num_classes=DummyDataLoader.num_class,
            embed_dim=32,
            depth=2,
            num_heads=4,
        )

    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()

    criterion = torch.nn.CrossEntropyLoss()
    return model_builder, trainloader, testloader, torch.optim.Adam, criterion