# Unit tests for dgl.frame.Frame and dgl.frame.FrameRef.
import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.frame import Frame, FrameRef

# Shape of every column produced by create_test_data().
N = 10  # number of rows
D = 5   # number of features per row


def check_eq(a, b):
    """Return True if tensors ``a`` and ``b`` have equal shapes and close values.

    Values are compared with ``np.allclose`` at its default tolerances.
    Both tensors are detached first: ``Tensor.numpy()`` refuses tensors that
    require grad, and several tests compare autograd leaves.
    """
    if a.shape != b.shape:
        return False
    return np.allclose(a.detach().numpy(), b.detach().numpy())

def check_fail(fn):
    """Call ``fn()`` and report whether it raised.

    Returns:
        True if ``fn`` raised an ``Exception``; False if it returned normally.

    ``KeyboardInterrupt``/``SystemExit`` are deliberately not caught, so an
    interrupted test run is not reported as an "expected failure".
    """
    try:
        fn()
    except Exception:
        return True
    return False


def create_test_data(grad=False):
    """Return a dict of three random ``(N, D)`` tensors keyed 'a1', 'a2', 'a3'.

    Args:
        grad: if True, the tensors are created with ``requires_grad=True``
            (used by the autograd tests).
    """
    # ``Variable`` is deprecated; tensor factories take ``requires_grad``
    # directly. Tensors are drawn in the same order as before: a1, a2, a3.
    return {name: th.randn(N, D, requires_grad=grad)
            for name in ('a1', 'a2', 'a3')}

def test_create():
    """Build a Frame column by column and in one shot, then clear it."""
    columns = create_test_data()
    expected_schemes = set(columns.keys())

    # incremental construction via add_column
    incremental = Frame()
    for name, tensor in columns.items():
        incremental.add_column(name, tensor)
    assert incremental.schemes == expected_schemes
    assert incremental.num_columns == 3
    assert incremental.num_rows == N

    # construction straight from a dict of columns
    bulk = Frame(columns)
    assert bulk.schemes == expected_schemes
    assert bulk.num_columns == 3
    assert bulk.num_rows == N

    # clearing leaves no columns and no rows behind
    incremental.clear()
    assert len(incremental.schemes) == 0
    assert incremental.num_rows == 0

def test_column1():
    """Test Frame column getter/setter and column deletion."""
    data = create_test_data()
    f = Frame(data)
    assert f.num_rows == N
    assert len(f) == 3
    assert check_eq(f['a1'], data['a1'])
    # Overwrite an existing column and read the overwritten column back
    # (checking only f['a2'] would not exercise the setter at all).
    f['a1'] = data['a2']
    assert check_eq(f['a1'], data['a2'])
    assert check_eq(f['a2'], data['a2'])
    # adding a column whose length differs from num_rows should fail
    def failed_add_col():
        f['a4'] = th.zeros([N+1, D])
    assert check_fail(failed_add_col)
    # delete all the columns
    del f['a1']
    del f['a2']
    assert len(f) == 1
    del f['a3']
    assert f.num_rows == 0
    assert len(f) == 0
    # once the frame is empty, a column of a different length is accepted
    f['a4'] = th.zeros([N+1, D])
    assert f.num_rows == N+1
    assert len(f) == 1

def test_column2():
    """FrameRef columns are a window onto the referenced Frame's rows."""
    lo, hi = 3, 8
    span = hi - lo
    base = Frame(create_test_data())
    ref = FrameRef(base, list(range(lo, hi)))
    assert ref.num_rows == span
    assert len(ref) == 3
    assert check_eq(ref['a1'], base['a1'][lo:hi])

    # writing a column through the ref lands in the referenced rows
    ref['a1'] = th.zeros([span, D])
    assert check_eq(base['a1'][lo:hi], th.zeros([span, D]))

    # a brand-new column is zero-filled outside the referenced rows
    ref['a4'] = th.ones([span, D])
    assert len(base) == 4
    assert check_eq(base['a4'][:lo], th.zeros([lo, D]))
    assert check_eq(base['a4'][lo:hi], th.ones([span, D]))
    assert check_eq(base['a4'][hi:], th.zeros([N - hi, D]))
def test_append1():
    """Test Frame.append with both a dict of columns and another Frame."""
    data = create_test_data()
    f1 = Frame()
    f2 = Frame(data)
    # append a plain dict of columns to an empty frame
    f1.append(data)
    assert f1.num_rows == N
    # append another Frame holding the same columns
    f1.append(f2)
    assert f1.num_rows == 2 * N
    c1 = f1['a1']
    assert c1.shape == (2 * N, D)
    truth = th.cat([data['a1'], data['a1']])
    assert check_eq(truth, c1)

def test_append2():
    """Appending through a FrameRef vs. appending to its underlying Frame."""
    base = Frame(create_test_data())
    ref = FrameRef(base)
    assert ref.is_contiguous()
    assert ref.is_span_whole_column()
    assert ref.num_rows == N

    # growing the base frame leaves the ref's row set unchanged
    base.append(base)
    assert ref.is_contiguous()
    assert not ref.is_span_whole_column()
    assert ref.num_rows == N

    # growing through the ref adds rows to both; the ref skips [N, 2N)
    ref.append(base)
    assert not ref.is_contiguous()
    assert not ref.is_span_whole_column()
    assert ref.num_rows == 3 * N
    leading = list(range(N))
    appended = list(range(2 * N, 4 * N))
    assert check_eq(ref.index_tensor(), th.tensor(leading + appended))
    assert base.num_rows == 4 * N

def test_row1():
    """Test the FrameRef row getter/setter with tensor row ids."""
    data = create_test_data()
    f = FrameRef(Frame(data))

    # getter: first without, then with duplicate row ids
    for rowid in (th.tensor([0, 2]), th.tensor([8, 2, 2, 1])):
        rows = f[rowid]
        # guard against the loop below passing vacuously on an empty result
        assert set(rows.keys()) == set(data.keys())
        for k, v in rows.items():
            assert v.shape == (len(rowid), D)
            assert check_eq(v, data[k][rowid])

    # setter
    rowid = th.tensor([0, 2, 4])
    vals = {'a1' : th.zeros((len(rowid), D)),
            'a2' : th.zeros((len(rowid), D)),
            'a3' : th.zeros((len(rowid), D)),
            }
    f[rowid] = vals
    for v in f[rowid].values():
        assert check_eq(v, th.zeros((len(rowid), D)))

def test_row2():
    """Test that the FrameRef row getter/setter are autograd compatible."""
    data = create_test_data(grad=True)
    f = FrameRef(Frame(data))

    # getter
    # NOTE: the expected-gradient literals below have 10 entries, i.e. they
    # assume N == 10.
    c1 = f['a1']
    # non-duplicate ids: each selected row receives gradient 1
    rowid = th.tensor([0, 2])
    rows = f[rowid]
    rows['a1'].backward(th.ones((len(rowid), D)))
    assert check_eq(c1.grad[:, 0],
                    th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
    c1.grad.zero_()
    # duplicate ids: gradients of a repeated row accumulate
    rowid = th.tensor([8, 2, 2, 1])
    rows = f[rowid]
    rows['a1'].backward(th.ones((len(rowid), D)))
    assert check_eq(c1.grad[:, 0],
                    th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
    c1.grad.zero_()

    # setter
    c1 = f['a1']
    rowid = th.tensor([0, 2, 4])
    vals = {'a1' : th.zeros((len(rowid), D), requires_grad=True),
            'a2' : th.zeros((len(rowid), D), requires_grad=True),
            'a3' : th.zeros((len(rowid), D), requires_grad=True),
            }
    f[rowid] = vals
    c11 = f['a1']
    c11.backward(th.ones((N, D)))
    # rows 0, 2, 4 were overwritten, so their gradient goes to vals['a1']
    # and not to the old column c1
    assert check_eq(c1.grad[:, 0],
                    th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
    assert check_eq(vals['a1'].grad, th.ones((len(rowid), D)))
    # only column 'a1' took part in the backward pass
    assert vals['a2'].grad is None

def test_row3():
    """Deleting rows from a FrameRef is lazy and leaves the Frame intact."""
    base = Frame(create_test_data())
    ref = FrameRef(base)
    assert ref.is_contiguous()
    assert ref.is_span_whole_column()
    assert ref.num_rows == N

    dropped = [2, 3]
    del ref[th.tensor(dropped)]
    assert not ref.is_contiguous()
    assert not ref.is_span_whole_column()

    # only the ref shrinks; the underlying storage keeps every row
    assert ref.num_rows == N - len(dropped)
    assert base.num_rows == N

    keep = th.tensor([row for row in range(N) if row not in dropped])
    for name, column in ref.items():
        assert check_eq(column, base[name][keep])

def test_sharing():
    """Two FrameRefs over overlapping rows of one Frame share its storage."""
    data = Frame(create_test_data())
    f1 = FrameRef(data, index=[0, 1, 2, 3])
    f2 = FrameRef(data, index=[2, 3, 4, 5, 6])
    # test read
    for k, v in f1.items():
        assert check_eq(data[k][0:4], v)
    for k, v in f2.items():
        assert check_eq(data[k][2:7], v)
    # Snapshot f2's column. clone() makes this an independent copy: if the
    # getter returned a view of the storage, the snapshot would track later
    # writes and the assertions below would pass no matter what.
    f2_a1 = f2['a1'].clone()
    # test write
    # updating rows only f1 references (0, 1) must not be seen through f2
    f1[th.tensor([0, 1])] = {
            'a1' : th.zeros([2, D]),
            'a2' : th.zeros([2, D]),
            'a3' : th.zeros([2, D]),
            }
    assert check_eq(f2['a1'], f2_a1)
    # updating the shared rows (2, 3) must be seen through f2, where they
    # are f2's rows 0 and 1
    f1[th.tensor([2, 3])] = {
            'a1' : th.ones([2, D]),
            'a2' : th.ones([2, D]),
            'a3' : th.ones([2, D]),
            }
    f2_a1[0:2] = th.ones([2, D])
    assert check_eq(f2['a1'], f2_a1)


if __name__ == '__main__':
    # Run every test in file order when executed as a script.
    test_create()
    test_column1()
    test_column2()
    test_append1()
    test_append2()
    test_row1()
    test_row2()
    test_row3()
    test_sharing()