csv_dataset.py 6.39 KB
Newer Older
1
2
3
4
import os
import numpy as np
from .dgl_dataset import DGLDataset
from .utils import save_graphs, load_graphs
5
from ..base import DGLError
6
7


8
class CSVDataset(DGLDataset):
    """Dataset class that loads and parses graph data from CSV files.

    Parameters
    ----------
    data_path : str
        Directory which contains 'meta.yaml' and CSV files.
    force_reload : bool, optional
        Whether to reload the dataset. Default: False.
    verbose : bool, optional
        Whether to print out progress information. Default: True.
    ndata_parser : dict[str, callable] or callable, optional
        Callable object which takes in the ``pandas.DataFrame`` object created from
        CSV file, parses node data and returns a dictionary of parsed data. If given a
        dictionary, the key is node type and the value is a callable object which is
        used to parse data of the corresponding node type. If given a single callable
        object, that object is used to parse the data of all node types. Default: None.
        If None, a default data parser is applied which loads data directly and tries
        to convert lists into arrays.
    edata_parser : dict[(str, str, str), callable], or callable, optional
        Callable object which takes in the ``pandas.DataFrame`` object created from
        CSV file, parses edge data and returns a dictionary of parsed data. If given a
        dictionary, the key is edge type and the value is a callable object which is
        used to parse data of the corresponding edge type. If given a single callable
        object, that object is used to parse the data of all edge types. Default: None.
        If None, a default data parser is applied which loads data directly and tries
        to convert lists into arrays.
    gdata_parser : callable, optional
        Callable object which takes in the ``pandas.DataFrame`` object created from
        CSV file, parses graph data and returns a dictionary of parsed data. Default:
        None. If None, a default data parser is applied which loads data directly and
        tries to convert lists into arrays.
    transform : callable, optional
        A transform that takes in a :class:`~dgl.DGLGraph` object and returns
        a transformed version. The :class:`~dgl.DGLGraph` object will be
        transformed before every access.

    Attributes
    ----------
    graphs : list[:class:`dgl.DGLGraph`]
        Graphs of the dataset.
    data : dict
        Any available graph-level data, such as graph-level features and labels.
        May be empty, in which case indexing the dataset returns only the graph.

    Examples
    --------
    Please refer to :ref:`guide-data-pipeline-loadcsv`.
    """
    META_YAML_NAME = 'meta.yaml'

    def __init__(self, data_path, force_reload=False, verbose=True, ndata_parser=None,
                 edata_parser=None, gdata_parser=None, transform=None):
        # Imported lazily, presumably so that csv_dataset_base (and whatever it
        # depends on) is only loaded when this class is actually used -- TODO confirm.
        from .csv_dataset_base import load_yaml_with_sanity_check, DefaultDataParser

        self.graphs = None
        self.data = None
        # An empty dict means "use the default parser for every node/edge type".
        self.ndata_parser = {} if ndata_parser is None else ndata_parser
        self.edata_parser = {} if edata_parser is None else edata_parser
        self.gdata_parser = gdata_parser
        self.default_data_parser = DefaultDataParser()

        meta_yaml_path = os.path.join(data_path, CSVDataset.META_YAML_NAME)
        if not os.path.exists(meta_yaml_path):
            raise DGLError(
                "'{}' cannot be found under {}.".format(CSVDataset.META_YAML_NAME, data_path))

        self.meta_yaml = load_yaml_with_sanity_check(meta_yaml_path)
        ds_name = self.meta_yaml.dataset_name
        # All attributes above are assigned before the base constructor runs;
        # presumably it triggers process()/load(), which read them -- see DGLDataset.
        super().__init__(ds_name, raw_dir=os.path.dirname(meta_yaml_path),
                         force_reload=force_reload, verbose=verbose, transform=transform)

    def _resolve_parser(self, parser, key):
        """Return the data parser to use for the node/edge type ``key``.

        ``parser`` is either a single callable applied to every type, or a dict
        mapping types to callables; types missing from the dict fall back to
        ``self.default_data_parser``.
        """
        if callable(parser):
            return parser
        return parser.get(key, self.default_data_parser)

    def _cache_file_path(self):
        """Return the path of the binary file used to cache the processed graphs."""
        return os.path.join(self.save_path, self.name + '.bin')

    def process(self):
        """Parse node/edge/graph data from CSV files and construct DGLGraphs.

        Sets ``self.graphs`` and ``self.data`` from the result of
        ``DGLGraphConstructor.construct_graphs``.
        """
        from .csv_dataset_base import NodeData, EdgeData, GraphData, DGLGraphConstructor

        meta_yaml = self.meta_yaml
        base_dir = self.raw_dir
        separator = meta_yaml.separator

        node_data = []
        for meta_node in meta_yaml.node_data:
            # Skip ``None`` entries in the parsed metadata.
            if meta_node is None:
                continue
            data_parser = self._resolve_parser(self.ndata_parser, meta_node.ntype)
            node_data.append(NodeData.load_from_csv(
                meta_node, base_dir=base_dir, separator=separator, data_parser=data_parser))

        edge_data = []
        for meta_edge in meta_yaml.edge_data:
            if meta_edge is None:
                continue
            # Converted to a tuple so it can match the ``(str, str, str)`` dict keys
            # accepted by ``edata_parser``.
            etype = tuple(meta_edge.etype)
            data_parser = self._resolve_parser(self.edata_parser, etype)
            edge_data.append(EdgeData.load_from_csv(
                meta_edge, base_dir=base_dir, separator=separator, data_parser=data_parser))

        graph_data = None
        if meta_yaml.graph_data is not None:
            data_parser = (self.default_data_parser if self.gdata_parser is None
                           else self.gdata_parser)
            graph_data = GraphData.load_from_csv(
                meta_yaml.graph_data, base_dir=base_dir, separator=separator,
                data_parser=data_parser)

        # construct graphs
        self.graphs, self.data = DGLGraphConstructor.construct_graphs(
            node_data, edge_data, graph_data)

    def has_cache(self):
        """Return True if a cached ``<name>.bin`` file exists under ``save_path``."""
        return os.path.exists(self._cache_file_path())

    def save(self):
        """Save the graphs, with ``self.data`` as labels, to the cache file.

        Raises
        ------
        DGLError
            If no graphs have been loaded or processed yet.
        """
        if self.graphs is None:
            raise DGLError("No graphs available in dataset")
        save_graphs(self._cache_file_path(), self.graphs, labels=self.data)

    def load(self):
        """Load the graphs and graph-level data from the cache file."""
        self.graphs, self.data = load_graphs(self._cache_file_path())

    def __getitem__(self, i):
        """Return the ``i``-th graph, transformed if a transform was given.

        Returns ``(graph, data)``, where ``data`` maps each graph-level data key to
        its ``i``-th entry, or only the graph when there is no graph-level data.
        """
        if self._transform is None:
            g = self.graphs[i]
        else:
            g = self._transform(self.graphs[i])

        if len(self.data) > 0:
            return g, {k: v[i] for k, v in self.data.items()}
        return g

    def __len__(self):
        """Return the number of graphs in the dataset."""
        return len(self.graphs)