# simple_model.py -- simple models, optimizers, and argument helpers for unit tests.
import os
import json
import argparse
import torch

from deepspeed.pipe import PipelineModule, LayerSpec


class SimpleModel(torch.nn.Module):
Jeff Rasley's avatar
Jeff Rasley committed
10
    def __init__(self, hidden_dim, empty_grad=False, rank=0):
11
12
13
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        if empty_grad:
Jeff Rasley's avatar
Jeff Rasley committed
14
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
15
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
Jeff Rasley's avatar
Jeff Rasley committed
16
17
        self.rank = rank
        self.empty_grad = empty_grad
18
19
20

    def forward(self, x, y):
        hidden_dim = x
Jeff Rasley's avatar
Jeff Rasley committed
21
22
23
24
        if self.rank == 0 and self.empty_grad:
            hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim)
        else:
            hidden_dim = self.linear(hidden_dim)
25
26
27
        return self.cross_entropy_loss(hidden_dim, y)


class LinearStack(torch.nn.Module):
    """Feed-forward stack: input projection, ``num_layers`` bias-free hidden
    linear layers, and an output projection.

    ``forward`` takes ``y`` only to match the ``(x, y)`` calling convention
    of the other test models; the label and the ``cross_entropy_loss``
    module are not used.
    """
    def __init__(self, input_dim=128, hidden_dim=128, output_dim=128, num_layers=4):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        self.input_layer = torch.nn.Linear(in_features=self.input_dim,
                                           out_features=self.hidden_dim)
        hidden = [
            torch.nn.Linear(in_features=self.hidden_dim,
                            out_features=self.hidden_dim,
                            bias=False)
            for _ in range(num_layers)
        ]
        self.layers = torch.nn.ModuleList(hidden)
        self.output_layer = torch.nn.Linear(in_features=self.hidden_dim,
                                            out_features=self.output_dim)

        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        out = self.input_layer(x)
        for hidden_layer in self.layers:
            out = hidden_layer(out)
        return self.output_layer(out)


class LinearStackPipe(PipelineModule):
    """Pipeline-parallel counterpart of ``LinearStack`` built from LayerSpecs.

    An identity lambda is interleaved after every hidden layer so the
    pipeline also contains non-module (callable) stages.
    """
    def __init__(self,
                 input_dim=128,
                 hidden_dim=128,
                 output_dim=128,
                 num_layers=4,
                 **kwargs):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        specs = [LayerSpec(torch.nn.Linear, self.input_dim, self.hidden_dim)]
        for _ in range(self.num_layers):
            specs.append(
                LayerSpec(torch.nn.Linear,
                          self.hidden_dim,
                          self.hidden_dim,
                          bias=False))
            specs.append(lambda x: x)
        specs.append(LayerSpec(torch.nn.Linear, self.hidden_dim, self.output_dim))

        super().__init__(layers=specs, loss_fn=torch.nn.CrossEntropyLoss(), **kwargs)


class SimpleOptimizer(torch.optim.Optimizer):
    """Minimal stateless SGD-style optimizer: ``p <- p - lr * grad``.

    Exists so unit tests can exercise engine code paths with a bare-bones
    client optimizer.
    """
    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(SimpleOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SimpleOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The closure's loss if a closure was given, else ``None``.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # add_(Number, Tensor) is deprecated (and removed in newer
                # PyTorch); use the alpha keyword form instead.
                p.data.add_(d_p, alpha=-group['lr'])

        return loss


class HybridStateOptimizer(torch.optim.Optimizer):
    """SGD-style optimizer that keeps both a plain-int and a tensor step
    counter in per-parameter state.

    The mixed (int + tensor) state lets checkpointing tests verify that
    both kinds of optimizer state survive a save/load round trip.
    """
    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(HybridStateOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(HybridStateOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        """Perform a single optimization step and bump both step counters.

        Args:
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The closure's loss if a closure was given, else ``None``.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    # Lazily initialize state on the first step for this param.
                    state['integer_step'] = 0
                    state['tensor_step'] = torch.zeros(1)

                d_p = p.grad.data
                # add_(Number, Tensor) is deprecated (and removed in newer
                # PyTorch); use the alpha keyword form instead.
                p.data.add_(d_p, alpha=-group['lr'])
                state['integer_step'] += 1
                state['tensor_step'] += 1

        return loss


class PLD_SimpleModel(SimpleModel):
    """``SimpleModel`` whose forward accepts progressive-layer-drop kwargs.

    The PLD keyword arguments are read but deliberately unused here; the
    loss comes straight from the parent's forward pass.
    """
    def __init__(self, hidden_dim, empty_grad=False, rank=0):
        super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad, rank)

    def forward(self, x, y, **kwargs):
        pld = kwargs.get('progressive_layer_drop', False)
        theta = kwargs.get('pld_theta', 1.0)
        loss = super(PLD_SimpleModel, self).forward(x, y)
        return loss


def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
    """Build a DataLoader of random features and random integer labels.

    The batch size is taken from the engine's micro-batch setting; labels
    are drawn uniformly from ``[0, hidden_dim)`` so they are valid class
    indices for a ``hidden_dim``-way cross-entropy loss.
    """
    batch_size = model.train_micro_batch_size_per_gpu()
    features = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    labels = torch.empty(total_samples,
                         dtype=torch.long,
                         device=device).random_(hidden_dim)
    dataset = torch.utils.data.TensorDataset(features, labels)
    return torch.utils.data.DataLoader(dataset, batch_size=batch_size)


def create_config_from_dict(tmpdir, config_dict):
    """Serialize ``config_dict`` as JSON under ``tmpdir`` and return the path."""
    path = os.path.join(tmpdir, 'temp_config.json')
    with open(path, 'w') as fh:
        fh.write(json.dumps(config_dict))
    return path


def args_from_dict(tmpdir, config_dict):
    """Build an argparse namespace mimicking deepspeed launcher arguments.

    Writes ``config_dict`` to a JSON config file under ``tmpdir`` and points
    ``args.deepspeed_config`` at it. ``local_rank`` mirrors the current
    distributed rank when a process group is already initialized, else 0.
    """
    config_path = create_config_from_dict(tmpdir, config_dict)
    args = argparse.ArgumentParser().parse_args(args='')
    args.deepspeed = True
    args.deepspeed_config = config_path
    if torch.distributed.is_initialized():
        # We assume up to one full node executing unit tests
        assert torch.distributed.get_world_size() <= torch.cuda.device_count()
        args.local_rank = torch.distributed.get_rank()
    else:
        args.local_rank = 0
    return args