"""
Make Your Own Dataset
=====================

This tutorial assumes that you already know :doc:`the basics of training a
GNN for node classification <1_introduction>` and :doc:`how to
create, load, and store a DGL graph <2_dglgraph>`.

By the end of this tutorial, you will be able to

-  Create your own graph dataset for node classification, link
   prediction, or graph classification.

(Time estimate: 15 minutes)
"""


######################################################################
# ``DGLDataset`` Object Overview
# ------------------------------
#
# Your custom graph dataset should inherit the ``dgl.data.DGLDataset``
# class and implement the following methods:
#
# -  ``__getitem__(self, i)``: retrieve the ``i``-th example of the
#    dataset. An example often contains a single DGL graph, and
#    occasionally its label.
# -  ``__len__(self)``: the number of examples in the dataset.
# -  ``process(self)``: load and process raw data from disk.
#


######################################################################
# Creating a Dataset for Node Classification or Link Prediction from CSV
# ----------------------------------------------------------------------
#
# A node classification dataset often consists of a single graph, as well
# as its node and edge features.
#
# This tutorial takes a small dataset based on `Zachary’s Karate Club
# network <https://en.wikipedia.org/wiki/Zachary%27s_karate_club>`__. It
# contains
#
# * A ``members.csv`` file containing the attributes of all
#   members.
#
# * An ``interactions.csv`` file
#   containing the pair-wise interactions between two club members.
#

# Download the raw CSV files and take a first look at both tables.
import urllib.request

import pandas as pd

urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/members.csv", "./members.csv"
)
urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/interactions.csv",
    "./interactions.csv",
)

members = pd.read_csv("./members.csv")
members.head()

interactions = pd.read_csv("./interactions.csv")
interactions.head()


######################################################################
# This tutorial treats the members as nodes and interactions as edges. It
# takes age as a numeric feature of the nodes, affiliated club as the label
# of the nodes, and edge weight as a numeric feature of the edges.
#
# .. note::
#
#    The original Zachary’s Karate Club network does not have
#    member ages. The ages in this tutorial are generated synthetically
#    for demonstrating how to add node features into the graph for dataset
#    creation.
#
# .. note::
#
#    In practice, taking age directly as a numeric feature may
#    not work well in machine learning; strategies like binning or
#    normalizing the feature would work better. This tutorial directly
#    takes the values as-is for simplicity.
#

import os

import torch

import dgl
from dgl.data import DGLDataset

class KarateClubDataset(DGLDataset):
    """Zachary's Karate Club as a single-graph node classification dataset.

    Members are nodes (feature: age, label: club affiliation) and
    interactions are edges (feature: weight). The dataset holds exactly
    one graph, so ``__len__`` is 1 and ``__getitem__`` ignores its index
    beyond returning that graph.
    """

    def __init__(self):
        super().__init__(name="karate_club")

    def process(self):
        # Load and process the raw CSV tables from disk.
        nodes_data = pd.read_csv("./members.csv")
        edges_data = pd.read_csv("./interactions.csv")

        # Node feature (age), node label (club encoded as integer codes),
        # edge feature (interaction weight), and edge endpoints.
        node_features = torch.from_numpy(nodes_data["Age"].to_numpy())
        node_labels = torch.from_numpy(
            nodes_data["Club"].astype("category").cat.codes.to_numpy()
        )
        edge_features = torch.from_numpy(edges_data["Weight"].to_numpy())
        edges_src = torch.from_numpy(edges_data["Src"].to_numpy())
        edges_dst = torch.from_numpy(edges_data["Dst"].to_numpy())

        n_nodes = nodes_data.shape[0]
        self.graph = dgl.graph((edges_src, edges_dst), num_nodes=n_nodes)
        self.graph.ndata["feat"] = node_features
        self.graph.ndata["label"] = node_labels
        self.graph.edata["weight"] = edge_features

        # If your dataset is a node classification dataset, you will need to
        # assign masks indicating whether a node belongs to the training,
        # validation, or test set. Here: 60/20/20 split by node ID order.
        n_train = int(n_nodes * 0.6)
        n_val = int(n_nodes * 0.2)
        split_bounds = {
            "train_mask": (0, n_train),
            "val_mask": (n_train, n_train + n_val),
            "test_mask": (n_train + n_val, n_nodes),
        }
        for mask_name, (lo, hi) in split_bounds.items():
            mask = torch.zeros(n_nodes, dtype=torch.bool)
            mask[lo:hi] = True
            self.graph.ndata[mask_name] = mask

    def __getitem__(self, i):
        return self.graph

    def __len__(self):
        return 1

# Instantiating the dataset triggers process(); index 0 is its only graph.
dataset = KarateClubDataset()
graph = dataset[0]

print(graph)


######################################################################
# Since a link prediction dataset only involves a single graph, preparing
# a link prediction dataset will have the same experience as preparing a
# node classification dataset.
#


######################################################################
# Creating a Dataset for Graph Classification from CSV
# ----------------------------------------------------
#
# Creating a graph classification dataset involves implementing
# ``__getitem__`` to return both the graph and its graph-level label.
#
# This tutorial demonstrates how to create a graph classification dataset
# with the following synthetic CSV data:
#
# -  ``graph_edges.csv``: containing three columns:
#
#    -  ``graph_id``: the ID of the graph.
#    -  ``src``: the source node of an edge of the given graph.
#    -  ``dst``: the destination node of an edge of the given graph.
#
# -  ``graph_properties.csv``: containing three columns:
#
#    -  ``graph_id``: the ID of the graph.
#    -  ``label``: the label of the graph.
#    -  ``num_nodes``: the number of nodes in the graph.
#

# Download the two synthetic CSV files and preview them.
urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/graph_edges.csv", "./graph_edges.csv"
)
urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/graph_properties.csv",
    "./graph_properties.csv",
)

edges = pd.read_csv("./graph_edges.csv")
properties = pd.read_csv("./graph_properties.csv")

edges.head()

properties.head()

class SyntheticDataset(DGLDataset):
    """Synthetic graph classification dataset built from two CSV files.

    Each example is a ``(graph, label)`` pair: edge lists come from
    ``graph_edges.csv`` and per-graph labels/node counts from
    ``graph_properties.csv``.
    """

    def __init__(self):
        super().__init__(name="synthetic")

    def process(self):
        edges = pd.read_csv("./graph_edges.csv")
        properties = pd.read_csv("./graph_properties.csv")
        self.graphs = []
        self.labels = []

        # Map each graph ID to its label and node count. Building the
        # dicts with zip() over whole columns is idiomatic pandas and
        # avoids the slow row-by-row iterrows() scan.
        label_dict = dict(zip(properties["graph_id"], properties["label"]))
        num_nodes_dict = dict(
            zip(properties["graph_id"], properties["num_nodes"])
        )

        # For the edges, first group the table by graph IDs, then build
        # one DGLGraph per group.
        edges_group = edges.groupby("graph_id")
        for graph_id in edges_group.groups:
            # Find the edges as well as the number of nodes and its label.
            edges_of_id = edges_group.get_group(graph_id)
            src = edges_of_id["src"].to_numpy()
            dst = edges_of_id["dst"].to_numpy()
            num_nodes = num_nodes_dict[graph_id]
            label = label_dict[graph_id]

            # Create a graph and add it to the list of graphs and labels.
            g = dgl.graph((src, dst), num_nodes=num_nodes)
            self.graphs.append(g)
            self.labels.append(label)

        # Convert the label list to tensor for saving.
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)

# Each example of a graph classification dataset is a (graph, label) pair.
dataset = SyntheticDataset()
graph, label = dataset[0]
print(graph, label)

######################################################################
# Creating Dataset from CSV via :class:`~dgl.data.CSVDataset`
# ------------------------------------------------------------
#
# The previous examples describe how to create a dataset from CSV files
# step-by-step. DGL also provides a utility class :class:`~dgl.data.CSVDataset`
# for reading and parsing data from CSV files. See :ref:`guide-data-pipeline-loadcsv`
# for more details.
#

# Thumbnail credits: (Un)common Use Cases for Graph Databases, Michal Bachman
# sphinx_gallery_thumbnail_path = '_static/blitz_6_load_data.png'