test_frame.py 11 KB
Newer Older
1
2
3
import torch as th
from torch.autograd import Variable
import numpy as np
Minjie Wang's avatar
Minjie Wang committed
4
from dgl.frame import Frame, FrameRef
Minjie Wang's avatar
Minjie Wang committed
5
from dgl.utils import Index, toindex
6
import utils as U
7
8

# Number of rows in every column built by the tests in this file.
N = 10
# Width (second dimension) of every column built by the tests in this file.
D = 5
10

Minjie Wang's avatar
Minjie Wang committed
11
12
13
14
15
16
def check_fail(fn):
    """Report whether calling ``fn()`` raises an ordinary exception.

    Parameters
    ----------
    fn : callable
        Zero-argument callable to invoke.

    Returns
    -------
    bool
        True if ``fn()`` raised an ``Exception`` subclass, False if it
        returned normally.
    """
    try:
        fn()
    except Exception:
        # Only ordinary errors count as an expected failure; KeyboardInterrupt
        # and SystemExit derive from BaseException and must keep propagating
        # so the test run can still be interrupted.
        return True
    return False
17
18
19
20
21
22
23
24
25

def create_test_data(grad=False):
    """Build the three random columns shared by the tests.

    Parameters
    ----------
    grad : bool
        Forwarded as ``requires_grad`` for every generated column.

    Returns
    -------
    dict
        Maps 'a1', 'a2' and 'a3' to freshly drawn ``th.randn(N, D)`` values
        wrapped in ``Variable``.
    """
    columns = {}
    # Draw the columns in a1, a2, a3 order, one randn call per column.
    for name in ('a1', 'a2', 'a3'):
        columns[name] = Variable(th.randn(N, D), requires_grad=grad)
    return columns

def test_create():
    """Check Frame construction column-by-column, from a dict, and clear()."""
    data = create_test_data()
    # Start from an empty frame with a preset row count and add columns one
    # at a time.
    f1 = Frame(num_rows=N)
    for k, v in data.items():
        f1.update_column(k, v)
    print(f1.schemes)
    assert f1.keys() == set(data.keys())
    assert f1.num_columns == 3
    assert f1.num_rows == N
    # Building directly from the dict must give the same keys and sizes.
    f2 = Frame(data)
    assert f2.keys() == set(data.keys())
    assert f2.num_columns == 3
    assert f2.num_rows == N
    # After clear() the frame is expected to report no schemes and no rows.
    f1.clear()
    assert len(f1.schemes) == 0
    assert f1.num_rows == 0

Minjie Wang's avatar
Minjie Wang committed
41
42
def test_column1():
    """Test the Frame column getter, setter and deleter."""
    # Test frame column getter/setter
    data = create_test_data()
    f = Frame(data)
    assert f.num_rows == N
    assert len(f) == 3
    assert U.allclose(f['a1'].data, data['a1'].data)
    f['a1'] = data['a2']
    # The overwritten column must now hold the new values ...
    assert U.allclose(f['a1'].data, data['a2'].data)
    # ... and the column that supplied them must be unchanged.
    assert U.allclose(f['a2'].data, data['a2'].data)
    # add a different length column should fail
    def failed_add_col():
        f['a4'] = th.zeros([N+1, D])
    assert check_fail(failed_add_col)
    # delete all the columns
    del f['a1']
    del f['a2']
    assert len(f) == 1
    del f['a3']
    assert len(f) == 0

def test_column2():
    """Test the FrameRef column getter and setter over rows 3..7."""
    # Test frameref column getter/setter
    data = Frame(create_test_data())
    f = FrameRef(data, toindex([3, 4, 5, 6, 7]))
    assert f.num_rows == 5
    assert len(f) == 3
    assert U.allclose(f['a1'], data['a1'].data[3:8])
    # set column should reflect on the referenced data
    f['a1'] = th.zeros([5, D])
    assert U.allclose(data['a1'].data[3:8], th.zeros([5, D]))

    # add new partial column should fail with error initializer
    def error_initializer(shape, dtype):
        # Raise explicitly: the previously used `assert_` is not defined in
        # this module, so the failure came from a NameError by accident.
        raise AssertionError('initializer must not be called')
    f.set_initializer(error_initializer)
    def failed_add_col():
        f['a4'] = th.ones([5, D])
    assert check_fail(failed_add_col)
76

Minjie Wang's avatar
Minjie Wang committed
77
78
def test_append1():
    """Test Frame.append with a dict, with another Frame, and with bad input."""
    # test append API on Frame
    data = create_test_data()
    f1 = Frame()
    f2 = Frame(data)
    # Appending a dict to an empty frame gives N rows.
    f1.append(data)
    assert f1.num_rows == N
    # Appending a Frame holding the same data doubles the row count.
    f1.append(f2)
    assert f1.num_rows == 2 * N
    c1 = f1['a1']
    assert c1.data.shape == (2 * N, D)
    truth = th.cat([data['a1'], data['a1']])
    assert U.allclose(truth, c1.data)
    # append dict of different length columns should fail
    f3 = {'a1' : th.zeros((3, D)), 'a2' : th.zeros((3, D)), 'a3' : th.zeros((2, D))}
    def failed_append():
        f1.append(f3)
    assert check_fail(failed_append)
Minjie Wang's avatar
Minjie Wang committed
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113

def test_append2():
    """Test how appends to a Frame and to its FrameRef affect the reference."""
    # test append on FrameRef
    data = Frame(create_test_data())
    f = FrameRef(data)
    assert f.is_contiguous()
    assert f.is_span_whole_column()
    assert f.num_rows == N
    # append on the underlying frame should not reflect on the ref
    data.append(data)
    assert f.is_contiguous()
    assert not f.is_span_whole_column()
    assert f.num_rows == N
    # append on the FrameRef should work
    f.append(data)
    assert not f.is_contiguous()
    assert not f.is_span_whole_column()
    assert f.num_rows == 3 * N
    # The ref is expected to index its original N rows plus the 2N rows it
    # just appended; rows N..2N-1 were added only to the underlying frame.
    new_idx = list(range(N)) + list(range(2*N, 4*N))
    assert th.all(f._index.tousertensor() == th.tensor(new_idx, dtype=th.int64))
    assert data.num_rows == 4 * N

117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
def test_append3():
    """Test appending to a column-less Frame and appending a new column.

    The expected tensors assert that rows which existed before a column was
    first appended read back as zeros.
    """
    frame = Frame(num_rows=5)

    # First append: column 'h' does not exist yet on the 5-row frame.
    frame.append({'h' : th.ones((3, 2))})
    assert frame.num_rows == 8
    expected_h = th.cat([th.zeros((5, 2)), th.ones((3, 2))], dim=0)
    assert U.allclose(frame['h'].data, expected_h)

    # Second append: extends 'h' and introduces the new column 'w'.
    batch = {'h' : 2 * th.ones((3, 2)), 'w' : 2 * th.ones((3, 2))}
    frame.append(batch)
    assert frame.num_rows == 11
    expected_h2 = th.cat([expected_h, 2 * th.ones((3, 2))], 0)
    expected_w = th.cat([th.zeros((8, 2)), 2 * th.ones((3, 2))], 0)
    assert U.allclose(frame['h'].data, expected_h2)
    assert U.allclose(frame['w'].data, expected_w)

Minjie Wang's avatar
Minjie Wang committed
134
135
136
137
def test_row1():
    """Test the FrameRef row getter and setter."""
    # test row getter/setter
    data = create_test_data()
    f = FrameRef(Frame(data))

    # getter
    # test non-duplicate keys
    rowid = Index(th.tensor([0, 2]))
    rows = f[rowid]
    for k, v in rows.items():
        assert v.shape == (len(rowid), D)
        assert U.allclose(v, data[k][rowid])
    # test duplicate keys
    rowid = Index(th.tensor([8, 2, 2, 1]))
    rows = f[rowid]
    for k, v in rows.items():
        assert v.shape == (len(rowid), D)
        assert U.allclose(v, data[k][rowid])

    # setter
    rowid = Index(th.tensor([0, 2, 4]))
    vals = {'a1' : th.zeros((len(rowid), D)),
            'a2' : th.zeros((len(rowid), D)),
            'a3' : th.zeros((len(rowid), D)),
            }
    f[rowid] = vals
    for k, v in f[rowid].items():
        assert U.allclose(v, th.zeros((len(rowid), D)))

    # setting rows with new column should raise error with error initializer
    def error_initializer(shape, dtype):
        # Raise explicitly: the previously used `assert_` is not defined in
        # this module, so the failure came from a NameError by accident.
        raise AssertionError('initializer must not be called')
    f.set_initializer(error_initializer)
    def failed_update_rows():
        vals['a4'] = th.ones((len(rowid), D))
        f[rowid] = vals
    assert check_fail(failed_update_rows)
169

Minjie Wang's avatar
Minjie Wang committed
170
171
def test_row2():
    """Test that gradients flow through the FrameRef row getter and setter."""
    # test row getter/setter autograd compatibility
    data = create_test_data(grad=True)
    f = FrameRef(Frame(data))

    # getter
    c1 = f['a1']
    # test non-duplicate keys
    rowid = Index(th.tensor([0, 2]))
    rows = f[rowid]
    rows['a1'].backward(th.ones((len(rowid), D)))
    # Only the selected rows (0 and 2) should receive gradient.
    assert U.allclose(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
    c1.grad.data.zero_()
    # test duplicate keys
    rowid = Index(th.tensor([8, 2, 2, 1]))
    rows = f[rowid]
    rows['a1'].backward(th.ones((len(rowid), D)))
    # Row 2 is selected twice, so its gradient accumulates to 2.
    assert U.allclose(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
    c1.grad.data.zero_()

    # setter
    c1 = f['a1']
    rowid = Index(th.tensor([0, 2, 4]))
    vals = {'a1' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
            'a2' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
            'a3' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
            }
    f[rowid] = vals
    c11 = f['a1']
    c11.backward(th.ones((N, D)))
    # Overwritten rows (0, 2, 4) pass their gradient to `vals`, not to the
    # old column; the remaining rows still reach the old column.
    assert U.allclose(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
    assert U.allclose(vals['a1'].grad, th.ones((len(rowid), D)))
    # Only column 'a1' was back-propagated, so 'a2' has no gradient.
    assert vals['a2'].grad is None

Minjie Wang's avatar
Minjie Wang committed
204
205
206
207
208
209
210
def test_row3():
    """Test deleting rows through a FrameRef."""
    # test row delete
    data = Frame(create_test_data())
    f = FrameRef(data)
    assert f.is_contiguous()
    assert f.is_span_whole_column()
    assert f.num_rows == N
    del f[toindex(th.tensor([2, 3]))]
    assert not f.is_contiguous()
    assert not f.is_span_whole_column()
    # delete is lazy: only reflect on the ref while the
    # underlying storage should not be touched
    assert f.num_rows == N - 2
    assert data.num_rows == N
    # Build the surviving row ids: popping position 2 twice removes the
    # original rows 2 and 3.
    newidx = list(range(N))
    newidx.pop(2)
    newidx.pop(2)
    newidx = toindex(newidx)
    for k, v in f.items():
        assert U.allclose(v, data[k][newidx])
Minjie Wang's avatar
Minjie Wang committed
224

225
226
227
228
229
230
231
232
233
def test_row4():
    """Test a row update on a column-less frame that has a preset row count.

    The expected tensor asserts that rows not written read back as zeros.
    """
    written_rows = [0, 2, 4]
    ref = FrameRef(Frame(num_rows=5))
    ref[Index(th.tensor(written_rows))] = {'h' : th.ones((3, 2))}

    # Expected column: zeros everywhere except the rows that were written.
    expected = th.zeros((5, 2))
    expected[th.tensor(written_rows)] = th.ones((3, 2))
    assert U.allclose(ref['h'], expected)

Minjie Wang's avatar
Minjie Wang committed
234
235
def test_sharing():
    """Test reads and writes through two FrameRefs that overlap on rows 2-3."""
    data = Frame(create_test_data())
    f1 = FrameRef(data, index=toindex([0, 1, 2, 3]))
    f2 = FrameRef(data, index=toindex([2, 3, 4, 5, 6]))
    # test read
    for k, v in f1.items():
        assert U.allclose(data[k].data[0:4], v)
    for k, v in f2.items():
        assert U.allclose(data[k].data[2:7], v)
    f2_a1 = f2['a1'].data
    # test write
    # update to own rows should not be seen by the other.
    f1[Index(th.tensor([0, 1]))] = {
            'a1' : th.zeros([2, D]),
            'a2' : th.zeros([2, D]),
            'a3' : th.zeros([2, D]),
            }
    assert U.allclose(f2['a1'], f2_a1)
    # update to shared rows should be seen by the other.
    f1[Index(th.tensor([2, 3]))] = {
            'a1' : th.ones([2, D]),
            'a2' : th.ones([2, D]),
            'a3' : th.ones([2, D]),
            }
    # f1's local rows 2-3 are f2's local rows 0-1.
    f2_a1[0:2] = th.ones([2, D])
    assert U.allclose(f2['a1'], f2_a1)
260

261
262
def test_slicing():
    """Test reads and writes through two slice-indexed, overlapping FrameRefs."""
    data = Frame(create_test_data(grad=True))
    f1 = FrameRef(data, index=toindex(slice(1, 5)))
    f2 = FrameRef(data, index=toindex(slice(3, 8)))
    # test read
    for k, v in f1.items():
        assert U.allclose(data[k].data[1:5], v)
    f2_a1 = f2['a1'].data
    # test write
    # f1's local rows 0-1 are frame rows 1-2, outside f2's range.
    f1[Index(th.tensor([0, 1]))] = {
            'a1': th.zeros([2, D]),
            'a2': th.zeros([2, D]),
            'a3': th.zeros([2, D]),
            }
    assert U.allclose(f2['a1'], f2_a1)

    # f1's local rows 2-3 are frame rows 3-4, i.e. f2's local rows 0-1.
    f1[Index(th.tensor([2, 3]))] = {
            'a1': th.ones([2, D]),
            'a2': th.ones([2, D]),
            'a3': th.ones([2, D]),
            }
    f2_a1[toindex(slice(0,2))] = 1
    assert U.allclose(f2['a1'], f2_a1)

    # Same shared rows again, this time addressed with a slice index.
    f1[toindex(slice(2,4))] = {
            'a1': th.zeros([2, D]),
            'a2': th.zeros([2, D]),
            'a3': th.zeros([2, D]),
            }
    f2_a1[toindex(slice(0,2))] = 0
    assert U.allclose(f2['a1'], f2_a1)
292

293
294
295
296
297
298
299
300
301
def test_add_rows():
    """Test FrameRef.add_rows followed by row updates.

    The expected tensors assert that rows never written read back as zeros.
    """
    data = Frame()
    f1 = FrameRef(data)
    f1.add_rows(4)
    x = th.randn(1, 4)
    # Writing one row creates column 'x'; the other three rows are expected
    # to be zeros.
    f1[Index(th.tensor([0]))] = {'x': x}
    ans = th.cat([x, th.zeros(3, 4)])
    assert U.allclose(f1['x'], ans)
    f1.add_rows(4)
    # Fill the four new rows of 'x' and introduce column 'y' on them only.
    f1[toindex(slice(4,8))] = {'x': th.ones(4, 4), 'y': th.ones(4, 5)}
    ans = th.cat([ans, th.ones(4, 4)])
    assert U.allclose(f1['x'], ans)
    ans = th.cat([th.zeros(4, 5), th.ones(4, 5)])
    assert U.allclose(f1['y'], ans)

Minjie Wang's avatar
Minjie Wang committed
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
def test_inplace():
    """Check which FrameRef updates replace a column's storage.

    Storage identity is observed through ``data_ptr()`` of each column.
    """
    frame = FrameRef(Frame(create_test_data()))
    print(frame.schemes)

    def storage_addr(name):
        # Address of the tensor currently backing column `name`.
        return frame[name].data.data_ptr()

    a1_seen = storage_addr('a1')
    a2_seen = storage_addr('a2')
    a3_seen = storage_addr('a3')  # recorded but never compared

    # column updates are always out-of-place
    frame['a1'] = th.ones((N, D))
    a1_now = storage_addr('a1')
    assert a1_seen != a1_now
    a1_seen = a1_now
    # full row update that becomes column update
    frame[toindex(slice(0, N))] = {'a1' : th.ones((N, D))}
    assert storage_addr('a1') != a1_seen

    # row update (outplace), first w/ slice and then w/ list
    for spec in (slice(1, 4), [1, 3, 5]):
        frame[toindex(spec)] = {'a2' : th.ones((3, D))}
        a2_now = storage_addr('a2')
        assert a2_seen != a2_now
        a2_seen = a2_now

    # row update (inplace), first w/ slice and then w/ list
    for spec in (slice(1, 4), [1, 3, 5]):
        frame.update_data(toindex(spec), {'a2' : th.ones((3, D))}, True)
        assert a2_seen == storage_addr('a2')

344
345
if __name__ == '__main__':
    # Run every test in this file, in definition order, when the file is
    # executed as a script.
    test_create()
    test_column1()
    test_column2()
    test_append1()
    test_append2()
    test_append3()
    test_row1()
    test_row2()
    test_row3()
    test_row4()
    test_sharing()
    test_slicing()
    test_add_rows()
    test_inplace()