test_frame.py.bak 12.8 KB
Newer Older
1
import numpy as np
Minjie Wang's avatar
Minjie Wang committed
2
from dgl.frame import Frame, FrameRef
Minjie Wang's avatar
Minjie Wang committed
3
from dgl.utils import Index, toindex
4
import backend as F
VoVAllen's avatar
VoVAllen committed
5
6
import dgl
import unittest
7
8
9
import pickle
import pytest
import io
10
11

# Number of rows in every column built by the tests below.
N = 10
Minjie Wang's avatar
Minjie Wang committed
12
# Size of the second dimension of every column built by the tests below.
D = 5
13

Minjie Wang's avatar
Minjie Wang committed
14
15
16
17
18
19
def check_fail(fn):
    """Report whether calling ``fn`` raises an error.

    Parameters
    ----------
    fn : callable
        Zero-argument callable to invoke.

    Returns
    -------
    bool
        True if ``fn()`` raised an ``Exception``, False if it returned normally.
    """
    try:
        fn()
    except Exception:
        # Only ordinary errors count as the expected failure. KeyboardInterrupt
        # and SystemExit derive from BaseException and are left to propagate so
        # a test run can still be interrupted.
        return True
    return False
20

21
22
23
24
def create_test_data(grad=False, dtype=F.float32):
    """Build three random (N, D) columns keyed 'a1', 'a2' and 'a3'.

    Every column is drawn with ``F.randn`` and cast to ``dtype``. When ``grad``
    is true each column is additionally passed through ``F.attach_grad``.
    """
    names = ('a1', 'a2', 'a3')
    # Draw all columns first, then attach gradients, so the backend calls
    # happen in the same order regardless of ``grad``.
    columns = [F.astype(F.randn((N, D)), dtype) for _ in names]
    if grad:
        columns = [F.attach_grad(col) for col in columns]
    return dict(zip(names, columns))

def test_create():
    """Check Frame construction column by column and from a dict, then clear()."""
    columns = create_test_data()
    expected_keys = set(columns)

    # Start from a frame with a preset row count and add the columns one by one.
    incremental = Frame(num_rows=N)
    for name, tensor in columns.items():
        incremental.update_column(name, tensor)
    print(incremental.schemes)
    assert incremental.keys() == expected_keys
    assert incremental.num_columns == 3
    assert incremental.num_rows == N

    # Building the frame directly from the dict must give the same layout.
    from_dict = Frame(columns)
    assert from_dict.keys() == expected_keys
    assert from_dict.num_columns == 3
    assert from_dict.num_rows == N

    # Clearing drops both the schemes and the row count.
    incremental.clear()
    assert len(incremental.schemes) == 0
    assert incremental.num_rows == 0

Minjie Wang's avatar
Minjie Wang committed
48
49
def test_column1():
    """Test the Frame column getter, setter and deleter."""
    data = create_test_data()
    f = Frame(data)
    assert f.num_rows == N
    assert len(f) == 3
    assert F.allclose(f['a1'].data, data['a1'])
    # Overwrite 'a1' and read that same column back so the setter is actually
    # verified; 'a2' must be left untouched by the assignment.
    f['a1'] = data['a2']
    assert F.allclose(f['a1'].data, data['a2'])
    assert F.allclose(f['a2'].data, data['a2'])
    # adding a column with a different number of rows should fail
    def failed_add_col():
        f['a4'] = F.zeros([N+1, D])
    assert check_fail(failed_add_col)
    # delete all the columns
    del f['a1']
    del f['a2']
    assert len(f) == 1
    del f['a3']
    assert len(f) == 0

def test_column2():
    """Test the FrameRef column getter and setter on a 5-row view (rows 3..7)."""
    data = Frame(create_test_data())
    f = FrameRef(data, toindex([3, 4, 5, 6, 7]))
    assert f.num_rows == 5
    assert len(f) == 3
    assert F.allclose(f['a1'], F.narrow_row(data['a1'].data, 3, 8))
    # set column should reflect on the referenced data
    f['a1'] = F.zeros([5, D])
    assert F.allclose(F.narrow_row(data['a1'].data, 3, 8), F.zeros([5, D]))

    # add new partial column should fail with error initializer
    def error_initializer(*args, **kwargs):
        # Raise explicitly instead of relying on an undefined name, and accept
        # any arguments so the failure does not depend on how the frame
        # invokes its initializer.
        raise RuntimeError('initializer must not be called')
    f.set_initializer(error_initializer)
    def failed_add_col():
        f['a4'] = F.ones([5, D])
    assert check_fail(failed_add_col)
83

Minjie Wang's avatar
Minjie Wang committed
84
85
def test_append1():
    """Test Frame.append with a dict, with another Frame, and with ragged input."""
    columns = create_test_data()
    target = Frame()
    other = Frame(columns)

    # Appending a dict to an empty frame adopts its rows.
    target.append(columns)
    assert target.num_rows == N
    # Appending a frame stacks its rows after the existing ones.
    target.append(other)
    assert target.num_rows == 2 * N

    stacked = target['a1']
    assert tuple(F.shape(stacked.data)) == (2 * N, D)
    expected = F.cat([columns['a1'], columns['a1']], 0)
    assert F.allclose(expected, stacked.data)

    # append dict of different length columns should fail
    ragged = {
        'a1': F.zeros((3, D)),
        'a2': F.zeros((3, D)),
        'a3': F.zeros((2, D)),
    }
    assert check_fail(lambda: target.append(ragged))
Minjie Wang's avatar
Minjie Wang committed
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120

def test_append2():
    """Test append through a FrameRef and its effect on the ref's index."""
    storage = Frame(create_test_data())
    ref = FrameRef(storage)
    assert ref.is_contiguous()
    assert ref.is_span_whole_column()
    assert ref.num_rows == N

    # append on the underlying frame should not reflect on the ref
    storage.append(storage)
    assert ref.is_contiguous()
    assert not ref.is_span_whole_column()
    assert ref.num_rows == N

    # append on the FrameRef should work
    ref.append(storage)
    assert not ref.is_contiguous()
    assert not ref.is_span_whole_column()
    assert ref.num_rows == 3 * N

    # The ref keeps its first N rows and gains the 2N rows added at the end
    # of the storage.
    expected_rows = [*range(N), *range(2 * N, 4 * N)]
    expected_index = F.copy_to(F.tensor(expected_rows, dtype=F.int64), F.cpu())
    assert F.array_equal(ref._index.tousertensor(), expected_index)
    assert storage.num_rows == 4 * N

124
125
126
def test_append3():
    """Test append on a column-less frame that has a preset number of rows."""
    frame = Frame(num_rows=5)

    # The expected values below encode that the 5 pre-existing rows of the new
    # column 'h' read back as zeros.
    frame.append({'h': F.ones((3, 2))})
    assert frame.num_rows == 8
    expected_h = F.cat([F.zeros((5, 2)), F.ones((3, 2))], 0)
    assert F.allclose(frame['h'].data, expected_h)

    # test append with new column
    batch = {'h': 2 * F.ones((3, 2)), 'w': 2 * F.ones((3, 2))}
    frame.append(batch)
    assert frame.num_rows == 11
    expected_h2 = F.cat([expected_h, 2 * F.ones((3, 2))], 0)
    expected_w = F.cat([F.zeros((8, 2)), 2 * F.ones((3, 2))], 0)
    assert F.allclose(frame['h'].data, expected_h2)
    assert F.allclose(frame['w'].data, expected_w)
140

Minjie Wang's avatar
Minjie Wang committed
141
142
143
144
def test_row1():
    """Test the FrameRef row getter and setter."""
    data = create_test_data()
    f = FrameRef(Frame(data))

    # getter: first with non-duplicate keys, then with duplicate keys
    for ids in ([0, 2], [8, 2, 2, 1]):
        rowid = Index(F.tensor(ids))
        rows = f[rowid]
        for k, v in rows.items():
            assert tuple(F.shape(v)) == (len(rowid), D)
            assert F.allclose(v, F.gather_row(data[k], F.tensor(rowid.tousertensor())))

    # setter
    rowid = Index(F.tensor([0, 2, 4]))
    vals = {'a1' : F.zeros((len(rowid), D)),
            'a2' : F.zeros((len(rowid), D)),
            'a3' : F.zeros((len(rowid), D)),
            }
    f[rowid] = vals
    for k, v in f[rowid].items():
        assert F.allclose(v, F.zeros((len(rowid), D)))

    # setting rows with new column should raise error with error initializer
    def error_initializer(*args, **kwargs):
        # Raise explicitly instead of relying on an undefined name, and accept
        # any arguments so the failure does not depend on how the frame
        # invokes its initializer.
        raise RuntimeError('initializer must not be called')
    f.set_initializer(error_initializer)
    def failed_update_rows():
        vals['a4'] = F.ones((len(rowid), D))
        f[rowid] = vals
    assert check_fail(failed_update_rows)
176

Minjie Wang's avatar
Minjie Wang committed
177
178
def test_row2():
    """Test that the FrameRef row getter and setter work with autograd."""
    columns = create_test_data(grad=True)
    ref = FrameRef(Frame(columns))

    # getter, non-duplicate keys: rows 0 and 2 each receive gradient 1
    with F.record_grad():
        col = ref['a1']
        picked = Index(F.tensor([0, 2]))
        out = ref[picked]['a1']
        F.backward(out, F.ones((len(picked), D)))
    expected = F.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.])
    assert F.allclose(F.grad(col)[:, 0], expected)

    # getter, duplicate keys: row 2 is read twice so its gradient is 2
    ref['a1'] = F.attach_grad(ref['a1'])
    with F.record_grad():
        col = ref['a1']
        picked = Index(F.tensor([8, 2, 2, 1]))
        out = ref[picked]['a1']
        F.backward(out, F.ones((len(picked), D)))
    expected = F.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.])
    assert F.allclose(F.grad(col)[:, 0], expected)

    # setter: overwritten rows 0, 2, 4 get no gradient in the old column,
    # the written values of 'a1' do
    ref['a1'] = F.attach_grad(ref['a1'])
    with F.record_grad():
        col = ref['a1']
        picked = Index(F.tensor([0, 2, 4]))
        new_vals = {
            'a1': F.attach_grad(F.zeros((len(picked), D))),
            'a2': F.attach_grad(F.zeros((len(picked), D))),
            'a3': F.attach_grad(F.zeros((len(picked), D))),
        }
        ref[picked] = new_vals
        updated = ref['a1']
        F.backward(updated, F.ones((N, D)))
    expected = F.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.])
    assert F.allclose(F.grad(col)[:, 0], expected)
    assert F.allclose(F.grad(new_vals['a1']), F.ones((len(picked), D)))
    # 'a2' was written but never fed into the backward pass
    assert F.is_no_grad(new_vals['a2'])
217

Minjie Wang's avatar
Minjie Wang committed
218
219
220
221
222
223
224
def test_row3():
    """Test deleting rows through a FrameRef."""
    storage = Frame(create_test_data())
    ref = FrameRef(storage)
    assert ref.is_contiguous()
    assert ref.is_span_whole_column()
    assert ref.num_rows == N

    del ref[toindex(F.tensor([2, 3]))]
    assert not ref.is_contiguous()
    assert not ref.is_span_whole_column()
    # delete is lazy: only reflect on the ref while the
    # underlying storage should not be touched
    assert ref.num_rows == N - 2
    assert storage.num_rows == N

    # Rows 2 and 3 are gone from the view; every other row is kept in order.
    kept = toindex([row for row in range(N) if row not in (2, 3)])
    for name, values in ref.items():
        assert F.allclose(values, storage[name][kept])
Minjie Wang's avatar
Minjie Wang committed
238

VoVAllen's avatar
VoVAllen committed
239
240

@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support inplace update")
def test_row4():
    """Test a row update on a column-less frame that has a preset num_rows."""
    ref = FrameRef(Frame(num_rows=5))
    picked = Index(F.tensor([0, 2, 4]))
    ref[picked] = {'h': F.ones((3, 2))}

    # Rows that were not written are expected to read back as zeros.
    expected = F.zeros((5, 2))
    expected[F.tensor([0, 2, 4])] = F.ones((3, 2))
    assert F.allclose(ref['h'], expected)
249

VoVAllen's avatar
VoVAllen committed
250
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support inplace update")
def test_sharing():
    """Test two list-indexed FrameRefs that overlap on rows 2 and 3."""
    storage = Frame(create_test_data())
    names = ('a1', 'a2', 'a3')
    left = FrameRef(storage, index=toindex([0, 1, 2, 3]))
    right = FrameRef(storage, index=toindex([2, 3, 4, 5, 6]))

    # test read
    for name, values in left.items():
        assert F.allclose(F.narrow_row(storage[name].data, 0, 4), values)
    for name, values in right.items():
        assert F.allclose(F.narrow_row(storage[name].data, 2, 7), values)
    right_a1 = right['a1']

    # test write
    # an update to rows only `left` owns should not be seen by `right`
    left[Index(F.tensor([0, 1]))] = {name: F.zeros([2, D]) for name in names}
    assert F.allclose(right['a1'], right_a1)

    # an update to the shared rows should be seen by `right`
    left[Index(F.tensor([2, 3]))] = {name: F.ones([2, D]) for name in names}
    F.narrow_row_set(right_a1, 0, 2, F.ones([2, D]))
    assert F.allclose(right['a1'], right_a1)
277

VoVAllen's avatar
VoVAllen committed
278
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support inplace update")
def test_slicing():
    """Test two slice-indexed FrameRefs (rows 1..4 and 3..7) that overlap."""
    storage = Frame(create_test_data(grad=True))
    names = ('a1', 'a2', 'a3')
    left = FrameRef(storage, index=toindex(slice(1, 5)))
    right = FrameRef(storage, index=toindex(slice(3, 8)))

    # test read
    for name, values in left.items():
        assert F.allclose(F.narrow_row(storage[name].data, 1, 5), values)
    right_a1 = right['a1']    # is a tensor

    # test write
    # rows 0..1 of `left` are not part of `right`
    left[Index(F.tensor([0, 1]))] = {name: F.zeros([2, D]) for name in names}
    assert F.allclose(right['a1'], right_a1)

    # rows 2..3 of `left` are rows 0..1 of `right` (tensor index)
    left[Index(F.tensor([2, 3]))] = {name: F.ones([2, D]) for name in names}
    F.narrow_row_set(right_a1, 0, 2, 1)
    assert F.allclose(right['a1'], right_a1)

    # same shared rows again, addressed with a slice index
    left[toindex(slice(2, 4))] = {name: F.zeros([2, D]) for name in names}
    F.narrow_row_set(right_a1, 0, 2, 0)
    assert F.allclose(right['a1'], right_a1)
310

311
312
313
314
def test_add_rows():
    """Test FrameRef.add_rows followed by row updates that create new columns."""
    ref = FrameRef(Frame())
    ref.add_rows(4)

    # Writing one row of a new column: the expected value encodes that the
    # other three rows read back as zeros.
    first_row = F.randn((1, 4))
    ref[Index(F.tensor([0]))] = {'x': first_row}
    expected_x = F.cat([first_row, F.zeros((3, 4))], 0)
    assert F.allclose(ref['x'], expected_x)

    # Grow again, then fill the new rows of 'x' and introduce 'y' at once.
    ref.add_rows(4)
    ref[toindex(slice(4, 8))] = {'x': F.ones((4, 4)), 'y': F.ones((4, 5))}
    expected_x = F.cat([expected_x, F.ones((4, 4))], 0)
    assert F.allclose(ref['x'], expected_x)
    expected_y = F.cat([F.zeros((4, 5)), F.ones((4, 5))], 0)
    assert F.allclose(ref['y'], expected_y)
325

VoVAllen's avatar
VoVAllen committed
326
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support inplace update")
def test_inplace():
    """Check, via object identity, which updates replace a column's tensor."""
    ref = FrameRef(Frame(create_test_data()))
    print(ref.schemes)
    a1_id = id(ref['a1'])
    a2_id = id(ref['a2'])
    a3_id = id(ref['a3'])

    # column updates are always out-of-place
    ref['a1'] = F.ones((N, D))
    a1_id_after = id(ref['a1'])
    assert a1_id != a1_id_after
    a1_id = a1_id_after
    # full row update that becomes column update
    ref[toindex(slice(0, N))] = {'a1': F.ones((N, D))}
    assert id(ref['a1']) != a1_id

    # row update (outplace) w/ slice
    ref[toindex(slice(1, 4))] = {'a2': F.ones((3, D))}
    a2_id_after = id(ref['a2'])
    assert a2_id != a2_id_after
    a2_id = a2_id_after
    # row update (outplace) w/ list
    ref[toindex([1, 3, 5])] = {'a2': F.ones((3, D))}
    a2_id_after = id(ref['a2'])
    assert a2_id != a2_id_after
    a2_id = a2_id_after

    # row update (inplace) w/ slice
    ref.update_data(toindex(slice(1, 4)), {'a2': F.ones((3, D))}, True)
    a2_id_after = id(ref['a2'])
    assert a2_id == a2_id_after
    # row update (inplace) w/ list
    ref.update_data(toindex([1, 3, 5]), {'a2': F.ones((3, D))}, True)
    a2_id_after = id(ref['a2'])
    assert a2_id == a2_id_after

363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
@unittest.skipIf(dgl.backend.backend_name == "tensorflow", reason="TF doesn't support inplace update")
def test_clone():
    """Test FrameRef.clone (shares column data) and deepclone (does not)."""
    original = FrameRef(Frame(create_test_data()))
    shallow = original.clone()
    deep = original.deepclone()

    # Columns added to either clone must not show up in the original.
    shallow['b'] = F.randn((N, D))
    deep['c'] = F.randn((N, D))
    assert 'b' not in original
    assert 'c' not in original

    # An in-place write through the shallow clone is visible in the original.
    shallow['a1'][0, 0] = -10.
    assert float(F.asnumpy(original['a1'][0, 0])) == -10.

    # An in-place write through the deep clone leaves the original unchanged.
    before = float(F.asnumpy(original['a2'][0, 0]))
    deep['a2'][0, 0] = -10.
    assert float(F.asnumpy(original['a2'][0, 0])) == before

380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
def _reconstruct_pickle(obj):
    f = io.BytesIO()
    pickle.dump(obj, f)
    f.seek(0)
    obj = pickle.load(f)
    f.close()
    return obj

@pytest.mark.parametrize(
    'dtype',
    [F.float32, F.int32] if dgl.backend.backend_name == "mxnet" else [F.float32, F.int32, F.bool],
)
def test_pickle(dtype):
    """Check that the test columns survive a pickle round trip unchanged."""
    columns = create_test_data(dtype=dtype)
    restored = _reconstruct_pickle(columns)
    for name in ('a1', 'a2', 'a3'):
        assert F.array_equal(columns[name], restored[name])

397
398
if __name__ == '__main__':
    # Run every test defined in this file when executed as a script,
    # including test_clone and test_pickle.
    test_create()
    test_column1()
    test_column2()
    test_append1()
    test_append2()
    test_append3()
    test_row1()
    test_row2()
    test_row3()
    test_row4()
    test_sharing()
    test_slicing()
    test_add_rows()
    test_inplace()
    test_clone()
    # test_pickle is parametrized under pytest; called directly it needs an
    # explicit dtype, so run it with float32, which every backend branch lists.
    test_pickle(F.float32)