simple_model.py 4.6 KB
Newer Older
1
2
3
4
5
import os
import json
import argparse
import torch

6
7
from deepspeed.pipe import PipelineModule, LayerSpec

8
9

class SimpleModel(torch.nn.Module):
    """Minimal linear model that returns a cross-entropy loss.

    Args:
        hidden_dim: Input and output width of every linear layer.
        empty_grad: When true, a second linear layer ``linear2`` is created.
            It only takes part in the forward pass when ``rank == 0``
            (presumably so that it receives no gradient on other ranks).
        rank: Rank recorded on the module; compared against 0 in ``forward``.
    """

    def __init__(self, hidden_dim, empty_grad=False, rank=0):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_dim, hidden_dim)
        if empty_grad:
            self.linear2 = torch.nn.Linear(hidden_dim, hidden_dim)
        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()
        self.rank = rank
        self.empty_grad = empty_grad

    def forward(self, x, y):
        """Return the cross-entropy loss of the model output for ``x`` against ``y``."""
        hidden_dim = x
        if self.rank == 0 and self.empty_grad:
            # Only this branch touches linear2; every other configuration
            # leaves it out of the graph.
            hidden_dim = self.linear(hidden_dim) + self.linear2(hidden_dim)
        else:
            hidden_dim = self.linear(hidden_dim)
        return self.cross_entropy_loss(hidden_dim, y)


28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
class LinearStack(torch.nn.Module):
    """Stack of linear layers: input layer, ``num_layers`` hidden layers, output layer.

    Args:
        input_dim: Width of the input features.
        hidden_dim: Width of every hidden layer.
        output_dim: Width of the output layer.
        num_layers: Number of bias-free hidden ``hidden_dim -> hidden_dim`` layers.
    """

    def __init__(self, input_dim=128, hidden_dim=128, output_dim=128, num_layers=4):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim

        # A plain Linear is used here: no ``VerboseLinear`` is defined or
        # imported in this module, so referencing it raised NameError.
        self.input_layer = torch.nn.Linear(in_features=self.input_dim,
                                           out_features=self.hidden_dim)
        self.layers = torch.nn.ModuleList([
            torch.nn.Linear(in_features=self.hidden_dim,
                            out_features=self.hidden_dim,
                            bias=False) for _ in range(num_layers)
        ])
        self.output_layer = torch.nn.Linear(in_features=self.hidden_dim,
                                            out_features=self.output_dim)

        self.cross_entropy_loss = torch.nn.CrossEntropyLoss()

    def forward(self, x, y):
        """Run ``x`` through every layer and return the output-layer activations.

        ``y`` is accepted for signature compatibility but is not used; no loss
        is computed here.
        """
        x = self.input_layer(x)
        for layer in self.layers:
            x = layer(x)
        x = self.output_layer(x)
        return x


class LinearStackPipe(PipelineModule):
    """Linear stack described as a list of layer specs for ``PipelineModule``.

    The list holds an input ``Linear``, then ``num_layers`` repetitions of a
    bias-free hidden ``Linear`` followed by an identity function, then an
    output ``Linear``. Cross-entropy is passed as the loss function and any
    extra keyword arguments are forwarded to ``PipelineModule``.
    """

    def __init__(self,
                 input_dim=128,
                 hidden_dim=128,
                 output_dim=128,
                 num_layers=4,
                 **kwargs):
        # Plain attributes are recorded before the parent constructor runs,
        # exactly as the layer list below is built before it.
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        specs = [LayerSpec(torch.nn.Linear, self.input_dim, self.hidden_dim)]
        for _ in range(self.num_layers):
            hidden_spec = LayerSpec(torch.nn.Linear,
                                    self.hidden_dim,
                                    self.hidden_dim,
                                    bias=False)
            # Each hidden layer is followed by an identity callable.
            specs.extend([hidden_spec, lambda x: x])
        specs.append(LayerSpec(torch.nn.Linear, self.hidden_dim, self.output_dim))

        super().__init__(layers=specs, loss_fn=torch.nn.CrossEntropyLoss(), **kwargs)


81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
class SimpleOptimizer(torch.optim.Optimizer):
    """Bare-bones gradient-descent optimizer: ``p <- p - lr * p.grad``.

    Args:
        params: Iterable of parameters (or parameter-group dicts) to optimize.
        lr: Step size stored as the per-group ``'lr'`` default.
    """

    def __init__(self, params, lr=0.11072018):
        defaults = dict(lr=lr)
        super(SimpleOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(SimpleOptimizer, self).__setstate__(state)

    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: Optional callable; when given it is invoked once before
                the update and its return value is returned.

        Returns:
            The closure's return value, or ``None`` when no closure is given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                d_p = p.grad.data
                # Keyword ``alpha`` replaces the deprecated positional
                # ``add_(Number, Tensor)`` overload.
                p.data.add_(d_p, alpha=-group['lr'])

        return loss


104
def random_dataloader(model, total_samples, hidden_dim, device, dtype=torch.half):
    """Build a DataLoader over random features and random integer labels.

    Args:
        model: Object exposing ``train_micro_batch_size_per_gpu()``; its
            return value is used as the loader's batch size.
        total_samples: Number of samples in the dataset.
        hidden_dim: Feature width, and also the exclusive upper bound of the
            generated labels.
        device: Device on which both tensors are created.
        dtype: Dtype of the feature tensor.

    Returns:
        A ``torch.utils.data.DataLoader`` yielding ``(features, labels)`` batches.
    """
    batch_size = model.train_micro_batch_size_per_gpu()
    train_data = torch.randn(total_samples, hidden_dim, device=device, dtype=dtype)
    # Labels are drawn from [0, hidden_dim).
    train_label = torch.empty(total_samples,
                              dtype=torch.long,
                              device=device).random_(hidden_dim)
    train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size)
    return train_loader


def create_config_from_dict(tmpdir, config_dict):
    """Serialize ``config_dict`` as JSON to ``<tmpdir>/temp_config.json``.

    Returns:
        The path of the file that was written.
    """
    target = os.path.join(tmpdir, 'temp_config.json')
    with open(target, 'w') as handle:
        json.dump(config_dict, handle)
    return target


def args_from_dict(tmpdir, config_dict):
    """Build an argparse namespace pointing at a config file written from ``config_dict``.

    The config is written into ``tmpdir`` via ``create_config_from_dict``. The
    returned namespace carries ``deepspeed=True``, ``deepspeed_config`` (the
    written path) and ``local_rank``.

    Args:
        tmpdir: Directory in which the temporary JSON config is created.
        config_dict: JSON-serializable configuration.

    Returns:
        The populated ``argparse.Namespace``.
    """
    config_path = create_config_from_dict(tmpdir, config_dict)
    parser = argparse.ArgumentParser()
    # Parse an empty argument list so the real command line is ignored.
    args = parser.parse_args(args='')
    args.deepspeed = True
    args.deepspeed_config = config_path
    if torch.distributed.is_initialized():
        # We assume up to one full node executing unit tests, so the global
        # rank is used as the local rank.
        assert torch.distributed.get_world_size() <= torch.cuda.device_count()
        args.local_rank = torch.distributed.get_rank()
    else:
        args.local_rank = 0
    return args