import gzip
import io
import os
import tarfile
import tempfile
import unittest

import backend as F

import dgl
import dgl.data as data
import numpy as np

import pandas as pd
import pytest
import yaml
from dgl import DGLError


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_minigc():
    ds = data.MiniGCDataset(16, 10, 20)
    g, l = list(zip(*ds))
    print(g, l)

    g1 = ds[0][0]
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    ds = data.MiniGCDataset(16, 10, 20, transform=transform)
    g2 = ds[0][0]
    assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
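

def _self_loop_edge_delta_sanity():
    # Minimal sketch, not part of the original suite (underscore-prefixed so
    # pytest does not collect it): AddSelfLoop with allow_duplicate=True
    # appends exactly one self-loop per node, which is the invariant the
    # transform checks in this file rely on when comparing edge counts.
    g = dgl.graph(([0, 1], [1, 2]))  # toy graph: 3 nodes, 2 edges
    g2 = dgl.AddSelfLoop(allow_duplicate=True)(g)
    assert g2.num_edges() - g.num_edges() == g.num_nodes()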


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_gin():
    ds_n_graphs = {
        "MUTAG": 188,
        "IMDBBINARY": 1000,
        "IMDBMULTI": 1500,
        "PROTEINS": 1113,
        "PTC": 344,
    }
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    for name, n_graphs in ds_n_graphs.items():
        ds = data.GINDataset(name, self_loop=False, degree_as_nlabel=False)
        assert len(ds) == n_graphs, (len(ds), name)
        g1 = ds[0][0]
        ds = data.GINDataset(
            name, self_loop=False, degree_as_nlabel=False, transform=transform
        )
        g2 = ds[0][0]
        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
        assert ds.num_classes == ds.gclasses


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_fraud():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    g = data.FraudDataset("amazon")[0]
    assert g.num_nodes() == 11944
    num_edges1 = g.num_edges()
    g2 = data.FraudDataset("amazon", transform=transform)[0]
    # 3 edge types
    assert g2.num_edges() - num_edges1 == g.num_nodes() * 3

    g = data.FraudAmazonDataset()[0]
    assert g.num_nodes() == 11944
    g2 = data.FraudAmazonDataset(transform=transform)[0]
    # 3 edge types
    assert g2.num_edges() - g.num_edges() == g.num_nodes() * 3

    g = data.FraudYelpDataset()[0]
    assert g.num_nodes() == 45954
    g2 = data.FraudYelpDataset(transform=transform)[0]
    # 3 edge types
    assert g2.num_edges() - g.num_edges() == g.num_nodes() * 3


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_tudataset_regression():
    ds = data.TUDataset("ZINC_test", force_reload=True)
    assert ds.num_classes == ds.num_labels
    assert len(ds) == 5000
    g = ds[0][0]

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    ds = data.TUDataset("ZINC_test", force_reload=True, transform=transform)
    g2 = ds[0][0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_data_hash():
    class HashTestDataset(data.DGLDataset):
        def __init__(self, hash_key=()):
            super(HashTestDataset, self).__init__("hashtest", hash_key=hash_key)

        def _load(self):
            pass

    a = HashTestDataset((True, 0, "1", (1, 2, 3)))
    b = HashTestDataset((True, 0, "1", (1, 2, 3)))
    c = HashTestDataset((True, 0, "1", (1, 2, 4)))
    assert a.hash == b.hash
    assert a.hash != c.hash
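
    # The hash is derived from hash_key; DGLDataset uses it to tell apart
    # cached artifacts of the same dataset built with different settings.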


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_citation_graph():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    # Cora
    g = data.CoraGraphDataset(force_reload=True, reorder=True)[0]
    assert g.num_nodes() == 2708
    assert g.num_edges() == 10556
    # with reorder=True, edges are stored sorted by destination node
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CoraGraphDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # Citeseer
    g = data.CiteseerGraphDataset(force_reload=True, reorder=True)[0]
    assert g.num_nodes() == 3327
    assert g.num_edges() == 9228
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CiteseerGraphDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # Pubmed
    g = data.PubmedGraphDataset(force_reload=True, reorder=True)[0]
    assert g.num_nodes() == 19717
    assert g.num_edges() == 88651
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.PubmedGraphDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_gnn_benchmark():
    transform = dgl.AddSelfLoop(allow_duplicate=True)

    # AmazonCoBuyComputerDataset
    g = data.AmazonCoBuyComputerDataset()[0]
    assert g.num_nodes() == 13752
    assert g.num_edges() == 491722
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.AmazonCoBuyComputerDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # AmazonCoBuyPhotoDataset
    g = data.AmazonCoBuyPhotoDataset()[0]
    assert g.num_nodes() == 7650
    assert g.num_edges() == 238163
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.AmazonCoBuyPhotoDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # CoauthorPhysicsDataset
    g = data.CoauthorPhysicsDataset()[0]
    assert g.num_nodes() == 34493
    assert g.num_edges() == 495924
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CoauthorPhysicsDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # CoauthorCSDataset
    g = data.CoauthorCSDataset()[0]
    assert g.num_nodes() == 18333
    assert g.num_edges() == 163788
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CoauthorCSDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()

    # CoraFullDataset
    g = data.CoraFullDataset()[0]
    assert g.num_nodes() == 19793
    assert g.num_edges() == 126842
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))
    g2 = data.CoraFullDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_explain_syn():
    dataset = data.BAShapeDataset()
    assert dataset.num_classes == 4
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata
    g1 = data.BAShapeDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.BAShapeDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.BACommunityDataset()
    assert dataset.num_classes == 8
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.BACommunityDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.BACommunityDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.TreeCycleDataset()
    assert dataset.num_classes == 2
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.TreeCycleDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.TreeCycleDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.TreeGridDataset()
    assert dataset.num_classes == 2
    g = dataset[0]
    assert "label" in g.ndata
    assert "feat" in g.ndata

    g1 = data.TreeGridDataset(force_reload=True, seed=0)[0]
    src1, dst1 = g1.edges()
    g2 = data.TreeGridDataset(force_reload=True, seed=0)[0]
    src2, dst2 = g2.edges()
    assert F.allclose(src1, src2)
    assert F.allclose(dst1, dst2)

    dataset = data.BA2MotifDataset()
    assert dataset.num_classes == 2
    g, label = dataset[0]
    assert "feat" in g.ndata


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_wiki_cs():
    g = data.WikiCSDataset()[0]
    assert g.num_nodes() == 11701
    assert g.num_edges() == 431726
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.WikiCSDataset(transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skip(reason="Dataset too large to download for the latest CI.")
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_yelp():
    g = data.YelpDataset(reorder=True)[0]
    assert g.num_nodes() == 716847
    assert g.num_edges() == 13954819
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.YelpDataset(reorder=True, transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_flickr():
    g = data.FlickrDataset(reorder=True)[0]
    assert g.num_nodes() == 89250
    assert g.num_edges() == 899756
    dst = F.asnumpy(g.edges()[1])
    assert np.array_equal(dst, np.sort(dst))

    transform = dgl.AddSelfLoop(allow_duplicate=True)
    g2 = data.FlickrDataset(reorder=True, transform=transform)[0]
    assert g2.num_edges() - g.num_edges() == g.num_nodes()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_pattern():
    mode_n_graphs = {
        "train": 10000,
        "valid": 2000,
        "test": 2000,
    }
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    for mode, n_graphs in mode_n_graphs.items():
        ds = data.PATTERNDataset(mode=mode)
        assert len(ds) == n_graphs, (len(ds), mode)
        g1 = ds[0]
        ds = data.PATTERNDataset(mode=mode, transform=transform)
        g2 = ds[0]
        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
        assert ds.num_classes == 2


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_cluster():
    mode_n_graphs = {
        "train": 10000,
        "valid": 1000,
        "test": 1000,
    }
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    for mode, n_graphs in mode_n_graphs.items():
        ds = data.CLUSTERDataset(mode=mode)
        assert len(ds) == n_graphs, (len(ds), mode)
        g1 = ds[0]
        ds = data.CLUSTERDataset(mode=mode, transform=transform)
        g2 = ds[0]
        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
        assert ds.num_classes == 6


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="only supports pytorch"
)
def test_zinc():
    mode_n_graphs = {
        "train": 10000,
        "valid": 1000,
        "test": 1000,
    }
    transform = dgl.AddSelfLoop(allow_duplicate=True)
    for mode, n_graphs in mode_n_graphs.items():
        dataset1 = data.ZINCDataset(mode=mode)
        g1, label = dataset1[0]
        dataset2 = data.ZINCDataset(mode=mode, transform=transform)
        g2, _ = dataset2[0]

        assert g2.num_edges() - g1.num_edges() == g1.num_nodes()
        # return a scalar tensor
        assert not label.shape


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_extract_archive():
    # gzip
    with tempfile.TemporaryDirectory() as src_dir:
        gz_file = "gz_archive"
        gz_path = os.path.join(src_dir, gz_file + ".gz")
        content = b"test extract archive gzip"
        with gzip.open(gz_path, "wb") as f:
            f.write(content)
        with tempfile.TemporaryDirectory() as dst_dir:
            data.utils.extract_archive(gz_path, dst_dir, overwrite=True)
            assert os.path.exists(os.path.join(dst_dir, gz_file))

    # tar
    with tempfile.TemporaryDirectory() as src_dir:
        tar_file = "tar_archive"
        tar_path = os.path.join(src_dir, tar_file + ".tar")
        # default encode to utf8
        content = "test extract archive tar\n".encode()
        info = tarfile.TarInfo(name="tar_archive")
        info.size = len(content)
        with tarfile.open(tar_path, "w") as f:
            f.addfile(info, io.BytesIO(content))
        with tempfile.TemporaryDirectory() as dst_dir:
            data.utils.extract_archive(tar_path, dst_dir, overwrite=True)
            assert os.path.exists(os.path.join(dst_dir, tar_file))
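
# Only gzip and tar archives are exercised above; extract_archive is the
# helper DGL's built-in datasets typically use to unpack downloaded files.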


def _test_construct_graphs_node_ids():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    num_nodes = 100
    num_edges = 1000

    # node IDs are required to be unique: drawing num_nodes IDs from only
    # num_nodes / 2 distinct values guarantees duplicates, so construction
    # is expected to raise
    node_ids = np.random.choice(np.arange(num_nodes / 2), num_nodes)
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    node_data = NodeData(node_ids, {})
    edge_data = EdgeData(src_ids, dst_ids, {})
    expect_except = False
    try:
        _, _ = DGLGraphConstructor.construct_graphs(node_data, edge_data)
    except:
        expect_except = True
    assert expect_except

    # node IDs are already labeled 0 to num_nodes-1
    node_ids = np.arange(num_nodes)
    np.random.shuffle(node_ids)
    _, idx = np.unique(node_ids, return_index=True)
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    node_feat = np.random.rand(num_nodes, 3)
    node_data = NodeData(node_ids, {"feat": node_feat})
    edge_data = EdgeData(src_ids, dst_ids, {})
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == len(node_ids)
    assert g.num_edges() == len(src_ids)
    assert F.array_equal(
        F.tensor(node_feat[idx], dtype=F.float32), g.ndata["feat"]
    )

    # node IDs are mixed with numeric and non-numeric values
    # homogeneous graph
    node_ids = [1, 2, 3, "a"]
    src_ids = [1, 2, 3]
    dst_ids = ["a", 1, 2]
    node_data = NodeData(node_ids, {})
    edge_data = EdgeData(src_ids, dst_ids, {})
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == len(node_ids)
    assert g.num_edges() == len(src_ids)

    # heterogeneous graph
    node_ids_user = [1, 2, 3]
    node_ids_item = ["a", "b", "c"]
    src_ids = node_ids_user
    dst_ids = node_ids_item
    node_data_user = NodeData(node_ids_user, {}, type="user")
    node_data_item = NodeData(node_ids_item, {}, type="item")
    edge_data = EdgeData(src_ids, dst_ids, {}, type=("user", "like", "item"))
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        [node_data_user, node_data_item], edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert not g.is_homogeneous
    assert g.num_nodes("user") == len(node_ids_user)
    assert g.num_nodes("item") == len(node_ids_item)
    assert g.num_edges() == len(src_ids)


def _test_construct_graphs_homo():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    # node_id could be non-sorted, non-numeric.
    num_nodes = 100
    num_edges = 1000
    num_dims = 3
    node_ids = np.random.choice(
        np.arange(num_nodes * 2), size=num_nodes, replace=False
    )
    assert len(node_ids) == num_nodes
    # to be non-sorted
    np.random.shuffle(node_ids)
    # to be non-numeric
    node_ids = ["id_{}".format(id) for id in node_ids]
    t_ndata = {
        "feat": np.random.rand(num_nodes, num_dims),
        "label": np.random.randint(2, size=num_nodes),
    }
    _, u_indices = np.unique(node_ids, return_index=True)
    ndata = {
        "feat": t_ndata["feat"][u_indices],
        "label": t_ndata["label"][u_indices],
    }
    node_data = NodeData(node_ids, t_ndata)
    src_ids = np.random.choice(node_ids, size=num_edges)
    dst_ids = np.random.choice(node_ids, size=num_edges)
    edata = {
        "feat": np.random.rand(num_edges, num_dims),
        "label": np.random.randint(2, size=num_edges),
    }
    edge_data = EdgeData(src_ids, dst_ids, edata)
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert g.is_homogeneous
    assert g.num_nodes() == num_nodes
    assert g.num_edges() == num_edges

    def assert_data(lhs, rhs):
        for key, value in lhs.items():
            assert key in rhs
            assert F.dtype(rhs[key]) != F.float64
            assert F.array_equal(
                F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
            )

    assert_data(ndata, g.ndata)
    assert_data(edata, g.edata)


def _test_construct_graphs_hetero():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        NodeData,
    )

    # node_id/src_id/dst_id could be non-sorted, duplicated, non-numeric.
    num_nodes = 100
    num_edges = 1000
    num_dims = 3
    ntypes = ["user", "item"]
    node_data = []
    node_ids_dict = {}
    ndata_dict = {}
    for ntype in ntypes:
        node_ids = np.random.choice(
            np.arange(num_nodes * 2), size=num_nodes, replace=False
        )
        assert len(node_ids) == num_nodes
        # to be non-sorted
        np.random.shuffle(node_ids)
        # to be non-numeric
        node_ids = ["id_{}".format(id) for id in node_ids]
        t_ndata = {
            "feat": np.random.rand(num_nodes, num_dims),
            "label": np.random.randint(2, size=num_nodes),
        }
        _, u_indices = np.unique(node_ids, return_index=True)
        ndata = {
            "feat": t_ndata["feat"][u_indices],
            "label": t_ndata["label"][u_indices],
        }
        node_data.append(NodeData(node_ids, t_ndata, type=ntype))
        node_ids_dict[ntype] = node_ids
        ndata_dict[ntype] = ndata
    etypes = [("user", "follow", "user"), ("user", "like", "item")]
    edge_data = []
    edata_dict = {}
    for src_type, e_type, dst_type in etypes:
        src_ids = np.random.choice(node_ids_dict[src_type], size=num_edges)
        dst_ids = np.random.choice(node_ids_dict[dst_type], size=num_edges)
        edata = {
            "feat": np.random.rand(num_edges, num_dims),
            "label": np.random.randint(2, size=num_edges),
        }
        edge_data.append(
            EdgeData(src_ids, dst_ids, edata, type=(src_type, e_type, dst_type))
        )
        edata_dict[(src_type, e_type, dst_type)] = edata
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data
    )
    assert len(graphs) == 1
    assert len(data_dict) == 0
    g = graphs[0]
    assert not g.is_homogeneous
    assert g.num_nodes() == num_nodes * len(ntypes)
    assert g.num_edges() == num_edges * len(etypes)

    def assert_data(lhs, rhs):
        for key, value in lhs.items():
            assert key in rhs
            assert F.dtype(rhs[key]) != F.float64
            assert F.array_equal(
                F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
            )

    for ntype in g.ntypes:
        assert g.num_nodes(ntype) == num_nodes
        assert_data(ndata_dict[ntype], g.nodes[ntype].data)
    for etype in g.canonical_etypes:
        assert g.num_edges(etype) == num_edges
        assert_data(edata_dict[etype], g.edges[etype].data)


def _test_construct_graphs_multiple():
    from dgl.data.csv_dataset_base import (
        DGLGraphConstructor,
        EdgeData,
        GraphData,
        NodeData,
    )

    num_nodes = 100
    num_edges = 1000
    num_graphs = 10
    num_dims = 3
    node_ids = np.array([], dtype=int)
    src_ids = np.array([], dtype=int)
    dst_ids = np.array([], dtype=int)
    ngraph_ids = np.array([], dtype=int)
    egraph_ids = np.array([], dtype=int)
    u_indices = np.array([], dtype=int)
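    # u_indices records, per graph, how np.unique reorders that graph's node
    # IDs; assert_data below uses it to permute expected node features to
    # match what the constructed graphs store.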
    for i in range(num_graphs):
        l_node_ids = np.random.choice(
            np.arange(num_nodes * 2), size=num_nodes, replace=False
        )
        node_ids = np.append(node_ids, l_node_ids)
        _, l_u_indices = np.unique(l_node_ids, return_index=True)
        u_indices = np.append(u_indices, l_u_indices)
        ngraph_ids = np.append(ngraph_ids, np.full(num_nodes, i))
        src_ids = np.append(
            src_ids, np.random.choice(l_node_ids, size=num_edges)
        )
        dst_ids = np.append(
            dst_ids, np.random.choice(l_node_ids, size=num_edges)
        )
        egraph_ids = np.append(egraph_ids, np.full(num_edges, i))
    ndata = {
        "feat": np.random.rand(num_nodes * num_graphs, num_dims),
        "label": np.random.randint(2, size=num_nodes * num_graphs),
    }
    ngraph_ids = ["graph_{}".format(id) for id in ngraph_ids]
    node_data = NodeData(node_ids, ndata, graph_id=ngraph_ids)
    egraph_ids = ["graph_{}".format(id) for id in egraph_ids]
    edata = {
        "feat": np.random.rand(num_edges * num_graphs, num_dims),
        "label": np.random.randint(2, size=num_edges * num_graphs),
    }
    edge_data = EdgeData(src_ids, dst_ids, edata, graph_id=egraph_ids)
    gdata = {
        "feat": np.random.rand(num_graphs, num_dims),
        "label": np.random.randint(2, size=num_graphs),
    }
    graph_ids = ["graph_{}".format(id) for id in np.arange(num_graphs)]
    graph_data = GraphData(graph_ids, gdata)
    graphs, data_dict = DGLGraphConstructor.construct_graphs(
        node_data, edge_data, graph_data
    )
    assert len(graphs) == num_graphs
    assert len(data_dict) == len(gdata)
    for k, v in data_dict.items():
        assert F.dtype(v) != F.float64
        assert F.array_equal(
            F.reshape(F.tensor(gdata[k], dtype=F.dtype(v)), (len(graphs), -1)),
            v,
        )
    for i, g in enumerate(graphs):
        assert g.is_homogeneous
        assert g.num_nodes() == num_nodes
        assert g.num_edges() == num_edges

        def assert_data(lhs, rhs, size, node=False):
            for key, value in lhs.items():
                assert key in rhs
                value = value[i * size : (i + 1) * size]
                if node:
                    indices = u_indices[i * size : (i + 1) * size]
                    value = value[indices]
                assert F.dtype(rhs[key]) != F.float64
                assert F.array_equal(
                    F.tensor(value, dtype=F.dtype(rhs[key])), rhs[key]
                )

        assert_data(ndata, g.ndata, num_nodes, node=True)
        assert_data(edata, g.edata, num_edges)

    # Graph IDs found in node/edge CSV but not in graph CSV
    graph_data = GraphData(np.arange(num_graphs - 2), {})
    expect_except = False
    try:
        _, _ = DGLGraphConstructor.construct_graphs(
            node_data, edge_data, graph_data
        )
    except:
        expect_except = True
    assert expect_except


def _test_DefaultDataParser():
    from dgl.data.csv_dataset_base import DefaultDataParser

    # common csv
    with tempfile.TemporaryDirectory() as test_dir:
        csv_path = os.path.join(test_dir, "nodes.csv")
        num_nodes = 5
        num_labels = 3
        num_dims = 2
        node_id = np.arange(num_nodes)
        label = np.random.randint(num_labels, size=num_nodes)
        feat = np.random.rand(num_nodes, num_dims)
        df = pd.DataFrame(
            {
                "node_id": node_id,
                "label": label,
                "feat": [line.tolist() for line in feat],
            }
        )
        df.to_csv(csv_path, index=False)
        dp = DefaultDataParser()
        df = pd.read_csv(csv_path)
        dt = dp(df)
        assert np.array_equal(node_id, dt["node_id"])
        assert np.array_equal(label, dt["label"])
        assert np.array_equal(feat, dt["feat"])
    # string consists of non-numeric values
    with tempfile.TemporaryDirectory() as test_dir:
        csv_path = os.path.join(test_dir, "nodes.csv")
        df = pd.DataFrame({"label": ["a", "b", "c"]})
        df.to_csv(csv_path, index=False)
        dp = DefaultDataParser()
        df = pd.read_csv(csv_path)
        expect_except = False
        try:
            dt = dp(df)
        except:
            expect_except = True
        assert expect_except

    # csv has index column which is ignored as it's unnamed
    with tempfile.TemporaryDirectory() as test_dir:
        csv_path = os.path.join(test_dir, "nodes.csv")
        df = pd.DataFrame({"label": [1, 2, 3]})
        df.to_csv(csv_path)
        dp = DefaultDataParser()
        df = pd.read_csv(csv_path)
        dt = dp(df)
        assert len(dt) == 1


def _test_load_yaml_with_sanity_check():
    from dgl.data.csv_dataset_base import load_yaml_with_sanity_check

    with tempfile.TemporaryDirectory() as test_dir:
        yaml_path = os.path.join(test_dir, "meta.yaml")
        # workable, though usually meaningless
        yaml_data = {
            "dataset_name": "default",
            "node_data": [],
            "edge_data": [],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        meta = load_yaml_with_sanity_check(yaml_path)
        assert meta.version == "1.0.0"
        assert meta.dataset_name == "default"
        assert meta.separator == ","
        assert len(meta.node_data) == 0
        assert len(meta.edge_data) == 0
        assert meta.graph_data is None

        # minimum with required fields only
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "node_data": [{"file_name": "nodes.csv"}],
            "edge_data": [{"file_name": "edges.csv"}],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        meta = load_yaml_with_sanity_check(yaml_path)
        for ndata in meta.node_data:
            assert ndata.file_name == "nodes.csv"
            assert ndata.ntype == "_V"
            assert ndata.graph_id_field == "graph_id"
            assert ndata.node_id_field == "node_id"
        for edata in meta.edge_data:
            assert edata.file_name == "edges.csv"
            assert edata.etype == ["_V", "_E", "_V"]
            assert edata.graph_id_field == "graph_id"
            assert edata.src_id_field == "src_id"
            assert edata.dst_id_field == "dst_id"

        # optional fields are specified
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "separator": "|",
            "node_data": [
                {
                    "file_name": "nodes.csv",
                    "ntype": "user",
                    "graph_id_field": "xxx",
                    "node_id_field": "xxx",
                }
            ],
            "edge_data": [
                {
                    "file_name": "edges.csv",
                    "etype": ["user", "follow", "user"],
                    "graph_id_field": "xxx",
                    "src_id_field": "xxx",
                    "dst_id_field": "xxx",
                }
            ],
            "graph_data": {"file_name": "graph.csv", "graph_id_field": "xxx"},
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        meta = load_yaml_with_sanity_check(yaml_path)
        assert len(meta.node_data) == 1
        ndata = meta.node_data[0]
        assert ndata.ntype == "user"
        assert ndata.graph_id_field == "xxx"
        assert ndata.node_id_field == "xxx"
        assert len(meta.edge_data) == 1
        edata = meta.edge_data[0]
        assert edata.etype == ["user", "follow", "user"]
        assert edata.graph_id_field == "xxx"
        assert edata.src_id_field == "xxx"
        assert edata.dst_id_field == "xxx"
        assert meta.graph_data is not None
        assert meta.graph_data.file_name == "graph.csv"
        assert meta.graph_data.graph_id_field == "xxx"

        # some required fields are missing
        yaml_data = {
            "dataset_name": "default",
            "node_data": [],
            "edge_data": [],
        }
        for field in yaml_data.keys():
            ydata = {k: v for k, v in yaml_data.items()}
            ydata.pop(field)
            with open(yaml_path, "w") as f:
                yaml.dump(ydata, f, sort_keys=False)
            expect_except = False
            try:
                meta = load_yaml_with_sanity_check(yaml_path)
            except:
                expect_except = True
            assert expect_except

        # inapplicable version
        yaml_data = {
            "version": "0.0.0",
            "dataset_name": "default",
            "node_data": [{"file_name": "nodes_0.csv"}],
            "edge_data": [{"file_name": "edges_0.csv"}],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        expect_except = False
        try:
            meta = load_yaml_with_sanity_check(yaml_path)
        except DGLError:
            expect_except = True
        assert expect_except

        # duplicate node types
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "node_data": [
                {"file_name": "nodes.csv"},
                {"file_name": "nodes.csv"},
            ],
            "edge_data": [{"file_name": "edges.csv"}],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        expect_except = False
        try:
            meta = load_yaml_with_sanity_check(yaml_path)
        except DGLError:
            expect_except = True
        assert expect_except

        # duplicate edge types
        yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default",
            "node_data": [{"file_name": "nodes.csv"}],
            "edge_data": [
                {"file_name": "edges.csv"},
                {"file_name": "edges.csv"},
            ],
        }
        with open(yaml_path, "w") as f:
            yaml.dump(yaml_data, f, sort_keys=False)
        expect_except = False
        try:
            meta = load_yaml_with_sanity_check(yaml_path)
        except DGLError:
            expect_except = True
        assert expect_except


def _test_load_node_data_from_csv():
    from dgl.data.csv_dataset_base import DefaultDataParser, MetaNode, NodeData

    with tempfile.TemporaryDirectory() as test_dir:
        num_nodes = 100
        # minimum
        df = pd.DataFrame({"node_id": np.arange(num_nodes)})
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
        assert np.array_equal(df["node_id"], node_data.id)
        assert len(node_data.data) == 0

        # common case
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": np.random.randint(3, size=num_nodes),
            }
        )
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
        assert np.array_equal(df["node_id"], node_data.id)
        assert len(node_data.data) == 1
        assert np.array_equal(df["label"], node_data.data["label"])
        assert np.array_equal(np.full(num_nodes, 0), node_data.graph_id)
        assert node_data.type == "_V"

        # add more fields into nodes.csv
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": np.random.randint(3, size=num_nodes),
                "graph_id": np.full(num_nodes, 1),
            }
        )
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        node_data = NodeData.load_from_csv(meta_node, DefaultDataParser())
        assert np.array_equal(df["node_id"], node_data.id)
        assert len(node_data.data) == 1
        assert np.array_equal(df["label"], node_data.data["label"])
        assert np.array_equal(df["graph_id"], node_data.graph_id)
        assert node_data.type == "_V"

        # required header is missing
        df = pd.DataFrame({"label": np.random.randint(3, size=num_nodes)})
        csv_path = os.path.join(test_dir, "nodes.csv")
        df.to_csv(csv_path, index=False)
        meta_node = MetaNode(file_name=csv_path)
        expect_except = False
        try:
            NodeData.load_from_csv(meta_node, DefaultDataParser())
        except:
            expect_except = True
        assert expect_except


def _test_load_edge_data_from_csv():
    from dgl.data.csv_dataset_base import DefaultDataParser, EdgeData, MetaEdge

    with tempfile.TemporaryDirectory() as test_dir:
        num_nodes = 100
        num_edges = 1000
        # minimum
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        assert np.array_equal(df["src_id"], edge_data.src)
        assert np.array_equal(df["dst_id"], edge_data.dst)
        assert len(edge_data.data) == 0

        # common case
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
                "label": np.random.randint(3, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        assert np.array_equal(df["src_id"], edge_data.src)
        assert np.array_equal(df["dst_id"], edge_data.dst)
        assert len(edge_data.data) == 1
        assert np.array_equal(df["label"], edge_data.data["label"])
        assert np.array_equal(np.full(num_edges, 0), edge_data.graph_id)
        assert edge_data.type == ("_V", "_E", "_V")

        # add more fields into edges.csv
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
                "graph_id": np.arange(num_edges),
                "feat": np.random.randint(3, size=num_edges),
                "label": np.random.randint(3, size=num_edges),
            }
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        edge_data = EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        assert np.array_equal(df["src_id"], edge_data.src)
        assert np.array_equal(df["dst_id"], edge_data.dst)
        assert len(edge_data.data) == 2
        assert np.array_equal(df["feat"], edge_data.data["feat"])
        assert np.array_equal(df["label"], edge_data.data["label"])
        assert np.array_equal(df["graph_id"], edge_data.graph_id)
        assert edge_data.type == ("_V", "_E", "_V")

        # required headers are missing
        df = pd.DataFrame(
            {"src_id": np.random.randint(num_nodes, size=num_edges)}
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        expect_except = False
        try:
            EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        except DGLError:
            expect_except = True
        assert expect_except

        df = pd.DataFrame(
            {"dst_id": np.random.randint(num_nodes, size=num_edges)}
        )
        csv_path = os.path.join(test_dir, "edges.csv")
        df.to_csv(csv_path, index=False)
        meta_edge = MetaEdge(file_name=csv_path)
        expect_except = False
        try:
            EdgeData.load_from_csv(meta_edge, DefaultDataParser())
        except DGLError:
            expect_except = True
        assert expect_except


def _test_load_graph_data_from_csv():
    from dgl.data.csv_dataset_base import (
        DefaultDataParser,
        GraphData,
        MetaGraph,
    )

    with tempfile.TemporaryDirectory() as test_dir:
        num_graphs = 100
        # minimum
        df = pd.DataFrame({"graph_id": np.arange(num_graphs)})
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
        assert np.array_equal(df["graph_id"], graph_data.graph_id)
        assert len(graph_data.data) == 0

        # common case
        df = pd.DataFrame(
            {
                "graph_id": np.arange(num_graphs),
                "label": np.random.randint(3, size=num_graphs),
            }
        )
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
        assert np.array_equal(df["graph_id"], graph_data.graph_id)
        assert len(graph_data.data) == 1
        assert np.array_equal(df["label"], graph_data.data["label"])

        # add more fields into graph.csv
        df = pd.DataFrame(
            {
                "graph_id": np.arange(num_graphs),
                "feat": np.random.randint(3, size=num_graphs),
                "label": np.random.randint(3, size=num_graphs),
            }
        )
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        graph_data = GraphData.load_from_csv(meta_graph, DefaultDataParser())
        assert np.array_equal(df["graph_id"], graph_data.graph_id)
        assert len(graph_data.data) == 2
        assert np.array_equal(df["feat"], graph_data.data["feat"])
        assert np.array_equal(df["label"], graph_data.data["label"])

        # required header is missing
        df = pd.DataFrame({"label": np.random.randint(3, size=num_graphs)})
        csv_path = os.path.join(test_dir, "graph.csv")
        df.to_csv(csv_path, index=False)
        meta_graph = MetaGraph(file_name=csv_path)
        expect_except = False
        try:
            GraphData.load_from_csv(meta_graph, DefaultDataParser())
        except DGLError:
            expect_except = True
        assert expect_except


def _test_CSVDataset_single():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path_0 = os.path.join(test_dir, "test_edges_0.csv")
        edges_csv_path_1 = os.path.join(test_dir, "test_edges_1.csv")
        nodes_csv_path_0 = os.path.join(test_dir, "test_nodes_0.csv")
        nodes_csv_path_1 = os.path.join(test_dir, "test_nodes_1.csv")
        meta_yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default_name",
            "node_data": [
                {
                    "file_name": os.path.basename(nodes_csv_path_0),
                    "ntype": "user",
                },
                {
                    "file_name": os.path.basename(nodes_csv_path_1),
                    "ntype": "item",
                },
            ],
            "edge_data": [
                {
                    "file_name": os.path.basename(edges_csv_path_0),
                    "etype": ["user", "follow", "user"],
                },
                {
                    "file_name": os.path.basename(edges_csv_path_1),
                    "etype": ["user", "like", "item"],
                },
            ],
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_dims = 3
        feat_ndata = np.random.rand(num_nodes, num_dims)
        label_ndata = np.random.randint(2, size=num_nodes)
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": label_ndata,
                "feat": [line.tolist() for line in feat_ndata],
            }
        )
        df.to_csv(nodes_csv_path_0, index=False)
        df.to_csv(nodes_csv_path_1, index=False)
        feat_edata = np.random.rand(num_edges, num_dims)
        label_edata = np.random.randint(2, size=num_edges)
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
                "label": label_edata,
                "feat": [line.tolist() for line in feat_edata],
            }
        )
        df.to_csv(edges_csv_path_0, index=False)
        df.to_csv(edges_csv_path_1, index=False)

        # load CSVDataset
        for force_reload in [True, False]:
            if not force_reload:
                # remove original node data file to verify reload from cached files
                os.remove(nodes_csv_path_0)
                assert not os.path.exists(nodes_csv_path_0)
            csv_dataset = data.CSVDataset(test_dir, force_reload=force_reload)
            assert len(csv_dataset) == 1
            g = csv_dataset[0]
            assert not g.is_homogeneous
            assert csv_dataset.has_cache()
            for ntype in g.ntypes:
                assert g.num_nodes(ntype) == num_nodes
                assert F.array_equal(
                    F.tensor(feat_ndata, dtype=F.float32),
                    g.nodes[ntype].data["feat"],
                )
                assert np.array_equal(
                    label_ndata, F.asnumpy(g.nodes[ntype].data["label"])
                )
            for etype in g.etypes:
                assert g.num_edges(etype) == num_edges
                assert F.array_equal(
                    F.tensor(feat_edata, dtype=F.float32),
                    g.edges[etype].data["feat"],
                )
                assert np.array_equal(
                    label_edata, F.asnumpy(g.edges[etype].data["label"])
                )


def _test_CSVDataset_multiple():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path_0 = os.path.join(test_dir, "test_edges_0.csv")
        edges_csv_path_1 = os.path.join(test_dir, "test_edges_1.csv")
        nodes_csv_path_0 = os.path.join(test_dir, "test_nodes_0.csv")
        nodes_csv_path_1 = os.path.join(test_dir, "test_nodes_1.csv")
        graph_csv_path = os.path.join(test_dir, "test_graph.csv")
        meta_yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default_name",
            "node_data": [
                {
                    "file_name": os.path.basename(nodes_csv_path_0),
                    "ntype": "user",
                },
                {
                    "file_name": os.path.basename(nodes_csv_path_1),
                    "ntype": "item",
                },
            ],
            "edge_data": [
                {
                    "file_name": os.path.basename(edges_csv_path_0),
                    "etype": ["user", "follow", "user"],
                },
                {
                    "file_name": os.path.basename(edges_csv_path_1),
                    "etype": ["user", "like", "item"],
                },
            ],
            "graph_data": {"file_name": os.path.basename(graph_csv_path)},
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_graphs = 10
        num_dims = 3
        feat_ndata = np.random.rand(num_nodes * num_graphs, num_dims)
        label_ndata = np.random.randint(2, size=num_nodes * num_graphs)
        df = pd.DataFrame(
            {
                "node_id": np.hstack(
                    [np.arange(num_nodes) for _ in range(num_graphs)]
                ),
                "label": label_ndata,
                "feat": [line.tolist() for line in feat_ndata],
                "graph_id": np.hstack(
                    [np.full(num_nodes, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(nodes_csv_path_0, index=False)
        df.to_csv(nodes_csv_path_1, index=False)
        feat_edata = np.random.rand(num_edges * num_graphs, num_dims)
        label_edata = np.random.randint(2, size=num_edges * num_graphs)
        df = pd.DataFrame(
            {
                "src_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "dst_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "label": label_edata,
                "feat": [line.tolist() for line in feat_edata],
                "graph_id": np.hstack(
                    [np.full(num_edges, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(edges_csv_path_0, index=False)
        df.to_csv(edges_csv_path_1, index=False)
        feat_gdata = np.random.rand(num_graphs, num_dims)
        label_gdata = np.random.randint(2, size=num_graphs)
        df = pd.DataFrame(
            {
                "label": label_gdata,
                "feat": [line.tolist() for line in feat_gdata],
                "graph_id": np.arange(num_graphs),
            }
        )
        df.to_csv(graph_csv_path, index=False)

        # load CSVDataset with default node/edge/gdata_parser
        for force_reload in [True, False]:
            if not force_reload:
                # remove original node data file to verify reload from cached files
                os.remove(nodes_csv_path_0)
                assert not os.path.exists(nodes_csv_path_0)
            csv_dataset = data.CSVDataset(test_dir, force_reload=force_reload)
            assert len(csv_dataset) == num_graphs
            assert csv_dataset.has_cache()
            assert len(csv_dataset.data) == 2
            assert "feat" in csv_dataset.data
            assert "label" in csv_dataset.data
            assert F.array_equal(
                F.tensor(feat_gdata, dtype=F.float32), csv_dataset.data["feat"]
            )
            for i, (g, g_data) in enumerate(csv_dataset):
                assert not g.is_homogeneous
                assert F.asnumpy(g_data["label"]) == label_gdata[i]
                assert F.array_equal(
                    g_data["feat"], F.tensor(feat_gdata[i], dtype=F.float32)
                )
                for ntype in g.ntypes:
                    assert g.num_nodes(ntype) == num_nodes
                    assert F.array_equal(
                        F.tensor(
                            feat_ndata[i * num_nodes : (i + 1) * num_nodes],
                            dtype=F.float32,
                        ),
                        g.nodes[ntype].data["feat"],
                    )
                    assert np.array_equal(
                        label_ndata[i * num_nodes : (i + 1) * num_nodes],
                        F.asnumpy(g.nodes[ntype].data["label"]),
                    )
                for etype in g.etypes:
                    assert g.num_edges(etype) == num_edges
                    assert F.array_equal(
                        F.tensor(
                            feat_edata[i * num_edges : (i + 1) * num_edges],
                            dtype=F.float32,
                        ),
                        g.edges[etype].data["feat"],
                    )
                    assert np.array_equal(
                        label_edata[i * num_edges : (i + 1) * num_edges],
                        F.asnumpy(g.edges[etype].data["label"]),
                    )


def _test_CSVDataset_customized_data_parser():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path_0 = os.path.join(test_dir, "test_edges_0.csv")
        edges_csv_path_1 = os.path.join(test_dir, "test_edges_1.csv")
        nodes_csv_path_0 = os.path.join(test_dir, "test_nodes_0.csv")
        nodes_csv_path_1 = os.path.join(test_dir, "test_nodes_1.csv")
        graph_csv_path = os.path.join(test_dir, "test_graph.csv")
        meta_yaml_data = {
            "dataset_name": "default_name",
            "node_data": [
                {
                    "file_name": os.path.basename(nodes_csv_path_0),
                    "ntype": "user",
                },
                {
                    "file_name": os.path.basename(nodes_csv_path_1),
                    "ntype": "item",
                },
            ],
            "edge_data": [
                {
                    "file_name": os.path.basename(edges_csv_path_0),
                    "etype": ["user", "follow", "user"],
                },
                {
                    "file_name": os.path.basename(edges_csv_path_1),
                    "etype": ["user", "like", "item"],
                },
            ],
            "graph_data": {"file_name": os.path.basename(graph_csv_path)},
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
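        # For reference, the dumped meta.yaml looks roughly like:
        #
        #   dataset_name: default_name
        #   node_data:
        #     - {file_name: test_nodes_0.csv, ntype: user}
        #     - {file_name: test_nodes_1.csv, ntype: item}
        #   edge_data:
        #     - {file_name: test_edges_0.csv, etype: [user, follow, user]}
        #     - {file_name: test_edges_1.csv, etype: [user, like, item]}
        #   graph_data: {file_name: test_graph.csv}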
        num_nodes = 100
        num_edges = 500
        num_graphs = 10
        label_ndata = np.random.randint(2, size=num_nodes * num_graphs)
        df = pd.DataFrame(
            {
                "node_id": np.hstack(
                    [np.arange(num_nodes) for _ in range(num_graphs)]
                ),
                "label": label_ndata,
                "graph_id": np.hstack(
                    [np.full(num_nodes, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(nodes_csv_path_0, index=False)
        df.to_csv(nodes_csv_path_1, index=False)
        label_edata = np.random.randint(2, size=num_edges * num_graphs)
        df = pd.DataFrame(
            {
                "src_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "dst_id": np.hstack(
                    [
                        np.random.randint(num_nodes, size=num_edges)
                        for _ in range(num_graphs)
                    ]
                ),
                "label": label_edata,
                "graph_id": np.hstack(
                    [np.full(num_edges, i) for i in range(num_graphs)]
                ),
            }
        )
        df.to_csv(edges_csv_path_0, index=False)
        df.to_csv(edges_csv_path_1, index=False)
        label_gdata = np.random.randint(2, size=num_graphs)
        df = pd.DataFrame(
            {"label": label_gdata, "graph_id": np.arange(num_graphs)}
        )
        df.to_csv(graph_csv_path, index=False)

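        # A data parser is any callable that takes the pandas DataFrame of
        # one CSV file and returns a dict mapping column name -> ndarray.
        # This one adds 2 to the 'label' column so the assertions below can
        # tell which node/edge/graph types went through the custom parser.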
        class CustDataParser:
            def __call__(self, df):
                data = {}
                for header in df:
                    dt = df[header].to_numpy().squeeze()
                    if header == "label":
                        dt += 2
                    data[header] = dt
                return data

        # load CSVDataset with customized node/edge/gdata_parser
        # specify via dict[ntype/etype, callable]
        csv_dataset = data.CSVDataset(
            test_dir,
            force_reload=True,
            ndata_parser={"user": CustDataParser()},
            edata_parser={("user", "like", "item"): CustDataParser()},
            gdata_parser=CustDataParser(),
        )
        assert len(csv_dataset) == num_graphs
        assert len(csv_dataset.data) == 1
        assert "label" in csv_dataset.data
        for i, (g, g_data) in enumerate(csv_dataset):
            assert not g.is_homogeneous
            assert F.asnumpy(g_data) == label_gdata[i] + 2
            for ntype in g.ntypes:
                assert g.num_nodes(ntype) == num_nodes
                offset = 2 if ntype == "user" else 0
                assert np.array_equal(
                    label_ndata[i * num_nodes : (i + 1) * num_nodes] + offset,
                    F.asnumpy(g.nodes[ntype].data["label"]),
                )
            for etype in g.etypes:
                assert g.num_edges(etype) == num_edges
                offset = 2 if etype == "like" else 0
                assert np.array_equal(
                    label_edata[i * num_edges : (i + 1) * num_edges] + offset,
                    F.asnumpy(g.edges[etype].data["label"]),
                )
        # specify via callable
        csv_dataset = data.CSVDataset(
            test_dir,
            force_reload=True,
            ndata_parser=CustDataParser(),
            edata_parser=CustDataParser(),
            gdata_parser=CustDataParser(),
        )
        assert len(csv_dataset) == num_graphs
        assert len(csv_dataset.data) == 1
        assert "label" in csv_dataset.data
        for i, (g, g_data) in enumerate(csv_dataset):
            assert not g.is_homogeneous
            assert F.asnumpy(g_data) == label_gdata[i] + 2
            for ntype in g.ntypes:
                assert g.num_nodes(ntype) == num_nodes
                offset = 2
                assert np.array_equal(
                    label_ndata[i * num_nodes : (i + 1) * num_nodes] + offset,
                    F.asnumpy(g.nodes[ntype].data["label"]),
                )
            for etype in g.etypes:
                assert g.num_edges(etype) == num_edges
                offset = 2
                assert np.array_equal(
                    label_edata[i * num_edges : (i + 1) * num_edges] + offset,
                    F.asnumpy(g.edges[etype].data["label"]),
                )


def _test_NodeEdgeGraphData():
    from dgl.data.csv_dataset_base import EdgeData, GraphData, NodeData
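    # NodeData/EdgeData/GraphData are the internal containers CSVDataset
    # builds graphs from; each holds the id/feature arrays of one node or
    # edge type (or the graph-level data) plus the graph_id of each row.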

    # NodeData basics
    num_nodes = 100
    node_ids = np.arange(num_nodes, dtype=float)
    ndata = NodeData(node_ids, {})
    assert np.array_equal(ndata.id, node_ids)
    assert len(ndata.data) == 0
    assert ndata.type == "_V"
    assert np.array_equal(ndata.graph_id, np.full(num_nodes, 0))
    # NodeData more
    data = {"feat": np.random.rand(num_nodes, 3)}
    graph_id = np.arange(num_nodes)
    ndata = NodeData(node_ids, data, type="user", graph_id=graph_id)
    assert ndata.type == "user"
    assert np.array_equal(ndata.graph_id, graph_id)
    assert len(ndata.data) == len(data)
    for k, v in data.items():
        assert k in ndata.data
        assert np.array_equal(ndata.data[k], v)
    # NodeData should raise when id/data/graph_id lengths disagree
    expect_except = False
    try:
        NodeData(
            np.arange(num_nodes),
            {"feat": np.random.rand(num_nodes + 1, 3)},
            graph_id=np.arange(num_nodes - 1),
        )
    except Exception:
        expect_except = True
    assert expect_except

    # EdgeData basics
    num_nodes = 100
    num_edges = 1000
    src_ids = np.random.randint(num_nodes, size=num_edges)
    dst_ids = np.random.randint(num_nodes, size=num_edges)
    edata = EdgeData(src_ids, dst_ids, {})
    assert np.array_equal(edata.src, src_ids)
    assert np.array_equal(edata.dst, dst_ids)
    assert edata.type == ("_V", "_E", "_V")
    assert len(edata.data) == 0
    assert np.array_equal(edata.graph_id, np.full(num_edges, 0))
    # EdgeData more
    src_ids = np.random.randint(num_nodes, size=num_edges).astype(float)
    dst_ids = np.random.randint(num_nodes, size=num_edges).astype(float)
    data = {"feat": np.random.rand(num_edges, 3)}
    etype = ("user", "like", "item")
    graph_ids = np.arange(num_edges)
    edata = EdgeData(src_ids, dst_ids, data, type=etype, graph_id=graph_ids)
    assert np.array_equal(edata.src, src_ids)
    assert np.array_equal(edata.dst, dst_ids)
    assert edata.type == etype
    assert len(edata.data) == len(data)
    for k, v in data.items():
        assert k in edata.data
        assert np.array_equal(edata.data[k], v)
    assert np.array_equal(edata.graph_id, graph_ids)
    # EdgeData should raise when src/dst/data/graph_id lengths disagree
    expect_except = False
    try:
        EdgeData(
            np.arange(num_edges),
            np.arange(num_edges + 1),
            {"feat": np.random.rand(num_edges - 1, 3)},
            graph_id=np.arange(num_edges + 2),
        )
    except Exception:
        expect_except = True
    assert expect_except

    # GraphData basics
    num_graphs = 10
    graph_ids = np.arange(num_graphs)
    gdata = GraphData(graph_ids, {})
    assert np.array_equal(gdata.graph_id, graph_ids)
    assert len(gdata.data) == 0
    # GraphData more
    graph_ids = np.arange(num_graphs).astype(float)
    data = {"feat": np.random.rand(num_graphs, 3)}
    gdata = GraphData(graph_ids, data)
    assert np.array_equal(gdata.graph_id, graph_ids)
    assert len(gdata.data) == len(data)
    for k, v in data.items():
        assert k in gdata.data
        assert np.array_equal(gdata.data[k], v)


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow", reason="Skip Tensorflow"
)
def test_csvdataset():
    _test_NodeEdgeGraphData()
    _test_construct_graphs_node_ids()
    _test_construct_graphs_homo()
    _test_construct_graphs_hetero()
    _test_construct_graphs_multiple()
    _test_DefaultDataParser()
    _test_load_yaml_with_sanity_check()
    _test_load_node_data_from_csv()
    _test_load_edge_data_from_csv()
    _test_load_graph_data_from_csv()
    _test_CSVDataset_single()
    _test_CSVDataset_multiple()
    _test_CSVDataset_customized_data_parser()


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred1():
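    # AsNodePredDataset re-splits a node classification dataset by the given
    # ratios, attaching boolean train/val/test masks to the graph and
    # exposing matching train_idx/val_idx/test_idx tensors, whose
    # consistency is checked below.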
    ds = data.AmazonCoBuyComputerDataset()
    print("train_mask" in ds[0].ndata)
    new_ds = data.AsNodePredDataset(ds, [0.8, 0.1, 0.1], verbose=True)
    assert len(new_ds) == 1
    assert new_ds[0].num_nodes() == ds[0].num_nodes()
    assert new_ds[0].num_edges() == ds[0].num_edges()
    assert "train_mask" in new_ds[0].ndata
    assert F.array_equal(
        new_ds.train_idx, F.nonzero_1d(new_ds[0].ndata["train_mask"])
    )
    assert F.array_equal(
        new_ds.val_idx, F.nonzero_1d(new_ds[0].ndata["val_mask"])
    )
    assert F.array_equal(
        new_ds.test_idx, F.nonzero_1d(new_ds[0].ndata["test_mask"])
    )

    ds = data.AIFBDataset()
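    # For a heterogeneous dataset such as AIFB, the target node type
    # ("Personen") must be passed explicitly; the split masks are then
    # attached to that node type only.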
    print("train_mask" in ds[0].nodes["Personen"].data)
    new_ds = data.AsNodePredDataset(
        ds, [0.8, 0.1, 0.1], "Personen", verbose=True
    )
    assert len(new_ds) == 1
    assert new_ds[0].ntypes == ds[0].ntypes
    assert new_ds[0].canonical_etypes == ds[0].canonical_etypes
    assert "train_mask" in new_ds[0].nodes["Personen"].data
    assert F.array_equal(
        new_ds.train_idx,
        F.nonzero_1d(new_ds[0].nodes["Personen"].data["train_mask"]),
    )
    assert F.array_equal(
        new_ds.val_idx,
        F.nonzero_1d(new_ds[0].nodes["Personen"].data["val_mask"]),
    )
    assert F.array_equal(
        new_ds.test_idx,
        F.nonzero_1d(new_ds[0].nodes["Personen"].data["test_mask"]),
    )


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred2():
    # test proper reprocessing: an identical split ratio is served from the
    # on-disk cache, while a different ratio invalidates the cache and
    # triggers a re-split

    # create
    ds = data.AsNodePredDataset(
        data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1]
    )
    assert F.sum(F.astype(ds[0].ndata["train_mask"], F.int32), 0) == int(
        ds[0].num_nodes() * 0.8
    )
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.8)
    # read from cache
    ds = data.AsNodePredDataset(
        data.AmazonCoBuyComputerDataset(), [0.8, 0.1, 0.1]
    )
    assert F.sum(F.astype(ds[0].ndata["train_mask"], F.int32), 0) == int(
        ds[0].num_nodes() * 0.8
    )
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.8)
    # invalid cache, re-read
    ds = data.AsNodePredDataset(
        data.AmazonCoBuyComputerDataset(), [0.1, 0.1, 0.8]
    )
    assert F.sum(F.astype(ds[0].ndata["train_mask"], F.int32), 0) == int(
        ds[0].num_nodes() * 0.1
    )
    assert len(ds.train_idx) == int(ds[0].num_nodes() * 0.1)

    # create
    ds = data.AsNodePredDataset(
        data.AIFBDataset(), [0.8, 0.1, 0.1], "Personen", verbose=True
    )
    assert F.sum(
        F.astype(ds[0].nodes["Personen"].data["train_mask"], F.int32), 0
    ) == int(ds[0].num_nodes("Personen") * 0.8)
    assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.8)
    # read from cache
    ds = data.AsNodePredDataset(
        data.AIFBDataset(), [0.8, 0.1, 0.1], "Personen", verbose=True
    )
    assert F.sum(
        F.astype(ds[0].nodes["Personen"].data["train_mask"], F.int32), 0
    ) == int(ds[0].num_nodes("Personen") * 0.8)
    assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.8)
    # invalid cache, re-read
    ds = data.AsNodePredDataset(
        data.AIFBDataset(), [0.1, 0.1, 0.8], "Personen", verbose=True
    )
    assert F.sum(
        F.astype(ds[0].nodes["Personen"].data["train_mask"], F.int32), 0
    ) == int(ds[0].num_nodes("Personen") * 0.1)
    assert len(ds.train_idx) == int(ds[0].num_nodes("Personen") * 0.1)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_nodepred_ogb():
    from ogb.nodeproppred import DglNodePropPredDataset
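    # With split_ratio=None, AsNodePredDataset adopts the dataset's own
    # split (here the official OGB one), verified below against
    # get_idx_split(); passing an explicit ratio forces a fresh split.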

    ds = data.AsNodePredDataset(
        DglNodePropPredDataset("ogbn-arxiv"), split_ratio=None, verbose=True
    )
    split = DglNodePropPredDataset("ogbn-arxiv").get_idx_split()
    train_idx, val_idx, test_idx = split["train"], split["valid"], split["test"]
    assert F.array_equal(ds.train_idx, F.tensor(train_idx))
    assert F.array_equal(ds.val_idx, F.tensor(val_idx))
    assert F.array_equal(ds.test_idx, F.tensor(test_idx))
    # force generate new split
    ds = data.AsNodePredDataset(
        DglNodePropPredDataset("ogbn-arxiv"),
        split_ratio=[0.7, 0.2, 0.1],
        verbose=True,
    )


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_linkpred():
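    # AsLinkPredDataset splits the edges into train/val/test sets and pairs
    # each split with sampled negative edges (neg_ratio negatives per
    # positive); ds.test_edges is (positive, negative), each a (src, dst)
    # pair of tensors.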
    # create
    ds = data.AsLinkPredDataset(
        data.CoraGraphDataset(),
        split_ratio=[0.8, 0.1, 0.1],
        neg_ratio=1,
        verbose=True,
    )
    # Cora has 10556 edges; the 10% test split yields 1057 positive edges
    assert ds.test_edges[0][0].shape[0] == 1057
    # the number of negative samples is not guaranteed, so assert a relaxed range
    assert 1000 <= ds.test_edges[1][0].shape[0] <= 1057
    # read from cache
    ds = data.AsLinkPredDataset(
        data.CoraGraphDataset(),
        split_ratio=[0.7, 0.1, 0.2],
        neg_ratio=2,
        verbose=True,
    )
    assert ds.test_edges[0][0].shape[0] == 2112
    # negative sampling is not guaranteed to hit ratio 2 exactly, so assert a relaxed range
    assert 4000 < ds.test_edges[1][0].shape[0] <= 4224


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
def test_as_linkpred_ogb():
    from ogb.linkproppred import DglLinkPropPredDataset
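    # split_ratio=None keeps the official ogbl-collab split (46329 test
    # edges); an explicit ratio below forces a fresh split over all edges.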

    ds = data.AsLinkPredDataset(
        DglLinkPropPredDataset("ogbl-collab"), split_ratio=None, verbose=True
    )
    # original dataset has 46329 test edges
    assert ds.test_edges[0][0].shape[0] == 46329
    # force generate new split
    ds = data.AsLinkPredDataset(
        DglLinkPropPredDataset("ogbl-collab"),
        split_ratio=[0.7, 0.2, 0.1],
        verbose=True,
    )
    assert ds.test_edges[0][0].shape[0] == 235812


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
@unittest.skipIf(
    dgl.backend.backend_name == "tensorflow", reason="Skip Tensorflow"
)
def test_as_nodepred_csvdataset():
    with tempfile.TemporaryDirectory() as test_dir:
        # generate YAML/CSVs
        meta_yaml_path = os.path.join(test_dir, "meta.yaml")
        edges_csv_path = os.path.join(test_dir, "test_edges.csv")
        nodes_csv_path = os.path.join(test_dir, "test_nodes.csv")
        meta_yaml_data = {
            "version": "1.0.0",
            "dataset_name": "default_name",
            "node_data": [{"file_name": os.path.basename(nodes_csv_path)}],
            "edge_data": [{"file_name": os.path.basename(edges_csv_path)}],
        }
        with open(meta_yaml_path, "w") as f:
            yaml.dump(meta_yaml_data, f, sort_keys=False)
        num_nodes = 100
        num_edges = 500
        num_dims = 3
        num_classes = num_nodes
        feat_ndata = np.random.rand(num_nodes, num_dims)
        label_ndata = np.arange(num_classes)
        df = pd.DataFrame(
            {
                "node_id": np.arange(num_nodes),
                "label": label_ndata,
                "feat": [line.tolist() for line in feat_ndata],
            }
        )
        df.to_csv(nodes_csv_path, index=False)
        df = pd.DataFrame(
            {
                "src_id": np.random.randint(num_nodes, size=num_edges),
                "dst_id": np.random.randint(num_nodes, size=num_edges),
            }
        )
        df.to_csv(edges_csv_path, index=False)

        ds = data.CSVDataset(test_dir, force_reload=True)
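        # A raw CSVDataset carries features and labels but no split masks
        # and no num_classes attribute; wrapping it in AsNodePredDataset
        # below adds both.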
        assert "feat" in ds[0].ndata
        assert "label" in ds[0].ndata
        assert "train_mask" not in ds[0].ndata
        assert not hasattr(ds, "num_classes")
        new_ds = data.AsNodePredDataset(
            ds, split_ratio=[0.8, 0.1, 0.1], force_reload=True
        )
        assert new_ds.num_classes == num_classes
        assert "feat" in new_ds[0].ndata
        assert "label" in new_ds[0].ndata
        assert "train_mask" in new_ds[0].ndata


@unittest.skipIf(
    F._default_context_str == "gpu",
    reason="Datasets don't need to be tested on GPU.",
)
@unittest.skipIf(dgl.backend.backend_name == "mxnet", reason="Skip MXNet")
def test_as_graphpred_reprocess():
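    # Same create / cached-reload / cache-invalidation pattern as the node
    # prediction tests above, repeated for several graph-level datasets: an
    # unchanged split ratio is served from cache, a changed one forces a
    # re-split.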
    ds = data.AsGraphPredDataset(
        data.GINDataset(name="MUTAG", self_loop=True), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.GINDataset(name="MUTAG", self_loop=True), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.GINDataset(name="MUTAG", self_loop=True), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(
        data.FakeNewsDataset("politifact", "profile"), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.FakeNewsDataset("politifact", "profile"), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.FakeNewsDataset("politifact", "profile"), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.QM7bDataset(), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(
        data.QM9Dataset(label_keys=["mu", "gap"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.QM9Dataset(label_keys=["mu", "gap"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.QM9Dataset(label_keys=["mu", "gap"]), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(
        data.QM9EdgeDataset(label_keys=["mu", "alpha"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(
        data.QM9EdgeDataset(label_keys=["mu", "alpha"]), [0.8, 0.1, 0.1]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(
        data.QM9EdgeDataset(label_keys=["mu", "alpha"]), [0.1, 0.1, 0.8]
    )
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.TUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.TUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.TUDataset("DD"), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.LegacyTUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.LegacyTUDataset("DD"), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.LegacyTUDataset("DD"), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)

    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # read from cache
    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.8, 0.1, 0.1])
    assert len(ds.train_idx) == int(len(ds) * 0.8)
    # invalid cache, re-read
    ds = data.AsGraphPredDataset(data.BA2MotifDataset(), [0.1, 0.1, 0.8])
    assert len(ds.train_idx) == int(len(ds) * 0.1)


@unittest.skipIf(
    dgl.backend.backend_name != "pytorch", reason="ogb only supports pytorch"
)
def test_as_graphpred_ogb():
    from ogb.graphproppred import DglGraphPropPredDataset
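    # As with the node and link variants, split_ratio=None keeps the
    # official ogbg-molhiv split; an explicit ratio below discards it and
    # generates a new split.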

    ds = data.AsGraphPredDataset(
        DglGraphPropPredDataset("ogbg-molhiv"), split_ratio=None, verbose=True
    )
    assert len(ds.train_idx) == 32901
    # force generate new split
    ds = data.AsGraphPredDataset(
        DglGraphPropPredDataset("ogbg-molhiv"),
        split_ratio=[0.6, 0.2, 0.2],
        verbose=True,
    )
    assert len(ds.train_idx) == 24676


if __name__ == "__main__":
    test_minigc()
    test_gin()
    test_data_hash()
    test_tudataset_regression()
    test_fraud()
    test_fakenews()
    test_csvdataset()
    test_as_nodepred1()
    test_as_nodepred2()
    test_as_nodepred_csvdataset()