Unverified commit 2b983869, authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4705)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent 408eba24
import logging

import numpy as np
from numpy.lib.format import open_memmap

from .registry import register_array_parser


@register_array_parser("numpy")
class NumpyArrayParser(object):
    def __init__(self):
        pass

    def read(self, path):
        logging.info("Reading from %s using numpy format" % path)
        arr = np.load(path, mmap_mode="r")
        logging.info("Done reading from %s" % path)
        return arr

    def write(self, path, arr):
        logging.info("Writing to %s using numpy format" % path)
        # np.save would load the entire memmap array up into CPU. So we manually open
        # an empty npy file with memmap mode and manually flush it instead.
        new_arr = open_memmap(path, mode="w+", dtype=arr.dtype, shape=arr.shape)
        new_arr[:] = arr[:]
        logging.info("Done writing to %s" % path)
REGISTRY = {}


def register_array_parser(name):
    def _deco(cls):
        REGISTRY[name] = cls
        return cls

    return _deco


def get_array_parser(**fmt_meta):
    cls = REGISTRY[fmt_meta.pop("name")]
    return cls(**fmt_meta)
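
# A usage sketch: ``get_array_parser`` looks up the class registered under
# the given name and forwards any remaining keyword arguments to its
# constructor. With the numpy parser above registered:
#
#   parser = get_array_parser(name="numpy")
#   arr = parser.read("embeddings.npy")       # path is illustrative
#   parser.write("embeddings_copy.npy", arr)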
import logging
import os
from contextlib import contextmanager

from numpy.lib.format import open_memmap


@contextmanager
def setdir(path):
    try:
        os.makedirs(path, exist_ok=True)
        cwd = os.getcwd()
        logging.info("Changing directory to %s" % path)
        logging.info("Previously: %s" % cwd)
        os.chdir(path)
        yield
    finally:
        logging.info("Restoring directory to %s" % cwd)
        os.chdir(cwd)
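
# A usage sketch (the directory name is illustrative): the body runs with
# ``scratch`` as the working directory, and the previous working directory
# is restored on exit even if the body raises.
#
#   with setdir("scratch"):
#       open("output.txt", "w").write("done")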
@@ -21,11 +21,12 @@ networks with PyTorch.
"""

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.data
######################################################################
# Overview of Node Classification with GNN
@@ -45,7 +46,7 @@ import torch.nn.functional as F
# task. With the help of only a small portion of labeled nodes, a graph
# neural network (GNN) can accurately predict the node category of the
# others.
#
# This tutorial will show how to build such a GNN for semi-supervised node
# classification with only a small number of labels on the Cora
# dataset,
@@ -54,21 +55,20 @@ import torch.nn.functional as F
# word count vector as its features, normalized so that they sum up to one,
# as described in Section 5.2 of
# `the paper <https://arxiv.org/abs/1609.02907>`__.
#
# Loading Cora Dataset
# --------------------
#

dataset = dgl.data.CoraGraphDataset()
print("Number of categories:", dataset.num_classes)


######################################################################
# A DGL Dataset object may contain one or multiple graphs. The Cora
# dataset used in this tutorial only consists of one single graph.
#

g = dataset[0]
@@ -77,7 +77,7 @@ g = dataset[0]
# A DGL graph can store node features and edge features in two
# dictionary-like attributes called ``ndata`` and ``edata``.
# In the DGL Cora dataset, the graph contains the following node features:
#
# - ``train_mask``: A boolean tensor indicating whether the node is in the
#   training set.
#
@@ -90,68 +90,71 @@ g = dataset[0]
# - ``label``: The ground truth node category.
#
# - ``feat``: The node features.
#

print("Node features")
print(g.ndata)
print("Edge features")
print(g.edata)
######################################################################
# Defining a Graph Convolutional Network (GCN)
# --------------------------------------------
#
# This tutorial will build a two-layer `Graph Convolutional Network
# (GCN) <http://tkipf.github.io/graph-convolutional-networks/>`__. Each
# layer computes new node representations by aggregating neighbor
# information.
#
# To build a multi-layer GCN you can simply stack ``dgl.nn.GraphConv``
# modules, which inherit ``torch.nn.Module``.
#

from dgl.nn import GraphConv


class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h


# Create the model with given dimensions
model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes)
######################################################################
# DGL provides implementations of many popular neighbor aggregation
# modules. You can easily invoke them with one line of code, as in the
# sketch below.
#
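
######################################################################
# For instance, a minimal sketch that swaps in DGL's built-in GraphSAGE
# layer instead of ``GraphConv`` (the ``"mean"`` aggregator picked here
# is just one of ``SAGEConv``'s built-in choices):
#

from dgl.nn import SAGEConv

sage_conv1 = SAGEConv(g.ndata["feat"].shape[1], 16, "mean")
sage_conv2 = SAGEConv(16, dataset.num_classes, "mean")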
######################################################################
# Training the GCN
# ----------------
#
# Training this GCN is similar to training other PyTorch neural networks.
#


def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    for e in range(100):
        # Forward
        logits = model(g, features)
@@ -179,19 +182,24 @@ def train(g, model):
        optimizer.step()

        if e % 5 == 0:
            print(
                "In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
                    e, loss, val_acc, best_val_acc, test_acc, best_test_acc
                )
            )


model = GCN(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)
######################################################################
# Training on GPU
# ---------------
#
# Training on GPU requires putting both the model and the graph onto the
# GPU with the ``to`` method, similar to what you would do in PyTorch.
#
# .. code:: python
#
#    g = g.to('cuda')
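#
#    # A fuller sketch in the same vein, assuming a CUDA device is
#    # available; the model and the graph must live on the same device.
#    model = GCN(g.ndata['feat'].shape[1], 16, dataset.num_classes).to('cuda')
#    train(g, model)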
@@ -203,7 +211,7 @@ train(g, model)

######################################################################
# What’s next?
# ------------
#
# - :doc:`How does DGL represent a graph <2_dglgraph>`?
# - :doc:`Write your own GNN module <3_message_passing>`.
# - :doc:`Link prediction (predicting existence of edges) on full
@@ -213,7 +221,7 @@ train(g, model)
# - :ref:`The list of supported graph convolution
#   modules <apinn-pytorch>`.
# - :ref:`The list of datasets provided by DGL <apidata>`.
#
# Thumbnail credits: Stanford CS224W Notes
...
@@ -19,24 +19,28 @@ By the end of this tutorial you will be able to:

######################################################################
# DGL Graph Construction
# ----------------------
#
# DGL represents a directed graph as a ``DGLGraph`` object. You can
# construct a graph by specifying the number of nodes in the graph as well
# as the list of source and destination nodes. Nodes in the graph have
# consecutive IDs starting from 0.
#
# For instance, the following code constructs a directed star graph with 5
# leaves. The center node's ID is 0. The edges go from the
# center node to the leaves.
#

import dgl
import numpy as np
import torch

g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]), num_nodes=6)
# Equivalently, PyTorch LongTensors also work.
g = dgl.graph(
    (torch.LongTensor([0, 0, 0, 0, 0]), torch.LongTensor([1, 2, 3, 4, 5])),
    num_nodes=6,
)

# You can omit the number of nodes argument if you can tell the number of nodes from the edge list alone.
g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]))
@@ -46,7 +50,7 @@ g = dgl.graph(([0, 0, 0, 0, 0], [1, 2, 3, 4, 5]))
# Edges in the graph have consecutive IDs starting from 0, and are
# in the same order as the list of source and destination nodes during
# creation.
#
# Print the source and destination nodes of every edge.
print(g.edges())

@@ -54,7 +58,7 @@ print(g.edges())

######################################################################
# .. note::
#
#    ``DGLGraph``'s are always directed to best fit the computation
#    pattern of graph neural networks, where the messages sent
#    from one node to the other are often different between both
@@ -62,59 +66,59 @@ print(g.edges())
#    treating it as a bidirectional graph. See `Graph
#    Transformations`_ for an example of making
#    a bidirectional graph.
#
######################################################################
# Assigning Node and Edge Features to Graph
# -----------------------------------------
#
# Many graph datasets contain attributes on nodes and edges.
# Although the types of node and edge attributes can be arbitrary in the
# real world, ``DGLGraph`` only accepts attributes stored in tensors (with
# numerical contents). Consequently, an attribute of all the nodes or
# edges must have the same shape. In the context of deep learning, those
# attributes are often called *features*.
#
# You can assign and retrieve node and edge features via the ``ndata`` and
# ``edata`` interfaces.
#

# Assign a 3-dimensional node feature vector for each node.
g.ndata["x"] = torch.randn(6, 3)
# Assign a 4-dimensional edge feature vector for each edge.
g.edata["a"] = torch.randn(5, 4)
# Assign a 5x4 node feature matrix for each node.  Node and edge features in DGL can be multi-dimensional.
g.ndata["y"] = torch.randn(6, 5, 4)

print(g.edata["a"])
######################################################################
# .. note::
#
#    The rapid development of deep learning has provided many ways to
#    encode various types of attributes into numerical features.
#    Here are some general suggestions:
#
#    - For categorical attributes (e.g. gender, occupation), consider
#      converting them to integers or one-hot encoding, as in the sketch
#      after this note.
#    - For variable length string contents (e.g. news article, quote),
#      consider applying a language model.
#    - For images, consider applying a vision model such as CNNs.
#
#    You can find plenty of materials on how to encode such attributes
#    into a tensor in the `PyTorch Deep Learning
#    Tutorials <https://pytorch.org/tutorials/>`__.
#
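
# A minimal sketch of the one-hot suggestion above. The occupation codes
# are made up for illustration; ``torch.eye(3)[codes]`` turns integer
# codes into one-hot float vectors, one row per node.
occupation = torch.tensor([0, 2, 1, 0, 2, 1])  # one integer code per node
g.ndata["occ"] = torch.eye(3)[occupation]  # shape (6, 3)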
######################################################################
# Querying Graph Structures
# -------------------------
#
# A ``DGLGraph`` object provides various methods to query the graph
# structure.
#

print(g.num_nodes())
print(g.num_edges())
@@ -127,13 +131,13 @@ print(g.in_degrees(0))

######################################################################
# Graph Transformations
# ---------------------
#

######################################################################
# DGL provides many APIs to transform a graph into another, such as
# extracting a subgraph:
#

# Induce a subgraph from node 0, node 1 and node 3 from the original graph.
sg1 = g.subgraph([0, 1, 3])

@@ -145,7 +149,7 @@ sg2 = g.edge_subgraph([0, 1, 3])
# You can obtain the node/edge mapping from the subgraph to the original
# graph by looking into the node feature ``dgl.NID`` or edge feature
# ``dgl.EID`` in the new graph.
#

# The original IDs of each node in sg1
print(sg1.ndata[dgl.NID])
@@ -163,24 +167,24 @@ print(sg2.edata[dgl.EID])
#

# The original node feature of each node in sg1
print(sg1.ndata["x"])
# The original edge feature of each edge in sg1
print(sg1.edata["a"])
# The original node feature of each node in sg2
print(sg2.ndata["x"])
# The original edge feature of each edge in sg2
print(sg2.edata["a"])


######################################################################
# Another common transformation is to add a reverse edge for each edge in
# the original graph with ``dgl.add_reverse_edges``.
#
# .. note::
#
#    If you have an undirected graph, it is better to convert it
#    into a bidirectional graph first via adding reverse edges.
#

newg = dgl.add_reverse_edges(g)
print(newg.edges())
@@ -189,19 +193,19 @@ print(newg.edges())

######################################################################
# Loading and Saving Graphs
# -------------------------
#
# You can save a graph or a list of graphs via ``dgl.save_graphs`` and
# load them back with ``dgl.load_graphs``.
#

# Save graphs
dgl.save_graphs("graph.dgl", g)
dgl.save_graphs("graphs.dgl", [g, sg1, sg2])

# Load graphs
(g,), _ = dgl.load_graphs("graph.dgl")
print(g)
(g, sg1, sg2), _ = dgl.load_graphs("graphs.dgl")
print(g)
print(sg1)
print(sg2)
@@ -210,7 +214,7 @@ print(sg2)

######################################################################
# What’s next?
# ------------
#
# - See
#   :ref:`here <apigraph-querying-graph-structure>`
#   for a list of graph structure query APIs.
@@ -223,7 +227,7 @@ print(sg2)
# - API reference of :func:`dgl.save_graphs`
#   and
#   :func:`dgl.load_graphs`
#
# Thumbnail credits: Wikipedia
...
@@ -18,73 +18,73 @@ GNN for node classification <1_introduction>`.
"""

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as fn
######################################################################
# Message passing and GNNs
# ------------------------
#
# DGL follows the *message passing paradigm* inspired by the Message
# Passing Neural Network proposed by `Gilmer et
# al. <https://arxiv.org/abs/1704.01212>`__ Essentially, they found that
# many GNN models can fit into the following framework:
#
# .. math::
#
#    m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)
#
# .. math::
#
#    m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)}
#
# .. math::
#
#    h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)
#
# where DGL calls :math:`M^{(l)}` the *message function*, :math:`\sum` the
# *reduce function* and :math:`U^{(l)}` the *update function*. Note that
# :math:`\sum` here can represent any function and is not necessarily a
# summation.
#

######################################################################
# For example, the `GraphSAGE convolution (Hamilton et al.,
# 2017) <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__
# takes the following mathematical form:
#
# .. math::
#
#    h_{\mathcal{N}(v)}^k\leftarrow \text{Average}\{h_u^{k-1},\forall u\in\mathcal{N}(v)\}
#
# .. math::
#
#    h_v^k\leftarrow \text{ReLU}\left(W^k\cdot \text{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k) \right)
#
# You can see that message passing is directional: the message sent from
# one node :math:`u` to another node :math:`v` is not necessarily the same
# as the message sent from node :math:`v` to node :math:`u` in the
# opposite direction.
#
# Although DGL has built-in support for GraphSAGE via
# :class:`dgl.nn.SAGEConv <dgl.nn.pytorch.SAGEConv>`,
# here is how you can implement GraphSAGE convolution in DGL on your own.
#
class SAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model.

    Parameters
    ----------
    in_feat : int
@@ -92,14 +92,15 @@ class SAGEConv(nn.Module):
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(SAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h):
        """Forward computation

        Parameters
        ----------
        g : Graph
@@ -108,10 +109,13 @@ class SAGEConv(nn.Module):
            The input node feature.
        """
        with g.local_scope():
            g.ndata["h"] = h
            # update_all is a message passing API.
            g.update_all(
                message_func=fn.copy_u("h", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)
@@ -132,7 +136,7 @@ class SAGEConv(nn.Module):
#
# * ``update_all`` tells DGL to trigger the
#   message and reduce functions for all the nodes and edges.
#

######################################################################
@@ -140,12 +144,13 @@ class SAGEConv(nn.Module):
# a multi-layer GraphSAGE network.
#


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
@@ -158,24 +163,25 @@ class Model(nn.Module):
# ~~~~~~~~~~~~~
# The following code for data loading and training loop is directly copied
# from the introduction tutorial.
#

import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]
def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    for e in range(200):
        # Forward
        logits = model(g, features)
@@ -205,21 +211,25 @@ def train(g, model):
        all_logits.append(logits.detach())

        if e % 5 == 0:
            print(
                "In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
                    e, loss, val_acc, best_val_acc, test_acc, best_test_acc
                )
            )


model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)
######################################################################
# More customization
# ------------------
#
# In DGL, we provide many built-in message and reduce functions under the
# ``dgl.function`` package. You can find more details in :ref:`the API
# doc <apifunction>`.
#
######################################################################
@@ -228,11 +238,12 @@ train(g, model)
# neighbor representations using a weighted average. Note that the
# ``edata`` member can hold edge features which can also take part in
# message passing.
#


class WeightedSAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model with edge weights.

    Parameters
    ----------
    in_feat : int
@@ -240,14 +251,15 @@ class WeightedSAGEConv(nn.Module):
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h, w):
        """Forward computation

        Parameters
        ----------
        g : Graph
@@ -258,10 +270,13 @@ class WeightedSAGEConv(nn.Module):
            The edge weight.
        """
        with g.local_scope():
            g.ndata["h"] = h
            g.edata["w"] = w
            g.update_all(
                message_func=fn.u_mul_e("h", "w", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)
@@ -270,88 +285,92 @@ class WeightedSAGEConv(nn.Module):
# Because the graph in this dataset does not have edge weights, we
# manually assign all edge weights to one in the ``forward()`` function of
# the model. You can replace it with your own edge weights.
#


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
        h = F.relu(h)
        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
        return h


model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)
######################################################################
# Even more customization by user-defined function
# ------------------------------------------------
#
# DGL allows user-defined message and reduce functions for maximal
# expressiveness. Here is a user-defined message function that is
# equivalent to ``fn.u_mul_e('h', 'w', 'm')``.
#


def u_mul_e_udf(edges):
    return {"m": edges.src["h"] * edges.data["w"]}
######################################################################
# ``edges`` has three members: ``src``, ``data`` and ``dst``, representing
# the source node feature, edge feature, and destination node feature for
# all edges.
#

######################################################################
# You can also write your own reduce function. For example, the following
# is equivalent to the builtin ``fn.mean('m', 'h_N')`` function that averages
# the incoming messages:
#


def mean_udf(nodes):
    return {"h_N": nodes.mailbox["m"].mean(1)}
######################################################################
# In short, DGL will group the nodes by their in-degrees, and for each
# group DGL stacks the incoming messages along the second dimension. You
# can then perform a reduction along the second dimension to aggregate
# messages.
#
# For more details on customizing message and reduce functions with
# user-defined functions, please refer to the :ref:`API
# reference <apiudf>`.
#
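
######################################################################
# As a quick sketch, the two UDFs above drop straight into ``update_all``
# in place of the built-in functions, mirroring the ``WeightedSAGEConv``
# forward pass. The ``h`` and ``w`` fields here are random placeholders.
#

with g.local_scope():
    g.ndata["h"] = torch.randn(g.num_nodes(), 4)
    g.edata["w"] = torch.randn(g.num_edges(), 1)
    g.update_all(u_mul_e_udf, mean_udf)
    print(g.ndata["h_N"].shape)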
######################################################################
# Best practice of writing custom GNN modules
# -------------------------------------------
#
# DGL recommends the following practices, ranked by preference:
#
# - Use ``dgl.nn`` modules.
# - Use ``dgl.nn.functional`` functions, which contain lower-level complex
#   operations such as computing a softmax for each node over incoming
#   edges (sketched after this list).
# - Use ``update_all`` with builtin message and reduce functions.
# - Use user-defined message or reduce functions.
#
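
######################################################################
# As a small illustration of the second option, a sketch using
# ``dgl.nn.functional.edge_softmax``, which normalizes a score over each
# node's incoming edges. The scores here are random placeholders.
#

from dgl.nn.functional import edge_softmax

scores = torch.randn(g.num_edges(), 1)
attention = edge_softmax(g, scores)  # softmax per destination node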
######################################################################
# What’s next?
# ------------
#
# - :ref:`Writing Efficient Message Passing
#   Code <guide-message-passing-efficient>`.
#
# Thumbnail credits: Representation Learning on Networks, Jure Leskovec, WWW 2018
...
@@ -17,14 +17,16 @@ By the end of this tutorial you will be able to
"""

import itertools

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.data
######################################################################
# Overview of Link Prediction with GNN
@@ -67,7 +69,6 @@ import scipy.sparse as sp
# first loads the Cora dataset.
#

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]

@@ -98,8 +99,14 @@ adj_neg = 1 - adj.todense() - np.eye(g.number_of_nodes())
neg_u, neg_v = np.where(adj_neg != 0)

neg_eids = np.random.choice(len(neg_u), g.number_of_edges())
test_neg_u, test_neg_v = (
    neg_u[neg_eids[:test_size]],
    neg_v[neg_eids[:test_size]],
)
train_neg_u, train_neg_v = (
    neg_u[neg_eids[test_size:]],
    neg_v[neg_eids[test_size:]],
)
###################################################################### ######################################################################
...@@ -129,14 +136,15 @@ train_g = dgl.remove_edges(g, eids[:test_size]) ...@@ -129,14 +136,15 @@ train_g = dgl.remove_edges(g, eids[:test_size])
from dgl.nn import SAGEConv from dgl.nn import SAGEConv
# ----------- 2. create model -------------- # # ----------- 2. create model -------------- #
# build a two-layer GraphSAGE model # build a two-layer GraphSAGE model
class GraphSAGE(nn.Module): class GraphSAGE(nn.Module):
def __init__(self, in_feats, h_feats): def __init__(self, in_feats, h_feats):
super(GraphSAGE, self).__init__() super(GraphSAGE, self).__init__()
self.conv1 = SAGEConv(in_feats, h_feats, 'mean') self.conv1 = SAGEConv(in_feats, h_feats, "mean")
self.conv2 = SAGEConv(h_feats, h_feats, 'mean') self.conv2 = SAGEConv(h_feats, h_feats, "mean")
def forward(self, g, in_feat): def forward(self, g, in_feat):
h = self.conv1(g, in_feat) h = self.conv1(g, in_feat)
h = F.relu(h) h = F.relu(h)
...@@ -180,8 +188,12 @@ class GraphSAGE(nn.Module): ...@@ -180,8 +188,12 @@ class GraphSAGE(nn.Module):
# for the training set and the test set respectively. # for the training set and the test set respectively.
# #
train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes()) train_pos_g = dgl.graph(
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes()) (train_pos_u, train_pos_v), num_nodes=g.number_of_nodes()
)
train_neg_g = dgl.graph(
(train_neg_u, train_neg_v), num_nodes=g.number_of_nodes()
)
test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes()) test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes()) test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())
@@ -201,15 +213,16 @@ test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())

import dgl.function as fn


class DotPredictor(nn.Module):
    def forward(self, g, h):
        with g.local_scope():
            g.ndata["h"] = h
            # Compute a new edge feature named 'score' by a dot-product between the
            # source node feature 'h' and destination node feature 'h'.
            g.apply_edges(fn.u_dot_v("h", "h", "score"))
            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
            return g.edata["score"][:, 0]


######################################################################
@@ -218,6 +231,7 @@ class DotPredictor(nn.Module):
# by concatenating the incident nodes’ features and passing it to an MLP.
#


class MLPPredictor(nn.Module):
    def __init__(self, h_feats):
        super().__init__()
@@ -241,14 +255,14 @@ class MLPPredictor(nn.Module):
        dict
            A dictionary of new edge features.
        """
        h = torch.cat([edges.src["h"], edges.dst["h"]], 1)
        return {"score": self.W2(F.relu(self.W1(h))).squeeze(1)}

    def forward(self, g, h):
        with g.local_scope():
            g.ndata["h"] = h
            g.apply_edges(self.apply_edges)
            return g.edata["score"]
######################################################################
@@ -284,20 +298,25 @@ class MLPPredictor(nn.Module):
# The evaluation metric in this tutorial is AUC.
#

model = GraphSAGE(train_g.ndata["feat"].shape[1], 16)
# You can replace DotPredictor with MLPPredictor.
# pred = MLPPredictor(16)
pred = DotPredictor()


def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    )
    return F.binary_cross_entropy_with_logits(scores, labels)


def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]
    ).numpy()
    return roc_auc_score(labels, scores)
@@ -313,31 +332,34 @@ def compute_auc(pos_score, neg_score):

# ----------- 3. set up loss and optimizer -------------- #
# in this case, the loss is computed inside the training loop
optimizer = torch.optim.Adam(
    itertools.chain(model.parameters(), pred.parameters()), lr=0.01
)

# ----------- 4. training -------------------------------- #
all_logits = []
for e in range(100):
    # forward
    h = model(train_g, train_g.ndata["feat"])
    pos_score = pred(train_pos_g, h)
    neg_score = pred(train_neg_g, h)
    loss = compute_loss(pos_score, neg_score)

    # backward
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if e % 5 == 0:
        print("In epoch {}, loss: {}".format(e, loss))

# ----------- 5. check results ------------------------ #
from sklearn.metrics import roc_auc_score

with torch.no_grad():
    pos_score = pred(test_pos_g, h)
    neg_score = pred(test_neg_g, h)
    print("AUC", compute_auc(pos_score, neg_score))


# Thumbnail credits: Link Prediction with Neo4j, Mark Needham
...
@@ -13,32 +13,32 @@ By the end of this tutorial, you will be able to

(Time estimate: 18 minutes)
"""

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.data
######################################################################
# Overview of Graph Classification with GNN
# -----------------------------------------
#
# Graph classification or regression requires a model to predict certain
# graph-level properties of a single graph given its node and edge
# features. Molecular property prediction is one particular application.
#
# This tutorial shows how to train a graph classification model for a
# small dataset from the paper `How Powerful Are Graph Neural
# Networks <https://arxiv.org/abs/1810.00826>`__.
#
# Loading Data
# ------------
#

# Load the PROTEINS dataset from the GIN paper, with self loops added.
dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)

@@ -46,33 +46,34 @@ dataset = dgl.data.GINDataset("PROTEINS", self_loop=True)
# label. One can see the node feature dimensionality and the number of
# possible graph categories of ``GINDataset`` objects in the ``dim_nfeats``
# and ``gclasses`` attributes.
#

print("Node feature dimensionality:", dataset.dim_nfeats)
print("Number of graph categories:", dataset.gclasses)
######################################################################
# Defining Data Loader
# --------------------
#
# A graph classification dataset usually contains two types of elements: a
# set of graphs, and their graph-level labels. Similar to an image
# classification task, when the dataset is large enough, we need to train
# with mini-batches. When you train a model for image classification or
# language modeling, you will use a ``DataLoader`` to iterate over the
# dataset. In DGL, you can use the ``GraphDataLoader``.
#
# You can also use various dataset samplers provided in
# `torch.utils.data.sampler <https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler>`__.
# For example, this tutorial creates a training ``GraphDataLoader`` and
# test ``GraphDataLoader``, using ``SubsetRandomSampler`` to tell PyTorch
# to sample from only a subset of the dataset.
#

from torch.utils.data.sampler import SubsetRandomSampler

from dgl.dataloading import GraphDataLoader

num_examples = len(dataset)
num_train = int(num_examples * 0.8)

@@ -80,15 +81,17 @@ train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

train_dataloader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=5, drop_last=False
)
test_dataloader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=5, drop_last=False
)


######################################################################
# You can try to iterate over the created ``GraphDataLoader`` and see what it
# gives:
#

it = iter(train_dataloader)
batch = next(it)
@@ -101,10 +104,10 @@ print(batch)
# first element is the batched graph, and the second element is simply a
# label vector representing the category of each graph in the mini-batch.
# Next, we’ll talk about the batched graph.
#
# A Batched Graph in DGL
# ----------------------
#
# In each mini-batch, the sampled graphs are combined into a single bigger
# batched graph via ``dgl.batch``. The single bigger batched graph merges
# all original graphs as separately connected components, with the node
@@ -114,29 +117,35 @@ print(batch)
# `here <2_dglgraph.ipynb>`__). It does, however, contain the information
# necessary for recovering the original graphs, such as the number of
# nodes and edges of each graph element.
#
batched_graph, labels = batch
print(
    "Number of nodes for each graph element in the batch:",
    batched_graph.batch_num_nodes(),
)
print(
    "Number of edges for each graph element in the batch:",
    batched_graph.batch_num_edges(),
)

# Recover the original graph elements from the minibatch
graphs = dgl.unbatch(batched_graph)
print("The original graphs in the minibatch:")
print(graphs)
######################################################################
# Define Model
# ------------
#
# This tutorial will build a two-layer `Graph Convolutional Network
# (GCN) <http://tkipf.github.io/graph-convolutional-networks/>`__. Each of
# its layers computes new node representations by aggregating neighbor
# information. If you have gone through the
# :doc:`introduction <1_introduction>`, you will notice two
# differences:
#
# - Since the task is to predict a single category for the *entire graph*
#   instead of for every node, you will need to aggregate the
#   representations of all the nodes and potentially the edges to form a
@@ -148,33 +157,33 @@ print(graphs)
# ``GraphDataLoader``. The readout functions provided by DGL can handle
# batched graphs so that they will return one representation for each
# minibatch element.
#
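######################################################################
# As a quick sketch of that readout behavior (reusing ``bg`` from the
# batching snippet above), a readout such as ``dgl.mean_nodes`` returns
# one row per graph element in the batch rather than one global average:

bg.ndata["h"] = torch.ones(bg.num_nodes(), 4)
print(dgl.mean_nodes(bg, "h").shape)  # torch.Size([2, 4]): one row per graph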
from dgl.nn import GraphConv


class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata["h"] = h
        return dgl.mean_nodes(g, "h")


######################################################################
# Training Loop
# -------------
#
# The training loop iterates over the training set with the
# ``GraphDataLoader`` object and computes the gradients, just like
# image classification or language modeling.
#

# Create the model with given dimensions
model = GCN(dataset.dim_nfeats, 16, dataset.gclasses)
@@ -182,7 +191,7 @@ optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
for epoch in range(20):
    for batched_graph, labels in train_dataloader:
        pred = model(batched_graph, batched_graph.ndata["attr"].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
@@ -191,21 +200,21 @@ for epoch in range(20):
num_correct = 0
num_tests = 0
for batched_graph, labels in test_dataloader:
    pred = model(batched_graph, batched_graph.ndata["attr"].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_tests += len(labels)

print("Test accuracy:", num_correct / num_tests)
######################################################################
# What’s next
# -----------
#
# - See `GIN
#   example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gin>`__
#   for an end-to-end graph classification model.
#
# Thumbnail credits: DGL
@@ -18,25 +18,25 @@ By the end of this tutorial, you will be able to
######################################################################
# ``DGLDataset`` Object Overview
# ------------------------------
#
# Your custom graph dataset should inherit the ``dgl.data.DGLDataset``
# class and implement the following methods:
#
# - ``__getitem__(self, i)``: retrieve the ``i``-th example of the
#   dataset. An example often contains a single DGL graph, and
#   occasionally its label.
# - ``__len__(self)``: the number of examples in the dataset.
# - ``process(self)``: load and process raw data from disk.
#
# A minimal skeleton of this interface is sketched below.
#
######################################################################
# Creating a Dataset for Node Classification or Link Prediction from CSV
# ----------------------------------------------------------------------
#
# A node classification dataset often consists of a single graph, as well
# as its node and edge features.
#
# This tutorial takes a small dataset based on `Zachary’s Karate Club
# network <https://en.wikipedia.org/wiki/Zachary%27s_karate_club>`__. It
# contains
@@ -49,16 +49,21 @@ By the end of this tutorial, you will be able to
#
import urllib.request

import pandas as pd

urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/members.csv", "./members.csv"
)
urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/interactions.csv",
    "./interactions.csv",
)

members = pd.read_csv("./members.csv")
members.head()

interactions = pd.read_csv("./interactions.csv")
interactions.head()
@@ -66,45 +71,52 @@ interactions.head()
# This tutorial treats the members as nodes and interactions as edges. It
# takes age as a numeric feature of the nodes, affiliated club as the label
# of the nodes, and edge weight as a numeric feature of the edges.
#
# .. note::
#
#    The original Zachary’s Karate Club network does not have
#    member ages. The ages in this tutorial are generated synthetically
#    to demonstrate how to add node features into the graph for dataset
#    creation.
#
# .. note::
#
#    In practice, taking age directly as a numeric feature may
#    not work well in machine learning; strategies like binning or
#    normalizing the feature would work better. This tutorial directly
#    takes the values as-is for simplicity.
#
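######################################################################
# As a quick illustrative sketch of such preprocessing (not used in the
# rest of this tutorial, which keeps the raw values; the bin edges are
# chosen arbitrarily for illustration):

import numpy as np

ages = members["Age"].to_numpy()
age_z = (ages - ages.mean()) / ages.std()  # z-score normalization
age_bin = np.digitize(ages, bins=[20, 30, 40, 50])  # coarse age buckets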
import os

import torch

import dgl
from dgl.data import DGLDataset


class KarateClubDataset(DGLDataset):
    def __init__(self):
        super().__init__(name="karate_club")

    def process(self):
        nodes_data = pd.read_csv("./members.csv")
        edges_data = pd.read_csv("./interactions.csv")
        node_features = torch.from_numpy(nodes_data["Age"].to_numpy())
        node_labels = torch.from_numpy(
            nodes_data["Club"].astype("category").cat.codes.to_numpy()
        )
        edge_features = torch.from_numpy(edges_data["Weight"].to_numpy())
        edges_src = torch.from_numpy(edges_data["Src"].to_numpy())
        edges_dst = torch.from_numpy(edges_data["Dst"].to_numpy())

        self.graph = dgl.graph(
            (edges_src, edges_dst), num_nodes=nodes_data.shape[0]
        )
        self.graph.ndata["feat"] = node_features
        self.graph.ndata["label"] = node_labels
        self.graph.edata["weight"] = edge_features

        # If your dataset is a node classification dataset, you will need to assign
        # masks indicating whether a node belongs to training, validation, and test set.
        n_nodes = nodes_data.shape[0]
@@ -114,18 +126,19 @@ class KarateClubDataset(DGLDataset):
        val_mask = torch.zeros(n_nodes, dtype=torch.bool)
        test_mask = torch.zeros(n_nodes, dtype=torch.bool)
        train_mask[:n_train] = True
        val_mask[n_train : n_train + n_val] = True
        test_mask[n_train + n_val :] = True
        self.graph.ndata["train_mask"] = train_mask
        self.graph.ndata["val_mask"] = val_mask
        self.graph.ndata["test_mask"] = test_mask

    def __getitem__(self, i):
        return self.graph

    def __len__(self):
        return 1


dataset = KarateClubDataset()
graph = dataset[0]
@@ -136,88 +149,93 @@ print(graph)
# Since a link prediction dataset only involves a single graph, preparing
# a link prediction dataset follows the same process as preparing a
# node classification dataset.
#
######################################################################
# Creating a Dataset for Graph Classification from CSV
# ----------------------------------------------------
#
# Creating a graph classification dataset involves implementing
# ``__getitem__`` to return both the graph and its graph-level label.
#
# This tutorial demonstrates how to create a graph classification dataset
# with the following synthetic CSV data:
#
# - ``graph_edges.csv``: containing three columns:
#
#    - ``graph_id``: the ID of the graph.
#    - ``src``: the source node of an edge of the given graph.
#    - ``dst``: the destination node of an edge of the given graph.
#
# - ``graph_properties.csv``: containing three columns:
#
#    - ``graph_id``: the ID of the graph.
#    - ``label``: the label of the graph.
#    - ``num_nodes``: the number of nodes in the graph.
#
urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/graph_edges.csv", "./graph_edges.csv"
)
urllib.request.urlretrieve(
    "https://data.dgl.ai/tutorial/dataset/graph_properties.csv",
    "./graph_properties.csv",
)
edges = pd.read_csv("./graph_edges.csv")
properties = pd.read_csv("./graph_properties.csv")

edges.head()

properties.head()
class SyntheticDataset(DGLDataset):
    def __init__(self):
        super().__init__(name="synthetic")

    def process(self):
        edges = pd.read_csv("./graph_edges.csv")
        properties = pd.read_csv("./graph_properties.csv")
        self.graphs = []
        self.labels = []

        # Create a graph for each graph ID from the edges table.
        # First process the properties table into two dictionaries with graph IDs as keys.
        # The label and number of nodes are values.
        label_dict = {}
        num_nodes_dict = {}
        for _, row in properties.iterrows():
            label_dict[row["graph_id"]] = row["label"]
            num_nodes_dict[row["graph_id"]] = row["num_nodes"]

        # For the edges, first group the table by graph IDs.
        edges_group = edges.groupby("graph_id")

        # For each graph ID...
        for graph_id in edges_group.groups:
            # Find the edges as well as the number of nodes and its label.
            edges_of_id = edges_group.get_group(graph_id)
            src = edges_of_id["src"].to_numpy()
            dst = edges_of_id["dst"].to_numpy()
            num_nodes = num_nodes_dict[graph_id]
            label = label_dict[graph_id]

            # Create a graph and add it to the list of graphs and labels.
            g = dgl.graph((src, dst), num_nodes=num_nodes)
            self.graphs.append(g)
            self.labels.append(label)

        # Convert the label list to tensor for saving.
        self.labels = torch.LongTensor(self.labels)

    def __getitem__(self, i):
        return self.graphs[i], self.labels[i]

    def __len__(self):
        return len(self.graphs)


dataset = SyntheticDataset()
graph, label = dataset[0]
print(graph, label)
@@ -30,30 +30,31 @@ message passing APIs.
# We describe a layer of graph convolutional neural network from a message
# passing perspective; the math can be found `here <math_>`_.
# It boils down to the following steps, for each node :math:`u`:
#
# 1) Aggregate neighbors' representations :math:`h_{v}` to produce an
#    intermediate representation :math:`\hat{h}_u`.
# 2) Transform the aggregated representation :math:`\hat{h}_{u}` with a
#    linear projection followed by a non-linearity:
#    :math:`h_{u} = f(W_{u} \hat{h}_u)`.
#
# We will implement step 1 with DGL message passing, and step 2 with
# PyTorch ``nn.Module``.
#
# GCN implementation with DGL
# ``````````````````````````````````````````
# We first define the message and reduce functions as usual. Since the
# aggregation on a node :math:`u` only involves summing over the neighbors'
# representations :math:`h_v`, we can simply use builtin functions:
import torch as th
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dgl import DGLGraph

gcn_msg = fn.copy_u(u="h", out="m")
gcn_reduce = fn.sum(msg="m", out="h")
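###############################################################################
# For reference, the two builtins above are equivalent in effect to the
# following user-defined functions (a sketch; the fused builtins run faster):


def gcn_msg_udf(edges):
    # Copy the source node feature "h" into the message field "m".
    return {"m": edges.src["h"]}


def gcn_reduce_udf(nodes):
    # Sum incoming messages over the mailbox dimension.
    return {"h": th.sum(nodes.mailbox["m"], dim=1)}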
###############################################################################
# We then proceed to define the GCNLayer module. A GCNLayer essentially performs
@@ -65,6 +66,7 @@ gcn_reduce = fn.sum(msg='m', out='h')
# efficient :class:`builtin GCN layer module <dgl.nn.pytorch.conv.GraphConv>`.
#
class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
@@ -75,11 +77,12 @@ class GCNLayer(nn.Module):
        # (such as the `'h'` ndata below) are automatically popped out
        # when the scope exits.
        with g.local_scope():
            g.ndata["h"] = feature
            g.update_all(gcn_msg, gcn_reduce)
            h = g.ndata["h"]
            return self.linear(h)
###############################################################################
# The forward function is essentially the same as any other commonly seen NN
# model in PyTorch. We can initialize GCN like any ``nn.Module``. For example,
@@ -88,16 +91,19 @@ class GCNLayer(nn.Module):
# 1433 and the number of classes is 7). The last GCN layer computes node embeddings,
# so the last layer in general does not apply activation.
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = GCNLayer(1433, 16)
        self.layer2 = GCNLayer(16, 7)

    def forward(self, g, features):
        x = F.relu(self.layer1(g, features))
        x = self.layer2(g, x)
        return x


net = Net()
print(net)
@@ -105,19 +111,23 @@ print(net)
# We load the Cora dataset using DGL's built-in data module.
from dgl.data import CoraGraphDataset


def load_cora_data():
    dataset = CoraGraphDataset()
    g = dataset[0]
    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    test_mask = g.ndata["test_mask"]
    return g, features, labels, train_mask, test_mask
###############################################################################
# When a model is trained, we can use the following method to evaluate
# the performance of the model on the test dataset:
def evaluate(model, g, features, labels, mask):
    model.eval()
    with th.no_grad():
@@ -128,35 +138,41 @@ def evaluate(model, g, features, labels, mask):
        correct = th.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


###############################################################################
# We then train the network as follows:
import time

import numpy as np

g, features, labels, train_mask, test_mask = load_cora_data()
# Add edges between each node and itself to preserve old node representations
g.add_edges(g.nodes(), g.nodes())
optimizer = th.optim.Adam(net.parameters(), lr=1e-2)
dur = []
for epoch in range(50):
    if epoch >= 3:
        t0 = time.time()

    net.train()
    logits = net(g, features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[train_mask], labels[train_mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    acc = evaluate(net, g, features, labels, test_mask)
    print(
        "Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), acc, np.mean(dur)
        )
    )
###############################################################################
# .. _math:
@@ -164,9 +180,9 @@ for epoch in range(50):
# GCN in one formula
# ------------------
# Mathematically, the GCN model follows this formula:
#
# :math:`H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})`
#
# Here, :math:`H^{(l)}` denotes the :math:`l^{th}` layer in the network,
# :math:`\sigma` is the non-linearity, and :math:`W` is the weight matrix for
# this layer. :math:`\tilde{D}` and :math:`\tilde{A}` are separately the degree
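###############################################################################
# As a dense-matrix sketch of that formula (illustrative only; DGL's message
# passing computes the same result without materializing the normalized
# adjacency matrix):

n, d_in, d_out = 4, 5, 2
A = th.bernoulli(th.full((n, n), 0.5))  # a random adjacency matrix
H = th.randn(n, d_in)  # layer input H^{(l)}
W = th.randn(d_in, d_out)  # layer weight W^{(l)}

A_tilde = A + th.eye(n)  # \tilde{A}: adjacency with self-loops
D_inv_sqrt = th.diag(A_tilde.sum(1) ** -0.5)  # \tilde{D}^{-1/2}
H_next = th.relu(D_inv_sqrt @ A_tilde @ D_inv_sqrt @ H @ W)  # sigma = ReLU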
@@ -67,11 +67,12 @@ offers a different perspective. The tutorial describes how to implement a Capsul
#
# Here's how we set up the graph and initialize node and edge features.
import matplotlib.pyplot as plt
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

import dgl

@@ -80,8 +81,8 @@ def init_graph(in_nodes, out_nodes, f_size):
    v = np.tile(np.arange(in_nodes, in_nodes + out_nodes), in_nodes)
    g = dgl.DGLGraph((u, v))
    # init states
    g.ndata["v"] = th.zeros(in_nodes + out_nodes, f_size)
    g.edata["b"] = th.zeros(in_nodes * out_nodes, 1)
    return g
@@ -116,6 +117,7 @@ def init_graph(in_nodes, out_nodes, f_size):
import dgl.function as fn


class DGLRoutingLayer(nn.Module):
    def __init__(self, in_nodes, out_nodes, f_size):
        super(DGLRoutingLayer, self).__init__()
@@ -126,27 +128,33 @@ class DGLRoutingLayer(nn.Module):
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))

    def forward(self, u_hat, routing_num=1):
        self.g.edata["u_hat"] = u_hat

        for r in range(routing_num):
            # step 1 (line 4): normalize over out edges
            edges_b = self.g.edata["b"].view(self.in_nodes, self.out_nodes)
            self.g.edata["c"] = F.softmax(edges_b, dim=1).view(-1, 1)
            self.g.edata["c u_hat"] = self.g.edata["c"] * self.g.edata["u_hat"]

            # Execute step 1 & 2
            self.g.update_all(fn.copy_e("c u_hat", "m"), fn.sum("m", "s"))

            # step 3 (line 6)
            self.g.nodes[self.out_indx].data["v"] = self.squash(
                self.g.nodes[self.out_indx].data["s"], dim=1
            )

            # step 4 (line 7)
            v = th.cat(
                [self.g.nodes[self.out_indx].data["v"]] * self.in_nodes, dim=0
            )
            self.g.edata["b"] = self.g.edata["b"] + (
                self.g.edata["u_hat"] * v
            ).sum(dim=1, keepdim=True)

    @staticmethod
    def squash(s, dim=1):
        sq = th.sum(s**2, dim=dim, keepdim=True)
        s_norm = th.sqrt(sq)
        s = (sq / (1.0 + sq)) * (s / s_norm)
        return s
@@ -172,14 +180,14 @@ dist_list = []
for i in range(10):
    routing(u_hat)
    dist_matrix = routing.g.edata["c"].view(in_nodes, out_nodes)
    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1)
    entropy_list.append(entropy.data.numpy())
    dist_list.append(dist_matrix.data.numpy())

stds = np.std(entropy_list, axis=1)
means = np.mean(entropy_list, axis=1)
plt.errorbar(np.arange(len(entropy_list)), means, stds, marker="o")
plt.ylabel("Entropy of Weight Distribution")
plt.xlabel("Number of Routing")
plt.xticks(np.arange(len(entropy_list)))
@@ -189,8 +197,8 @@ plt.close()
#
# Alternatively, we can also watch the evolution of histograms.

import matplotlib.animation as animation
import seaborn as sns

fig = plt.figure(dpi=150)
fig.clf()
@@ -204,7 +212,9 @@ def dist_animate(i):
    ax.set_title("Routing: %d" % (i))


ani = animation.FuncAnimation(
    fig, dist_animate, frames=len(entropy_list), interval=500
)
plt.close()
############################################################################################################
@@ -226,22 +236,43 @@ pos = dict()
fig2 = plt.figure(figsize=(8, 3), dpi=150)
fig2.clf()
ax = fig2.subplots()
pos.update(
    (n, (i, 1)) for i, n in zip(height_in_y, X)
)  # put nodes from X at x=1
pos.update(
    (n, (i, 2)) for i, n in zip(height_out_y, Y)
)  # put nodes from Y at x=2
def weight_animate(i):
    ax.cla()
    ax.axis("off")
    ax.set_title("Routing: %d " % i)
    dm = dist_list[i]
    nx.draw_networkx_nodes(
        g, pos, nodelist=range(in_nodes), node_color="r", node_size=100, ax=ax
    )
    nx.draw_networkx_nodes(
        g,
        pos,
        nodelist=range(in_nodes, in_nodes + out_nodes),
        node_color="b",
        node_size=100,
        ax=ax,
    )
    for edge in g.edges():
        nx.draw_networkx_edges(
            g,
            pos,
            edgelist=[edge],
            width=dm[edge[0], edge[1] - in_nodes] * 1.5,
            ax=ax,
        )


ani2 = animation.FuncAnimation(
    fig2, weight_animate, frames=len(dist_list), interval=500
)
plt.close()

############################################################################################################
@@ -257,4 +288,3 @@ plt.close()
# .. |image3| image:: https://i.imgur.com/dMvu7p3.png
# .. |image4| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif
# .. |image5| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif
@@ -68,16 +68,19 @@ For communication between multiple processes in multi-gpu training, we need
to start the distributed backend at the beginning of each process. We use
`world_size` to refer to the number of processes and `rank` to refer to the
process ID, which should be an integer from `0` to `world_size - 1`.
"""

import torch.distributed as dist


def init_process_group(world_size, rank):
    dist.init_process_group(
        backend="gloo",  # change to 'nccl' for multiple GPUs
        init_method="tcp://127.0.0.1:12345",
        world_size=world_size,
        rank=rank,
    )
###############################################################################
# Data Loader Preparation
@@ -87,25 +90,28 @@ def init_process_group(world_size, rank):
# splitting, we need to use the same random seed across processes to ensure
# the same split. We follow the common practice of training with multiple
# GPUs and evaluating with a single GPU, and thus only set `use_ddp` to True
# in the :func:`~dgl.dataloading.pytorch.GraphDataLoader` for the training
# set, where `ddp` stands for :func:`~torch.nn.parallel.DistributedDataParallel`.
#
from dgl.data import split_dataset
from dgl.dataloading import GraphDataLoader


def get_dataloaders(dataset, seed, batch_size=32):
    # Use an 80:10:10 train-val-test split
    train_set, val_set, test_set = split_dataset(
        dataset, frac_list=[0.8, 0.1, 0.1], shuffle=True, random_state=seed
    )
    train_loader = GraphDataLoader(
        train_set, use_ddp=True, batch_size=batch_size, shuffle=True
    )
    val_loader = GraphDataLoader(val_set, batch_size=batch_size)
    test_loader = GraphDataLoader(test_set, batch_size=batch_size)

    return train_loader, val_loader, test_loader
###############################################################################
# Model Initialization
# --------------------
@@ -115,14 +121,20 @@ def get_dataloaders(dataset, seed, batch_size=32):
import torch.nn as nn
import torch.nn.functional as F

from dgl.nn.pytorch import GINConv, SumPooling


class GIN(nn.Module):
    def __init__(self, input_size=1, num_classes=2):
        super(GIN, self).__init__()
        self.conv1 = GINConv(
            nn.Linear(input_size, num_classes), aggregator_type="sum"
        )
        self.conv2 = GINConv(
            nn.Linear(num_classes, num_classes), aggregator_type="sum"
        )
        self.pool = SumPooling()

    def forward(self, g, feats):
@@ -132,6 +144,7 @@ class GIN(nn.Module):
        return self.pool(g, feats)
###############################################################################
# To ensure the same initial model parameters across processes, we need to set
# the same random seed before model initialization. Once we construct a model
@@ -141,16 +154,20 @@ class GIN(nn.Module):
import torch
from torch.nn.parallel import DistributedDataParallel


def init_model(seed, device):
    torch.manual_seed(seed)
    model = GIN().to(device)
    if device.type == "cpu":
        model = DistributedDataParallel(model)
    else:
        model = DistributedDataParallel(
            model, device_ids=[device], output_device=device
        )

    return model
###############################################################################
# Main Function for Each Process
# ------------------------------
@@ -158,6 +175,7 @@ def init_model(seed, device):
# Define the model evaluation function as in the single-GPU setting.
#
def evaluate(model, dataloader, device):
    model.eval()
@@ -168,7 +186,7 @@ def evaluate(model, dataloader, device):
        bg = bg.to(device)
        labels = labels.to(device)
        # Get input node features
        feats = bg.ndata.pop("attr")
        with torch.no_grad():
            pred = model(bg, feats)
        _, pred = torch.max(pred, 1)
@@ -177,26 +195,27 @@ def evaluate(model, dataloader, device):
    return 1.0 * total_correct / total
###############################################################################
# Define the main function for each process.
#
from torch.optim import Adam


def main(rank, world_size, dataset, seed=0):
    init_process_group(world_size, rank)
    if torch.cuda.is_available():
        device = torch.device("cuda:{:d}".format(rank))
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")

    model = init_model(seed, device)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.01)

    train_loader, val_loader, test_loader = get_dataloaders(dataset, seed)
    for epoch in range(5):
        model.train()
        # The line below ensures all processes use a different
@@ -207,7 +226,7 @@ def main(rank, world_size, dataset, seed=0):
        for bg, labels in train_loader:
            bg = bg.to(device)
            labels = labels.to(device)
            feats = bg.ndata.pop("attr")
            pred = model(bg, feats)
            loss = criterion(pred, labels)
@@ -216,15 +235,16 @@ def main(rank, world_size, dataset, seed=0):
            loss.backward()
            optimizer.step()
        loss = total_loss
        print("Loss: {:.4f}".format(loss))

        val_acc = evaluate(model, val_loader, device)
        print("Val acc: {:.4f}".format(val_acc))

    test_acc = evaluate(model, test_loader, device)
    print("Test acc: {:.4f}".format(test_acc))
    dist.destroy_process_group()
###############################################################################
# Finally we load the dataset and launch the processes.
#
@@ -232,9 +252,9 @@ def main(rank, world_size, dataset, seed=0):
#
# if __name__ == '__main__':
#     import torch.multiprocessing as mp
#
#     from dgl.data import GINDataset
#
#     num_gpus = 4
#     procs = []
#     dataset = GINDataset(name='IMDBBINARY', self_loop=False)