# test_frame.py: tests for dgl.frame.Frame and dgl.frame.FrameRef.
import torch as th
from torch.autograd import Variable
import numpy as np
from dgl.frame import Frame, FrameRef
from dgl.utils import Index

# Shape of every test column built by create_test_data(): N rows, each a
# feature vector of width D.
N = 10
D = 5

def check_eq(a, b):
    """Return True when two tensors have the same shape and close values.

    Values are compared with ``np.allclose`` using its default tolerances.

    Parameters
    ----------
    a, b : torch tensors to compare.

    Returns
    -------
    bool
        False if the shapes differ or any element is not close.
    """
    if a.shape != b.shape:
        return False
    # detach()/cpu() so tensors that require grad (or live on another device)
    # can be converted; a plain .numpy() raises on such tensors.
    return np.allclose(a.detach().cpu().numpy(), b.detach().cpu().numpy())

def check_fail(fn):
    """Report whether calling ``fn()`` raises an exception.

    Parameters
    ----------
    fn : zero-argument callable to invoke.

    Returns
    -------
    bool
        True if ``fn()`` raised an ``Exception``, False if it returned normally.

    Only ``Exception`` subclasses count as an expected failure, so
    ``KeyboardInterrupt`` and ``SystemExit`` still propagate to the caller
    instead of being reported as a passing "failure" check.
    """
    try:
        fn()
    except Exception:
        return True
    return False

def create_test_data(grad=False):
    """Build three random N x D columns keyed 'a1', 'a2' and 'a3'.

    ``grad`` is forwarded as ``requires_grad`` to every column.
    """
    return {
        name: Variable(th.randn(N, D), requires_grad=grad)
        for name in ('a1', 'a2', 'a3')
    }

def test_create():
    """Frame construction: column by column, from a dict, and clear()."""
    columns = create_test_data()
    expected_names = set(columns.keys())

    # Build the frame incrementally, one column at a time.
    incremental = Frame()
    for name, tensor in columns.items():
        incremental.add_column(name, tensor)
    assert incremental.schemes == expected_names
    assert incremental.num_columns == 3
    assert incremental.num_rows == N

    # Build the frame in one shot from the dict.
    from_dict = Frame(columns)
    assert from_dict.schemes == expected_names
    assert from_dict.num_columns == 3
    assert from_dict.num_rows == N

    # After clear() no scheme is left and the row count is zero.
    incremental.clear()
    assert len(incremental.schemes) == 0
    assert incremental.num_rows == 0

def test_column1():
    """Test the Frame column getter, setter and deleter."""
    data = create_test_data()
    f = Frame(data)
    assert f.num_rows == N
    assert len(f) == 3
    assert check_eq(f['a1'], data['a1'])
    # overwriting a column must be visible through the getter of that same
    # column and must leave the other columns untouched
    f['a1'] = data['a2']
    assert check_eq(f['a1'], data['a2'])
    assert check_eq(f['a2'], data['a2'])
    # adding a column of a different length should fail
    def failed_add_col():
        f['a4'] = th.zeros([N+1, D])
    assert check_fail(failed_add_col)
    # delete all the columns
    del f['a1']
    del f['a2']
    assert len(f) == 1
    del f['a3']
    assert f.num_rows == 0
    assert len(f) == 0
    # once the frame is empty, a column of a different length should succeed
    f['a4'] = th.zeros([N+1, D])
    assert f.num_rows == N+1
    assert len(f) == 1

def test_column2():
    """Test the FrameRef column getter/setter against its backing Frame."""
    storage = Frame(create_test_data())
    ref = FrameRef(storage, [3, 4, 5, 6, 7])
    window = slice(3, 8)  # the rows of `storage` that `ref` points at

    assert ref.num_rows == 5
    assert len(ref) == 3
    assert check_eq(ref['a1'], storage['a1'][window])

    # Setting a column through the ref must show up in the referenced rows.
    ref['a1'] = th.zeros([5, D])
    assert check_eq(storage['a1'][window], th.zeros([5, D]))

    # A brand-new column is zero-padded on the rows outside the ref.
    ref['a4'] = th.ones([5, D])
    assert len(storage) == 4
    new_column_checks = (
        (slice(0, 3), th.zeros([3, D])),
        (window, th.ones([5, D])),
        (slice(8, 10), th.zeros([2, D])),
    )
    for rows, expected in new_column_checks:
        assert check_eq(storage['a4'][rows], expected)

def test_append1():
    """Test the append API on Frame, with both a dict and a Frame argument."""
    data = create_test_data()
    f1 = Frame()
    f2 = Frame(data)
    # appending a dict to an empty frame
    f1.append(data)
    assert f1.num_rows == N
    # appending another frame doubles the rows
    f1.append(f2)
    assert f1.num_rows == 2 * N
    c1 = f1['a1']
    assert c1.shape == (2 * N, D)
    # every column should now be the original data stacked twice
    for k, v in data.items():
        truth = th.cat([v, v])
        assert check_eq(truth, f1[k])

def test_append2():
    """Test append on FrameRef and how it interacts with the backing Frame."""
    data = Frame(create_test_data())
    f = FrameRef(data)
    assert f.is_contiguous()
    assert f.is_span_whole_column()
    assert f.num_rows == N
    # append on the underlying frame should not reflect on the ref
    data.append(data)
    assert f.is_contiguous()
    assert not f.is_span_whole_column()
    assert f.num_rows == N
    # append on the FrameRef should work
    f.append(data)
    assert not f.is_contiguous()
    assert not f.is_span_whole_column()
    assert f.num_rows == 3 * N
    # the ref is expected to cover its original N rows plus the 2N rows it
    # just appended; rows N..2N-1 were added to the storage behind its back
    new_idx = list(range(N)) + list(range(2*N, 4*N))
    assert check_eq(f.index().totensor(), th.tensor(new_idx))
    assert data.num_rows == 4 * N

def test_row1():
    """Test the FrameRef row getter/setter."""
    data = create_test_data()
    f = FrameRef(Frame(data))

    # getter
    # test non-duplicate keys
    rowid = Index(th.tensor([0, 2]))
    rows = f[rowid]
    for k, v in rows.items():
        assert v.shape == (len(rowid), D)
        assert check_eq(v, data[k][rowid])
    # test duplicate keys
    rowid = Index(th.tensor([8, 2, 2, 1]))
    rows = f[rowid]
    for k, v in rows.items():
        assert v.shape == (len(rowid), D)
        assert check_eq(v, data[k][rowid])

    # setter
    rowid = Index(th.tensor([0, 2, 4]))
    vals = {'a1' : th.zeros((len(rowid), D)),
            'a2' : th.zeros((len(rowid), D)),
            'a3' : th.zeros((len(rowid), D)),
            }
    f[rowid] = vals
    # reading the same rows back must return what was written
    for k, v in f[rowid].items():
        assert check_eq(v, th.zeros((len(rowid), D)))

    # setting rows with new column should automatically add a new column
    vals['a4'] = th.ones((len(rowid), D))
    f[rowid] = vals
    assert len(f) == 4

def test_row2():
    """Test autograd compatibility of the FrameRef row getter/setter."""
    data = create_test_data(grad=True)
    f = FrameRef(Frame(data))

    # getter
    c1 = f['a1']
    # test non-duplicate keys
    rowid = Index(th.tensor([0, 2]))
    rows = f[rowid]
    rows['a1'].backward(th.ones((len(rowid), D)))
    # only the selected rows (0 and 2) receive gradient
    assert check_eq(c1.grad[:,0], th.tensor([1., 0., 1., 0., 0., 0., 0., 0., 0., 0.]))
    c1.grad.data.zero_()
    # test duplicate keys
    rowid = Index(th.tensor([8, 2, 2, 1]))
    rows = f[rowid]
    rows['a1'].backward(th.ones((len(rowid), D)))
    # row 2 is selected twice, so its gradient accumulates to 2
    assert check_eq(c1.grad[:,0], th.tensor([0., 1., 2., 0., 0., 0., 0., 0., 1., 0.]))
    c1.grad.data.zero_()

    # setter
    c1 = f['a1']
    rowid = Index(th.tensor([0, 2, 4]))
    vals = {'a1' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
            'a2' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
            'a3' : Variable(th.zeros((len(rowid), D)), requires_grad=True),
            }
    f[rowid] = vals
    c11 = f['a1']
    c11.backward(th.ones((N, D)))
    # the overwritten rows (0, 2, 4) get no gradient in the old column;
    # that gradient goes to the values written in their place
    assert check_eq(c1.grad[:,0], th.tensor([0., 1., 0., 1., 0., 1., 1., 1., 1., 1.]))
    assert check_eq(vals['a1'].grad, th.ones((len(rowid), D)))
    # backward was only run through column 'a1', so 'a2' has no gradient
    assert vals['a2'].grad is None

def test_row3():
    """Test row deletion through a FrameRef."""
    storage = Frame(create_test_data())
    ref = FrameRef(storage)
    assert ref.is_contiguous()
    assert ref.is_span_whole_column()
    assert ref.num_rows == N

    del ref[th.tensor([2, 3])]
    assert not ref.is_contiguous()
    assert not ref.is_span_whole_column()

    # Deletion is lazy: the ref shrinks, the underlying storage does not.
    assert ref.num_rows == N - 2
    assert storage.num_rows == N

    # Every surviving row must still match the storage row it points at.
    surviving = [row for row in range(N) if row not in (2, 3)]
    for name, column in ref.items():
        assert check_eq(column, storage[name][th.tensor(surviving)])

def test_sharing():
    """Test two FrameRefs with overlapping rows over one shared Frame."""
    data = Frame(create_test_data())
    # f1 covers rows 0-3 and f2 covers rows 2-6, so rows 2 and 3 are shared
    f1 = FrameRef(data, index=[0, 1, 2, 3])
    f2 = FrameRef(data, index=[2, 3, 4, 5, 6])
    # test read
    for k, v in f1.items():
        assert check_eq(data[k][0:4], v)
    for k, v in f2.items():
        assert check_eq(data[k][2:7], v)
    f2_a1 = f2['a1']
    # test write
    # an update to rows owned only by f1 should not be seen by f2.
    f1[Index(th.tensor([0, 1]))] = {
            'a1' : th.zeros([2, D]),
            'a2' : th.zeros([2, D]),
            'a3' : th.zeros([2, D]),
            }
    assert check_eq(f2['a1'], f2_a1)
    # an update to the shared rows should be seen by f2.
    f1[Index(th.tensor([2, 3]))] = {
            'a1' : th.ones([2, D]),
            'a2' : th.ones([2, D]),
            'a3' : th.ones([2, D]),
            }
    # f1's local rows 2-3 are f2's local rows 0-1; patch the expected value
    f2_a1[0:2] = th.ones([2, D])
    assert check_eq(f2['a1'], f2_a1)

# Run every test in definition order when the file is executed as a script.
if __name__ == '__main__':
    test_create()
    test_column1()
    test_column2()
    test_append1()
    test_append2()
    test_row1()
    test_row2()
    test_row3()
    test_sharing()