import gzip
import os
import tempfile
import unittest

import backend as F
import numpy as np
import pandas as pd
import pytest
import yaml

import dgl
import dgl.data as data
from dgl import DGLError


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_minigc():
    ds = data.MiniGCDataset(16, 10, 20)
    g, l = list(zip(*ds))
    print(g, l)
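
    # AddSelfLoop(allow_duplicate=True) appends one self-loop per node, so the
    # transformed dataset's first graph gains exactly g1.num_nodes() edges.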
    g1 = ds[0][0]
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    ds = data.MiniGCDataset(16, 10, 20, transform=transform)
    g2 = ds[0][0]
    assert g2.num_edges() - g1.num_edges() == g1.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_gin():
    ds_n_graphs = {
        "MUTAG": 188,
        "IMDBBINARY": 1000,
        "IMDBMULTI": 1500,
        "PROTEINS": 1113,
        "PTC": 344,
    }
    transform = dgl.AddSelfLoop(allow_duplicate=True)
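
    # Each dataset must load with its expected graph count; the transform
    # should add one self-loop per node to the first graph.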
    for name, n_graphs in ds_n_graphs.items():
        ds = data.GINDataset(name, self_loop=False, degree_as_nlabel=False)
        assert len(ds) == n_graphs, (len(ds), name)
        g1 = ds[0][0]
        ds = data.GINDataset(
            name, self_loop=False, degree_as_nlabel=False, transform=transform
        )
        g2 = ds[0][0]
        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
        assert ds.num_classes == ds.gclasses


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_fraud():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

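    # Each fraud graph is heterogeneous with 3 canonical edge types, so adding
    # self-loops increases the edge count by 3 * num_nodes.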
    g = data.FraudDataset("amazon")[0]
    assert g.num_nodes() == 11944
    num_edges1 = g.num_edges()
    g2 = data.FraudDataset("amazon", transform=transform)[0]
    # 3 edge types
    assert g2.num_edges() - num_edges1 == g.num_nodes() * 3

    g = data.FraudAmazonDataset()[0]
    assert g.num_nodes() == 11944
    g2 = data.FraudAmazonDataset(transform=transform)[0]
    # 3 edge types
    assert g2.num_edges() - g.num_edges() == g.num_nodes() * 3

    g = data.FraudYelpDataset()[0]
    assert g.num_nodes() == 45954
    g2 = data.FraudYelpDataset(transform=transform)[0]
    # 3 edge types
    assert g2.num_edges() - g.num_edges() == g.num_nodes() * 3


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_fakenews():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    ds = data.FakeNewsDataset("politifact", "bert")
    assert len(ds) == 314
    g = ds[0][0]
    g2 = data.FakeNewsDataset("politifact", "bert", transform=transform)[0][0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    ds = data.FakeNewsDataset("gossipcop", "profile")
    assert len(ds) == 5464
    g = ds[0][0]
    g2 = data.FakeNewsDataset("gossipcop", "profile", transform=transform)[0][0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_tudataset_regression():
    ds = data.TUDataset("ZINC_test", force_reload=True)
    assert ds.num_classes == ds.num_labels
    assert len(ds) == 5000
    g = ds[0][0]

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    ds = data.TUDataset("ZINC_test", force_reload=True, transform=transform)
    g2 = ds[0][0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_data_hash():
    class HashTestDataset(data.DGLDataset):
        def __init__(self, hash_key=()):
            super(HashTestDataset, self).__init__("hashtest", hash_key=hash_key)

        def _load(self):
            pass

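    # Identical hash_key tuples must produce identical dataset hashes, and any
    # change to the key must change the hash.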
    a = HashTestDataset((True, 0, "1", (1, 2, 3)))
    b = HashTestDataset((True, 0, "1", (1, 2, 3)))
    c = HashTestDataset((True, 0, "1", (1, 2, 4)))
    assert a.hash == b.hash
    assert a.hash != c.hash


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_citation_graph():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    # cora
    g = data.CoraGraphDataset(force_reload=True, reorder=True)[0]
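    # reorder=True should yield a graph whose edges are sorted by destination
    # node, which the sortedness check below verifies.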
    assert g.num_nodes() == 2708
    assert g.num_edges() == 10556
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CoraGraphDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # Citeseer
    g = data.CiteseerGraphDataset(force_reload=True, reorder=True)[0]
    assert g.num_nodes() == 3327
    assert g.num_edges() == 9228
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CiteseerGraphDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # Pubmed
    g = data.PubmedGraphDataset(force_reload=True, reorder=True)[0]
    assert g.num_nodes() == 19717
    assert g.num_edges() == 88651
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.PubmedGraphDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_gnn_benchmark():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    # AmazonCoBuyComputerDataset
    g = data.AmazonCoBuyComputerDataset()[0]
    assert g.num_nodes() == 13752
    assert g.num_edges() == 491722
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.AmazonCoBuyComputerDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # AmazonCoBuyPhotoDataset
    g = data.AmazonCoBuyPhotoDataset()[0]
    assert g.num_nodes() == 7650
    assert g.num_edges() == 238163
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.AmazonCoBuyPhotoDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # CoauthorPhysicsDataset
    g = data.CoauthorPhysicsDataset()[0]
    assert g.num_nodes() == 34493
    assert g.num_edges() == 495924
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CoauthorPhysicsDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # CoauthorCSDataset
    g = data.CoauthorCSDataset()[0]
    assert g.num_nodes() == 18333
    assert g.num_edges() == 163788
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CoauthorCSDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # CoraFullDataset
    g = data.CoraFullDataset()[0]
    assert g.num_nodes() == 19793
    assert g.num_edges() == 126842
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CoraFullDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_reddit():
    # RedditDataset
    g = data.RedditDataset()[0]
    assert g.num_nodes() == 232965
    assert g.num_edges() == 114615892
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.RedditDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_explain_syn():
    dataset = data.BAShapeDataset()
    assert dataset.num_classes == 4
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

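    # With a fixed seed, regenerating the synthetic dataset must reproduce the
    # exact same edge list.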
    g1 = data.BAShapeDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.BAShapeDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.BACommunityDataset()
    assert dataset.num_classes == 8
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.BACommunityDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.BACommunityDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.TreeCycleDataset()
    assert dataset.num_classes == 2
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.TreeCycleDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.TreeCycleDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.TreeGridDataset()
    assert dataset.num_classes == 2
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.TreeGridDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.TreeGridDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.BA2MotifDataset()
    assert dataset.num_classes == 2
    g, label = dataset[0]
    assert "feat" in g.ndata


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_wiki_cs():
    g = data.WikiCSDataset()[0]
    assert g.num_nodes() == 11701
    assert g.num_edges() == 431726
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.WikiCSDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skip(reason="Dataset too large to download for the latest CI.")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_yelp():
    g = data.YelpDataset(reorder=True)[0]
    assert g.num_nodes() == 716847
    assert g.num_edges() == 13954819
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.YelpDataset(reorder=True, transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_flickr():
    g = data.FlickrDataset(reorder=True)[0]
    assert g.num_nodes() == 89250
    assert g.num_edges() == 899756
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.FlickrDataset(reorder=True, transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_extract_archive():
    # gzip
    with tempfile.TemporaryDirectory() as src_dir:
        gz_file = "gz_archive"
        gz_path = os.path.join(src_dir, gz_file + ".gz")
        content = b"test extract archive gzip"
        with gzip.open(gz_path, "wb") as f:
            f.write(content)
        with tempfile.TemporaryDirectory() as dst_dir:
            data.utils.extract_archive(gz_path, dst_dir, overwrite=True)
            assert os.path.exists(os.path.join(dst_dir, gz_file))


def _test_construct_graphs_node_ids():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    num_nodes = 100
    num_edges = 1000

    # node IDs are required to be unique; duplicates must raise an error
    node_ids = np.random.choice(np.arange(num_nodes // 2), num_nodes)
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    node_data = NodeData(node_ids, {})
    edge_data = EdgeData(src_ids, dst_ids, {})
    expect_except = False
    try:
        _, _ = DGLGraphConstructor.construct_graphs(node_data, edge_data)
    except Exception:
        expect_except = True
    assert expect_except

    # node IDs are already labeled from 0 to num_nodes - 1
    node_ids = np.arange(num_nodes)
    np.random.shuffle(node_ids)
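    # np.unique returns the sorted IDs and each ID's first-occurrence index;
    # the feature assertion below uses idx to build the expected ordering.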
    _, idx = np.unique(node_ids, return_index=True)
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    node_feat = np.random.rand(num_nodes, 3)
    node_data = NodeData(node_ids, {"feat": node_feat})
    edge_data = EdgeData(src_ids, dst_ids, {})
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == len(node_ids)
    assert g.num_edges() == len(src_ids)
    assert F.array_equal(
        F.tensor(node_feat[idx], dtype=F.float32), g.ndata["feat"]
    )

    # node IDs are mixed with numeric and non-numeric values
    # homogeneous graph
    node_ids = [1, 2, 3, "a"]
    src_ids = [1, 2, 3]
    dst_ids = ["a", 1, 2]
    node_data = NodeData(node_ids, {})
    edge_data = EdgeData(src_ids, dst_ids, {})
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == len(node_ids)
    assert g.num_edges() == len(src_ids)

    # heterogeneous graph
    node_ids_user = [1, 2, 3]
    node_ids_item = ["a", "b", "c"]
    src_ids = node_ids_user
    dst_ids = node_ids_item
    node_data_user = NodeData(node_ids_user, {}, type="user")
    node_data_item = NodeData(node_ids_item, {}, type="item")
    edge_data = EdgeData(src_ids, dst_ids, {}, type=("user", "like", "item"))
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        [node_data_user, node_data_item], edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert not g.is_homogeneous
    assert g.num_nodes("user") == len(node_ids_user)
    assert g.num_nodes("item") == len(node_ids_item)
    assert g.num_edges() == len(src_ids)


def _test_construct_graphs_homo():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    # node_id could be non-sorted, non-numeric.
    num_nodes = 100
    num_edges = 1000
    num_dims = 3
    node_ids = np.random.choice(
        np.arange(num_nodes * 2), size=num_nodes, replace=False
    )
    assert len(node_ids) == num_nodes
    # to be non-sorted
    np.random.shuffle(node_ids)
    # to be non-numeric
    node_ids = ["id_{}".format(id) for id in node_ids]
    t_ndata = {
        "feat": np.random.rand(num_nodes, num_dims),
        "label": np.random.randint(2, size=num_nodes),
    }
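    # np.unique sorts the string IDs lexicographically; u_indices recovers the
    # row of each sorted ID to build the expected ndata ordering.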
    _, u_indices = np.unique(node_ids, return_index=True)
    ndata = {
        "feat": t_ndata["feat"][u_indices],
        "label": t_ndata["label"][u_indices],
    }
    node_data = NodeData(node_ids, t_ndata)
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    edata = {
        "feat": np.random.rand(num_edges, num_dims),
        "label": np.random.randint(2, size=num_edges),
    }
    edge_data = EdgeData(src_ids, dst_ids, edata)
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    def assert_data(lhs, rhs):
        for key, value in lhs.items():
            assert key in rhs
            assert F.dtype(rhs[key]) != F.float64
            assert F.array_equal(
                F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
            )

    assert_data(ndata, g.ndata)
    assert_data(edata, g.edata)


def _test_construct_graphs_hetero():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    # node_id/src_id/dst_id could be non-sorted, duplicated, non-numeric.
    num_nodes = 100
    num_edges = 1000
    num_dims = 3
    ntypes = ["user", "item"]
    node_data = []
    node_ids_dict = {}
    ndata_dict = {}
    for ntype in ntypes:
        node_ids = np.random.choice(
            np.arange(num_nodes * 2), size=num_nodes, replace=False
        )
        assert len(node_ids) == num_nodes
        # to be non-sorted
        np.random.shuffle(node_ids)
        # to be non-numeric
        node_ids = ["id_{}".format(id) for id in node_ids]
        t_ndata = {
            "feat": np.random.rand(num_nodes, num_dims),
            "label": np.random.randint(2, size=num_nodes),
        }
        _, u_indices = np.unique(node_ids, return_index=True)
        ndata = {
            "feat": t_ndata["feat"][u_indices],
            "label": t_ndata["label"][u_indices],
        }
        node_data.append(NodeData(node_ids, t_ndata, type=ntype))
        node_ids_dict[ntype] = node_ids
        ndata_dict[ntype] = ndata
    etypes = [("user", "follow", "user"), ("user", "like", "item")]
    edge_data = []
    edata_dict = {}
    for src_type, e_type, dst_type in etypes:
        src_ids = np.random.choice(node_ids_dict[src_type], size=num_edges)
        dst_ids = np.random.choice(node_ids_dict[dst_type], size=num_edges)
        edata = {
            "feat": np.random.rand(num_edges, num_dims),
            "label": np.random.randint(2, size=num_edges),
        }
        edge_data.append(
            EdgeData(src_ids, dst_ids, edata, type=(src_type, e_type, dst_type))
        )
        edata_dict[(src_type, e_type, dst_type)] = edata
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert not g.is_homogeneous
    assert g.num_nodes() == num_nodes * len(ntypes)
    assert g.num_edges() == num_edges * len(etypes)

    def assert_data(lhs, rhs):
        for key, value in lhs.items():
            assert key in rhs
            assert F.dtype(rhs[key]) != F.float64
            assert F.array_equal(
                F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
            )

    for ntype in g.ntypes:
        assert g.num_nodes(ntype) == num_nodes
        assert_data(ndata_dict[ntype], g.nodes[ntype].data)
    for etype in g.canonical_etypes:
        assert g.num_edges(etype) == num_edges
        assert_data(edata_dict[etype], g.edges[etype].data)


def _test_construct_graphs_multiple():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        GraphData,
        NodeData,
    )

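    # Build all graphs' nodes and edges as flat arrays, tagging each row with
    # a graph_id so the constructor can split them into separate graphs.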
    num_nodes = 100
    num_edges = 1000
    num_graphs = 10
    num_dims = 3
    node_ids = np.array([], dtype=np.int64)
    src_ids = np.array([], dtype=np.int64)
    dst_ids = np.array([], dtype=np.int64)
    ngraph_ids = np.array([], dtype=np.int64)
    egraph_ids = np.array([], dtype=np.int64)
    u_indices = np.array([], dtype=np.int64)
    for i in range(num_graphs):
        l_node_ids = np.random.choice(
            np.arange(num_nodes * 2), size=num_nodes, replace=False
        )
        node_ids = np.append(node_ids, l_node_ids)
        _, l_u_indices = np.unique(l_node_ids, return_index=True)
        u_indices = np.append(u_indices, l_u_indices)
        ngraph_ids = np.append(ngraph_ids, np.full(num_nodes, i))
        src_ids = np.append(
            src_ids, np.random.choice(l_node_ids, size=num_edges)
        )
        dst_ids = np.append(
            dst_ids, np.random.choice(l_node_ids, size=num_edges)
        )
        egraph_ids = np.append(egraph_ids, np.full(num_edges, i))
    ndata = {
        "feat": np.random.rand(num_nodes * num_graphs, num_dims),
        "label": np.random.randint(2, size=num_nodes * num_graphs),
    }
    ngraph_ids = ["graph_{}".format(id) for id in ngraph_ids]
    node_data = NodeData(node_ids, ndata, graph_id=ngraph_ids)
    egraph_ids = ["graph_{}".format(id) for id in egraph_ids]
    edata = {
        "feat": np.random.rand(num_edges * num_graphs, num_dims),
        "label": np.random.randint(2, size=num_edges * num_graphs),
    }
    edge_data = EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids)
    gdata = {
        "feat": np.random.rand(num_graphs, num_dims),
        "label": np.random.randint(2, size=num_graphs),
    }
    graph_ids = ["graph_{}".format(id) for id in np.arange(num_graphs)]
    graph_data = GraphData(graph_ids, gdata)
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data, graph_data
    )
    assert len(graphs) == num_graphs
    assert len(data_dict) == len(gdata)
    for k, v in data_dict.items():
        assert F.dtype(v) != F.float64
        assert F.array_equal(
            F.reshape(F.tensor(gdata[k], dtype=F.dtype(v)), (len(graphs), -1)),
            v,
        )
    for i, g in enumerate(graphs):
        assert g.is_homogeneous
        assert g.num_nodes() == num_nodes
        assert g.num_edges() == num_edges

        def assert_data(lhs, rhs, size, node=False):
            for key, value in lhs.items():
                assert key in rhs
                value = value[i * size : (i + 1) * size]
                if node:
                    indices = u_indices[i * size : (i + 1) * size]
                    value = value[indices]
                assert F.dtype(rhs[key]) != F.float64
                assert F.array_equal(
                    F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
                )

        assert_data(ndata, g.ndata, num_nodes, node=True)
        assert_data(edata, g.edata, num_edges)

    # Graph IDs found in node/edge CSV but not in graph CSV
    graph_data = GraphData(np.arange(num_graphs - 2), {})
    expect_except = False
    try:
        _, _ = DGLGraphConstructor.construct_graphs(
            node_data, edge_data, graph_data
        )
    except Exception:
        expect_except = True
    assert expect_except


def _test_DefaultDataParser():
    from dgl.data.csv_dataset_base import DefaultDataParser

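    # DefaultDataParser turns each DataFrame column into a NumPy array,
    # including list-valued cells that pandas round-trips as strings.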
    # common csv
    with tempfile.TemporaryDirectory() as test_dir:
        csv_path = os.path.join(test_dir, "nodes.csv")
        num_nodes = 5
        num_labels = 3
        num_dims = 2
        node_id = np.arange(num_nodes)
        label = np.random.randint(num_labels, size=num_nodes)
        feat = np.random.rand(num_nodes, num_dims)
        df = pd.DataFrame(
            {
                "node_id": node_id,
                "label": label,
                "feat": [line.tolist() for line in feat],
            }
        )
        df.to_csv(csv_path, index=False)
        dp = DefaultDataParser()
        df = pd.read_csv(csv_path)
        dt = dp(df)
        assert np.array_equal(node_id, dt["node_id"])
        assert np.array_equal(label, dt["label"])
        assert np.array_equal(feat, dt["feat"])
    # string consists of non-numeric values
    with tempfile.TemporaryDirectory() as test_dir:
        csv_path = os.path.join(test_dir, "nodes.csv")
        df = pd.DataFrame(
            {
                "label": ["a", "b", "c"],
            }
        )
        df.to_csv(csv_path, index=False)
        dp = DefaultDataParser()
        df = pd.read_csv(csv_path)
        expect_except = False
        try:
            dt = dp(df)
        except Exception:
            expect_except = True
        assert expect_except
    # the CSV has an index column, which is ignored because it's unnamed
    with tempfile.TemporaryDirectory() as test_dir:
        csv_path = os.path.join(test_dir, "nodes.csv")
        df = pd.DataFrame(
            {
                "label": [1, 2, 3],
            }
        )
        df.to_csv(csv_path)
        dp = DefaultDataParser()
        df = pd.read_csv(csv_path)
        dt = dp(df)
        assert len(dt) == 1


def _test_load_yaml_with_sanity_check():
    from dgl.data.csv_dataset_base import load_yaml_with_sanity_check

    with tempfile.TemporaryDirectory() as test_dir:
        yaml_path = os.path.join(test_dir, "meta.yaml")
        # workable, though usually meaningless
        yaml_data = {
            "dataset_name": "default",
            "node_data": [],
            "edge_data": [],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        meta = load_yaml_with_sanity_check(yaml_path)
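        # Omitted optional fields fall back to defaults such as version
        # "1.0.0" and the "," separator.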
        assert meta.version == "1.0.0"
        assert meta.dataset_name == "default"
        assert meta.separator == ","
        assert len(meta.node_data) == 0
        assert len(meta.edge_data) == 0
        assert meta.graph_data is None
        # minimum with required fields only
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "node_data": [{"file_name": "nodes.csv"}],
            "edge_data": [{"file_name": "edges.csv"}],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        meta = load_yaml_with_sanity_check(yaml_path)
        for ndata in meta.node_data:
            assert ndata.file_name == "nodes.csv"
            assert ndata.ntype == "_V"
            assert ndata.graph_id_field == "graph_id"
            assert ndata.node_id_field == "node_id"
        for edata in meta.edge_data:
            assert edata.file_name == "edges.csv"
            assert edata.etype == ["_V", "_E", "_V"]
            assert edata.graph_id_field == "graph_id"
            assert edata.src_id_field == "src_id"
            assert edata.dst_id_field == "dst_id"
        # optional fields are specified
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "separator": "|",
            "node_data": [
                {
                    "file_name": "nodes.csv",
                    "ntype": "user",
                    "graph_id_field": "xxx",
                    "node_id_field": "xxx",
                }
            ],
            "edge_data": [
                {
                    "file_name": "edges.csv",
                    "etype": ["user", "follow", "user"],
                    "graph_id_field": "xxx",
                    "src_id_field": "xxx",
                    "dst_id_field": "xxx",
                }
            ],
            "graph_data": {"file_name": "graph.csv", "graph_id_field": "xxx"},
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        meta = load_yaml_with_sanity_check(yaml_path)
        assert len(meta.node_data) == 1
        ndata = meta.node_data[0]
        assert ndata.ntype == "user"
        assert ndata.graph_id_field == "xxx"
        assert ndata.node_id_field == "xxx"
        assert len(meta.edge_data) == 1
        edata = meta.edge_data[0]
        assert edata.etype == ["user", "follow", "user"]
        assert edata.graph_id_field == "xxx"
        assert edata.src_id_field == "xxx"
        assert edata.dst_id_field == "xxx"
        assert meta.graph_data is not None
        assert meta.graph_data.file_name == "graph.csv"
        assert meta.graph_data.graph_id_field == "xxx"
        # some required fields are missing
        yaml_data = {
            "dataset_name": "default",
            "node_data": [],
            "edge_data": [],
        }
        for field in yaml_data.keys():
            ydata = {k: v for k, v in yaml_data.items()}
            ydata.pop(field)
            with open(yaml_path, "w") as f:
                yaml.dump(ydata, f, sort_keys=False)
            expect_except = False
            try:
                meta = load_yaml_with_sanity_check(yaml_path)
            except Exception:
                expect_except = True
            assert expect_except
        # inapplicable version
        yaml_data = {
            "version": "0.0.0",
            "dataset_name": "default",
            "node_data": [{"file_name": "nodes_0.csv"}],
            "edge_data": [{"file_name": "edges_0.csv"}],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        expect_except = False
        try:
            meta = load_yaml_with_sanity_check(yaml_path)
        except DGLError:
            expect_except = True
        assert expect_except
        # duplicate node types
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "node_data": [
                {"file_name": "nodes.csv"},
                {"file_name": "nodes.csv"},
            ],
            "edge_data": [{"file_name": "edges.csv"}],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        expect_except = False
        try:
            meta = load_yaml_with_sanity_check(yaml_path)
        except DGLError:
            expect_except = True
        assert expect_except
        # duplicate edge types
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "node_data": [{"file_name": "nodes.csv"}],
            "edge_data": [
                {"file_name": "edges.csv"},
                {"file_name": "edges.csv"},
            ],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        expect_except = False
        try:
            meta = load_yaml_with_sanity_check(yaml_path)
        except DGLError:
            expect_except = True
        assert expect_except


def _test_load_node_data_from_csv():
    from dgl.data.csv_dataset_base import DefaultDataParser, MetaNode, NodeData

    with tempfile.TemporaryDirectory() as test_dir:
        num_nodes = 100
        # minimum
        df = pd.DataFrame({"node_id": np.arange(num_nodes)})
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
        assert np.array_equal(df["node_id"], node_data.id)
        assert len(node_data.data) == 0

        # common case
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": np.random.randint(3, size=num_nodes),
            }
        )
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
        assert np.array_equal(df["node_id"], node_data.id)
        assert len(node_data.data) == 1
        assert np.array_equal(df["label"], node_data.data["label"])
        assert np.array_equal(np.full(num_nodes, 0), node_data.graph_id)
        assert node_data.type == "_V"

        # add more fields into nodes.csv
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": np.random.randint(3, size=num_nodes),
                "graph_id": np.full(num_nodes, 1),
            }
        )
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
        assert np.array_equal(df["node_id"], node_data.id)
        assert len(node_data.data) == 1
        assert np.array_equal(df["label"], node_data.data["label"])
        assert np.array_equal(df["graph_id"], node_data.graph_id)
        assert node_data.type == "_V"

        # required header is missing
        df = pd.DataFrame({"label": np.random.randint(3, size=num_nodes)})
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        expect_except = False
        try:
            NodeData.load_from_csv(meta_node, DefaultDataParser())
        except Exception:
            expect_except = True
        assert expect_except


def _test_load_edge_data_from_csv():
    from dgl.data.csv_dataset_base import DefaultDataParser, EdgeData, MetaEdge

    with tempfile.TemporaryDirectory() as test_dir:
        num_nodes = 100
        num_edges = 1000
        # minimum
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        assert np.array_equal(df["src_id"], edge_data.src)
        assert np.array_equal(df["dst_id"], edge_data.dst)
        assert len(edge_data.data) == 0

        # common case
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
                "label": np.random.randint(3, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        assert np.array_equal(df["src_id"], edge_data.src)
        assert np.array_equal(df["dst_id"], edge_data.dst)
        assert len(edge_data.data) == 1
        assert np.array_equal(df["label"], edge_data.data["label"])
        assert np.array_equal(np.full(num_edges, 0), edge_data.graph_id)
        assert edge_data.type == ("_V", "_E", "_V")

        # add more fields into edges.csv
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
                "graph_id": np.arange(num_edges),
                "feat": np.random.randint(3, size=num_edges),
                "label": np.random.randint(3, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        assert np.array_equal(df["src_id"], edge_data.src)
        assert np.array_equal(df["dst_id"], edge_data.dst)
        assert len(edge_data.data) == 2
        assert np.array_equal(df["feat"], edge_data.data["feat"])
        assert np.array_equal(df["label"], edge_data.data["label"])
        assert np.array_equal(df["graph_id"], edge_data.graph_id)
        assert edge_data.type == ("_V", "_E", "_V")

        # required headers are missing
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        expect_except = False
        try:
            EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        except DGLError:
            expect_except = True
        assert expect_except
        df = pd.DataFrame(
            {
                "dst_id": np.random.randint(num_nodes, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        expect_except = False
        try:
            EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        except DGLError:
            expect_except = True
        assert expect_except


def _test_load_graph_data_from_csv():
    from dgl.data.csv_dataset_base import (
        DefaultDataParser,
        GraphData,
        MetaGraph,
    )

    with tempfile.TemporaryDirectory() as test_dir:
        num_graphs = 100
        # minimum
        df = pd.DataFrame({"graph_id": np.arange(num_graphs)})
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
        assert np.array_equal(df["graph_id"], graph_data.graph_id)
        assert len(graph_data.data) == 0

        # common case
        df = pd.DataFrame(
            {
                "graph_id": np.arange(num_graphs),
                "label": np.random.randint(3, size=num_graphs),
            }
        )
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
        assert np.array_equal(df["graph_id"], graph_data.graph_id)
        assert len(graph_data.data) == 1
        assert np.array_equal(df["label"], graph_data.data["label"])

        # add more fields into graph.csv
        df = pd.DataFrame(
            {
                "graph_id": np.arange(num_graphs),
                "feat": np.random.randint(3, size=num_graphs),
                "label": np.random.randint(3, size=num_graphs),
            }
        )
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
        assert np.array_equal(df["graph_id"], graph_data.graph_id)
        assert len(graph_data.data) == 2
        assert np.array_equal(df["feat"], graph_data.data["feat"])
        assert np.array_equal(df["label"], graph_data.data["label"])

        # required header is missing
        df = pd.DataFrame({"label": np.random.randint(3, size=num_graphs)})
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        expect_except = False
        try:
            GraphData.load_from_csv(meta_graph, DefaultDataParser())
        except DGLError:
            expect_except = True
        assert expect_except


def _test_CSVDataset_single():
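    # End-to-end: write meta.yaml plus node/edge CSVs into a temp dir, then
    # load them back via data.CSVDataset, both freshly parsed and from cache.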
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path_0 = os.path.join(test_dir, "test_edges_0.csv")
        edges_csv_path_1 = os.path.join(test_dir, "test_edges_1.csv")
        nodes_csv_path_0 = os.path.join(test_dir, "test_nodes_0.csv")
        nodes_csv_path_1 = os.path.join(test_dir, "test_nodes_1.csv")
        meta_yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default_name",
            "node_data": [
                {
                    "file_name": os.path.basename(nodes_csv_path_0),
                    "ntype": "user",
                },
                {
                    "file_name": os.path.basename(nodes_csv_path_1),
                    "ntype": "item",
                },
            ],
            "edge_data": [
                {
                    "file_name": os.path.basename(edges_csv_path_0),
                    "etype": ["user", "follow", "user"],
                },
                {
                    "file_name": os.path.basename(edges_csv_path_1),
                    "etype": ["user", "like", "item"],
                },
            ],
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_dims = 3
        feat_ndata = np.random.rand(num_nodes, num_dims)
        label_ndata = np.random.randint(2, size=num_nodes)
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": label_ndata,
                "feat": [line.tolist() for line in feat_ndata],
            }
        )
        df.to_csv(nodes_csv_path_0, index=False)
        df.to_csv(nodes_csv_path_1, index=False)
        feat_edata = np.random.rand(num_edges, num_dims)
        label_edata = np.random.randint(2, size=num_edges)
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
                "label": label_edata,
                "feat": [line.tolist() for line in feat_edata],
            }
        )
        df.to_csv(edges_csv_path_0, index=False)
        df.to_csv(edges_csv_path_1, index=False)

        # load CSVDataset
        for force_reload in [True, False]:
            if not force_reload:
                # remove original node data file to verify reload from cached files
                os.remove(nodes_csv_path_0)
                assert not os.path.exists(nodes_csv_path_0)
            csv_dataset = data.CSVDataset(test_dir, force_reload=force_reload)
            assert len(csv_dataset) == 1
            g = csv_dataset[0]
            assert not g.is_homogeneous
            assert csv_dataset.has_cache()
            for ntype in g.ntypes:
                assert g.num_nodes(ntype) == num_nodes
                assert F.array_equal(
                    F.tensor(feat_ndata, dtype=F.float32),
                    g.nodes[ntype].data["feat"],
                )
                assert np.array_equal(
                    label_ndata, F.asnumpy(g.nodes[ntype].data["label"])
                )
            for etype in g.etypes:
                assert g.num_edges(etype) == num_edges
                assert F.array_equal(
                    F.tensor(feat_edata, dtype=F.float32),
                    g.edges[etype].data["feat"],
                )
                assert np.array_equal(
                    label_edata, F.asnumpy(g.edges[etype].data["label"])
                )


def _test_CSVDataset_multiple():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path_0 = os.path.join(test_dir, "test_edges_0.csv")
        edges_csv_path_1 = os.path.join(test_dir, "test_edges_1.csv")
        nodes_csv_path_0 = os.path.join(test_dir, "test_nodes_0.csv")
        nodes_csv_path_1 = os.path.join(test_dir, "test_nodes_1.csv")
        graph_csv_path = os.path.join(test_dir, "test_graph.csv")
        meta_yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default_name",
            "node_data": [
                {
                    "file_name": os.path.basename(nodes_csv_path_0),
                    "ntype": "user",
                },
                {
                    "file_name": os.path.basename(nodes_csv_path_1),
                    "ntype": "item",
                },
            ],
            "edge_data": [
                {
                    "file_name": os.path.basename(edges_csv_path_0),
                    "etype": ["user", "follow", "user"],
                },
                {
                    "file_name": os.path.basename(edges_csv_path_1),
                    "etype": ["user", "like", "item"],
                },
            ],
            "graph_data": {"file_name": os.path.basename(graph_csv_path)},
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_graphs = 10
        num_dims = 3
        feat_ndata = np.random.rand(num_nodes * num_graphs, num_dims)
        label_ndata = np.random.randint(2, size=num_nodes * num_graphs)
        df = pd.DataFrame(
            {
                "node_id": np.hstack(
                    [np.arange(num_nodes) for _ in range(num_graphs)]
                ),
                "label": label_ndata,
                "feat": [line.tolist() for line in feat_ndata],
                "graph_id": np.hstack(
                    [np.full(num_nodes, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(nodes_csv_path_0, index=False)
        df.to_csv(nodes_csv_path_1, index=False)
        feat_edata = np.random.rand(num_edges * num_graphs, num_dims)
        label_edata = np.random.randint(2, size=num_edges * num_graphs)
        df = pd.DataFrame(
            {
                "src_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "dst_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "label": label_edata,
                "feat": [line.tolist() for line in feat_edata],
                "graph_id": np.hstack(
                    [np.full(num_edges, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(edges_csv_path_0, index=False)
        df.to_csv(edges_csv_path_1, index=False)
        feat_gdata = np.random.rand(num_graphs, num_dims)
        label_gdata = np.random.randint(2, size=num_graphs)
        df = pd.DataFrame(
            {
                "label": label_gdata,
                "feat": [line.tolist() for line in feat_gdata],
                "graph_id": np.arange(num_graphs),
            }
        )
        df.to_csv(graph_csv_path, index=False)

        # load CSVDataset with default node/edge/gdata_parser
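        # The first pass parses the CSVs; the second pass must succeed from
        # the cached binaries even after a source CSV has been removed.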
        for force_reload in [True, False]:
            if not force_reload:
                # remove original node data file to verify reload from cached files
                os.remove(nodes_csv_path_0)
                assert not os.path.exists(nodes_csv_path_0)
            csv_dataset = data.CSVDataset(test_dir, force_reload=force_reload)
            assert len(csv_dataset) == num_graphs
            assert csv_dataset.has_cache()
            assert len(csv_dataset.data) == 2
            assert "feat" in csv_dataset.data
            assert "label" in csv_dataset.data
            assert F.array_equal(
                F.tensor(feat_gdata, dtype=F.float32), csv_dataset.data["feat"]
            )
            for i, (g, g_data) in enumerate(csv_dataset):
                assert not g.is_homogeneous
                assert F.asnumpy(g_data["label"]) == label_gdata[i]
                assert F.array_equal(
                    g_data["feat"], F.tensor(feat_gdata[i], dtype=F.float32)
                )
                for ntype in g.ntypes:
                    assert g.num_nodes(ntype) == num_nodes
                    assert F.array_equal(
                        F.tensor(
                            feat_ndata[i * num_nodes : (i + 1) * num_nodes],
                            dtype=F.float32,
                        ),
                        g.nodes[ntype].data["feat"],
                    )
                    assert np.array_equal(
                        label_ndata[i * num_nodes : (i + 1) * num_nodes],
                        F.asnumpy(g.nodes[ntype].data["label"]),
                    )
                for etype in g.etypes:
                    assert g.num_edges(etype) == num_edges
                    assert F.array_equal(
                        F.tensor(
                            feat_edata[i * num_edges : (i + 1) * num_edges],
                            dtype=F.float32,
                        ),
                        g.edges[etype].data["feat"],
                    )
                    assert np.array_equal(
                        label_edata[i * num_edges : (i + 1) * num_edges],
                        F.asnumpy(g.edges[etype].data["label"]),
                    )


def _test_CSVDataset_customized_data_parser():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path_0 = os.path.join(test_dir, "test_edges_0.csv")
        edges_csv_path_1 = os.path.join(test_dir, "test_edges_1.csv")
        nodes_csv_path_0 = os.path.join(test_dir, "test_nodes_0.csv")
        nodes_csv_path_1 = os.path.join(test_dir, "test_nodes_1.csv")
        graph_csv_path = os.path.join(test_dir, "test_graph.csv")
        meta_yaml_data = {
            "dataset_name": "default_name",
            "node_data": [
                {
                    "file_name": os.path.basename(nodes_csv_path_0),
                    "ntype": "user",
                },
                {
                    "file_name": os.path.basename(nodes_csv_path_1),
                    "ntype": "item",
                },
            ],
            "edge_data": [
                {
                    "file_name": os.path.basename(edges_csv_path_0),
                    "etype": ["user", "follow", "user"],
                },
                {
                    "file_name": os.path.basename(edges_csv_path_1),
                    "etype": ["user", "like", "item"],
                },
            ],
            "graph_data": {"file_name": os.path.basename(graph_csv_path)},
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
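        # The dump above yields a meta.yaml roughly like:
        #   dataset_name: default_name
        #   node_data:
        #     - {file_name: test_nodes_0.csv, ntype: user}
        #     - {file_name: test_nodes_1.csv, ntype: item}
        #   edge_data:
        #     - {file_name: test_edges_0.csv, etype: [user, follow, user]}
        #     - {file_name: test_edges_1.csv, etype: [user, like, item]}
        #   graph_data: {file_name: test_graph.csv}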
        num_nodes = 100
        num_edges = 500
        num_graphs = 10
        label_ndata = np.random.randint(2, size=num_nodes * num_graphs)
        df = pd.DataFrame(
            {
                "node_id": np.hstack(
                    [np.arange(num_nodes) for _ in range(num_graphs)]
                ),
                "label": label_ndata,
                "graph_id": np.hstack(
                    [np.full(num_nodes, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(nodes_csv_path_0, index=False)
        df.to_csv(nodes_csv_path_1, index=False)
        label_edata = np.random.randint(2, size=num_edges * num_graphs)
        df = pd.DataFrame(
            {
                "src_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "dst_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "label": label_edata,
                "graph_id": np.hstack(
                    [np.full(num_edges, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(edges_csv_path_0, index=False)
        df.to_csv(edges_csv_path_1, index=False)
        label_gdata = np.random.randint(2, size=num_graphs)
        df = pd.DataFrame(
            {"label": label_gdata, "graph_id": np.arange(num_graphs)}
        )
        df.to_csv(graph_csv_path, index=False)

        class CustDataParser:
            def __call__(self, df):
                data = {}
                for header in df:
                    dt = df[header].to_numpy().squeeze()
                    if header == "label":
                        dt += 2
                    data[header] = dt
                return data
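
        # A data parser is any callable that takes the per-type DataFrame
        # and returns a dict of column name -> numpy array; the "+2" label
        # shift above lets the assertions below detect which node/edge
        # types were routed through the custom parser.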

        # load CSVDataset with customized node/edge/gdata_parser
        # specify via dict[ntype/etype, callable]
        csv_dataset = data.CSVDataset(
            test_dir,
            force_reload=True,
            ndata_parser={"user": CustDataParser()},
            edata_parser={("user", "like", "item"): CustDataParser()},
            gdata_parser=CustDataParser(),
        )
        assert len(csv_dataset) == num_graphs
        assert len(csv_dataset.data) == 1
        assert "label" in csv_dataset.data
        for i, (g, g_data) in enumerate(csv_dataset):
            assert not g.is_homogeneous
            assert F.asnumpy(g_data) == label_gdata[i] + 2
            for ntype in g.ntypes:
                assert g.num_nodes(ntype) == num_nodes
                offset = 2 if ntype == "user" else 0
                assert np.array_equal(
                    label_ndata[i * num_nodes : (i + 1) * num_nodes] + offset,
                    F.asnumpy(g.nodes[ntype].data["label"]),
                )
            for etype in g.etypes:
                assert g.num_edges(etype) == num_edges
                offset = 2 if etype == "like" else 0
                assert np.array_equal(
                    label_edata[i * num_edges : (i + 1) * num_edges] + offset,
                    F.asnumpy(g.edges[etype].data["label"]),
                )
        # specify via callable
        csv_dataset = data.CSVDataset(
            test_dir,
            force_reload=True,
            ndata_parser=CustDataParser(),
            edata_parser=CustDataParser(),
            gdata_parser=CustDataParser(),
        )
        assert len(csv_dataset) == num_graphs
        assert len(csv_dataset.data) == 1
        assert "label" in csv_dataset.data
        for i, (g, g_data) in enumerate(csv_dataset):
            assert not g.is_homogeneous
            assert F.asnumpy(g_data) == label_gdata[i] + 2
            for ntype in g.ntypes:
                assert g.num_nodes(ntype) == num_nodes
                offset = 2
                assert np.array_equal(
                    label_ndata[i * num_nodes : (i + 1) * num_nodes] + offset,
                    F.asnumpy(g.nodes[ntype].data["label"]),
                )
            for etype in g.etypes:
                assert g.num_edges(etype) == num_edges
                offset = 2
                assert np.array_equal(
                    label_edata[i * num_edges : (i + 1) * num_edges] + offset,
                    F.asnumpy(g.edges[etype].data["label"]),
                )


def _test_NodeEdgeGraphData():
    from dgl.data.csv_dataset_base import EdgeData, GraphData, NodeData
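    # NodeData/EdgeData/GraphData are the internal containers CSVDataset
    # assembles graphs from; the checks below cover their defaults,
    # explicit fields, and length-mismatch validation.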

    # NodeData basics
    num_nodes = 100
    node_ids = np.arange(num_nodes, dtype=np.float64)
    ndata = NodeData(node_ids, {})
    assert np.array_equal(ndata.id, node_ids)
    assert len(ndata.data) == 0
    assert ndata.type == "_V"
    assert np.array_equal(ndata.graph_id, np.full(num_nodes, 0))
    # NodeData with explicit type, data and graph_id
    data = {"feat": np.random.rand(num_nodes, 3)}
    graph_id = np.arange(num_nodes)
    ndata = NodeData(node_ids, data, type="user", graph_id=graph_id)
    assert ndata.type == "user"
    assert np.array_equal(ndata.graph_id, graph_id)
    assert len(ndata.data) == len(data)
    for k, v in data.items():
        assert k in ndata.data
        assert np.array_equal(ndata.data[k], v)
    # NodeData should raise when id/data/graph_id lengths mismatch
    expect_except = False
    try:
        NodeData(
            np.arange(num_nodes),
            {"feat": np.random.rand(num_nodes + 1, 3)},
            graph_id=np.arange(num_nodes - 1),
        )
    except Exception:
        expect_except = True
    assert expect_except

    # EdgeData basics
    num_nodes = 100
    num_edges = 1000
    src_ids = np.random.randint(num_nodes, size=num_edges)
    dst_ids = np.random.randint(num_nodes, size=num_edges)
    edata = EdgeData(src_ids, dst_ids, {})
    assert np.array_equal(edata.src, src_ids)
    assert np.array_equal(edata.dst, dst_ids)
    assert edata.type == ("_V", "_E", "_V")
    assert len(edata.data) == 0
    assert np.array_equal(edata.graph_id, np.full(num_edges, 0))
    # EdgeData with explicit type, data and graph_id
    src_ids = np.random.randint(num_nodes, size=num_edges).astype(np.float64)
    dst_ids = np.random.randint(num_nodes, size=num_edges).astype(np.float64)
    data = {"feat": np.random.rand(num_edges, 3)}
    etype = ("user", "like", "item")
    graph_ids = np.arange(num_edges)
    edata = EdgeData(src_ids, dst_ids, data, type=etype, graph_id=graph_ids)
    assert np.array_equal(edata.src, src_ids)
    assert np.array_equal(edata.dst, dst_ids)
    assert edata.type == etype
    assert len(edata.data) == len(data)
    for k, v in data.items():
        assert k in edata.data
        assert np.array_equal(edata.data[k], v)
    assert np.array_equal(edata.graph_id, graph_ids)
    # EdgeData should raise when src/dst/data/graph_id lengths mismatch
    expect_except = False
    try:
        EdgeData(
            np.arange(num_edges),
            np.arange(num_edges + 1),
            {"feat": np.random.rand(num_edges - 1, 3)},
            graph_id=np.arange(num_edges + 2),
        )
    except Exception:
        expect_except = True
    assert expect_except

    # GraphData basics
    num_graphs = 10
    graph_ids = np.arange(num_graphs)
    gdata = GraphData(graph_ids, {})
    assert np.array_equal(gdata.graph_id, graph_ids)
    assert len(gdata.data) == 0
    # GraphData with data
    graph_ids = np.arange(num_graphs).astype(np.float64)
    data = {"feat": np.random.rand(num_graphs, 3)}
    gdata = GraphData(graph_ids, data)
    assert np.array_equal(gdata.graph_id, graph_ids)
    assert len(gdata.data) == len(data)
    for k, v in data.items():
        assert k in gdata.data
        assert np.array_equal(gdata.data[k], v)


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_csvdataset():
    _test_NodeEdgeGraphData()
    _test_construct_graphs_node_ids()
    _test_construct_graphs_homo()
    _test_construct_graphs_hetero()
    _test_construct_graphs_multiple()
    _test_DefaultDataParser()
    _test_load_yaml_with_sanity_check()
    _test_load_node_data_from_csv()
    _test_load_edge_data_from_csv()
    _test_load_graph_data_from_csv()
    _test_CSVDataset_single()
    _test_CSVDataset_multiple()
    _test_CSVDataset_customized_data_parser()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_add_nodepred_split():
    dataset = data.AmazonCoBuyComputerDataset()
    print("train_mask" in dataset[0].ndata)
    data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])
    assert "train_mask" in dataset[0].ndata

    dataset = data.AIFBDataset()
    print("train_mask" in dataset[0].nodes["Publikationen"].data)
    data.utils.add_nodepred_split(
        dataset, [0.8, 0.1, 0.1], ntype="Publikationen"
    )
    assert "train_mask" in dataset[0].nodes["Publikationen"].data


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred1():
    ds = data.AmazonCoBuyComputerDataset()
    print("train_mask" in ds[0].ndata)
    new_ds = data.AsNodePredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1
    assert new_ds[0].num_nodes() == ds[0].num_nodes()
    assert new_ds[0].num_edges() == ds[0].num_edges()
    assert "train_mask" in new_ds[0].ndata
    assert F.array_equal(
        new_ds.train_idx, F.nonzero_1d(new_ds[0].ndata["train_mask"])
    )
    assert F.array_equal(
        new_ds.val_idx, F.nonzero_1d(new_ds[0].ndata["val_mask"])
    )
    assert F.array_equal(
        new_ds.test_idx, F.nonzero_1d(new_ds[0].ndata["test_mask"])
    )

    ds = data.AIFBDataset()
    print("train_mask" in ds[0].nodes["Personen"].data)
    new_ds = data.AsNodePredDataset(
        ds, [0.8, 0.1, 0.1], "Personen", verbose=True
    )
    assert len(new_ds) == 1
    assert new_ds[0].ntypes == ds[0].ntypes
    assert new_ds[0].canonical_etypes == ds[0].canonical_etypes
    assert "train_mask" in new_ds[0].nodes["Personen"].data
    assert F.array_equal(
        new_ds.train_idx,
        F.nonzero_1d(new_ds[0].nodes["Personen"].data["train_mask"]),
    )
    assert F.array_equal(
        new_ds.val_idx,
        F.nonzero_1d(new_ds[0].nodes["Personen"].data["val_mask"]),
    )
    assert F.array_equal(
        new_ds.test_idx,
        F.nonzero_1d(new_ds[0].nodes["Personen"].data["test_mask"]),
    )


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred2():
    # test proper reprocessing
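    # AsNodePredDataset caches the generated split on disk: re-creating it
    # with the same ratio should hit the cache, while a different ratio
    # should invalidate the cache and trigger a re-split.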

    # create
    ds = data.AsNodePredDataset(
        data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1]
    )
    assert F.sum(F.astype(ds[0].ndata["train_mask"], F.int32), 0) == int(
        ds[0].num_nodes() * 0.8
    )
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.8)
    # read from cache
    ds = data.AsNodePredDataset(
        data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1]
    )
    assert F.sum(F.astype(ds[0].ndata["train_mask"], F.int32), 0) == int(
        ds[0].num_nodes() * 0.8
    )
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.8)
    # invalid cache, re-read
    ds = data.AsNodePredDataset(
        data.AmazonCoBuyComputerDataset(), [0.1, 0.1, 0.8]
    )
    assert F.sum(F.astype(ds[0].ndata["train_mask"], F.int32), 0) == int(
        ds[0].num_nodes() * 0.1
    )
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.1)

    # create
    ds = data.AsNodePredDataset(
        data.AIFBDataset(), [0.8, 0.1, 0.1], "Personen", verbose=True
    )
    assert F.sum(
        F.astype(ds[0].nodes["Personen"].data["train_mask"], F.int32), 0
    ) == int(ds[0].num_nodes("Personen") * 0.8)
    assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.8)
    # read from cache
    ds = data.AsNodePredDataset(
        data.AIFBDataset(), [0.8, 0.1, 0.1], "Personen", verbose=True
    )
    assert F.sum(
        F.astype(ds[0].nodes["Personen"].data["train_mask"], F.int32), 0
    ) == int(ds[0].num_nodes("Personen") * 0.8)
    assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.8)
    # invalid cache, re-read
    ds = data.AsNodePredDataset(
        data.AIFBDataset(), [0.1, 0.1, 0.8], "Personen", verbose=True
    )
    assert F.sum(
        F.astype(ds[0].nodes["Personen"].data["train_mask"], F.int32), 0
    ) == int(ds[0].num_nodes("Personen") * 0.1)
    assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.1)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred_ogb():
    from ogb.nodeproppred import DglNodePropPredDataset

    ds = data.AsNodePredDataset(
        DglNodePropPredDataset("ogbn-arxiv"), split_ratio=None, verbose=True
    )
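    # split_ratio=None keeps the dataset's official split, which should
    # match DglNodePropPredDataset.get_idx_split() exactly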
    split = DglNodePropPredDataset("ogbn-arxiv").get_idx_split()
    train_idx, val_idx, test_idx = split["train"], split["valid"], split["test"]
    assert F.array_equal(ds.train_idx, F.tensor(train_idx))
    assert F.array_equal(ds.val_idx, F.tensor(val_idx))
    assert F.array_equal(ds.test_idx, F.tensor(test_idx))
    # force generate new split
    ds = data.AsNodePredDataset(
        DglNodePropPredDataset("ogbn-arxiv"),
        split_ratio=[0.7, 0.2, 0.1],
        verbose=True,
    )


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_linkpred():
    # create
    ds = data.AsLinkPredDataset(
        data.CoraGraphDataset(),
        split_ratio=[0.8, 0.1, 0.1],
        neg_ratio=1,
        verbose=True,
    )
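    # ds.test_edges is ((pos_src, pos_dst), (neg_src, neg_dst)), so
    # [0][0] counts positive test edges and [1][0] negative ones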
    # Cora has 10556 edges; a 10% test split yields 1057 positive test edges
    assert ds.test_edges[0][0].shape[0] == 1057
    # the number of negative samples is not guaranteed, so assert a relaxed range
    assert 1000 <= ds.test_edges[1][0].shape[0] <= 1057
    # different split ratio and neg_ratio invalidate the cache, re-split
    ds = data.AsLinkPredDataset(
        data.CoraGraphDataset(),
        split_ratio=[0.7, 0.1, 0.2],
        neg_ratio=2,
        verbose=True,
    )
    assert ds.test_edges[0][0].shape[0] == 2112
    # negative sampling is not guaranteed to hit ratio 2 exactly, so assert a relaxed range
    assert 4000 < ds.test_edges[1][0].shape[0] <= 4224


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
def test_as_linkpred_ogb():
    from ogb.linkproppred import DglLinkPropPredDataset

    ds = data.AsLinkPredDataset(
        DglLinkPropPredDataset("ogbl-collab"), split_ratio=None, verbose=True
    )
    # original dataset has 46329 test edges
    assert ds.test_edges[0][0].shape[0] == 46329
    # force generate new split
    ds = data.AsLinkPredDataset(
        DglLinkPropPredDataset("ogbl-collab"),
        split_ratio=[0.7, 0.2, 0.1],
        verbose=True,
    )
    assert ds.test_edges[0][0].shape[0] == 235812


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred_csvdataset():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path = os.path.join(test_dir, "test_edges.csv")
        nodes_csv_path = os.path.join(test_dir, "test_nodes.csv")
        meta_yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default_name",
            "node_data": [{"file_name": os.path.basename(nodes_csv_path)}],
            "edge_data": [{"file_name": os.path.basename(edges_csv_path)}],
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
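        # The dump above yields a meta.yaml roughly like:
        #   version: 1.0.0
        #   dataset_name: default_name
        #   node_data:
        #     - {file_name: test_nodes.csv}
        #   edge_data:
        #     - {file_name: test_edges.csv}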
        num_nodes = 100
        num_edges = 500
        num_dims = 3
        num_classes = num_nodes
        feat_ndata = np.random.rand(num_nodes, num_dims)
        label_ndata = np.arange(num_classes)
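        # give every node a distinct label so num_classes == num_nodes,
        # which AsNodePredDataset is expected to recover below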
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": label_ndata,
                "feat": [line.tolist() for line in feat_ndata],
            }
        )
        df.to_csv(nodes_csv_path, index=False)
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
            }
        )
        df.to_csv(edges_csv_path, index=False)

        ds = data.CSVDataset(test_dir, force_reload=True)
        assert "feat" in ds[0].ndata
        assert "label" in ds[0].ndata
        assert "train_mask" not in ds[0].ndata
        assert not hasattr(ds[0], "num_classes")
        new_ds = data.AsNodePredDataset(
            ds, split_ratio=[0.8, 0.1, 0.1], force_reload=True
        )
        assert new_ds.num_classes == num_classes
        assert "feat" in new_ds[0].ndata
        assert "label" in new_ds[0].ndata
        assert "train_mask" in new_ds[0].ndata


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_graphpred():
    ds = data.GINDataset(name="MUTAG", self_loop=True)
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 188
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2

    ds = data.FakeNewsDataset("politifact", "profile")
    new_ds = data.AsGraphPredDataset(ds, verbose=True)
    assert len(new_ds) == 314
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2
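
    # The regression datasets below (QM7b, QM9, QM9Edge) have no discrete
    # classes: num_classes is None and num_tasks counts regression targets.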

    ds = data.QM7bDataset()
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 7211
    assert new_ds.num_tasks == 14
    assert new_ds.num_classes is None

    ds = data.QM9Dataset(label_keys=["mu", "gap"])
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 130831
    assert new_ds.num_tasks == 2
    assert new_ds.num_classes is None

    ds = data.QM9EdgeDataset(label_keys=["mu", "alpha"])
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 130831
    assert new_ds.num_tasks == 2
    assert new_ds.num_classes is None

    ds = data.TUDataset("DD")
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1178
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2

    ds = data.LegacyTUDataset("DD")
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1178
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2

    ds = data.BA2MotifDataset()
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1000
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_graphpred_reprocess():
    ds = data.AsGraphPredDataset(
        data.GINDataset(name="MUTAG", self_loop=True), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.GINDataset(name="MUTAG", self_loop=True), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.GINDataset(name="MUTAG", self_loop=True), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(
        data.FakeNewsDataset("politifact", "profile"), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.FakeNewsDataset("politifact", "profile"), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.FakeNewsDataset("politifact", "profile"), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(
        data.QM9Dataset(label_keys=["mu", "gap"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.QM9Dataset(label_keys=["mu", "gap"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.QM9Dataset(label_keys=["mu", "gap"]), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(
        data.QM9EdgeDataset(label_keys=["mu", "alpha"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.QM9EdgeDataset(label_keys=["mu", "alpha"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.QM9EdgeDataset(label_keys=["mu", "alpha"]), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.TUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.TUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.TUDataset("DD"), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.LegacyTUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.LegacyTUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.LegacyTUDataset("DD"), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
def test_as_graphpred_ogb():
    from ogb.graphproppred import DglGraphPropPredDataset

    ds = data.AsGraphPredDataset(
        DglGraphPropPredDataset("ogbg-molhiv"), split_ratio=None, verbose=True
    )
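    # split_ratio=None preserves the official OGB split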
    assert len(ds.train_idx) == 32901
    # force generate new split
    ds = data.AsGraphPredDataset(
        DglGraphPropPredDataset("ogbg-molhiv"),
        split_ratio=[0.6, 0.2, 0.2],
        verbose=True,
    )
    assert len(ds.train_idx) == 24676


if __name__ == "__main__":
    test_minigc()
    test_gin()
    test_data_hash()
    test_tudataset_regression()
    test_fraud()
    test_fakenews()
    test_extract_archive()
    test_csvdataset()
    test_add_nodepred_split()
    test_as_nodepred1()
    test_as_nodepred2()
    test_as_nodepred_csvdataset()