# simple_model.py
# Helper models and utilities for DeepSpeed unit tests.

import os
import json
import argparse
import torch

from deepspeed.pipe import PipelineModule, LayerSpec


class SimpleModel(torch.nn.Module):
    def __init__(self, hidden_dim, empty_grad=False):
        super(SimpleModel, self).__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.empty_grad = empty_grad

    def forward(self, x, y):
        hidden_dim = x
        if self.empty_grad and torch.distributed.get_rank() == 0:
            # Only rank 0 routes activations through linear2, so on all
            # other ranks linear2 receives no gradient. This exercises
            # handling of parameters with empty/missing gradients.
            hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim)
        else:
            hidden_dim = self.linear(hidden_dim)
        # The forward pass returns the loss directly.
        return self.cross_entropy_loss(hidden_dim, y)
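
# A minimal usage sketch (not part of the original file; assumes no
# distributed init, so empty_grad stays False and get_rank() is never hit):
#
#   model = SimpleModel(hidden_dim=10)
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   x = torch.randn(4, 10)
#   y = torch.randint(0, 10, (4,))
#   loss = model(x, y)   # the forward pass returns the loss itself
#   loss.backward()
#   optimizer.step()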


class LinearStack(torch.nn.Module):
    def __init__(self, input_dim=128, hidden_dim=128, output_dim=128, num_layers=4):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        self.input_layer = torch.nn.Linear(in_features=self.input_dim,
                                           out_features=self.hidden_dim)
        self.layers = torch.nn.ModuleList([
            torch.nn.Linear(in_features=self.hidden_dim,
                            out_features=self.hidden_dim,
                            bias=False) for _ in range(num_layers)
        ])
        self.output_layer = torch.nn.Linear(in_features=self.hidden_dim,
                                            out_features=self.output_dim)

        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        # `y` is accepted for interface parity with SimpleModel but unused;
        # this model returns logits rather than a loss.
        x = self.input_layer(x)
        for layer in self.layers:
            x = layer(x)
        x = self.output_layer(x)
        return x


class LinearStackPipe(PipelineModule):
    def __init__(self,
                 input_dim=128,
                 hidden_dim=128,
                 output_dim=128,
                 num_layers=4,
                 **kwargs):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        layers = []
        layers.append(LayerSpec(torch.nn.Linear, self.input_dim, self.hidden_dim))
        for _ in range(self.num_layers):
            layers.append(
                LayerSpec(torch.nn.Linear,
                          self.hidden_dim,
                          self.hidden_dim,
                          bias=False))
            # Interleave a plain callable (identity) to exercise pipeline
            # stages that are not LayerSpec instances.
            layers.append(lambda x: x)
        layers.append(LayerSpec(torch.nn.Linear, self.hidden_dim, self.output_dim))

        super().__init__(layers=layers, loss_fn=torch.nn.CrossEntropyLoss(), **kwargs)
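
# A minimal sketch (not part of the original file; assumes `deepspeed` is
# imported and torch.distributed is initialized). LayerSpec defers layer
# construction until DeepSpeed partitions the pipeline stages:
#
#   net = LinearStackPipe(num_layers=4, num_stages=2)
#   engine, _, _, _ = deepspeed.initialize(args=args,
#                                          model=net,
#                                          model_parameters=net.parameters())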


class SimpleOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(SimpleOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SimpleOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # Plain SGD update: p <- p - lr * grad. The keyword `alpha`
                # form replaces the deprecated add_(-lr, d_p) overload.
                p.data.add_(d_p, alpha=-group['lr'])

        return loss
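
# A minimal sketch (not part of the original file) of driving the
# stateless SGD above directly; `model`, `x`, `y` are assumed to exist:
#
#   opt = SimpleOptimizer(model.parameters(), lr=0.1)
#   opt.zero_grad()
#   model(x, y).backward()
#   opt.step()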


class HybridStateOptimizer(torch.optim.Optimizer):
    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(HybridStateOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(HybridStateOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state['integer_step'] = 0
                    state['tensor_step'] = torch.zeros(1)

                d_p = p.grad.data
                p.data.add_(d_p, alpha=-group['lr'])
                # Track the step count twice, as a Python int and as a
                # tensor, so optimizer state mixes both kinds of values.
                state['integer_step'] += 1
                state['tensor_step'] += 1

        return loss
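
# A minimal sketch (not part of the original file) of what the hybrid
# state exercises: after one step, state_dict() holds both a Python int
# and a tensor per parameter, stressing checkpoint save/restore paths.
#
#   opt = HybridStateOptimizer(model.parameters())
#   opt.zero_grad()
#   model(x, y).backward()
#   opt.step()
#   saved = opt.state_dict()    # mixes int and tensor state entries
#   opt.load_state_dict(saved)  # both kinds must round-trip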


class PLD_SimpleModel(SimpleModel):
    def __init__(self, hidden_dim, empty_grad=False):
        super(PLD_SimpleModel, self).__init__(hidden_dim, empty_grad)

    def forward(self, x, y, **kwargs):
        # Accept the progressive-layer-drop kwargs that the DeepSpeed
        # engine forwards when PLD is enabled; this test model only needs
        # to tolerate the extended signature, so they are read but unused.
        pld = kwargs.get('progressive_layer_drop', False)
        theta = kwargs.get('pld_theta', 1.0)
        loss = super(PLD_SimpleModel, self).forward(x, y)
        return loss
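
# A sketch (not part of the original file) of the DeepSpeed config section
# this model pairs with; "progressive_layer_drop" is the config key PLD uses:
#
#   config_dict = {
#       ...,
#       "progressive_layer_drop": {"enabled": True, "theta": 0.5}
#   }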


def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
    # `model` is expected to be a DeepSpeed engine, which exposes the
    # configured micro-batch size.
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    # Random integer class labels in [0, hidden_dim).
    train_label = torch.empty(total_samples,
                              dtype=torch.long,
                              device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader
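
# A minimal sketch (not part of the original file; assumes `engine` came
# from deepspeed.initialize) of driving a training loop from this loader:
#
#   loader = random_dataloader(engine, total_samples=50,
#                              hidden_dim=10, device=engine.device)
#   for x, y in loader:
#       loss = engine(x, y)
#       engine.backward(loss)
#       engine.step()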


def create_config_from_dict(tmpdir, config_dict):
    config_path = os.path.join(tmpdir, 'temp_config.json')
    with open(config_path, 'w') as fd:
        json.dump(config_dict, fd)
    return config_path
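
# A sketch (not part of the original file) of a typical config_dict; the
# keys shown are standard DeepSpeed config fields:
#
#   config_dict = {
#       "train_batch_size": 1,
#       "optimizer": {"type": "Adam", "params": {"lr": 0.00015}}
#   }
#   config_path = create_config_from_dict(tmpdir, config_dict)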


def create_deepspeed_args():
    parser = argparse.ArgumentParser()
    # Parse an empty argument string to get a bare Namespace we can
    # populate by hand.
    args = parser.parse_args(args='')
    args.deepspeed = True
    if torch.distributed.is_initialized():
        # We assume up to one full node executing unit tests
        assert torch.distributed.get_world_size() <= torch.cuda.device_count()
        args.local_rank = torch.distributed.get_rank()
    return args


def args_from_dict(tmpdir, config_dict):
    args = create_deepspeed_args()
    config_path = create_config_from_dict(tmpdir, config_dict)
    args.deepspeed_config = config_path
    return args
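
# A minimal end-to-end sketch (not part of the original file; `tmpdir` and
# `config_dict` are assumed, and `deepspeed` must be imported):
#
#   args = args_from_dict(tmpdir, config_dict)  # writes temp_config.json
#   model = SimpleModel(hidden_dim=10)
#   engine, _, _, _ = deepspeed.initialize(args=args,
#                                          model=model,
#                                          model_parameters=model.parameters())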