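"""Tests for dgl.data: built-in datasets, graph transforms on datasets,
dataset hashing, archive extraction, and the CSV dataset loading pipeline."""
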
import gzip
import io
import os
import tarfile
import tempfile
import unittest

import backend as F

import dgl
import dgl.data as data
import numpy as np

import pandas as pd
import pytest
import yaml
from dgl import DGLError


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_minigc():
    ds = data.MiniGCDataset(16, 10, 20)
    g, l = list(zip(*ds))
    print(g, l)
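
    # AddSelfLoop(allow_duplicate=True) appends one self-loop per node, so
    # the transformed graph gains exactly num_nodes edges.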
    g1 = ds[0][0]
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    ds = data.MiniGCDataset(16, 10, 20, transform=transform)
    g2 = ds[0][0]
    assert g2.num_edges() - g1.num_edges() == g1.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_gin():
    ds_n_graphs = {
        "MUTAG": 188,
        "IMDBBINARY": 1000,
        "IMDBMULTI": 1500,
        "PROTEINS": 1113,
        "PTC": 344,
    }
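    # each dataset must contain its advertised number of graphs; the
    # self-loop transform should add exactly one edge per node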
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    for name, n_graphs in ds_n_graphs.items():
        ds = data.GINDataset(name, self_loop=False, degree_as_nlabel=False)
        assert len(ds) == n_graphs, (len(ds), name)
        g1 = ds[0][0]

        ds = data.GINDataset(
            name, self_loop=False, degree_as_nlabel=False, transform=transform
        )
        g2 = ds[0][0]
        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
        assert ds.num_classes == ds.gclasses


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_fraud():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    g = data.FraudDataset("amazon")[0]
    assert g.num_nodes() == 11944
    num_edges1 = g.num_edges()
    g2 = data.FraudDataset("amazon", transform=transform)[0]
    # one self-loop per node is added for each of the 3 edge types
    assert g2.num_edges() - num_edges1 == g.num_nodes() * 3

    g = data.FraudAmazonDataset()[0]
    assert g.num_nodes() == 11944
    g2 = data.FraudAmazonDataset(transform=transform)[0]
    # 3 edge types
    assert g2.num_edges() - g.num_edges() == g.num_nodes() * 3

    g = data.FraudYelpDataset()[0]
    assert g.num_nodes() == 45954
    g2 = data.FraudYelpDataset(transform=transform)[0]
    # 3 edge types
    assert g2.num_edges() - g.num_edges() == g.num_nodes() * 3


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_fakenews():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    ds = data.FakeNewsDataset("politifact", "bert")
    assert len(ds) == 314
    g = ds[0][0]
    g2 = data.FakeNewsDataset("politifact", "bert", transform=transform)[0][0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    ds = data.FakeNewsDataset("gossipcop", "profile")
    assert len(ds) == 5464
    g = ds[0][0]
    g2 = data.FakeNewsDataset("gossipcop", "profile", transform=transform)[0][0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_tudataset_regression():
    ds = data.TUDataset("ZINC_test", force_reload=True)
    assert ds.num_classes == ds.num_labels
    assert len(ds) == 5000
    g = ds[0][0]

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    ds = data.TUDataset("ZINC_test", force_reload=True, transform=transform)
    g2 = ds[0][0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_data_hash():
    class HashTestDataset(data.DGLDataset):
        def __init__(self, hash_key=()):
            super(HashTestDataset, self).__init__("hashtest", hash_key=hash_key)

        def _load(self):
            pass

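    # datasets built with the same hash_key must hash identically; changing
    # any element of the key must change the hash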
    a = HashTestDataset((True, 0, "1", (1, 2, 3)))
    b = HashTestDataset((True, 0, "1", (1, 2, 3)))
    c = HashTestDataset((True, 0, "1", (1, 2, 4)))
    assert a.hash == b.hash
    assert a.hash != c.hash

@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_citation_graph():
    transform = dgl.AddSelfLoop(allow_duplicate=True)
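    # reorder=True returns a graph whose edges are sorted by destination
    # node ID, which the np.sort checks below rely on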

    # cora
    g = data.CoraGraphDataset(force_reload=True, reorder=True)[0]
    assert g.num_nodes() == 2708
    assert g.num_edges() == 10556
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CoraGraphDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # Citeseer
    g = data.CiteseerGraphDataset(force_reload=True, reorder=True)[0]
    assert g.num_nodes() == 3327
    assert g.num_edges() == 9228
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CiteseerGraphDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # Pubmed
    g = data.PubmedGraphDataset(force_reload=True, reorder=True)[0]
    assert g.num_nodes() == 19717
    assert g.num_edges() == 88651
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.PubmedGraphDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_gnn_benchmark():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

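    # each dataset is a single graph with fixed, known statistics, and its
    # edges come sorted by destination node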
    # AmazonCoBuyComputerDataset
    g = data.AmazonCoBuyComputerDataset()[0]
    assert g.num_nodes() == 13752
    assert g.num_edges() == 491722
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.AmazonCoBuyComputerDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # AmazonCoBuyPhotoDataset
    g = data.AmazonCoBuyPhotoDataset()[0]
    assert g.num_nodes() == 7650
    assert g.num_edges() == 238163
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.AmazonCoBuyPhotoDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # CoauthorPhysicsDataset
    g = data.CoauthorPhysicsDataset()[0]
    assert g.num_nodes() == 34493
    assert g.num_edges() == 495924
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CoauthorPhysicsDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # CoauthorCSDataset
    g = data.CoauthorCSDataset()[0]
    assert g.num_nodes() == 18333
    assert g.num_edges() == 163788
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CoauthorCSDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # CoraFullDataset
    g = data.CoraFullDataset()[0]
    assert g.num_nodes() == 19793
    assert g.num_edges() == 126842
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CoraFullDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_reddit():
    # RedditDataset
    g = data.RedditDataset()[0]
    assert g.num_nodes() == 232965
    assert g.num_edges() == 114615892
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.RedditDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_explain_syn():
    dataset = data.BAShapeDataset()
    assert dataset.num_classes == 4
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata
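
    # construction must be deterministic for a fixed seed: reloading with
    # seed=0 has to reproduce the same edge list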
    g1 = data.BAShapeDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.BAShapeDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.BACommunityDataset()
    assert dataset.num_classes == 8
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.BACommunityDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.BACommunityDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.TreeCycleDataset()
    assert dataset.num_classes == 2
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.TreeCycleDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.TreeCycleDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.TreeGridDataset()
    assert dataset.num_classes == 2
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.TreeGridDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.TreeGridDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.BA2MotifDataset()
    assert dataset.num_classes == 2
    g, label = dataset[0]
    assert "feat" in g.ndata

@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_wiki_cs():
    g = data.WikiCSDataset()[0]
    assert g.num_nodes() == 11701
    assert g.num_edges() == 431726
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.WikiCSDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skip(reason="Dataset too large to download for the latest CI.")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_yelp():
    g = data.YelpDataset(reorder=True)[0]
    assert g.num_nodes() == 716847
    assert g.num_edges() == 13954819
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.YelpDataset(reorder=True, transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_flickr():
    g = data.FlickrDataset(reorder=True)[0]
    assert g.num_nodes() == 89250
    assert g.num_edges() == 899756
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.FlickrDataset(reorder=True, transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_cluster():
    mode_n_graphs = {
        "train": 10000,
        "valid": 1000,
        "test": 1000,
    }
    transform = dgl.AddSelfLoop(allow_duplicate=True)
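    # CLUSTER ships fixed train/valid/test splits; verify the advertised
    # split sizes, the transform, and the 6-way node classification target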
    for mode, n_graphs in mode_n_graphs.items():
        ds = data.CLUSTERDataset(mode=mode)
        assert len(ds) == n_graphs, (len(ds), mode)
        g1 = ds[0]
        ds = data.CLUSTERDataset(mode=mode, transform=transform)
        g2 = ds[0]
        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
        assert ds.num_classes == 6


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_extract_archive():
    # gzip
    with tempfile.TemporaryDirectory() as src_dir:
        gz_file = "gz_archive"
        gz_path = os.path.join(src_dir, gz_file + ".gz")
        content = b"test extract archive gzip"
        with gzip.open(gz_path, "wb") as f:
            f.write(content)
        with tempfile.TemporaryDirectory() as dst_dir:
            data.utils.extract_archive(gz_path, dst_dir, overwrite=True)
            assert os.path.exists(os.path.join(dst_dir, gz_file))

    # tar
    with tempfile.TemporaryDirectory() as src_dir:
        tar_file = "tar_archive"
        tar_path = os.path.join(src_dir, tar_file + ".tar")
        # str.encode() defaults to UTF-8
        content = "test extract archive tar\n".encode()
        info = tarfile.TarInfo(name="tar_archive")
        info.size = len(content)
        with tarfile.open(tar_path, "w") as f:
            f.addfile(info, io.BytesIO(content))
        with tempfile.TemporaryDirectory() as dst_dir:
            data.utils.extract_archive(tar_path, dst_dir, overwrite=True)
            assert os.path.exists(os.path.join(dst_dir, tar_file))

def _test_construct_graphs_node_ids():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    num_nodes = 100
    num_edges = 1000

    # node IDs are required to be unique; sampling with replacement from a
    # smaller range guarantees duplicates, which should raise an error
    node_ids = np.random.choice(np.arange(num_nodes // 2), num_nodes)
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    node_data = NodeData(node_ids, {})
    edge_data = EdgeData(src_ids, dst_ids, {})
    expect_except = False
    try:
        _, _ = DGLGraphConstructor.construct_graphs(node_data, edge_data)
    except:
        expect_except = True
    assert expect_except

    # node IDs are already labelled from 0 to num_nodes-1
    node_ids = np.arange(num_nodes)
    np.random.shuffle(node_ids)
    _, idx = np.unique(node_ids, return_index=True)
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    node_feat = np.random.rand(num_nodes, 3)
    node_data = NodeData(node_ids, {"feat": node_feat})
    edge_data = EdgeData(src_ids, dst_ids, {})
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == len(node_ids)
    assert g.num_edges() == len(src_ids)
    assert F.array_equal(
        F.tensor(node_feat[idx], dtype=F.float32), g.ndata["feat"]
    )

    # node IDs are mixed with numeric and non-numeric values
    # homogeneous graph
    node_ids = [1, 2, 3, "a"]
    src_ids = [1, 2, 3]
    dst_ids = ["a", 1, 2]
    node_data = NodeData(node_ids, {})
    edge_data = EdgeData(src_ids, dst_ids, {})
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == len(node_ids)
    assert g.num_edges() == len(src_ids)

    # heterogeneous graph
    node_ids_user = [1, 2, 3]
    node_ids_item = ["a", "b", "c"]
    src_ids = node_ids_user
    dst_ids = node_ids_item
    node_data_user = NodeData(node_ids_user, {}, type="user")
    node_data_item = NodeData(node_ids_item, {}, type="item")
    edge_data = EdgeData(src_ids, dst_ids, {}, type=("user", "like", "item"))
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        [node_data_user, node_data_item], edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert not g.is_homogeneous
    assert g.num_nodes("user") == len(node_ids_user)
    assert g.num_nodes("item") == len(node_ids_item)
    assert g.num_edges() == len(src_ids)


def _test_construct_graphs_homo():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    # node_id could be non-sorted, non-numeric.
    num_nodes = 100
    num_edges = 1000
    num_dims = 3
    node_ids = np.random.choice(
        np.arange(num_nodes * 2), size=num_nodes, replace=False
    )
    assert len(node_ids) == num_nodes
    # to be non-sorted
    np.random.shuffle(node_ids)
    # to be non-numeric
    node_ids = ["id_{}".format(id) for id in node_ids]
    t_ndata = {
        "feat": np.random.rand(num_nodes, num_dims),
        "label": np.random.randint(2, size=num_nodes),
    }
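    # np.unique sorts the string IDs and returns each ID's first occurrence;
    # the constructed graph is expected to order its nodes the same way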
    _, u_indices = np.unique(node_ids, return_index=True)
    ndata = {
        "feat": t_ndata["feat"][u_indices],
        "label": t_ndata["label"][u_indices],
    }
    node_data = NodeData(node_ids, t_ndata)
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    edata = {
        "feat": np.random.rand(num_edges, num_dims),
        "label": np.random.randint(2, size=num_edges),
    }
    edge_data = EdgeData(src_ids, dst_ids, edata)
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    def assert_data(lhs, rhs):
        for key, value in lhs.items():
            assert key in rhs
            assert F.dtype(rhs[key]) != F.float64
            assert F.array_equal(
                F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
            )

    assert_data(ndata, g.ndata)
    assert_data(edata, g.edata)


def _test_construct_graphs_hetero():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    # node_id/src_id/dst_id could be non-sorted, duplicated, non-numeric.
    num_nodes = 100
    num_edges = 1000
    num_dims = 3
    ntypes = ["user", "item"]
    node_data = []
    node_ids_dict = {}
    ndata_dict = {}
    for ntype in ntypes:
        node_ids = np.random.choice(
            np.arange(num_nodes * 2), size=num_nodes, replace=False
        )
        assert len(node_ids) == num_nodes
        # to be non-sorted
        np.random.shuffle(node_ids)
        # to be non-numeric
        node_ids = ["id_{}".format(id) for id in node_ids]
        t_ndata = {
            "feat": np.random.rand(num_nodes, num_dims),
            "label": np.random.randint(2, size=num_nodes),
        }
        _, u_indices = np.unique(node_ids, return_index=True)
        ndata = {
            "feat": t_ndata["feat"][u_indices],
            "label": t_ndata["label"][u_indices],
        }
        node_data.append(NodeData(node_ids, t_ndata, type=ntype))
        node_ids_dict[ntype] = node_ids
        ndata_dict[ntype] = ndata
    etypes = [("user", "follow", "user"), ("user", "like", "item")]
    edge_data = []
    edata_dict = {}
    for src_type, e_type, dst_type in etypes:
        src_ids = np.random.choice(node_ids_dict[src_type], size=num_edges)
        dst_ids = np.random.choice(node_ids_dict[dst_type], size=num_edges)
        edata = {
            "feat": np.random.rand(num_edges, num_dims),
            "label": np.random.randint(2, size=num_edges),
        }
        edge_data.append(
            EdgeData(src_ids, dst_ids, edata, type=(src_type, e_type, dst_type))
        )
        edata_dict[(src_type, e_type, dst_type)] = edata
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert not g.is_homogeneous
    assert g.num_nodes() == num_nodes * len(ntypes)
    assert g.num_edges() == num_edges * len(etypes)

    def assert_data(lhs, rhs):
        for key, value in lhs.items():
            assert key in rhs
            assert F.dtype(rhs[key]) != F.float64
            assert F.array_equal(
                F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
            )

    for ntype in g.ntypes:
        assert g.num_nodes(ntype) == num_nodes
        assert_data(ndata_dict[ntype], g.nodes[ntype].data)
    for etype in g.canonical_etypes:
        assert g.num_edges(etype) == num_edges
        assert_data(edata_dict[etype], g.edges[etype].data)


def _test_construct_graphs_multiple():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        GraphData,
        NodeData,
    )

    num_nodes = 100
    num_edges = 1000
    num_graphs = 10
    num_dims = 3
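    # build num_graphs blocks of nodes/edges back to back; each block is
    # tagged with its graph_id so the constructor can split them per graph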
    node_ids = np.array([], dtype=np.int64)
    src_ids = np.array([], dtype=np.int64)
    dst_ids = np.array([], dtype=np.int64)
    ngraph_ids = np.array([], dtype=np.int64)
    egraph_ids = np.array([], dtype=np.int64)
    u_indices = np.array([], dtype=np.int64)
    for i in range(num_graphs):
        l_node_ids = np.random.choice(
            np.arange(num_nodes * 2), size=num_nodes, replace=False
        )
        node_ids = np.append(node_ids, l_node_ids)
        _, l_u_indices = np.unique(l_node_ids, return_index=True)
        u_indices = np.append(u_indices, l_u_indices)
        ngraph_ids = np.append(ngraph_ids, np.full(num_nodes, i))
        src_ids = np.append(
            src_ids, np.random.choice(l_node_ids, size=num_edges)
        )
        dst_ids = np.append(
            dst_ids, np.random.choice(l_node_ids, size=num_edges)
        )
        egraph_ids = np.append(egraph_ids, np.full(num_edges, i))
    ndata = {
        "feat": np.random.rand(num_nodes * num_graphs, num_dims),
        "label": np.random.randint(2, size=num_nodes * num_graphs),
    }
    ngraph_ids = ["graph_{}".format(id) for id in ngraph_ids]
    node_data = NodeData(node_ids, ndata, graph_id=ngraph_ids)
    egraph_ids = ["graph_{}".format(id) for id in egraph_ids]
    edata = {
        "feat": np.random.rand(num_edges * num_graphs, num_dims),
        "label": np.random.randint(2, size=num_edges * num_graphs),
    }
    edge_data = EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids)
    gdata = {
        "feat": np.random.rand(num_graphs, num_dims),
        "label": np.random.randint(2, size=num_graphs),
    }
    graph_ids = ["graph_{}".format(id) for id in np.arange(num_graphs)]
    graph_data = GraphData(graph_ids, gdata)
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data, graph_data
    )
    assert len(graphs) == num_graphs
    assert len(data_dict) == len(gdata)
    for k, v in data_dict.items():
        assert F.dtype(v) != F.float64
        assert F.array_equal(
            F.reshape(F.tensor(gdata[k], dtype=F.dtype(v)), (len(graphs), -1)),
            v,
        )
    for i, g in enumerate(graphs):
        assert g.is_homogeneous
        assert g.num_nodes() == num_nodes
        assert g.num_edges() == num_edges

        def assert_data(lhs, rhs, size, node=False):
            for key, value in lhs.items():
                assert key in rhs
                value = value[i * size : (i + 1) * size]
                if node:
                    indices = u_indices[i * size : (i + 1) * size]
                    value = value[indices]
                assert F.dtype(rhs[key]) != F.float64
                assert F.array_equal(
                    F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
                )

        assert_data(ndata, g.ndata, num_nodes, node=True)
        assert_data(edata, g.edata, num_edges)

    # Graph IDs found in node/edge CSV but not in graph CSV
    graph_data = GraphData(np.arange(num_graphs - 2), {})
    expect_except = False
    try:
        _, _ = DGLGraphConstructor.construct_graphs(
            node_data, edge_data, graph_data
        )
    except:
        expect_except = True
    assert expect_except


def _test_DefaultDataParser():
    from dgl.data.csv_dataset_base import DefaultDataParser

    # common csv
    with tempfile.TemporaryDirectory() as test_dir:
        csv_path = os.path.join(test_dir, "nodes.csv")
        num_nodes = 5
        num_labels = 3
        num_dims = 2
        node_id = np.arange(num_nodes)
        label = np.random.randint(num_labels, size=num_nodes)
        feat = np.random.rand(num_nodes, num_dims)
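        # vector features are serialized as stringified lists in the CSV;
        # the parser is expected to decode them back into float arrays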
        df = pd.DataFrame(
            {
                "node_id": node_id,
                "label": label,
                "feat": [line.tolist() for line in feat],
            }
        )
        df.to_csv(csv_path, index=False)
        dp = DefaultDataParser()
        df = pd.read_csv(csv_path)
        dt = dp(df)
        assert np.array_equal(node_id, dt["node_id"])
        assert np.array_equal(label, dt["label"])
        assert np.array_equal(feat, dt["feat"])
    # string consists of non-numeric values
    with tempfile.TemporaryDirectory() as test_dir:
        csv_path = os.path.join(test_dir, "nodes.csv")
        df = pd.DataFrame({"label": ["a", "b", "c"]})
        df.to_csv(csv_path, index=False)
        dp = DefaultDataParser()
        df = pd.read_csv(csv_path)
        expect_except = False
        try:
            dt = dp(df)
        except:
            expect_except = True
        assert expect_except

    # the CSV's unnamed index column should be ignored by the parser
    with tempfile.TemporaryDirectory() as test_dir:
        csv_path = os.path.join(test_dir, "nodes.csv")
        df = pd.DataFrame({"label": [1, 2, 3]})
        df.to_csv(csv_path)
        dp = DefaultDataParser()
        df = pd.read_csv(csv_path)
        dt = dp(df)
        assert len(dt) == 1


def _test_load_yaml_with_sanity_check():
    from dgl.data.csv_dataset_base import load_yaml_with_sanity_check

    with tempfile.TemporaryDirectory() as test_dir:
        yaml_path = os.path.join(test_dir, "meta.yaml")
        # minimal config: valid, though usually not meaningful
        yaml_data = {
            "dataset_name": "default",
            "node_data": [],
            "edge_data": [],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        meta = load_yaml_with_sanity_check(yaml_path)
        assert meta.version == "1.0.0"
        assert meta.dataset_name == "default"
        assert meta.separator == ","
        assert len(meta.node_data) == 0
        assert len(meta.edge_data) == 0
        assert meta.graph_data is None

        # minimum with required fields only
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "node_data": [{"file_name": "nodes.csv"}],
            "edge_data": [{"file_name": "edges.csv"}],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        meta = load_yaml_with_sanity_check(yaml_path)
        for ndata in meta.node_data:
            assert ndata.file_name == "nodes.csv"
            assert ndata.ntype == "_V"
            assert ndata.graph_id_field == "graph_id"
            assert ndata.node_id_field == "node_id"
        for edata in meta.edge_data:
            assert edata.file_name == "edges.csv"
            assert edata.etype == ["_V", "_E", "_V"]
            assert edata.graph_id_field == "graph_id"
            assert edata.src_id_field == "src_id"
            assert edata.dst_id_field == "dst_id"

        # optional fields are specified
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "separator": "|",
            "node_data": [
                {
                    "file_name": "nodes.csv",
                    "ntype": "user",
                    "graph_id_field": "xxx",
                    "node_id_field": "xxx",
                }
            ],
            "edge_data": [
                {
                    "file_name": "edges.csv",
                    "etype": ["user", "follow", "user"],
                    "graph_id_field": "xxx",
                    "src_id_field": "xxx",
                    "dst_id_field": "xxx",
                }
            ],
            "graph_data": {"file_name": "graph.csv", "graph_id_field": "xxx"},
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        meta = load_yaml_with_sanity_check(yaml_path)
        assert len(meta.node_data) == 1
        ndata = meta.node_data[0]
        assert ndata.ntype == "user"
        assert ndata.graph_id_field == "xxx"
        assert ndata.node_id_field == "xxx"
        assert len(meta.edge_data) == 1
        edata = meta.edge_data[0]
        assert edata.etype == ["user", "follow", "user"]
        assert edata.graph_id_field == "xxx"
        assert edata.src_id_field == "xxx"
        assert edata.dst_id_field == "xxx"
        assert meta.graph_data is not None
        assert meta.graph_data.file_name == "graph.csv"
        assert meta.graph_data.graph_id_field == "xxx"

        # some required fields are missing
        yaml_data = {
            "dataset_name": "default",
            "node_data": [],
            "edge_data": [],
        }
        for field in yaml_data.keys():
            ydata = {k: v for k, v in yaml_data.items()}
            ydata.pop(field)
            with open(yaml_path, "w") as f:
                yaml.dump(ydata, f, sort_keys=False)
            expect_except = False
            try:
                meta = load_yaml_with_sanity_check(yaml_path)
            except:
                expect_except = True
            assert expect_except

        # inapplicable version
        yaml_data = {
            "version": "0.0.0",
            "dataset_name": "default",
            "node_data": [{"file_name": "nodes_0.csv"}],
            "edge_data": [{"file_name": "edges_0.csv"}],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        expect_except = False
        try:
            meta = load_yaml_with_sanity_check(yaml_path)
        except DGLError:
            expect_except = True
        assert expect_except
        # duplicate node types
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "node_data": [
                {"file_name": "nodes.csv"},
                {"file_name": "nodes.csv"},
            ],
            "edge_data": [{"file_name": "edges.csv"}],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        expect_except = False
        try:
            meta = load_yaml_with_sanity_check(yaml_path)
        except DGLError:
            expect_except = True
        assert expect_except
        # duplicate edge types
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "node_data": [{"file_name": "nodes.csv"}],
            "edge_data": [
                {"file_name": "edges.csv"},
                {"file_name": "edges.csv"},
            ],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        expect_except = False
        try:
            meta = load_yaml_with_sanity_check(yaml_path)
        except DGLError:
            expect_except = True
        assert expect_except


def _test_load_node_data_from_csv():
    from dgl.data.csv_dataset_base import DefaultDataParser, MetaNode, NodeData

    with tempfile.TemporaryDirectory() as test_dir:
        num_nodes = 100
        # minimum
        df = pd.DataFrame({"node_id": np.arange(num_nodes)})
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
        assert np.array_equal(df["node_id"], node_data.id)
        assert len(node_data.data) == 0

        # common case
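        # defaults: without a graph_id column every node belongs to graph 0,
        # and without an ntype in the YAML the node type is "_V"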
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": np.random.randint(3, size=num_nodes),
            }
        )
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
        assert np.array_equal(df["node_id"], node_data.id)
        assert len(node_data.data) == 1
        assert np.array_equal(df["label"], node_data.data["label"])
        assert np.array_equal(np.full(num_nodes, 0), node_data.graph_id)
        assert node_data.type == "_V"

        # add more fields into nodes.csv
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": np.random.randint(3, size=num_nodes),
                "graph_id": np.full(num_nodes, 1),
            }
        )
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
        assert np.array_equal(df["node_id"], node_data.id)
        assert len(node_data.data) == 1
        assert np.array_equal(df["label"], node_data.data["label"])
        assert np.array_equal(df["graph_id"], node_data.graph_id)
        assert node_data.type == "_V"

        # required header is missing
        df = pd.DataFrame({"label": np.random.randint(3, size=num_nodes)})
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        expect_except = False
        try:
            NodeData.load_from_csv(meta_node, DefaultDataParser())
        except:
            expect_except = True
        assert expect_except


def _test_load_edge_data_from_csv():
    from dgl.data.csv_dataset_base import DefaultDataParser, EdgeData, MetaEdge

    with tempfile.TemporaryDirectory() as test_dir:
        num_nodes = 100
        num_edges = 1000
        # minimum
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        assert np.array_equal(df["src_id"], edge_data.src)
        assert np.array_equal(df["dst_id"], edge_data.dst)
        assert len(edge_data.data) == 0

        # common case
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
                "label": np.random.randint(3, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        assert np.array_equal(df["src_id"], edge_data.src)
        assert np.array_equal(df["dst_id"], edge_data.dst)
        assert len(edge_data.data) == 1
        assert np.array_equal(df["label"], edge_data.data["label"])
        assert np.array_equal(np.full(num_edges, 0), edge_data.graph_id)
        assert edge_data.type == ("_V", "_E", "_V")

        # add more fields into edges.csv
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
                "graph_id": np.arange(num_edges),
                "feat": np.random.randint(3, size=num_edges),
                "label": np.random.randint(3, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        assert np.array_equal(df["src_id"], edge_data.src)
        assert np.array_equal(df["dst_id"], edge_data.dst)
        assert len(edge_data.data) == 2
        assert np.array_equal(df["feat"], edge_data.data["feat"])
        assert np.array_equal(df["label"], edge_data.data["label"])
        assert np.array_equal(df["graph_id"], edge_data.graph_id)
        assert edge_data.type == ("_V", "_E", "_V")

        # required headers are missing
        df = pd.DataFrame(
            {"src_id": np.random.randint(num_nodes, size=num_edges)}
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        expect_except = False
        try:
            EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        except DGLError:
            expect_except = True
        assert expect_except
        df = pd.DataFrame(
            {"dst_id": np.random.randint(num_nodes, size=num_edges)}
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        expect_except = False
        try:
            EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        except DGLError:
            expect_except = True
        assert expect_except


def _test_load_graph_data_from_csv():
    from dgl.data.csv_dataset_base import (
        DefaultDataParser,
        GraphData,
        MetaGraph,
    )

    with tempfile.TemporaryDirectory() as test_dir:
        num_graphs = 100
        # minimum
        df = pd.DataFrame({"graph_id": np.arange(num_graphs)})
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
        assert np.array_equal(df["graph_id"], graph_data.graph_id)
        assert len(graph_data.data) == 0

        # common case
        df = pd.DataFrame(
            {
                "graph_id": np.arange(num_graphs),
                "label": np.random.randint(3, size=num_graphs),
            }
        )
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
        assert np.array_equal(df["graph_id"], graph_data.graph_id)
        assert len(graph_data.data) == 1
        assert np.array_equal(df["label"], graph_data.data["label"])

        # add more fields into graph.csv
        df = pd.DataFrame(
            {
                "graph_id": np.arange(num_graphs),
                "feat": np.random.randint(3, size=num_graphs),
                "label": np.random.randint(3, size=num_graphs),
            }
        )
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
        assert np.array_equal(df["graph_id"], graph_data.graph_id)
        assert len(graph_data.data) == 2
        assert np.array_equal(df["feat"], graph_data.data["feat"])
        assert np.array_equal(df["label"], graph_data.data["label"])

        # required header is missing
        df = pd.DataFrame({"label": np.random.randint(3, size=num_graphs)})
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        expect_except = False
        try:
            GraphData.load_from_csv(meta_graph, DefaultDataParser())
        except DGLError:
            expect_except = True
        assert expect_except


def _test_CSVDataset_single():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path_0 = os.path.join(test_dir, "test_edges_0.csv")
        edges_csv_path_1 = os.path.join(test_dir, "test_edges_1.csv")
        nodes_csv_path_0 = os.path.join(test_dir, "test_nodes_0.csv")
        nodes_csv_path_1 = os.path.join(test_dir, "test_nodes_1.csv")
        meta_yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default_name",
            "node_data": [
                {
                    "file_name": os.path.basename(nodes_csv_path_0),
                    "ntype": "user",
                },
                {
                    "file_name": os.path.basename(nodes_csv_path_1),
                    "ntype": "item",
                },
            ],
            "edge_data": [
                {
                    "file_name": os.path.basename(edges_csv_path_0),
                    "etype": ["user", "follow", "user"],
                },
                {
                    "file_name": os.path.basename(edges_csv_path_1),
                    "etype": ["user", "like", "item"],
                },
            ],
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_dims = 3
        feat_ndata = np.random.rand(num_nodes, num_dims)
        label_ndata = np.random.randint(2, size=num_nodes)
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": label_ndata,
                "feat": [line.tolist() for line in feat_ndata],
            }
        )
        df.to_csv(nodes_csv_path_0, index=False)
        df.to_csv(nodes_csv_path_1, index=False)
        feat_edata = np.random.rand(num_edges, num_dims)
        label_edata = np.random.randint(2, size=num_edges)
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
                "label": label_edata,
                "feat": [line.tolist() for line in feat_edata],
            }
        )
        df.to_csv(edges_csv_path_0, index=False)
        df.to_csv(edges_csv_path_1, index=False)

        # load CSVDataset
        for force_reload in [True, False]:
            if not force_reload:
                # remove original node data file to verify reload from cached files
                os.remove(nodes_csv_path_0)
                assert not os.path.exists(nodes_csv_path_0)
            csv_dataset = data.CSVDataset(test_dir, force_reload=force_reload)
            assert len(csv_dataset) == 1
            g = csv_dataset[0]
            assert not g.is_homogeneous
            assert csv_dataset.has_cache()
            for ntype in g.ntypes:
                assert g.num_nodes(ntype) == num_nodes
                assert F.array_equal(
                    F.tensor(feat_ndata, dtype=F.float32),
                    g.nodes[ntype].data["feat"],
                )
                assert np.array_equal(
                    label_ndata, F.asnumpy(g.nodes[ntype].data["label"])
                )
            for etype in g.etypes:
                assert g.num_edges(etype) == num_edges
                assert F.array_equal(
                    F.tensor(feat_edata, dtype=F.float32),
                    g.edges[etype].data["feat"],
                )
                assert np.array_equal(
                    label_edata, F.asnumpy(g.edges[etype].data["label"])
                )


def _test_CSVDataset_multiple():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path_0 = os.path.join(test_dir, "test_edges_0.csv")
        edges_csv_path_1 = os.path.join(test_dir, "test_edges_1.csv")
        nodes_csv_path_0 = os.path.join(test_dir, "test_nodes_0.csv")
        nodes_csv_path_1 = os.path.join(test_dir, "test_nodes_1.csv")
        graph_csv_path = os.path.join(test_dir, "test_graph.csv")
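        # one meta.yaml referencing two node CSVs, two edge CSVs, and a
        # graph-level CSV; graph_id columns split the rows into num_graphs graphs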
        meta_yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default_name",
            "node_data": [
                {
                    "file_name": os.path.basename(nodes_csv_path_0),
                    "ntype": "user",
                },
                {
                    "file_name": os.path.basename(nodes_csv_path_1),
                    "ntype": "item",
                },
            ],
            "edge_data": [
                {
                    "file_name": os.path.basename(edges_csv_path_0),
                    "etype": ["user", "follow", "user"],
                },
                {
                    "file_name": os.path.basename(edges_csv_path_1),
                    "etype": ["user", "like", "item"],
                },
            ],
            "graph_data": {"file_name": os.path.basename(graph_csv_path)},
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_graphs = 10
        num_dims = 3
        feat_ndata = np.random.rand(num_nodes * num_graphs, num_dims)
        label_ndata = np.random.randint(2, size=num_nodes * num_graphs)
        df = pd.DataFrame(
            {
                "node_id": np.hstack(
                    [np.arange(num_nodes) for _ in range(num_graphs)]
                ),
                "label": label_ndata,
                "feat": [line.tolist() for line in feat_ndata],
                "graph_id": np.hstack(
                    [np.full(num_nodes, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(nodes_csv_path_0, index=False)
        df.to_csv(nodes_csv_path_1, index=False)
        feat_edata = np.random.rand(num_edges * num_graphs, num_dims)
        label_edata = np.random.randint(2, size=num_edges * num_graphs)
        df = pd.DataFrame(
            {
                "src_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "dst_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "label": label_edata,
                "feat": [line.tolist() for line in feat_edata],
                "graph_id": np.hstack(
                    [np.full(num_edges, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(edges_csv_path_0, index=False)
        df.to_csv(edges_csv_path_1, index=False)
        feat_gdata = np.random.rand(num_graphs, num_dims)
        label_gdata = np.random.randint(2, size=num_graphs)
        df = pd.DataFrame(
            {
                "label": label_gdata,
                "feat": [line.tolist() for line in feat_gdata],
                "graph_id": np.arange(num_graphs),
            }
        )
        df.to_csv(graph_csv_path, index=False)

        # load CSVDataset with default node/edge/gdata_parser
        for force_reload in [True, False]:
            if not force_reload:
                # remove original node data file to verify reload from cached files
                os.remove(nodes_csv_path_0)
                assert not os.path.exists(nodes_csv_path_0)
            csv_dataset = data.CSVDataset(test_dir, force_reload=force_reload)
            assert len(csv_dataset) == num_graphs
            assert csv_dataset.has_cache()
            assert len(csv_dataset.data) == 2
            assert "feat" in csv_dataset.data
            assert "label" in csv_dataset.data
            assert F.array_equal(
                F.tensor(feat_gdata, dtype=F.float32), csv_dataset.data["feat"]
            )
            for i, (g, g_data) in enumerate(csv_dataset):
                assert not g.is_homogeneous
                assert F.asnumpy(g_data["label"]) == label_gdata[i]
                assert F.array_equal(
                    g_data["feat"], F.tensor(feat_gdata[i], dtype=F.float32)
                )
                for ntype in g.ntypes:
                    assert g.num_nodes(ntype) == num_nodes
                    assert F.array_equal(
                        F.tensor(
                            feat_ndata[i * num_nodes : (i + 1) * num_nodes],
                            dtype=F.float32,
                        ),
                        g.nodes[ntype].data["feat"],
                    )
                    assert np.array_equal(
                        label_ndata[i * num_nodes : (i + 1) * num_nodes],
                        F.asnumpy(g.nodes[ntype].data["label"]),
                    )
                for etype in g.etypes:
                    assert g.num_edges(etype) == num_edges
                    assert F.array_equal(
                        F.tensor(
                            feat_edata[i * num_edges : (i + 1) * num_edges],
                            dtype=F.float32,
                        ),
                        g.edges[etype].data["feat"],
                    )
                    assert np.array_equal(
                        label_edata[i * num_edges : (i + 1) * num_edges],
                        F.asnumpy(g.edges[etype].data["label"]),
                    )


def _test_CSVDataset_customized_data_parser():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path_0 = os.path.join(test_dir, "test_edges_0.csv")
        edges_csv_path_1 = os.path.join(test_dir, "test_edges_1.csv")
        nodes_csv_path_0 = os.path.join(test_dir, "test_nodes_0.csv")
        nodes_csv_path_1 = os.path.join(test_dir, "test_nodes_1.csv")
        graph_csv_path = os.path.join(test_dir, "test_graph.csv")
        meta_yaml_data = {
            "dataset_name": "default_name",
            "node_data": [
                {
                    "file_name": os.path.basename(nodes_csv_path_0),
                    "ntype": "user",
                },
                {
                    "file_name": os.path.basename(nodes_csv_path_1),
                    "ntype": "item",
                },
            ],
            "edge_data": [
                {
                    "file_name": os.path.basename(edges_csv_path_0),
                    "etype": ["user", "follow", "user"],
                },
                {
                    "file_name": os.path.basename(edges_csv_path_1),
                    "etype": ["user", "like", "item"],
                },
            ],
            "graph_data": {"file_name": os.path.basename(graph_csv_path)},
        }
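        # For reference, yaml.dump should serialize the dict above roughly as:
        #
        #   dataset_name: default_name
        #   node_data:
        #   - file_name: test_nodes_0.csv
        #     ntype: user
        #   - file_name: test_nodes_1.csv
        #     ntype: item
        #   edge_data:
        #   - file_name: test_edges_0.csv
        #     etype: [user, follow, user]
        #   - file_name: test_edges_1.csv
        #     etype: [user, like, item]
        #   graph_data:
        #     file_name: test_graph.csv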
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_graphs = 10
        label_ndata = np.random.randint(2, size=num_nodes * num_graphs)
        df = pd.DataFrame(
            {
                "node_id": np.hstack(
                    [np.arange(num_nodes) for _ in range(num_graphs)]
                ),
                "label": label_ndata,
                "graph_id": np.hstack(
                    [np.full(num_nodes, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(nodes_csv_path_0, index=False)
        df.to_csv(nodes_csv_path_1, index=False)
        label_edata = np.random.randint(2, size=num_edges * num_graphs)
        df = pd.DataFrame(
            {
                "src_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "dst_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "label": label_edata,
                "graph_id": np.hstack(
                    [np.full(num_edges, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(edges_csv_path_0, index=False)
        df.to_csv(edges_csv_path_1, index=False)
        label_gdata = np.random.randint(2, size=num_graphs)
        df = pd.DataFrame(
            {"label": label_gdata, "graph_id": np.arange(num_graphs)}
        )
        df.to_csv(graph_csv_path, index=False)

        class CustDataParser:
            def __call__(self, df):
                data = {}
                for header in df:
                    dt = df[header].to_numpy().squeeze()
                    if header == "label":
                        dt += 2
                    data[header] = dt
                return data
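        # For illustration: given a DataFrame with columns "label" and
        # "graph_id", CustDataParser()(df) returns
        # {"label": df["label"].to_numpy() + 2,
        #  "graph_id": df["graph_id"].to_numpy()}; only "label" is shifted.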

        # load CSVDataset with customized node/edge/gdata_parser
        # specify via dict[ntype/etype, callable]
        csv_dataset = data.CSVDataset(
            test_dir,
            force_reload=True,
            ndata_parser={"user": CustDataParser()},
            edata_parser={("user", "like", "item"): CustDataParser()},
            gdata_parser=CustDataParser(),
        )
        assert len(csv_dataset) == num_graphs
        assert len(csv_dataset.data) == 1
        assert "label" in csv_dataset.data
        for i, (g, g_data) in enumerate(csv_dataset):
            assert not g.is_homogeneous
            assert F.asnumpy(g_data) == label_gdata[i] + 2
            for ntype in g.ntypes:
                assert g.num_nodes(ntype) == num_nodes
                offset = 2 if ntype == "user" else 0
                assert np.array_equal(
                    label_ndata[i * num_nodes : (i + 1) * num_nodes] + offset,
                    F.asnumpy(g.nodes[ntype].data["label"]),
                )
            for etype in g.etypes:
                assert g.num_edges(etype) == num_edges
                offset = 2 if etype == "like" else 0
                assert np.array_equal(
                    label_edata[i * num_edges : (i + 1) * num_edges] + offset,
                    F.asnumpy(g.edges[etype].data["label"]),
                )
        # specify via callable
        csv_dataset = data.CSVDataset(
            test_dir,
            force_reload=True,
            ndata_parser=CustDataParser(),
            edata_parser=CustDataParser(),
            gdata_parser=CustDataParser(),
        )
        assert len(csv_dataset) == num_graphs
        assert len(csv_dataset.data) == 1
        assert "label" in csv_dataset.data
        for i, (g, g_data) in enumerate(csv_dataset):
            assert not g.is_homogeneous
            assert F.asnumpy(g_data) == label_gdata[i] + 2
            for ntype in g.ntypes:
                assert g.num_nodes(ntype) == num_nodes
                offset = 2
                assert np.array_equal(
                    label_ndata[i * num_nodes : (i + 1) * num_nodes] + offset,
                    F.asnumpy(g.nodes[ntype].data["label"]),
                )
            for etype in g.etypes:
                assert g.num_edges(etype) == num_edges
                offset = 2
                assert np.array_equal(
                    label_edata[i * num_edges : (i + 1) * num_edges] + offset,
                    F.asnumpy(g.edges[etype].data["label"]),
                )


def _test_NodeEdgeGraphData():
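    """Exercise the NodeData/EdgeData/GraphData containers from
    dgl.data.csv_dataset_base: default type names, custom types, attached
    data dicts, and the length-mismatch constructions that must raise."""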
    from dgl.data.csv_dataset_base import EdgeData, GraphData, NodeData

    # NodeData basics
    num_nodes = 100
    node_ids = np.arange(num_nodes, dtype=np.float64)  # np.float was removed from NumPy; use np.float64
    ndata = NodeData(node_ids, {})
    assert np.array_equal(ndata.id, node_ids)
    assert len(ndata.data) == 0
    assert ndata.type == "_V"
    assert np.array_equal(ndata.graph_id, np.full(num_nodes, 0))
    # NodeData more
    data = {"feat": np.random.rand(num_nodes, 3)}
    graph_id = np.arange(num_nodes)
    ndata = NodeData(node_ids, data, type="user", graph_id=graph_id)
    assert ndata.type == "user"
    assert np.array_equal(ndata.graph_id, graph_id)
    assert len(ndata.data) == len(data)
    for k, v in data.items():
        assert k in ndata.data
        assert np.array_equal(ndata.data[k], v)
    # NodeData: mismatched id/data/graph_id lengths should raise
    expect_except = False
    try:
        NodeData(
            np.arange(num_nodes),
            {"feat": np.random.rand(num_nodes + 1, 3)},
            graph_id=np.arange(num_nodes - 1),
        )
    except Exception:
        expect_except = True
    assert expect_except

    # EdgeData basics
    num_nodes = 100
    num_edges = 1000
    src_ids = np.random.randint(num_nodes, size=num_edges)
    dst_ids = np.random.randint(num_nodes, size=num_edges)
    edata = EdgeData(src_ids, dst_ids, {})
    assert np.array_equal(edata.src, src_ids)
    assert np.array_equal(edata.dst, dst_ids)
    assert edata.type == ("_V", "_E", "_V")
    assert len(edata.data) == 0
    assert np.array_equal(edata.graph_id, np.full(num_edges, 0))
    # EdgeData more
    src_ids = np.random.randint(num_nodes, size=num_edges).astype(np.float64)
    dst_ids = np.random.randint(num_nodes, size=num_edges).astype(np.float64)
    data = {"feat": np.random.rand(num_edges, 3)}
    etype = ("user", "like", "item")
    graph_ids = np.arange(num_edges)
    edata = EdgeData(src_ids, dst_ids, data, type=etype, graph_id=graph_ids)
    assert np.array_equal(edata.src, src_ids)
    assert np.array_equal(edata.dst, dst_ids)
    assert edata.type == etype
    assert len(edata.data) == len(data)
    for k, v in data.items():
        assert k in edata.data
        assert np.array_equal(edata.data[k], v)
    assert np.array_equal(edata.graph_id, graph_ids)
    # EdgeData: mismatched src/dst/data/graph_id lengths should raise
    expect_except = False
    try:
        EdgeData(
            np.arange(num_edges),
            np.arange(num_edges + 1),
            {"feat": np.random.rand(num_edges - 1, 3)},
            graph_id=np.arange(num_edges + 2),
        )
    except Exception:
        expect_except = True
    assert expect_except

    # GraphData basics
    num_graphs = 10
    graph_ids = np.arange(num_graphs)
    gdata = GraphData(graph_ids, {})
    assert np.array_equal(gdata.graph_id, graph_ids)
    assert len(gdata.data) == 0
    # GraphData more
    graph_ids = np.arange(num_graphs).astype(np.float64)
    data = {"feat": np.random.rand(num_graphs, 3)}
    gdata = GraphData(graph_ids, data)
    assert np.array_equal(gdata.graph_id, graph_ids)
    assert len(gdata.data) == len(data)
    for k, v in data.items():
        assert k in gdata.data
        assert np.array_equal(gdata.data[k], v)


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_csvdataset():
    _test_NodeEdgeGraphData()
    _test_construct_graphs_node_ids()
    _test_construct_graphs_homo()
    _test_construct_graphs_hetero()
    _test_construct_graphs_multiple()
    _test_DefaultDataParser()
    _test_load_yaml_with_sanity_check()
    _test_load_node_data_from_csv()
    _test_load_edge_data_from_csv()
    _test_load_graph_data_from_csv()
    _test_CSVDataset_single()
    _test_CSVDataset_multiple()
    _test_CSVDataset_customized_data_parser()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_add_nodepred_split():
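    """data.utils.add_nodepred_split should attach train/val/test node masks
    in the requested ratio, in place, for both homogeneous datasets and
    heterogeneous ones (restricted to the given ntype)."""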
    dataset = data.AmazonCoBuyComputerDataset()
    print("train_mask" in dataset[0].ndata)
    data.utils.add_nodepred_split(dataset, [0.8, 0.1, 0.1])
    assert "train_mask" in dataset[0].ndata

    dataset = data.AIFBDataset()
    print("train_mask" in dataset[0].nodes["Publikationen"].data)
    data.utils.add_nodepred_split(
        dataset, [0.8, 0.1, 0.1], ntype="Publikationen"
    )
    assert "train_mask" in dataset[0].nodes["Publikationen"].data


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred1():
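    """AsNodePredDataset should wrap a dataset without altering the graph,
    add train/val/test masks, and expose train_idx/val_idx/test_idx that are
    consistent with those masks."""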
    ds = data.AmazonCoBuyComputerDataset()
    print("train_mask" in ds[0].ndata)
    new_ds = data.AsNodePredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1
    assert new_ds[0].num_nodes() == ds[0].num_nodes()
    assert new_ds[0].num_edges() == ds[0].num_edges()
    assert "train_mask" in new_ds[0].ndata
    assert F.array_equal(
        new_ds.train_idx, F.nonzero_1d(new_ds[0].ndata["train_mask"])
    )
    assert F.array_equal(
        new_ds.val_idx, F.nonzero_1d(new_ds[0].ndata["val_mask"])
    )
    assert F.array_equal(
        new_ds.test_idx, F.nonzero_1d(new_ds[0].ndata["test_mask"])
    )

    ds = data.AIFBDataset()
    print("train_mask" in ds[0].nodes["Personen"].data)
    new_ds = data.AsNodePredDataset(
        ds, [0.8, 0.1, 0.1], "Personen", verbose=True
    )
    assert len(new_ds) == 1
    assert new_ds[0].ntypes == ds[0].ntypes
    assert new_ds[0].canonical_etypes == ds[0].canonical_etypes
    assert "train_mask" in new_ds[0].nodes["Personen"].data
    assert F.array_equal(
        new_ds.train_idx,
        F.nonzero_1d(new_ds[0].nodes["Personen"].data["train_mask"]),
    )
    assert F.array_equal(
        new_ds.val_idx,
        F.nonzero_1d(new_ds[0].nodes["Personen"].data["val_mask"]),
    )
    assert F.array_equal(
        new_ds.test_idx,
        F.nonzero_1d(new_ds[0].nodes["Personen"].data["test_mask"]),
    )


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred2():
    # test proper reprocessing
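    # AsNodePredDataset caches the split it generates: rebuilding with the
    # same split_ratio should load from cache, while a different ratio
    # should invalidate the cache and trigger a fresh split.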

    # create
    ds = data.AsNodePredDataset(
        data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1]
    )
    assert F.sum(F.astype(ds[0].ndata["train_mask"], F.int32), 0) == int(
        ds[0].num_nodes() * 0.8
    )
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.8)
    # read from cache
    ds = data.AsNodePredDataset(
        data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1]
    )
    assert F.sum(F.astype(ds[0].ndata["train_mask"], F.int32), 0) == int(
        ds[0].num_nodes() * 0.8
    )
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.8)
    # invalid cache, re-read
    ds = data.AsNodePredDataset(
        data.AmazonCoBuyComputerDataset(), [0.1, 0.1, 0.8]
    )
    assert F.sum(F.astype(ds[0].ndata["train_mask"], F.int32), 0) == int(
        ds[0].num_nodes() * 0.1
    )
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.1)

    # create
    ds = data.AsNodePredDataset(
        data.AIFBDataset(), [0.8, 0.1, 0.1], "Personen", verbose=True
    )
    assert F.sum(
        F.astype(ds[0].nodes["Personen"].data["train_mask"], F.int32), 0
    ) == int(ds[0].num_nodes("Personen") * 0.8)
    assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.8)
    # read from cache
    ds = data.AsNodePredDataset(
        data.AIFBDataset(), [0.8, 0.1, 0.1], "Personen", verbose=True
    )
    assert F.sum(
        F.astype(ds[0].nodes["Personen"].data["train_mask"], F.int32), 0
    ) == int(ds[0].num_nodes("Personen") * 0.8)
    assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.8)
    # invalid cache, re-read
    ds = data.AsNodePredDataset(
        data.AIFBDataset(), [0.1, 0.1, 0.8], "Personen", verbose=True
    )
    assert F.sum(
        F.astype(ds[0].nodes["Personen"].data["train_mask"], F.int32), 0
    ) == int(ds[0].num_nodes("Personen") * 0.1)
    assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.1)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred_ogb():
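    """With split_ratio=None, AsNodePredDataset should adopt the official
    ogbn-arxiv split from get_idx_split(); passing an explicit ratio should
    force a newly generated split instead."""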
    from ogb.nodeproppred import DglNodePropPredDataset

    ds = data.AsNodePredDataset(
        DglNodePropPredDataset("ogbn-arxiv"), split_ratio=None, verbose=True
    )
    split = DglNodePropPredDataset("ogbn-arxiv").get_idx_split()
    train_idx, val_idx, test_idx = split["train"], split["valid"], split["test"]
    assert F.array_equal(ds.train_idx, F.tensor(train_idx))
    assert F.array_equal(ds.val_idx, F.tensor(val_idx))
    assert F.array_equal(ds.test_idx, F.tensor(test_idx))
    # force generate new split
    ds = data.AsNodePredDataset(
        DglNodePropPredDataset("ogbn-arxiv"),
        split_ratio=[0.7, 0.2, 0.1],
        verbose=True,
    )


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_linkpred():
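    """AsLinkPredDataset should split the edges into train/val/test positive
    sets and sample roughly neg_ratio negative edges per positive edge."""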
    # create
    ds = data.AsLinkPredDataset(
        data.CoraGraphDataset(),
        split_ratio=[0.8, 0.1, 0.1],
        neg_ratio=1,
        verbose=True,
    )
    # Cora has 10556 edges; the 10% test split yields 1057 positive edges
    assert ds.test_edges[0][0].shape[0] == 1057
    # negative sampling may yield fewer samples than positives, so the
    # assert uses a relaxed range
    assert 1000 <= ds.test_edges[1][0].shape[0] <= 1057
    # a different split_ratio/neg_ratio invalidates the cache, regenerate
    ds = data.AsLinkPredDataset(
        data.CoraGraphDataset(),
        split_ratio=[0.7, 0.1, 0.2],
        neg_ratio=2,
        verbose=True,
    )
    assert ds.test_edges[0][0].shape[0] == 2112
    # negative sampling is not guaranteed to reach exactly 2x the positives,
    # so the assert uses a relaxed range
    assert 4000 < ds.test_edges[1][0].shape[0] <= 4224


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
def test_as_linkpred_ogb():
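    """With split_ratio=None, AsLinkPredDataset should keep the official
    ogbl-collab edge split; an explicit ratio should regenerate it."""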
    from ogb.linkproppred import DglLinkPropPredDataset

    ds = data.AsLinkPredDataset(
        DglLinkPropPredDataset("ogbl-collab"), split_ratio=None, verbose=True
    )
    # original dataset has 46329 test edges
    assert ds.test_edges[0][0].shape[0] == 46329
    # force generate new split
    ds = data.AsLinkPredDataset(
        DglLinkPropPredDataset("ogbl-collab"),
        split_ratio=[0.7, 0.2, 0.1],
        verbose=True,
    )
    assert ds.test_edges[0][0].shape[0] == 235812


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred_csvdataset():
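    """A CSVDataset built on the fly has no masks and no num_classes;
    wrapping it in AsNodePredDataset should add the masks and infer
    num_classes from the node labels."""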
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path = os.path.join(test_dir, "test_edges.csv")
        nodes_csv_path = os.path.join(test_dir, "test_nodes.csv")
        meta_yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default_name",
            "node_data": [{"file_name": os.path.basename(nodes_csv_path)}],
            "edge_data": [{"file_name": os.path.basename(edges_csv_path)}],
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_dims = 3
        num_classes = num_nodes
        feat_ndata = np.random.rand(num_nodes, num_dims)
        label_ndata = np.arange(num_classes)
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": label_ndata,
                "feat": [line.tolist() for line in feat_ndata],
            }
        )
        df.to_csv(nodes_csv_path, index=False)
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
            }
        )
        df.to_csv(edges_csv_path, index=False)

        ds = data.CSVDataset(test_dir, force_reload=True)
        assert "feat" in ds[0].ndata
        assert "label" in ds[0].ndata
        assert "train_mask" not in ds[0].ndata
        assert not hasattr(ds[0], "num_classes")
        new_ds = data.AsNodePredDataset(
            ds, split_ratio=[0.8, 0.1, 0.1], force_reload=True
        )
        assert new_ds.num_classes == num_classes
        assert "feat" in new_ds[0].ndata
        assert "label" in new_ds[0].ndata
        assert "train_mask" in new_ds[0].ndata


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_graphpred():
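    """AsGraphPredDataset should standardize graph-level datasets: it exposes
    num_tasks and num_classes for classification datasets, while num_classes
    is None for the regression ones (QM7b/QM9/QM9Edge)."""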
    ds = data.GINDataset(name="MUTAG", self_loop=True)
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 188
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2

    ds = data.FakeNewsDataset("politifact", "profile")
    new_ds = data.AsGraphPredDataset(ds, verbose=True)
    assert len(new_ds) == 314
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2

    ds = data.QM7bDataset()
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 7211
    assert new_ds.num_tasks == 14
    assert new_ds.num_classes is None

    ds = data.QM9Dataset(label_keys=["mu", "gap"])
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 130831
    assert new_ds.num_tasks == 2
    assert new_ds.num_classes is None

    ds = data.QM9EdgeDataset(label_keys=["mu", "alpha"])
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 130831
    assert new_ds.num_tasks == 2
    assert new_ds.num_classes is None

    ds = data.TUDataset("DD")
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1178
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2

    ds = data.LegacyTUDataset("DD")
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1178
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2

    ds = data.BA2MotifDataset()
    new_ds = data.AsGraphPredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1000
    assert new_ds.num_tasks == 1
    assert new_ds.num_classes == 2


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_graphpred_reprocess():
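    """Like the node-prediction wrapper, AsGraphPredDataset caches its split:
    the same split_ratio should be read back from cache, while a different
    one should force reprocessing."""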
    ds = data.AsGraphPredDataset(
        data.GINDataset(name="MUTAG", self_loop=True), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.GINDataset(name="MUTAG", self_loop=True), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.GINDataset(name="MUTAG", self_loop=True), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(
        data.FakeNewsDataset("politifact", "profile"), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.FakeNewsDataset("politifact", "profile"), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.FakeNewsDataset("politifact", "profile"), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(
        data.QM9Dataset(label_keys=["mu", "gap"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.QM9Dataset(label_keys=["mu", "gap"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.QM9Dataset(label_keys=["mu", "gap"]), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(
        data.QM9EdgeDataset(label_keys=["mu", "alpha"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.QM9EdgeDataset(label_keys=["mu", "alpha"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.QM9EdgeDataset(label_keys=["mu", "alpha"]), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.TUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.TUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.TUDataset("DD"), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.LegacyTUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.LegacyTUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.LegacyTUDataset("DD"), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
def test_as_graphpred_ogb():
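    """With split_ratio=None, AsGraphPredDataset should keep the official
    ogbg-molhiv split; an explicit ratio should regenerate it."""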
    from ogb.graphproppred import DglGraphPropPredDataset

    ds = data.AsGraphPredDataset(
        DglGraphPropPredDataset("ogbg-molhiv"), split_ratio=None, verbose=True
    )
    assert len(ds.train_idx) == 32901
    # force generate new split
    ds = data.AsGraphPredDataset(
        DglGraphPropPredDataset("ogbg-molhiv"),
        split_ratio=[0.6, 0.2, 0.2],
        verbose=True,
    )
    assert len(ds.train_idx) == 24676


if __name__ == "__main__":
    test_minigc()
    test_gin()
    test_data_hash()
    test_tudataset_regression()
    test_fraud()
    test_fakenews()
    test_extract_archive()
    test_csvdataset()
    test_add_nodepred_split()
    test_as_nodepred1()
    test_as_nodepred2()
    test_as_nodepred_csvdataset()