# test_heterograph-pickle.py
import io
import pickle

4
5
import networkx as nx
import torch
6
7
8

import dgl

9
10
11
12
13
14
15
16
17

def _reconstruct_pickle(obj):
    f = io.BytesIO()
    pickle.dump(obj, f)
    f.seek(0)
    obj = pickle.load(f)
    f.close()
    return obj

18

19
20
21
22
23
24
def test_pickling_batched_graph():
    """Check that a batched graph with features survives a pickle round-trip.

    Regression test for a weird pickling bug reported in
    https://github.com/dmlc/dgl/issues/438.
    """
    # Five networkx path graphs with 5, 6, 7, 8 and 9 nodes, wrapped as DGL
    # graphs and combined into a single batched graph.
    glist = [dgl.DGLGraph(nx.path_graph(i + 5)) for i in range(5)]
    bg = dgl.batch(glist)

    # Size the feature tensors from the batched graph itself rather than
    # hard-coding the node/edge totals.
    num_nodes = bg.number_of_nodes()
    num_edges = bg.number_of_edges()
    bg.ndata["x"] = torch.randn((num_nodes, 5))
    bg.edata["y"] = torch.randn((num_edges, 3))

    new_bg = _reconstruct_pickle(bg)

    # Beyond "pickling does not raise", verify that batch structure and
    # features are restored unchanged.
    assert new_bg.batch_size == bg.batch_size
    assert new_bg.number_of_nodes() == num_nodes
    assert new_bg.number_of_edges() == num_edges
    assert torch.equal(new_bg.ndata["x"], bg.ndata["x"])
    assert torch.equal(new_bg.edata["y"], bg.edata["y"])

29
30

if __name__ == "__main__":
    # Allow running this test module directly without pytest.
    test_pickling_batched_graph()