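"""Tests for the built-in datasets in ``dgl.data`` and the CSV dataset utilities."""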
import gzip
import io
import os
import tarfile
import tempfile
import unittest

import backend as F

import dgl
import dgl.data as data
import numpy as np

import pandas as pd
import pytest
import yaml
from dgl import DGLError

@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_minigc():
    ds = data.MiniGCDataset(16, 10, 20)
    g, l = list(zip(*ds))
    print(g, l)

    g1 = ds[0][0]
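    # AddSelfLoop(allow_duplicate=True) appends one self-loop per node, so the
    # transformed graph should have exactly num_nodes() more edges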
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    ds = data.MiniGCDataset(16, 10, 20, transform=transform)
    g2 = ds[0][0]
    assert g2.num_edges() - g1.num_edges() == g1.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_gin():
    ds_n_graphs = {
        "MUTAG": 188,
        "IMDBBINARY": 1000,
        "IMDBMULTI": 1500,
        "PROTEINS": 1113,
        "PTC": 344,
    }
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    for name, n_graphs in ds_n_graphs.items():
        ds = data.GINDataset(name, self_loop=False, degree_as_nlabel=False)
        assert len(ds) == n_graphs, (len(ds), name)
        g1 = ds[0][0]
        ds = data.GINDataset(
            name, self_loop=False, degree_as_nlabel=False, transform=transform
        )
        g2 = ds[0][0]
        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
        assert ds.num_classes == ds.gclasses


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_fraud():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    g = data.FraudDataset("amazon")[0]
    assert g.num_nodes() == 11944
    num_edges1 = g.num_edges()
    g2 = data.FraudDataset("amazon", transform=transform)[0]
    # 3 edge types
    assert g2.num_edges() - num_edges1 == g.num_nodes() * 3

    g = data.FraudAmazonDataset()[0]
    assert g.num_nodes() == 11944
    g2 = data.FraudAmazonDataset(transform=transform)[0]
    # 3 edge types
    assert g2.num_edges() - g.num_edges() == g.num_nodes() * 3

    g = data.FraudYelpDataset()[0]
    assert g.num_nodes() == 45954
    g2 = data.FraudYelpDataset(transform=transform)[0]
    # 3 edge types
    assert g2.num_edges() - g.num_edges() == g.num_nodes() * 3


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_fakenews():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    ds = data.FakeNewsDataset("politifact", "bert")
    assert len(ds) == 314
    g = ds[0][0]
    g2 = data.FakeNewsDataset("politifact", "bert", transform=transform)[0][0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    ds = data.FakeNewsDataset("gossipcop", "profile")
    assert len(ds) == 5464
    g = ds[0][0]
    g2 = data.FakeNewsDataset("gossipcop", "profile", transform=transform)[0][0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_tudataset_regression():
    ds = data.TUDataset("ZINC_test", force_reload=True)
    assert ds.num_classes == ds.num_labels
    assert len(ds) == 5000
    g = ds[0][0]

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    ds = data.TUDataset("ZINC_test", force_reload=True, transform=transform)
    g2 = ds[0][0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_data_hash():
    class HashTestDataset(data.DGLDataset):
        def __init__(self, hash_key=()):
            super(HashTestDataset, self).__init__("hashtest", hash_key=hash_key)

        def _load(self):
            pass

    a = HashTestDataset((True, 0, "1", (1, 2, 3)))
    b = HashTestDataset((True, 0, "1", (1, 2, 3)))
    c = HashTestDataset((True, 0, "1", (1, 2, 4)))
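    # datasets built with identical hash keys must share a hash; any change to
    # the key must change it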
    assert a.hash == b.hash
    assert a.hash != c.hash


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_citation_graph():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    # cora
    g = data.CoraGraphDataset(force_reload=True, reorder=True)[0]
    assert g.num_nodes() == 2708
    assert g.num_edges() == 10556
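    # reorder=True should store edges sorted by destination node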
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    g2 = data.CoraGraphDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # Citeseer
    g = data.CiteseerGraphDataset(force_reload=True, reorder=True)[0]
    assert g.num_nodes() == 3327
    assert g.num_edges() == 9228
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    g2 = data.CiteseerGraphDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # Pubmed
    g = data.PubmedGraphDataset(force_reload=True, reorder=True)[0]
    assert g.num_nodes() == 19717
    assert g.num_edges() == 88651
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    g2 = data.PubmedGraphDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_gnn_benchmark():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    # AmazonCoBuyComputerDataset
    g = data.AmazonCoBuyComputerDataset()[0]
    assert g.num_nodes() == 13752
    assert g.num_edges() == 491722
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    g2 = data.AmazonCoBuyComputerDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # AmazonCoBuyPhotoDataset
    g = data.AmazonCoBuyPhotoDataset()[0]
    assert g.num_nodes() == 7650
    assert g.num_edges() == 238163
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    g2 = data.AmazonCoBuyPhotoDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # CoauthorPhysicsDataset
    g = data.CoauthorPhysicsDataset()[0]
    assert g.num_nodes() == 34493
    assert g.num_edges() == 495924
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    g2 = data.CoauthorPhysicsDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # CoauthorCSDataset
    g = data.CoauthorCSDataset()[0]
    assert g.num_nodes() == 18333
    assert g.num_edges() == 163788
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    g2 = data.CoauthorCSDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # CoraFullDataset
    g = data.CoraFullDataset()[0]
    assert g.num_nodes() == 19793
    assert g.num_edges() == 126842
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    g2 = data.CoraFullDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_reddit():
    # RedditDataset
    g = data.RedditDataset()[0]
    assert g.num_nodes() == 232965
    assert g.num_edges() == 114615892
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.RedditDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_explain_syn():
    dataset = data.BAShapeDataset()
    assert dataset.num_classes == 4
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.BAShapeDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.BAShapeDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.BACommunityDataset()
    assert dataset.num_classes == 8
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.BACommunityDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.BACommunityDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.TreeCycleDataset()
    assert dataset.num_classes == 2
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.TreeCycleDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.TreeCycleDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.TreeGridDataset()
    assert dataset.num_classes == 2
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.TreeGridDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.TreeGridDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.BA2MotifDataset()
    assert dataset.num_classes == 2
    g, label = dataset[0]
    assert "feat" in g.ndata


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_wiki_cs():
    g = data.WikiCSDataset()[0]
    assert g.num_nodes() == 11701
    assert g.num_edges() == 431726
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.WikiCSDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skip(reason="Dataset too large to download for the latest CI.")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_yelp():
    g = data.YelpDataset(reorder=True)[0]
    assert g.num_nodes() == 716847
    assert g.num_edges() == 13954819
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.YelpDataset(reorder=True, transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_flickr():
    g = data.FlickrDataset(reorder=True)[0]
    assert g.num_nodes() == 89250
    assert g.num_edges() == 899756
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.FlickrDataset(reorder=True, transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_pattern():
    mode_n_graphs = {
        "train": 10000,
        "valid": 2000,
        "test": 2000,
    }
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    for mode, n_graphs in mode_n_graphs.items():
        ds = data.PATTERNDataset(mode=mode)
        assert len(ds) == n_graphs, (len(ds), mode)
        g1 = ds[0]
        ds = data.PATTERNDataset(mode=mode, transform=transform)
        g2 = ds[0]
        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
        assert ds.num_classes == 2


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_cluster():
    mode_n_graphs = {
        "train": 10000,
        "valid": 1000,
        "test": 1000,
    }
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    for mode, n_graphs in mode_n_graphs.items():
        ds = data.CLUSTERDataset(mode=mode)
        assert len(ds) == n_graphs, (len(ds), mode)
        g1 = ds[0]
        ds = data.CLUSTERDataset(mode=mode, transform=transform)
        g2 = ds[0]
        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
        assert ds.num_classes == 6


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
)
def test_zinc():
    mode_n_graphs = {
        "train": 10000,
        "valid": 1000,
        "test": 1000,
    }
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    for mode, n_graphs in mode_n_graphs.items():
        dataset1 = data.ZINCDataset(mode=mode)
        g1, label = dataset1[0]
        dataset2 = data.ZINCDataset(mode=mode, transform=transform)
        g2, _ = dataset2[0]

        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
        # the label should be a scalar tensor (empty shape)
        assert not label.shape


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_extract_archive():
    # gzip
    with tempfile.TemporaryDirectory() as src_dir:
        gz_file = "gz_archive"
        gz_path = os.path.join(src_dir, gz_file + ".gz")
        content = b"test extract archive gzip"
        with gzip.open(gz_path, "wb") as f:
            f.write(content)
        with tempfile.TemporaryDirectory() as dst_dir:
            data.utils.extract_archive(gz_path, dst_dir, overwrite=True)
            assert os.path.exists(os.path.join(dst_dir, gz_file))

    # tar
    with tempfile.TemporaryDirectory() as src_dir:
        tar_file = "tar_archive"
        tar_path = os.path.join(src_dir, tar_file + ".tar")
        # str.encode() defaults to UTF-8
        content = "test extract archive tar\n".encode()
        info = tarfile.TarInfo(name="tar_archive")
        info.size = len(content)
        with tarfile.open(tar_path, "w") as f:
            f.addfile(info, io.BytesIO(content))
        with tempfile.TemporaryDirectory() as dst_dir:
            data.utils.extract_archive(tar_path, dst_dir, overwrite=True)
            assert os.path.exists(os.path.join(dst_dir, tar_file))


def _test_construct_graphs_node_ids():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    num_nodes = 100
    num_edges = 1000

    # duplicate node IDs are not allowed; graph construction should raise
    node_ids = np.random.choice(np.arange(num_nodes / 2), num_nodes)
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    node_data = NodeData(node_ids, {})
    edge_data = EdgeData(src_ids, dst_ids, {})
    expect_except = False
    try:
        _, _ = DGLGraphConstructor.construct_graphs(node_data, edge_data)
    except:
        expect_except = True
    assert expect_except

    # node IDs already labeled from 0 to num_nodes - 1
    node_ids = np.arange(num_nodes)
    np.random.shuffle(node_ids)
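    # construction orders nodes by sorted unique ID; record each ID's original
    # position so the feature matrix can be permuted for comparison below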
    _, idx = np.unique(node_ids, return_index=True)
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    node_feat = np.random.rand(num_nodes, 3)
    node_data = NodeData(node_ids, {"feat": node_feat})
    edge_data = EdgeData(src_ids, dst_ids, {})
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == len(node_ids)
    assert g.num_edges() == len(src_ids)
    assert F.array_equal(
        F.tensor(node_feat[idx], dtype=F.float32), g.ndata["feat"]
    )

    # node IDs are mixed with numeric and non-numeric values
    # homogeneous graph
    node_ids = [1, 2, 3, "a"]
    src_ids = [1, 2, 3]
    dst_ids = ["a", 1, 2]
    node_data = NodeData(node_ids, {})
    edge_data = EdgeData(src_ids, dst_ids, {})
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == len(node_ids)
    assert g.num_edges() == len(src_ids)

    # heterogeneous graph
    node_ids_user = [1, 2, 3]
    node_ids_item = ["a", "b", "c"]
    src_ids = node_ids_user
    dst_ids = node_ids_item
    node_data_user = NodeData(node_ids_user, {}, type="user")
    node_data_item = NodeData(node_ids_item, {}, type="item")
    edge_data = EdgeData(src_ids, dst_ids, {}, type=("user", "like", "item"))
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        [node_data_user, node_data_item], edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert not g.is_homogeneous
    assert g.num_nodes("user") == len(node_ids_user)
    assert g.num_nodes("item") == len(node_ids_item)
    assert g.num_edges() == len(src_ids)


def _test_construct_graphs_homo():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    # node_id may be unsorted and non-numeric.
    num_nodes = 100
    num_edges = 1000
    num_dims = 3
    node_ids = np.random.choice(
        np.arange(num_nodes * 2), size=num_nodes, replace=False
    )
    assert len(node_ids) == num_nodes
    # to be non-sorted
    np.random.shuffle(node_ids)
    # to be non-numeric
    node_ids = ["id_{}".format(id) for id in node_ids]
    t_ndata = {
        "feat": np.random.rand(num_nodes, num_dims),
        "label": np.random.randint(2, size=num_nodes),
    }
    _, u_indices = np.unique(node_ids, return_index=True)
    ndata = {
        "feat": t_ndata["feat"][u_indices],
        "label": t_ndata["label"][u_indices],
    }
    node_data = NodeData(node_ids, t_ndata)
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    edata = {
        "feat": np.random.rand(num_edges, num_dims),
        "label": np.random.randint(2, size=num_edges),
    }
    edge_data = EdgeData(src_ids, dst_ids, edata)
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    def assert_data(lhs, rhs):
        for key, value in lhs.items():
            assert key in rhs
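            # parsed data should be tensorized as framework floats (float32),
            # never float64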
            assert F.dtype(rhs[key]) != F.float64
            assert F.array_equal(
                F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
            )

    assert_data(ndata, g.ndata)
    assert_data(edata, g.edata)


def _test_construct_graphs_hetero():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    # node_id/src_id/dst_id may be unsorted, duplicated, or non-numeric.
    num_nodes = 100
    num_edges = 1000
    num_dims = 3
    ntypes = ["user", "item"]
    node_data = []
    node_ids_dict = {}
    ndata_dict = {}
    for ntype in ntypes:
        node_ids = np.random.choice(
            np.arange(num_nodes * 2), size=num_nodes, replace=False
        )
        assert len(node_ids) == num_nodes
        # to be non-sorted
        np.random.shuffle(node_ids)
        # to be non-numeric
        node_ids = ["id_{}".format(id) for id in node_ids]
        t_ndata = {
            "feat": np.random.rand(num_nodes, num_dims),
            "label": np.random.randint(2, size=num_nodes),
        }
        _, u_indices = np.unique(node_ids, return_index=True)
        ndata = {
            "feat": t_ndata["feat"][u_indices],
            "label": t_ndata["label"][u_indices],
        }
        node_data.append(NodeData(node_ids, t_ndata, type=ntype))
        node_ids_dict[ntype] = node_ids
        ndata_dict[ntype] = ndata
    etypes = [("user", "follow", "user"), ("user", "like", "item")]
    edge_data = []
    edata_dict = {}
    for src_type, e_type, dst_type in etypes:
        src_ids = np.random.choice(node_ids_dict[src_type], size=num_edges)
        dst_ids = np.random.choice(node_ids_dict[dst_type], size=num_edges)
        edata = {
            "feat": np.random.rand(num_edges, num_dims),
            "label": np.random.randint(2, size=num_edges),
        }
        edge_data.append(
            EdgeData(src_ids, dst_ids, edata, type=(src_type, e_type, dst_type))
        )
        edata_dict[(src_type, e_type, dst_type)] = edata
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert not g.is_homogeneous
    assert g.num_nodes() == num_nodes * len(ntypes)
    assert g.num_edges() == num_edges * len(etypes)

    def assert_data(lhs, rhs):
        for key, value in lhs.items():
            assert key in rhs
            assert F.dtype(rhs[key]) != F.float64
            assert F.array_equal(
                F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
            )

    for ntype in g.ntypes:
        assert g.num_nodes(ntype) == num_nodes
        assert_data(ndata_dict[ntype], g.nodes[ntype].data)
    for etype in g.canonical_etypes:
        assert g.num_edges(etype) == num_edges
        assert_data(edata_dict[etype], g.edges[etype].data)


def _test_construct_graphs_multiple():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        GraphData,
        NodeData,
    )

    num_nodes = 100
    num_edges = 1000
    num_graphs = 10
    num_dims = 3
    node_ids = np.array([], dtype=int)
    src_ids = np.array([], dtype=int)
    dst_ids = np.array([], dtype=int)
    ngraph_ids = np.array([], dtype=int)
    egraph_ids = np.array([], dtype=int)
    u_indices = np.array([], dtype=int)
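    # accumulate per-graph node/edge IDs, plus the unique-sort indices used to
    # permute node features when checking each constructed graph below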
    for i in range(num_graphs):
        l_node_ids = np.random.choice(
            np.arange(num_nodes * 2), size=num_nodes, replace=False
        )
        node_ids = np.append(node_ids, l_node_ids)
        _, l_u_indices = np.unique(l_node_ids, return_index=True)
        u_indices = np.append(u_indices, l_u_indices)
        ngraph_ids = np.append(ngraph_ids, np.full(num_nodes, i))
        src_ids = np.append(
            src_ids, np.random.choice(l_node_ids, size=num_edges)
        )
        dst_ids = np.append(
            dst_ids, np.random.choice(l_node_ids, size=num_edges)
        )
        egraph_ids = np.append(egraph_ids, np.full(num_edges, i))
    ndata = {
        "feat": np.random.rand(num_nodes * num_graphs, num_dims),
        "label": np.random.randint(2, size=num_nodes * num_graphs),
    }
    ngraph_ids = ["graph_{}".format(id) for id in ngraph_ids]
    node_data = NodeData(node_ids, ndata, graph_id=ngraph_ids)
    egraph_ids = ["graph_{}".format(id) for id in egraph_ids]
    edata = {
        "feat": np.random.rand(num_edges * num_graphs, num_dims),
        "label": np.random.randint(2, size=num_edges * num_graphs),
    }
    edge_data = EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids)
    gdata = {
        "feat": np.random.rand(num_graphs, num_dims),
        "label": np.random.randint(2, size=num_graphs),
    }
    graph_ids = ["graph_{}".format(id) for id in np.arange(num_graphs)]
    graph_data = GraphData(graph_ids, gdata)
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data, graph_data
    )
    assert len(graphs) == num_graphs
    assert len(data_dict) == len(gdata)
    for k, v in data_dict.items():
        assert F.dtype(v) != F.float64
        assert F.array_equal(
            F.reshape(F.tensor(gdata[k], dtype=F.dtype(v)), (len(graphs), -1)),
            v,
        )
    for i, g in enumerate(graphs):
        assert g.is_homogeneous
        assert g.num_nodes() == num_nodes
        assert g.num_edges() == num_edges

        def assert_data(lhs, rhs, size, node=False):
            for key, value in lhs.items():
                assert key in rhs
                value = value[i * size : (i + 1) * size]
                if node:
                    indices = u_indices[i * size : (i + 1) * size]
                    value = value[indices]
                assert F.dtype(rhs[key]) != F.float64
                assert F.array_equal(
                    F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
                )

        assert_data(ndata, g.ndata, num_nodes, node=True)
        assert_data(edata, g.edata, num_edges)

    # graph IDs referenced in node/edge data but missing from graph data should raise
    graph_data = GraphData(np.arange(num_graphs - 2), {})
    expect_except = False
    try:
        _, _ = DGLGraphConstructor.construct_graphs(
            node_data, edge_data, graph_data
        )
    except:
        expect_except = True
    assert expect_except


def _test_DefaultDataParser():
    from dgl.data.csv_dataset_base import DefaultDataParser

    # common csv
    with tempfile.TemporaryDirectory() as test_dir:
        csv_path = os.path.join(test_dir, "nodes.csv")
        num_nodes = 5
        num_labels = 3
        num_dims = 2
        node_id = np.arange(num_nodes)
        label = np.random.randint(num_labels, size=num_nodes)
        feat = np.random.rand(num_nodes, num_dims)
        df = pd.DataFrame(
            {
                "node_id": node_id,
                "label": label,
                "feat": [line.tolist() for line in feat],
            }
        )
        df.to_csv(csv_path, index=False)
        dp = DefaultDataParser()
        df = pd.read_csv(csv_path)
        dt = dp(df)
        assert np.array_equal(node_id, dt["node_id"])
        assert np.array_equal(label, dt["label"])
        assert np.array_equal(feat, dt["feat"])

    # strings of non-numeric values are expected to fail parsing
    with tempfile.TemporaryDirectory() as test_dir:
        csv_path = os.path.join(test_dir, "nodes.csv")
        df = pd.DataFrame({"label": ["a", "b", "c"]})
        df.to_csv(csv_path, index=False)
        dp = DefaultDataParser()
        df = pd.read_csv(csv_path)
        expect_except = False
        try:
            dt = dp(df)
        except:
            expect_except = True
        assert expect_except

    # the CSV's unnamed index column is ignored
    with tempfile.TemporaryDirectory() as test_dir:
        csv_path = os.path.join(test_dir, "nodes.csv")
        df = pd.DataFrame({"label": [1, 2, 3]})
        df.to_csv(csv_path)
        dp = DefaultDataParser()
        df = pd.read_csv(csv_path)
        dt = dp(df)
        assert len(dt) == 1


def _test_load_yaml_with_sanity_check():
    from dgl.data.csv_dataset_base import load_yaml_with_sanity_check

    with tempfile.TemporaryDirectory() as test_dir:
        yaml_path = os.path.join(test_dir, "meta.yaml")
        # valid, though such a minimal config is usually meaningless
        yaml_data = {
            "dataset_name": "default",
            "node_data": [],
            "edge_data": [],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        meta = load_yaml_with_sanity_check(yaml_path)
        assert meta.version == "1.0.0"
        assert meta.dataset_name == "default"
        assert meta.separator == ","
        assert len(meta.node_data) == 0
        assert len(meta.edge_data) == 0
        assert meta.graph_data is None
        # minimum with required fields only
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "node_data": [{"file_name": "nodes.csv"}],
            "edge_data": [{"file_name": "edges.csv"}],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        meta = load_yaml_with_sanity_check(yaml_path)
        for ndata in meta.node_data:
            assert ndata.file_name == "nodes.csv"
            assert ndata.ntype == "_V"
            assert ndata.graph_id_field == "graph_id"
            assert ndata.node_id_field == "node_id"
        for edata in meta.edge_data:
            assert edata.file_name == "edges.csv"
            assert edata.etype == ["_V", "_E", "_V"]
            assert edata.graph_id_field == "graph_id"
            assert edata.src_id_field == "src_id"
            assert edata.dst_id_field == "dst_id"
        # optional fields are specified
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "separator": "|",
            "node_data": [
                {
                    "file_name": "nodes.csv",
                    "ntype": "user",
                    "graph_id_field": "xxx",
                    "node_id_field": "xxx",
                }
            ],
            "edge_data": [
                {
                    "file_name": "edges.csv",
                    "etype": ["user", "follow", "user"],
                    "graph_id_field": "xxx",
                    "src_id_field": "xxx",
                    "dst_id_field": "xxx",
                }
            ],
            "graph_data": {"file_name": "graph.csv", "graph_id_field": "xxx"},
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        meta = load_yaml_with_sanity_check(yaml_path)
        assert len(meta.node_data) == 1
        ndata = meta.node_data[0]
        assert ndata.ntype == "user"
        assert ndata.graph_id_field == "xxx"
        assert ndata.node_id_field == "xxx"
        assert len(meta.edge_data) == 1
        edata = meta.edge_data[0]
        assert edata.etype == ["user", "follow", "user"]
        assert edata.graph_id_field == "xxx"
        assert edata.src_id_field == "xxx"
        assert edata.dst_id_field == "xxx"
        assert meta.graph_data is not None
        assert meta.graph_data.file_name == "graph.csv"
        assert meta.graph_data.graph_id_field == "xxx"
        # some required fields are missing
        yaml_data = {
            "dataset_name": "default",
            "node_data": [],
            "edge_data": [],
        }
        for field in yaml_data.keys():
            ydata = {k: v for k, v in yaml_data.items()}
            ydata.pop(field)
            with open(yaml_path, "w") as f:
                yaml.dump(ydata, f, sort_keys=False)
            expect_except = False
            try:
                meta = load_yaml_with_sanity_check(yaml_path)
            except:
                expect_except = True
            assert expect_except
        # an unsupported version should be rejected
        yaml_data = {
            "version": "0.0.0",
            "dataset_name": "default",
            "node_data": [{"file_name": "nodes_0.csv"}],
            "edge_data": [{"file_name": "edges_0.csv"}],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        expect_except = False
        try:
            meta = load_yaml_with_sanity_check(yaml_path)
        except DGLError:
            expect_except = True
        assert expect_except
        # duplicate node types should be rejected
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "node_data": [
                {"file_name": "nodes.csv"},
                {"file_name": "nodes.csv"},
            ],
            "edge_data": [{"file_name": "edges.csv"}],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        expect_except = False
        try:
            meta = load_yaml_with_sanity_check(yaml_path)
        except DGLError:
            expect_except = True
        assert expect_except
        # duplicate edge types should be rejected
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "node_data": [{"file_name": "nodes.csv"}],
            "edge_data": [
                {"file_name": "edges.csv"},
                {"file_name": "edges.csv"},
            ],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        expect_except = False
        try:
            meta = load_yaml_with_sanity_check(yaml_path)
        except DGLError:
            expect_except = True
        assert expect_except


def _test_load_node_data_from_csv():
    from dgl.data.csv_dataset_base import DefaultDataParser, MetaNode, NodeData

    with tempfile.TemporaryDirectory() as test_dir:
        num_nodes = 100
        # minimum
        df = pd.DataFrame({"node_id": np.arange(num_nodes)})
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
        assert np.array_equal(df["node_id"], node_data.id)
        assert len(node_data.data) == 0

        # common case
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": np.random.randint(3, size=num_nodes),
            }
        )
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
        assert np.array_equal(df["node_id"], node_data.id)
        assert len(node_data.data) == 1
        assert np.array_equal(df["label"], node_data.data["label"])
        assert np.array_equal(np.full(num_nodes, 0), node_data.graph_id)
        assert node_data.type == "_V"

        # add more fields into nodes.csv
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": np.random.randint(3, size=num_nodes),
                "graph_id": np.full(num_nodes, 1),
            }
        )
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
        assert np.array_equal(df["node_id"], node_data.id)
        assert len(node_data.data) == 1
        assert np.array_equal(df["label"], node_data.data["label"])
        assert np.array_equal(df["graph_id"], node_data.graph_id)
        assert node_data.type == "_V"

        # required header is missing
        df = pd.DataFrame({"label": np.random.randint(3, size=num_nodes)})
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        expect_except = False
        try:
            NodeData.load_from_csv(meta_node, DefaultDataParser())
        except:
            expect_except = True
        assert expect_except


def _test_load_edge_data_from_csv():
    from dgl.data.csv_dataset_base import DefaultDataParser, EdgeData, MetaEdge

    with tempfile.TemporaryDirectory() as test_dir:
        num_nodes = 100
        num_edges = 1000
        # minimum
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        assert np.array_equal(df["src_id"], edge_data.src)
        assert np.array_equal(df["dst_id"], edge_data.dst)
        assert len(edge_data.data) == 0

        # common case
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
                "label": np.random.randint(3, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        assert np.array_equal(df["src_id"], edge_data.src)
        assert np.array_equal(df["dst_id"], edge_data.dst)
        assert len(edge_data.data) == 1
        assert np.array_equal(df["label"], edge_data.data["label"])
        assert np.array_equal(np.full(num_edges, 0), edge_data.graph_id)
        assert edge_data.type == ("_V", "_E", "_V")

        # add more fields into edges.csv
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
                "graph_id": np.arange(num_edges),
                "feat": np.random.randint(3, size=num_edges),
                "label": np.random.randint(3, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        assert np.array_equal(df["src_id"], edge_data.src)
        assert np.array_equal(df["dst_id"], edge_data.dst)
        assert len(edge_data.data) == 2
        assert np.array_equal(df["feat"], edge_data.data["feat"])
        assert np.array_equal(df["label"], edge_data.data["label"])
        assert np.array_equal(df["graph_id"], edge_data.graph_id)
        assert edge_data.type == ("_V", "_E", "_V")

        # required headers are missing
        df = pd.DataFrame(
            {"src_id": np.random.randint(num_nodes, size=num_edges)}
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        expect_except = False
        try:
            EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        except DGLError:
            expect_except = True
        assert expect_except
        df = pd.DataFrame(
            {"dst_id": np.random.randint(num_nodes, size=num_edges)}
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        expect_except = False
        try:
            EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        except DGLError:
            expect_except = True
        assert expect_except


def _test_load_graph_data_from_csv():
    from dgl.data.csv_dataset_base import (
        DefaultDataParser,
        GraphData,
        MetaGraph,
    )

    with tempfile.TemporaryDirectory() as test_dir:
        num_graphs = 100
        # minimum
        df = pd.DataFrame({"graph_id": np.arange(num_graphs)})
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
        assert np.array_equal(df["graph_id"], graph_data.graph_id)
        assert len(graph_data.data) == 0

        # common case
        df = pd.DataFrame(
            {
                "graph_id": np.arange(num_graphs),
                "label": np.random.randint(3, size=num_graphs),
            }
        )
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
        assert np.array_equal(df["graph_id"], graph_data.graph_id)
        assert len(graph_data.data) == 1
        assert np.array_equal(df["label"], graph_data.data["label"])

        # add more fields into graph.csv
        df = pd.DataFrame(
            {
                "graph_id": np.arange(num_graphs),
                "feat": np.random.randint(3, size=num_graphs),
                "label": np.random.randint(3, size=num_graphs),
            }
        )
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
        assert np.array_equal(df["graph_id"], graph_data.graph_id)
        assert len(graph_data.data) == 2
        assert np.array_equal(df["feat"], graph_data.data["feat"])
        assert np.array_equal(df["label"], graph_data.data["label"])

        # required header is missing
        df = pd.DataFrame({"label": np.random.randint(3, size=num_graphs)})
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        expect_except = False
        try:
            GraphData.load_from_csv(meta_graph, DefaultDataParser())
        except DGLError:
            expect_except = True
        assert expect_except


def _test_CSVDataset_single():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path_0 = os.path.join(test_dir, "test_edges_0.csv")
        edges_csv_path_1 = os.path.join(test_dir, "test_edges_1.csv")
        nodes_csv_path_0 = os.path.join(test_dir, "test_nodes_0.csv")
        nodes_csv_path_1 = os.path.join(test_dir, "test_nodes_1.csv")
        meta_yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default_name",
            "node_data": [
                {
                    "file_name": os.path.basename(nodes_csv_path_0),
                    "ntype": "user",
                },
                {
                    "file_name": os.path.basename(nodes_csv_path_1),
                    "ntype": "item",
                },
            ],
            "edge_data": [
                {
                    "file_name": os.path.basename(edges_csv_path_0),
                    "etype": ["user", "follow", "user"],
                },
                {
                    "file_name": os.path.basename(edges_csv_path_1),
                    "etype": ["user", "like", "item"],
                },
            ],
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_dims = 3
        feat_ndata = np.random.rand(num_nodes, num_dims)
        label_ndata = np.random.randint(2, size=num_nodes)
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": label_ndata,
                "feat": [line.tolist() for line in feat_ndata],
            }
        )
        df.to_csv(nodes_csv_path_0, index=False)
        df.to_csv(nodes_csv_path_1, index=False)
        feat_edata = np.random.rand(num_edges, num_dims)
        label_edata = np.random.randint(2, size=num_edges)
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
                "label": label_edata,
                "feat": [line.tolist() for line in feat_edata],
            }
        )
        df.to_csv(edges_csv_path_0, index=False)
        df.to_csv(edges_csv_path_1, index=False)

        # load CSVDataset
        for force_reload in [True, False]:
            if not force_reload:
                # remove original node data file to verify reload from cached files
                os.remove(nodes_csv_path_0)
                assert not os.path.exists(nodes_csv_path_0)
            csv_dataset = data.CSVDataset(test_dir, force_reload=force_reload)
            assert len(csv_dataset) == 1
            g = csv_dataset[0]
            assert not g.is_homogeneous
            assert csv_dataset.has_cache()
            for ntype in g.ntypes:
                assert g.num_nodes(ntype) == num_nodes
                assert F.array_equal(
                    F.tensor(feat_ndata, dtype=F.float32),
                    g.nodes[ntype].data["feat"],
                )
                assert np.array_equal(
                    label_ndata, F.asnumpy(g.nodes[ntype].data["label"])
                )
            for etype in g.etypes:
                assert g.num_edges(etype) == num_edges
                assert F.array_equal(
                    F.tensor(feat_edata, dtype=F.float32),
                    g.edges[etype].data["feat"],
                )
                assert np.array_equal(
                    label_edata, F.asnumpy(g.edges[etype].data["label"])
                )


def _test_CSVDataset_multiple():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path_0 = os.path.join(test_dir, "test_edges_0.csv")
        edges_csv_path_1 = os.path.join(test_dir, "test_edges_1.csv")
        nodes_csv_path_0 = os.path.join(test_dir, "test_nodes_0.csv")
        nodes_csv_path_1 = os.path.join(test_dir, "test_nodes_1.csv")
        graph_csv_path = os.path.join(test_dir, "test_graph.csv")
        meta_yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default_name",
            "node_data": [
                {
                    "file_name": os.path.basename(nodes_csv_path_0),
                    "ntype": "user",
                },
                {
                    "file_name": os.path.basename(nodes_csv_path_1),
                    "ntype": "item",
                },
            ],
            "edge_data": [
                {
                    "file_name": os.path.basename(edges_csv_path_0),
                    "etype": ["user", "follow", "user"],
                },
                {
                    "file_name": os.path.basename(edges_csv_path_1),
                    "etype": ["user", "like", "item"],
                },
            ],
            "graph_data": {"file_name": os.path.basename(graph_csv_path)},
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_graphs = 10
        num_dims = 3
        feat_ndata = np.random.rand(num_nodes * num_graphs, num_dims)
        label_ndata = np.random.randint(2, size=num_nodes * num_graphs)
        df = pd.DataFrame(
            {
                "node_id": np.hstack(
                    [np.arange(num_nodes) for _ in range(num_graphs)]
                ),
                "label": label_ndata,
                "feat": [line.tolist() for line in feat_ndata],
                "graph_id": np.hstack(
                    [np.full(num_nodes, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(nodes_csv_path_0, index=False)
        df.to_csv(nodes_csv_path_1, index=False)
        feat_edata = np.random.rand(num_edges * num_graphs, num_dims)
        label_edata = np.random.randint(2, size=num_edges * num_graphs)
        df = pd.DataFrame(
            {
                "src_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "dst_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "label": label_edata,
                "feat": [line.tolist() for line in feat_edata],
                "graph_id": np.hstack(
                    [np.full(num_edges, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(edges_csv_path_0, index=False)
        df.to_csv(edges_csv_path_1, index=False)
        feat_gdata = np.random.rand(num_graphs, num_dims)
        label_gdata = np.random.randint(2, size=num_graphs)
        df = pd.DataFrame(
            {
                "label": label_gdata,
                "feat": [line.tolist() for line in feat_gdata],
                "graph_id": np.arange(num_graphs),
            }
        )
        df.to_csv(graph_csv_path, index=False)

        # load CSVDataset with default node/edge/gdata_parser
        for force_reload in [True, False]:
            if not force_reload:
                # remove original node data file to verify reload from cached files
                os.remove(nodes_csv_path_0)
                assert not os.path.exists(nodes_csv_path_0)
            csv_dataset = data.CSVDataset(test_dir, force_reload=force_reload)
            assert len(csv_dataset) == num_graphs
            assert csv_dataset.has_cache()
            assert len(csv_dataset.data) == 2
            assert "feat" in csv_dataset.data
            assert "label" in csv_dataset.data
            assert F.array_equal(
                F.tensor(feat_gdata, dtype=F.float32), csv_dataset.data["feat"]
            )
            for i, (g, g_data) in enumerate(csv_dataset):
                assert not g.is_homogeneous
                assert F.asnumpy(g_data["label"]) == label_gdata[i]
                assert F.array_equal(
                    g_data["feat"], F.tensor(feat_gdata[i], dtype=F.float32)
                )
                for ntype in g.ntypes:
                    assert g.num_nodes(ntype) == num_nodes
                    assert F.array_equal(
                        F.tensor(
                            feat_ndata[i * num_nodes : (i + 1) * num_nodes],
                            dtype=F.float32,
                        ),
                        g.nodes[ntype].data["feat"],
                    )
                    assert np.array_equal(
                        label_ndata[i * num_nodes : (i + 1) * num_nodes],
                        F.asnumpy(g.nodes[ntype].data["label"]),
                    )
                for etype in g.etypes:
                    assert g.num_edges(etype) == num_edges
                    assert F.array_equal(
                        F.tensor(
                            feat_edata[i * num_edges : (i + 1) * num_edges],
                            dtype=F.float32,
                        ),
                        g.edges[etype].data["feat"],
                    )
                    assert np.array_equal(
                        label_edata[i * num_edges : (i + 1) * num_edges],
                        F.asnumpy(g.edges[etype].data["label"]),
                    )


def _test_CSVDataset_customized_data_parser():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path_0 = os.path.join(test_dir, "test_edges_0.csv")
        edges_csv_path_1 = os.path.join(test_dir, "test_edges_1.csv")
        nodes_csv_path_0 = os.path.join(test_dir, "test_nodes_0.csv")
        nodes_csv_path_1 = os.path.join(test_dir, "test_nodes_1.csv")
        graph_csv_path = os.path.join(test_dir, "test_graph.csv")
        meta_yaml_data = {
            "dataset_name": "default_name",
            "node_data": [
                {
                    "file_name": os.path.basename(nodes_csv_path_0),
                    "ntype": "user",
                },
                {
                    "file_name": os.path.basename(nodes_csv_path_1),
                    "ntype": "item",
                },
            ],
            "edge_data": [
                {
                    "file_name": os.path.basename(edges_csv_path_0),
                    "etype": ["user", "follow", "user"],
                },
                {
                    "file_name": os.path.basename(edges_csv_path_1),
                    "etype": ["user", "like", "item"],
                },
            ],
            "graph_data": {"file_name": os.path.basename(graph_csv_path)},
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_graphs = 10
        label_ndata = np.random.randint(2, size=num_nodes * num_graphs)
        df = pd.DataFrame(
            {
                "node_id": np.hstack(
                    [np.arange(num_nodes) for _ in range(num_graphs)]
                ),
                "label": label_ndata,
                "graph_id": np.hstack(
                    [np.full(num_nodes, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(nodes_csv_path_0, index=False)
        df.to_csv(nodes_csv_path_1, index=False)
        label_edata = np.random.randint(2, size=num_edges * num_graphs)
        df = pd.DataFrame(
            {
                "src_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "dst_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "label": label_edata,
                "graph_id": np.hstack(
                    [np.full(num_edges, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(edges_csv_path_0, index=False)
        df.to_csv(edges_csv_path_1, index=False)
        label_gdata = np.random.randint(2, size=num_graphs)
        df = pd.DataFrame(
            {"label": label_gdata, "graph_id": np.arange(num_graphs)}
        )
        df.to_csv(graph_csv_path, index=False)
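
        # A customized data parser is any callable that takes the pandas
        # DataFrame read from one CSV file and returns a dict mapping column
        # names to numpy arrays; CustDataParser shifts every "label" column
        # by 2 so its effect can be asserted below.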

        class CustDataParser:
            def __call__(self, df):
                data = {}
                for header in df:
                    dt = df[header].to_numpy().squeeze()
                    if header == "label":
                        dt += 2
                    data[header] = dt
                return data

        # load CSVDataset with customized node/edge/gdata_parser
        # specify via dict[ntype/etype, callable]
        csv_dataset = data.CSVDataset(
            test_dir,
            force_reload=True,
            ndata_parser={"user": CustDataParser()},
            edata_parser={("user", "like", "item"): CustDataParser()},
            gdata_parser=CustDataParser(),
        )
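        # Only "user" nodes and ("user", "like", "item") edges go through
        # CustDataParser here; all other types fall back to the default
        # parser, hence the per-type offsets asserted below.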
        assert len(csv_dataset) == num_graphs
        assert len(csv_dataset.data) == 1
        assert "label" in csv_dataset.data
        for i, (g, g_data) in enumerate(csv_dataset):
            assert not g.is_homogeneous
            assert F.asnumpy(g_data) == label_gdata[i] + 2
            for ntype in g.ntypes:
                assert g.num_nodes(ntype) == num_nodes
                offset = 2 if ntype == "user" else 0
                assert np.array_equal(
                    label_ndata[i * num_nodes : (i + 1) * num_nodes] + offset,
                    F.asnumpy(g.nodes[ntype].data["label"]),
                )
            for etype in g.etypes:
                assert g.num_edges(etype) == num_edges
                offset = 2 if etype == "like" else 0
                assert np.array_equal(
                    label_edata[i * num_edges : (i + 1) * num_edges] + offset,
                    F.asnumpy(g.edges[etype].data["label"]),
                )
        # specify via callable
        csv_dataset = data.CSVDataset(
            test_dir,
            force_reload=True,
            ndata_parser=CustDataParser(),
            edata_parser=CustDataParser(),
            gdata_parser=CustDataParser(),
        )
        assert len(csv_dataset) == num_graphs
        assert len(csv_dataset.data) == 1
        assert "label" in csv_dataset.data
        for i, (g, g_data) in enumerate(csv_dataset):
            assert not g.is_homogeneous
            assert F.asnumpy(g_data) == label_gdata[i] + 2
            for ntype in g.ntypes:
                assert g.num_nodes(ntype) == num_nodes
                offset = 2
                assert np.array_equal(
                    label_ndata[i * num_nodes : (i + 1) * num_nodes] + offset,
                    F.asnumpy(g.nodes[ntype].data["label"]),
                )
            for etype in g.etypes:
                assert g.num_edges(etype) == num_edges
                offset = 2
                assert np.array_equal(
                    label_edata[i * num_edges : (i + 1) * num_edges] + offset,
                    F.asnumpy(g.edges[etype].data["label"]),
                )


def _test_NodeEdgeGraphData():
    from dgl.data.csv_dataset_base import EdgeData, GraphData, NodeData
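
    # NodeData/EdgeData/GraphData are thin containers over numpy arrays with
    # defaults ("_V"/"_E" type names, graph_id 0) and length validation on
    # construction; the checks below exercise both the defaults and the
    # validation.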

    # NodeData basics
    num_nodes = 100
    node_ids = np.arange(num_nodes, dtype=float)
    ndata = NodeData(node_ids, {})
    assert np.array_equal(ndata.id, node_ids)
    assert len(ndata.data) == 0
    assert ndata.type == "_V"
    assert np.array_equal(ndata.graph_id, np.full(num_nodes, 0))
    # NodeData more
    data = {"feat": np.random.rand(num_nodes, 3)}
    graph_id = np.arange(num_nodes)
    ndata = NodeData(node_ids, data, type="user", graph_id=graph_id)
    assert ndata.type == "user"
    assert np.array_equal(ndata.graph_id, graph_id)
    assert len(ndata.data) == len(data)
    for k, v in data.items():
        assert k in ndata.data
        assert np.array_equal(ndata.data[k], v)
    # NodeData: mismatched id/data/graph_id lengths should raise
    expect_except = False
    try:
        NodeData(
            np.arange(num_nodes),
            {"feat": np.random.rand(num_nodes + 1, 3)},
            graph_id=np.arange(num_nodes - 1),
        )
    except Exception:
        expect_except = True
    assert expect_except

    # EdgeData basics
    num_nodes = 100
    num_edges = 1000
    src_ids = np.random.randint(num_nodes, size=num_edges)
    dst_ids = np.random.randint(num_nodes, size=num_edges)
    edata = EdgeData(src_ids, dst_ids, {})
    assert np.array_equal(edata.src, src_ids)
    assert np.array_equal(edata.dst, dst_ids)
    assert edata.type == ("_V", "_E", "_V")
    assert len(edata.data) == 0
    assert np.array_equal(edata.graph_id, np.full(num_edges, 0))
    # EdgeData more
    src_ids = np.random.randint(num_nodes, size=num_edges).astype(float)
    dst_ids = np.random.randint(num_nodes, size=num_edges).astype(float)
    data = {"feat": np.random.rand(num_edges, 3)}
    etype = ("user", "like", "item")
    graph_ids = np.arange(num_edges)
    edata = EdgeData(src_ids, dst_ids, data, type=etype, graph_id=graph_ids)
    assert np.array_equal(edata.src, src_ids)
    assert np.array_equal(edata.dst, dst_ids)
    assert edata.type == etype
    assert len(edata.data) == len(data)
    for k, v in data.items():
        assert k in edata.data
        assert np.array_equal(edata.data[k], v)
    assert np.array_equal(edata.graph_id, graph_ids)
    # EdgeData: mismatched src/dst/data/graph_id lengths should raise
    expect_except = False
    try:
        EdgeData(
            np.arange(num_edges),
            np.arange(num_edges + 1),
            {"feat": np.random.rand(num_edges - 1, 3)},
            graph_id=np.arange(num_edges + 2),
        )
    except Exception:
        expect_except = True
    assert expect_except

    # GraphData basics
    num_graphs = 10
    graph_ids = np.arange(num_graphs)
    gdata = GraphData(graph_ids, {})
    assert np.array_equal(gdata.graph_id, graph_ids)
    assert len(gdata.data) == 0
    # GraphData more
    graph_ids = np.arange(num_graphs).astype(float)
    data = {"feat": np.random.rand(num_graphs, 3)}
    gdata = GraphData(graph_ids, data)
    assert np.array_equal(gdata.graph_id, graph_ids)
    assert len(gdata.data) == len(data)
    for k, v in data.items():
        assert k in gdata.data
        assert np.array_equal(gdata.data[k], v)


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_csvdataset():
    _test_NodeEdgeGraphData()
    _test_construct_graphs_node_ids()
    _test_construct_graphs_homo()
    _test_construct_graphs_hetero()
    _test_construct_graphs_multiple()
    _test_DefaultDataParser()
    _test_load_yaml_with_sanity_check()
    _test_load_node_data_from_csv()
    _test_load_edge_data_from_csv()
    _test_load_graph_data_from_csv()
    _test_CSVDataset_single()
    _test_CSVDataset_multiple()
    _test_CSVDataset_customized_data_parser()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred1():
    ds = data.AmazonCoBuyComputerDataset()
    print("train_mask" in ds[0].ndata)
    new_ds = data.AsNodePredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1
    assert new_ds[0].num_nodes() == ds[0].num_nodes()
    assert new_ds[0].num_edges() == ds[0].num_edges()
    assert "train_mask" in new_ds[0].ndata
    assert F.array_equal(
        new_ds.train_idx, F.nonzero_1d(new_ds[0].ndata["train_mask"])
    )
    assert F.array_equal(
        new_ds.val_idx, F.nonzero_1d(new_ds[0].ndata["val_mask"])
    )
    assert F.array_equal(
        new_ds.test_idx, F.nonzero_1d(new_ds[0].ndata["test_mask"])
    )

    ds = data.AIFBDataset()
    print("train_mask" in ds[0].nodes["Personen"].data)
    new_ds = data.AsNodePredDataset(
        ds, [0.8, 0.1, 0.1], "Personen", verbose=True
    )
    assert len(new_ds) == 1
    assert new_ds[0].ntypes == ds[0].ntypes
    assert new_ds[0].canonical_etypes == ds[0].canonical_etypes
    assert "train_mask" in new_ds[0].nodes["Personen"].data
    assert F.array_equal(
        new_ds.train_idx,
        F.nonzero_1d(new_ds[0].nodes["Personen"].data["train_mask"]),
    )
    assert F.array_equal(
        new_ds.val_idx,
        F.nonzero_1d(new_ds[0].nodes["Personen"].data["val_mask"]),
    )
    assert F.array_equal(
        new_ds.test_idx,
        F.nonzero_1d(new_ds[0].nodes["Personen"].data["test_mask"]),
    )


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred2():
    # test proper reprocessing
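    # Pattern: build once (fresh split), rebuild with the same ratio (should
    # be served from the on-disk cache), then rebuild with a different ratio,
    # which invalidates the cache and forces a re-split.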

    # create
    ds = data.AsNodePredDataset(
        data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1]
    )
    assert F.sum(F.astype(ds[0].ndata["train_mask"], F.int32), 0) == int(
        ds[0].num_nodes() * 0.8
    )
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.8)
    # read from cache
    ds = data.AsNodePredDataset(
        data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1]
    )
    assert F.sum(F.astype(ds[0].ndata["train_mask"], F.int32), 0) == int(
        ds[0].num_nodes() * 0.8
    )
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.8)
    # invalid cache, re-read
    ds = data.AsNodePredDataset(
        data.AmazonCoBuyComputerDataset(), [0.1, 0.1, 0.8]
    )
    assert F.sum(F.astype(ds[0].ndata["train_mask"], F.int32), 0) == int(
        ds[0].num_nodes() * 0.1
    )
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.1)

    # create
    ds = data.AsNodePredDataset(
        data.AIFBDataset(), [0.8, 0.1, 0.1], "Personen", verbose=True
    )
    assert F.sum(
        F.astype(ds[0].nodes["Personen"].data["train_mask"], F.int32), 0
    ) == int(ds[0].num_nodes("Personen") * 0.8)
    assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.8)
    # read from cache
    ds = data.AsNodePredDataset(
        data.AIFBDataset(), [0.8, 0.1, 0.1], "Personen", verbose=True
    )
    assert F.sum(
        F.astype(ds[0].nodes["Personen"].data["train_mask"], F.int32), 0
    ) == int(ds[0].num_nodes("Personen") * 0.8)
    assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.8)
    # invalid cache, re-read
    ds = data.AsNodePredDataset(
        data.AIFBDataset(), [0.1, 0.1, 0.8], "Personen", verbose=True
    )
    assert F.sum(
        F.astype(ds[0].nodes["Personen"].data["train_mask"], F.int32), 0
    ) == int(ds[0].num_nodes("Personen") * 0.1)
    assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.1)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred_ogb():
    from ogb.nodeproppred import DglNodePropPredDataset

    ds = data.AsNodePredDataset(
        DglNodePropPredDataset("ogbn-arxiv"), split_ratio=None, verbose=True
    )
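    # split_ratio=None keeps the split shipped with the OGB dataset; the
    # asserts below check it against get_idx_split().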
    split = DglNodePropPredDataset("ogbn-arxiv").get_idx_split()
    train_idx, val_idx, test_idx = split["train"], split["valid"], split["test"]
    assert F.array_equal(ds.train_idx, F.tensor(train_idx))
    assert F.array_equal(ds.val_idx, F.tensor(val_idx))
    assert F.array_equal(ds.test_idx, F.tensor(test_idx))
    # force generate new split
    ds = data.AsNodePredDataset(
        DglNodePropPredDataset("ogbn-arxiv"),
        split_ratio=[0.7, 0.2, 0.1],
        verbose=True,
    )


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_linkpred():
    # create
    ds = data.AsLinkPredDataset(
        data.CoraGraphDataset(),
        split_ratio=[0.8, 0.1, 0.1],
        neg_ratio=1,
        verbose=True,
    )
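    # ds.test_edges packs (positive_edges, negative_edges), each a tuple of
    # endpoint arrays; the shape checks below rely on that layout.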
    # Cora has 10556 edges; the 10% test split yields 1057 test edges
    assert ds.test_edges[0][0].shape[0] == 1057
    # negative sampling is not guaranteed to hit the target count, so the
    # assert uses a relaxed range
    assert 1000 <= ds.test_edges[1][0].shape[0] <= 1057
    # changed split_ratio/neg_ratio invalidate the cache and force a re-split
    ds = data.AsLinkPredDataset(
        data.CoraGraphDataset(),
        split_ratio=[0.7, 0.1, 0.2],
        neg_ratio=2,
        verbose=True,
    )
    assert ds.test_edges[0][0].shape[0] == 2112
    # negative sampling is not guaranteed to reach exactly ratio 2, so the
    # assert uses a relaxed range
    assert 4000 < ds.test_edges[1][0].shape[0] <= 4224


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
def test_as_linkpred_ogb():
    from ogb.linkproppred import DglLinkPropPredDataset

    ds = data.AsLinkPredDataset(
        DglLinkPropPredDataset("ogbl-collab"), split_ratio=None, verbose=True
    )
    # original dataset has 46329 test edges
    assert ds.test_edges[0][0].shape[0] == 46329
    # force generate new split
    ds = data.AsLinkPredDataset(
        DglLinkPropPredDataset("ogbl-collab"),
        split_ratio=[0.7, 0.2, 0.1],
        verbose=True,
    )
    assert ds.test_edges[0][0].shape[0] == 235812


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred_csvdataset():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path = os.path.join(test_dir, "test_edges.csv")
        nodes_csv_path = os.path.join(test_dir, "test_nodes.csv")
        meta_yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default_name",
            "node_data": [{"file_name": os.path.basename(nodes_csv_path)}],
            "edge_data": [{"file_name": os.path.basename(edges_csv_path)}],
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_dims = 3
        num_classes = num_nodes
        feat_ndata = np.random.rand(num_nodes, num_dims)
        label_ndata = np.arange(num_classes)
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": label_ndata,
                "feat": [line.tolist() for line in feat_ndata],
            }
        )
        df.to_csv(nodes_csv_path, index=False)
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
            }
        )
        df.to_csv(edges_csv_path, index=False)

        ds = data.CSVDataset(test_dir, force_reload=True)
        assert "feat" in ds[0].ndata
        assert "label" in ds[0].ndata
        assert "train_mask" not in ds[0].ndata
        assert not hasattr(ds[0], "num_classes")
        new_ds = data.AsNodePredDataset(
            ds, split_ratio=[0.8, 0.1, 0.1], force_reload=True
        )
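        # AsNodePredDataset infers num_classes from the labels and attaches
        # train/val/test masks to the wrapped graph.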
        assert new_ds.num_classes == num_classes
        assert "feat" in new_ds[0].ndata
        assert "label" in new_ds[0].ndata
        assert "train_mask" in new_ds[0].ndata


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_graphpred():
    ds = data.GINDataset(name="MUTAG", self_loop=True)
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 188
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2
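
    # For classification datasets AsGraphPredDataset reports num_classes;
    # for the regression datasets below (QM7b, QM9, QM9Edge) it is None and
    # only num_tasks applies.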

    ds = data.FakeNewsDataset("politifact", "profile")
    new_ds = data.AsGraphPredDataset(ds, verbose=True)
    assert len(new_ds) == 314
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2

    ds = data.QM7bDataset()
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 7211
    assert new_ds.num_tasks == 14
    assert new_ds.num_classes is None

    ds = data.QM9Dataset(label_keys=["mu", "gap"])
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 130831
    assert new_ds.num_tasks == 2
    assert new_ds.num_classes is None

    ds = data.QM9EdgeDataset(label_keys=["mu", "alpha"])
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 130831
    assert new_ds.num_tasks == 2
    assert new_ds.num_classes is None

    ds = data.TUDataset("DD")
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1178
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2

    ds = data.LegacyTUDataset("DD")
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1178
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2

    ds = data.BA2MotifDataset()
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1000
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_graphpred_reprocess():
    ds = data.AsGraphPredDataset(
        data.GINDataset(name="MUTAG", self_loop=True), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.GINDataset(name="MUTAG", self_loop=True), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.GINDataset(name="MUTAG", self_loop=True), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(
        data.FakeNewsDataset("politifact", "profile"), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.FakeNewsDataset("politifact", "profile"), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.FakeNewsDataset("politifact", "profile"), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(
        data.QM9Dataset(label_keys=["mu", "gap"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.QM9Dataset(label_keys=["mu", "gap"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.QM9Dataset(label_keys=["mu", "gap"]), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(
        data.QM9EdgeDataset(label_keys=["mu", "alpha"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.QM9EdgeDataset(label_keys=["mu", "alpha"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.QM9EdgeDataset(label_keys=["mu", "alpha"]), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.TUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.TUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.TUDataset("DD"), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.LegacyTUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.LegacyTUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.LegacyTUDataset("DD"), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
def test_as_graphpred_ogb():
    from ogb.graphproppred import DglGraphPropPredDataset

    ds = data.AsGraphPredDataset(
        DglGraphPropPredDataset("ogbg-molhiv"), split_ratio=None, verbose=True
    )
    assert len(ds.train_idx) == 32901
    # force generate new split
    ds = data.AsGraphPredDataset(
        DglGraphPropPredDataset("ogbg-molhiv"),
        split_ratio=[0.6, 0.2, 0.2],
        verbose=True,
    )
    assert len(ds.train_idx) == 24676


if __name__ == "__main__":
    test_minigc()
    test_gin()
    test_data_hash()
    test_tudataset_regression()
    test_fraud()
    test_fakenews()
    test_csvdataset()
    test_as_nodepred1()
    test_as_nodepred2()
    test_as_nodepred_csvdataset()