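"""Unit test for DeepSpeed's sparse gradient support.

A model mixing dense and sparse gradients takes one training step on two
data-parallel ranks; afterwards the sparse embedding weights are checked
for consistency across ranks.
"""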
import torch

import deepspeed
import deepspeed.utils.groups as groups

from .common import distributed_test


def test_sparse_adam(tmpdir):
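    # `sparse_gradients: True` enables DeepSpeed's handling of torch.sparse
    # gradient tensors (such as those produced by sparse embedding layers).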
    config_dict = {"train_batch_size": 2, "steps_per_print": 1, "sparse_gradients": True}

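    # Minimal model: a sparse EmbeddingBag (emits torch.sparse gradients)
    # feeding a dense Linear layer (emits ordinary dense gradients).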
    class Model(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.emb = torch.nn.EmbeddingBag(10, 3, mode="sum", sparse=True)
            self.linear = torch.nn.Linear(3, 1)

        def forward(self, x, offsets):
            return self.linear(self.emb(x, offsets))

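    # Hybrid optimizer: torch.optim.Adam rejects sparse gradients, so dense
    # parameters go to Adam and sparse ones to torch.optim.SparseAdam.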
    class Adam(torch.optim.Optimizer):
        def __init__(self, dense_params, sparse_params):
            super().__init__(dense_params + sparse_params, defaults={})
            self.adam = torch.optim.Adam(dense_params)
            self.adam_sparse = torch.optim.SparseAdam(sparse_params)

        @torch.no_grad()
        def step(self, closure=None):
            loss_1 = self.adam.step(closure)
            loss_2 = self.adam_sparse.step(closure)

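            # Return the closure loss when available; sum if both sub-steps produced one.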
            if loss_1 is not None and loss_2 is not None:
                return loss_1 + loss_2
            return loss_1 or loss_2

    model = Model()
    optimizer = Adam(list(model.linear.parameters()), list(model.emb.parameters()))

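    # The `distributed_test` helper from .common runs the wrapped function
    # on the requested number of ranks (here, two).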
    @distributed_test(world_size=[2])
    def _test(model, optimizer):
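        """Take one step on two ranks, then check the embedding weights agree across ranks."""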
        engine, _, _, _ = deepspeed.initialize(
            model=model, optimizer=optimizer, config=config_dict)
        loss = torch.nn.BCEWithLogitsLoss()
        # Two bags of four indices each: rows [1, 2, 4, 5] and [4, 3, 2, 9].
        x = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9], dtype=torch.long, device=engine.device)
        offsets = torch.tensor([0, 4], dtype=torch.long, device=engine.device)
        y = torch.tensor([[1.0], [0.0]], device=engine.device)
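        # Single forward/backward/step through the engine; backward produces
        # sparse gradients for the embedding and dense ones for the linear layer.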
        res = engine(x, offsets)
        engine.backward(loss(res, y))
        engine.step()

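        # Gather each embedding parameter from both data-parallel ranks and make
        # sure the two copies still match after the sparse gradient reduction.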
        results = [
            engine.all_gather_scalar(param, groups._get_data_parallel_group())
            for param in model.emb.parameters()
        ]
        for gathered in results:
            assert torch.allclose(gathered[0], gathered[1])

    _test(model, optimizer)