repeated_computed_layers.py 1.37 KB
Newer Older
1
2
3
4
#!/usr/bin/env python

import torch
import torch.nn as nn
HELSON's avatar
HELSON committed
5

6
from colossalai.nn import CheckpointModule
HELSON's avatar
HELSON committed
7

8
from .registry import non_distributed_component_funcs
HELSON's avatar
HELSON committed
9
from .utils.dummy_data_generator import DummyDataGenerator
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40


class NetWithRepeatedlyComputedLayers(CheckpointModule):
    """
    Toy model for testing layers that take part in the forward pass more than once.

    Per forward call, ``fc1`` and ``fc2`` are each applied twice and ``fc3`` once,
    in the order fc1 -> fc2 -> fc1 -> fc2 -> fc3.
    """

    def __init__(self, checkpoint=False) -> None:
        super().__init__(checkpoint=checkpoint)
        # Layers are created in this fixed order so parameter initialization
        # consumes the RNG identically on every construction.
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 2)
        # Kept as a plain Python list (not nn.ModuleList): the list itself is not
        # registered as a submodule. The layers are registered through the
        # fc1/fc2/fc3 attributes above, while the list may repeat them freely.
        self.layers = [self.fc1, self.fc2] * 2 + [self.fc3]

    def forward(self, x):
        """Feed ``x`` through every entry of ``self.layers`` in order.

        NOTE(review): this overrides ``forward`` directly; whether the
        ``checkpoint`` flag passed to CheckpointModule still takes effect
        depends on how CheckpointModule dispatches -- confirm.
        """
        out = x
        for module in self.layers:
            out = module(out)
        return out


class DummyDataLoader(DummyDataGenerator):
    """Random-batch source whose shapes match NetWithRepeatedlyComputedLayers."""

    def generate(self):
        """Return one random ``(data, label)`` batch.

        ``data`` is a (16, 5) tensor from ``torch.rand`` (5 features, matching
        ``fc1``'s input width). ``label`` is a (16,) tensor of class indices
        in {0, 1}, matching ``fc3``'s two outputs.
        """
        # The tuple is evaluated left to right, so rand() still draws before randint().
        return torch.rand(16, 5), torch.randint(low=0, high=2, size=(16,))


@non_distributed_component_funcs.register(name='repeated_computed_layers')
def get_training_components():
    """Assemble the training components for the ``repeated_computed_layers`` test model.

    Returns:
        A 5-tuple ``(model_builder, trainloader, testloader, optimizer_class, criterion)``:

        - ``model_builder``: callable ``model_builder(checkpoint=False)`` that
          returns a fresh ``NetWithRepeatedlyComputedLayers``.
        - ``trainloader`` / ``testloader``: independent ``DummyDataLoader``
          instances yielding random batches.
        - ``optimizer_class``: the ``torch.optim.Adam`` class itself (not an
          instance); the caller is expected to construct it with the model's
          parameters.
        - ``criterion``: a ``torch.nn.CrossEntropyLoss`` instance.
    """

    def model_builder(checkpoint=False):
        # Build lazily so each caller gets its own freshly-initialized model.
        return NetWithRepeatedlyComputedLayers(checkpoint)

    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()

    criterion = torch.nn.CrossEntropyLoss()
    return model_builder, trainloader, testloader, torch.optim.Adam, criterion