"""Test component ``inline_op_model``: a small model that uses in-place tensor ops."""
import torch
import torch.nn as nn

from colossalai.legacy.nn import CheckpointModule

from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator


class InlineOpModule(CheckpointModule):
    """A small two-layer model whose forward pass uses in-place ("inline") ops.

    ``forward`` applies ``Tensor.add_`` and ``torch.relu_`` to intermediate
    activations, so this model covers in-place ops on tensors produced
    earlier in the same forward pass.

    Args:
        checkpoint: Passed unchanged to ``CheckpointModule.__init__``.

    NOTE(review): this class overrides ``forward`` directly. If
    ``CheckpointModule`` implements checkpointing inside its own ``forward``,
    the ``checkpoint`` flag may have no effect here. Confirm against the
    base class.
    """

    def __init__(self, checkpoint=False) -> None:
        super().__init__(checkpoint=checkpoint)
        self.proj1 = nn.Linear(4, 8)
        self.proj2 = nn.Linear(8, 8)

    def forward(self, x):
        """Run ``proj1 -> add_(10) -> proj2 -> relu_ -> proj2``.

        Args:
            x: Input tensor whose last dimension must be 4, because ``proj1``
                is ``nn.Linear(4, 8)``.

        Returns:
            The output of the second ``proj2`` call (last dimension 8).
        """
        x = self.proj1(x)
        # inline add_: mutates the proj1 output in place
        x.add_(10)
        x = self.proj2(x)
        # inline relu_: in-place activation on the proj2 output
        x = torch.relu_(x)
        # proj2 is applied a second time, so both calls share the same weights
        x = self.proj2(x)
        return x


class DummyDataLoader(DummyDataGenerator):
    """Random data source for the inline-op model."""

    def generate(self):
        """Return one random batch as an ``(inputs, targets)`` pair.

        ``inputs`` is a ``(16, 4)`` float tensor drawn from ``torch.rand``.
        ``targets`` is a length-16 integer tensor of class indices in {0, 1}.
        """
        batch_size = 16
        inputs = torch.rand(batch_size, 4)
        targets = torch.randint(low=0, high=2, size=(batch_size,))
        return inputs, targets


@non_distributed_component_funcs.register(name="inline_op_model")
def get_training_components():
    """Build the training components for the ``inline_op_model`` test case.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class,
        criterion)`` with these parts:

        - ``model_builder(checkpoint=False)`` returns a new ``InlineOpModule``.
        - Both loaders are ``DummyDataLoader`` instances.
        - ``optimizer_class`` is the ``HybridAdam`` class itself, not an instance.
        - ``criterion`` is a ``torch.nn.CrossEntropyLoss`` instance.
    """

    def model_builder(checkpoint=False):
        return InlineOpModule(checkpoint)

    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()

    criterion = torch.nn.CrossEntropyLoss()
    # Kept as a function-local import, as in the original (presumably to avoid
    # importing the optimizer package at module load time).
    from colossalai.nn.optimizer import HybridAdam

    return model_builder, trainloader, testloader, HybridAdam, criterion