test_optim.py
import torch as th
import backend as F

from dgl.nn import NodeEmbedding
from dgl.optim import SparseAdam, SparseAdagrad

import os
import unittest

@unittest.skipIf(os.name == 'nt', reason='Windows is not supported yet')
def test_sparse_adam():
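    """Compare dgl.optim.SparseAdam against torch.optim.SparseAdam for a single
    optimization step on identically initialized embedding tables."""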
    num_embs = 10
    emb_dim = 4
    device = F.ctx()
    dgl_emb = NodeEmbedding(num_embs, emb_dim, 'test')
    torch_emb = th.nn.Embedding(num_embs, emb_dim, sparse=True)
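    # Seed the RNG identically before each init so both embeddings start from the same weights.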
    th.manual_seed(0)
    th.nn.init.uniform_(torch_emb.weight, 0, 1.0)
    th.manual_seed(0)
    th.nn.init.uniform_(dgl_emb.emb_tensor, 0, 1.0)

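    # Create both sparse optimizers with the same learning rate.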
    dgl_adam = SparseAdam(params=[dgl_emb], lr=0.01)
    torch_adam = th.optim.SparseAdam(list(torch_emb.parameters()), lr=0.01)

    # first step
    idx = th.randint(0, num_embs, size=(4,))
    dgl_value = dgl_emb(idx, device).to(th.device('cpu'))
    torch_value = torch_emb(idx)
    labels = th.ones((4,)).long()

    dgl_adam.zero_grad()
    torch_adam.zero_grad()
    dgl_loss = th.nn.functional.cross_entropy(dgl_value, labels)
    torch_loss = th.nn.functional.cross_entropy(torch_value, labels)
    dgl_loss.backward()
    torch_loss.backward()

    dgl_adam.step()
    torch_adam.step()
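    # After one identical update, both embedding tables should still match.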
    assert F.allclose(dgl_emb.emb_tensor, torch_emb.weight)

    # The second step cannot be compared:
    # PyTorch's SparseAdam maintains a single global step count,
    # while DGL's SparseAdam keeps a per-embedding step count.

if __name__ == '__main__':
    test_sparse_adam()