"...text-generation-inference.git" did not exist on "f41d644a903d179915e122896aba6bc77821795a"
graph_serialize.py 5.37 KB
Newer Older
VoVAllen's avatar
VoVAllen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
"""For Graph Serialization"""
from __future__ import absolute_import
from ..graph import DGLGraph
from ..batched_graph import BatchedDGLGraph
from .._ffi.object import ObjectBase, register_object
from .._ffi.function import _init_api
from .. import backend as F

# Populate this module with the ``_CAPI_*`` functions used below
# (presumably registered on the C++ side under this name prefix).
_init_api("dgl.data.graph_serialize")

# Public API exported by ``from ... import *``.
__all__ = [
    "save_graphs",
    "load_graphs",
    "load_labels",
]

@register_object("graph_serialize.StorageMetaData")
class StorageMetaData(ObjectBase):
    """StorageMetaData Object

    Object returned by ``_CAPI_DGLLoadGraphs`` (as consumed by
    ``load_graphs`` and ``load_labels``). Attributes available:

      num_graph [int]: number of graphs
      nodes_num_list [Value of NDArray]: number of nodes for each graph
      edges_num_list [Value of NDArray]: number of edges for each graph
      labels [dict of Value]: graph labels keyed by str; each value wraps a
        DGL NDArray in its ``.data`` field (convert it with
        ``F.zerocopy_from_dgl_ndarray`` to get a backend tensor)
      graph_data [list of GraphData]: one GraphData object per loaded graph
    """


@register_object("graph_serialize.GraphData")
class GraphData(ObjectBase):
    """GraphData Object

    Container handed to / returned from the serialization C API for a single
    graph: its graph handle plus node and edge feature arrays.
    """

    @staticmethod
    def _frame_to_ndarray_dict(frame):
        """Convert a feature mapping (``g.ndata`` / ``g.edata``) into a dict of
        DGL NDArrays, or ``None`` when the mapping holds no features."""
        if len(frame) == 0:
            # The C API is given None (not an empty dict) for "no features".
            return None
        return {key: F.zerocopy_to_dgl_ndarray(value) for key, value in frame.items()}

    @staticmethod
    def create(g: DGLGraph):
        """Create a GraphData from a DGLGraph.

        Parameters
        ----------
        g : DGLGraph
            Graph to wrap. ``BatchedDGLGraph`` instances are rejected.

        Returns
        -------
        GraphData
            Object built by ``_CAPI_MakeGraphData`` from the graph handle and
            zero-copy NDArray views of the node/edge features (``None`` for
            whichever feature set is empty).

        Raises
        ------
        AssertionError
            If ``g`` is a ``BatchedDGLGraph``.
        """
        # NOTE(review): this is an ``assert``, so the check is stripped under
        # ``python -O``; it is kept to preserve the AssertionError contract.
        assert not isinstance(g, BatchedDGLGraph), "BatchedDGLGraph is not supported for serialization"
        node_tensors = GraphData._frame_to_ndarray_dict(g.ndata)
        edge_tensors = GraphData._frame_to_ndarray_dict(g.edata)
        return _CAPI_MakeGraphData(g._graph, node_tensors, edge_tensors)

    def get_graph(self):
        """Rebuild a readonly DGLGraph from this GraphData.

        Returns
        -------
        DGLGraph
            Readonly graph built from the stored graph handle, with node and
            edge features restored as backend tensors via zero-copy conversion.
        """
        ghandle = _CAPI_GDataGraphHandle(self)
        g = DGLGraph(graph_data=ghandle, readonly=True)
        node_items = _CAPI_GDataNodeTensors(self).items()
        edge_items = _CAPI_GDataEdgeTensors(self).items()
        # Values returned by the C API wrap the NDArray in their ``.data`` field.
        for key, value in node_items:
            g.ndata[key] = F.zerocopy_from_dgl_ndarray(value.data)
        for key, value in edge_items:
            g.edata[key] = F.zerocopy_from_dgl_ndarray(value.data)
        return g


def save_graphs(filename, g_list, labels=None):
    r"""
    Save DGLGraphs and graph labels to file

    Parameters
    ----------
    filename : str
        File name to store DGLGraphs.
    g_list : DGLGraph or list of DGLGraph
        Graph(s) to save. A single DGLGraph is wrapped into a list.
        ``BatchedDGLGraph`` is not supported.
    labels : dict (Default: None)
        Graph labels: a dict of tensors with str as keys. ``None`` or an
        empty dict stores no labels.

    Examples
    --------
    >>> import dgl
    >>> import torch as th

    Create :code:`DGLGraph` objects and initialize node and edge features.

    >>> g1 = dgl.DGLGraph()
    >>> g1.add_nodes(3)
    >>> g1.add_edges([0, 0, 0, 1, 1, 2], [0, 1, 2, 1, 2, 2])
    >>> g1.ndata["e"] = th.ones(3, 5)
    >>> g2 = dgl.DGLGraph()
    >>> g2.add_nodes(3)
    >>> g2.add_edges([0, 1, 2], [1, 2, 1])
    >>> g2.edata["e"] = th.ones(3, 4)

    Save Graphs into file

    >>> from dgl.data.utils import save_graphs
    >>> graph_labels = {"glabel": th.tensor([0, 1])}
    >>> save_graphs("./data.bin", [g1, g2], graph_labels)

    """
    if isinstance(g_list, DGLGraph):
        g_list = [g_list]
    if labels is not None and len(labels) != 0:
        label_dict = {key: F.zerocopy_to_dgl_ndarray(value) for key, value in labels.items()}
    else:
        # The C API is given None (not an empty dict) for "no labels".
        label_dict = None
    gdata_list = [GraphData.create(g) for g in g_list]
    _CAPI_DGLSaveGraphs(filename, gdata_list, label_dict)


def load_graphs(filename, idx_list=None):
    """
    Load DGLGraphs from file

    Parameters
    ----------
    filename : str
        File to load DGLGraphs from.
    idx_list : list of int, optional
        Indices of the graphs to load. A tuple is also accepted. If not
        specified (or empty), all graphs in the file are loaded.

    Returns
    -------
    graph_list : list of DGLGraph
        The loaded graphs (readonly).
    labels : dict
        Labels stored in the file (empty dict if no label was stored).

    Examples
    --------
    Following the example in save_graphs.

    >>> from dgl.data.utils import load_graphs
    >>> glist, label_dict = load_graphs("./data.bin") # glist will be [g1, g2]
    >>> glist, label_dict = load_graphs("./data.bin", [0]) # glist will be [g1]

    """
    if idx_list is None:
        # An empty index list asks the C API to load every graph.
        idx_list = []
    elif isinstance(idx_list, tuple):
        idx_list = list(idx_list)
    assert isinstance(idx_list, list), "idx_list must be a list of int"
    metadata = _CAPI_DGLLoadGraphs(filename, idx_list, False)
    # Label values wrap the NDArray in their ``.data`` field.
    label_dict = {k: F.zerocopy_from_dgl_ndarray(v.data) for k, v in metadata.labels.items()}
    return [gdata.get_graph() for gdata in metadata.graph_data], label_dict


def load_labels(filename):
    """
    Load label dict from file

    Parameters
    ----------
    filename : str
        File (written by ``save_graphs``) to load the labels from.

    Returns
    -------
    labels : dict
        Labels stored in the file (empty dict if no label was stored).

    Examples
    --------
    Following the example in save_graphs.

    >>> from dgl.data.utils import load_labels
    >>> label_dict = load_labels("./data.bin")

    """
    # Empty index list plus a True flag: presumably "labels only, skip graph
    # data" on the C++ side (load_graphs passes False) -- TODO confirm.
    metadata = _CAPI_DGLLoadGraphs(filename, [], True)
    # Label values wrap the NDArray in their ``.data`` field.
    return {k: F.zerocopy_from_dgl_ndarray(v.data) for k, v in metadata.labels.items()}