Unverified Commit 1d2a1cdc authored by Rhett Ying, committed by GitHub

[doc] update edge classification chapter (#6642)

parent e3752754
......@@ -16,36 +16,45 @@ You can use the
.. code:: python
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
datapipe = datapipe.sample_neighbor(g, [10, 10])
# Or equivalently
datapipe = dgl.graphbolt.NeighborSampler(datapipe, g, [10, 10])
To use the neighborhood sampler provided by DGL for edge classification,
one instead needs to combine it with
:func:`~dgl.dataloading.as_edge_prediction_sampler`, which iterates
over a set of edges in minibatches, yielding the subgraph induced by the
edge minibatch and *message flow graphs* (MFGs) to be consumed by the module below.
For example, the following code creates a PyTorch DataLoader that
iterates over the training edge ID array ``train_eids`` in batches,
putting the list of generated MFGs onto GPU.
The code for defining a data loader is also the same as that of node
classification. The only difference is that it iterates over the
edges (namely, node pairs) in the training set instead of the nodes.
.. code:: python
sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
dataloader = dgl.dataloading.DataLoader(
g, train_eid_dict, sampler,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=4)
import dgl.graphbolt as gb
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = gb.SamplingGraph()
node_pairs = torch.arange(0, 1000).reshape(-1, 2)
labels = torch.randint(0, 2, (500,))
train_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
# Or equivalently:
# datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.to_dgl()
datapipe = datapipe.copy_to(device)
dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
Iterating over the DataLoader will yield :class:`~dgl.graphbolt.DGLMiniBatch`
which contains a list of specially created graphs representing the computation
dependencies on each layer. They are called *message flow graphs* (MFGs) in DGL.
.. code:: python
mini_batch = next(iter(dataloader))
print(mini_batch.blocks)
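In addition to the MFGs, the mini-batch carries the seed edges of the batch as
node pairs together with their labels; a small sketch of inspecting them (the
attribute names match the training loop later in this section):

.. code:: python

    # Seed edges of this mini-batch, represented as (src, dst) node pairs,
    # and their corresponding labels.
    print(mini_batch.positive_node_pairs)
    print(mini_batch.labels)
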
.. note::
See the :doc:`Stochastic Training Tutorial
<tutorials/large/L0_neighbor_sampling_overview>` for the concept of
message flow graph.
For a complete list of supported builtin samplers, please refer to the
:ref:`neighborhood sampler API reference <api-dataloading-neighbor-sampling>`.
<../notebooks/stochastic_training/neighbor_sampling_overview.nblink>`__
for the concept of message flow graph.
If you wish to develop your own neighborhood sampler or you want a more
detailed explanation of the concept of MFGs, please refer to
......@@ -63,26 +72,29 @@ an edge exists between the two nodes, and potentially use it for
advantage.
Therefore in edge classification you sometimes would like to exclude the
edges sampled in the minibatch from the original graph for neighborhood
sampling, as well as the reverse edges of the sampled edges on an
undirected graph. You can specify ``exclude='reverse_id'`` in calling
:func:`~dgl.dataloading.as_edge_prediction_sampler`, with the mapping of the edge
IDs to their reverse edge IDs. Usually doing so will lead to a much slower
sampling process due to locating the reverse edges involved in the minibatch
and removing them.
seed edges as well as their reverse edges from the sampled minibatch.
You can use :func:`~dgl.graphbolt.exclude_seed_edges` alongside
:class:`~dgl.graphbolt.MiniBatchTransformer` to achieve this.
.. code:: python
n_edges = g.num_edges()
sampler = dgl.dataloading.as_edge_prediction_sampler(
sampler, exclude='reverse_id', reverse_eids=torch.cat([
torch.arange(n_edges // 2, n_edges), torch.arange(0, n_edges // 2)]))
dataloader = dgl.dataloading.DataLoader(
g, train_eid_dict, sampler,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=4)
import dgl.graphbolt as gb
from functools import partial
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = gb.SamplingGraph()
node_pairs = torch.arange(0, 1000).reshape(-1, 2)
labels = torch.randint(0, 2, (500,))
train_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
exclude_seed_edges = partial(gb.exclude_seed_edges, include_reverse_edges=True)
datapipe = datapipe.transform(exclude_seed_edges)
datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
datapipe = datapipe.to_dgl()
datapipe = datapipe.copy_to(device)
dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
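The ``datapipe.transform(...)`` call above is the datapipe shorthand; the same
step can also be written with :class:`~dgl.graphbolt.MiniBatchTransformer`
explicitly. A sketch, assuming the class takes the upstream datapipe and the
transform callable:

.. code:: python

    # Assumed equivalent of ``datapipe.transform(exclude_seed_edges)`` above.
    datapipe = gb.MiniBatchTransformer(datapipe, exclude_seed_edges)
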
Adapt your model for minibatch training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -113,14 +125,12 @@ input features.
return x
The input to the latter part is usually the output from the
former part, as well as the subgraph of the original graph induced by the
edges in the minibatch. The subgraph is yielded from the same data
loader. One can call :meth:`dgl.DGLGraph.apply_edges` to compute the
scores on the edges with the edge subgraph.
former part, as well as the subgraph (node pairs) of the original graph induced
by the edges in the minibatch. The subgraph is yielded from the same data
loader.
The following code shows an example of predicting scores on the edges by
concatenating the incident node features and projecting it with a dense
layer.
concatenating the incident node features and projecting it with a dense layer.
.. code:: python
......@@ -129,19 +139,15 @@ layer.
super().__init__()
self.W = nn.Linear(2 * in_features, num_classes)
def apply_edges(self, edges):
data = torch.cat([edges.src['x'], edges.dst['x']], 1)
return {'score': self.W(data)}
def forward(self, edge_subgraph, x):
with edge_subgraph.local_scope():
edge_subgraph.ndata['x'] = x
edge_subgraph.apply_edges(self.apply_edges)
return edge_subgraph.edata['score']
def forward(self, node_pairs, x):
src_x = x[node_pairs[0]]
dst_x = x[node_pairs[1]]
data = torch.cat([src_x, dst_x], 1)
return self.W(data)
The entire model will take the list of MFGs and the edge subgraph
generated by the data loader, as well as the input node features as
follows:
The entire model will take the list of MFGs and the edges generated by the data
loader, as well as the input node features as follows:
.. code:: python
......@@ -151,10 +157,10 @@ follows:
self.gcn = StochasticTwoLayerGCN(
in_features, hidden_features, out_features)
self.predictor = ScorePredictor(num_classes, out_features)
def forward(self, edge_subgraph, blocks, x):
def forward(self, blocks, x, node_pairs):
x = self.gcn(blocks, x)
return self.predictor(edge_subgraph, x)
return self.predictor(node_pairs, x)
DGL ensures that the nodes in the edge subgraph are the same as the
output nodes of the last MFG in the generated list of MFGs.
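A quick way to see this correspondence in the new API: the GNN output has one
row per destination node of the last MFG, so the node pairs of the mini-batch
can index straight into it. A hedged sketch, assuming ``model`` is the
``Model`` above, ``data`` is one mini-batch from the loader, and node features
are stored under ``"feat"``:

.. code:: python

    x = model.gcn(data.blocks, data.node_features["feat"])
    # One representation per output node of the last MFG.
    assert x.shape[0] == data.blocks[-1].num_dst_nodes()
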
......@@ -169,21 +175,21 @@ their incident node representations.
.. code:: python
import torch.nn.functional as F
model = Model(in_features, hidden_features, out_features, num_classes)
model = model.cuda()
model = model.to(device)
opt = torch.optim.Adam(model.parameters())
for input_nodes, edge_subgraph, blocks in dataloader:
blocks = [b.to(torch.device('cuda')) for b in blocks]
edge_subgraph = edge_subgraph.to(torch.device('cuda'))
input_features = blocks[0].srcdata['features']
edge_labels = edge_subgraph.edata['labels']
edge_predictions = model(edge_subgraph, blocks, input_features)
loss = compute_loss(edge_labels, edge_predictions)
for data in dataloader:
    blocks = data.blocks
    x = data.node_features["feat"]
    y_hat = model(blocks, x, data.positive_node_pairs)
    loss = F.cross_entropy(y_hat, data.labels)
    opt.zero_grad()
    loss.backward()
    opt.step()
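Evaluation over a held-out set follows the same pattern. A minimal sketch,
assuming a hypothetical ``val_dataloader`` built the same way as
``dataloader`` but over the validation edges:

.. code:: python

    model.eval()
    correct = total = 0
    with torch.no_grad():
        for data in val_dataloader:
            x = data.node_features["feat"]
            y_hat = model(data.blocks, x, data.positive_node_pairs)
            correct += (y_hat.argmax(dim=1) == data.labels).sum().item()
            total += data.labels.numel()
    print(f"Validation accuracy: {correct / total:.4f}")
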
For heterogeneous graphs
~~~~~~~~~~~~~~~~~~~~~~~~
......@@ -212,7 +218,7 @@ classification/regression.
For score prediction, the only implementation difference between the
homogeneous graph and the heterogeneous graph is that we are looping
over the edge types for :meth:`~dgl.DGLGraph.apply_edges`.
over the edge types.
.. code:: python
......@@ -221,16 +227,13 @@ over the edge types for :meth:`~dgl.DGLGraph.apply_edges`.
super().__init__()
self.W = nn.Linear(2 * in_features, num_classes)
def apply_edges(self, edges):
data = torch.cat([edges.src['x'], edges.dst['x']], 1)
return {'score': self.W(data)}
def forward(self, edge_subgraph, x):
with edge_subgraph.local_scope():
edge_subgraph.ndata['x'] = x
for etype in edge_subgraph.canonical_etypes:
edge_subgraph.apply_edges(self.apply_edges, etype=etype)
return edge_subgraph.edata['score']
def forward(self, node_pairs, x):
    scores = {}
    for etype in node_pairs.keys():
        # Canonical edge types are strings like "user:like:item"; node
        # representations in ``x`` are keyed by node type.
        src_type, _, dst_type = etype.split(":")
        src, dst = node_pairs[etype]
        data = torch.cat([x[src_type][src], x[dst_type][dst]], 1)
        scores[etype] = self.W(data)
    return scores
class Model(nn.Module):
def __init__(self, in_features, hidden_features, out_features, num_classes,
......@@ -240,34 +243,46 @@ over the edge types for :meth:`~dgl.DGLGraph.apply_edges`.
in_features, hidden_features, out_features, etypes)
self.pred = ScorePredictor(num_classes, out_features)
def forward(self, edge_subgraph, blocks, x):
def forward(self, node_pairs, blocks, x):
x = self.rgcn(blocks, x)
return self.pred(edge_subgraph, x)
return self.pred(node_pairs, x)
Data loader definition is also very similar to that of node
classification. The only difference is that you need
:func:`~dgl.dataloading.as_edge_prediction_sampler`,
and you will be supplying a
dictionary of edge types and edge ID tensors instead of a dictionary of
node types and node ID tensors.
Data loader definition is almost identical to that of a homogeneous graph. The
only difference is that the ``train_set`` is now an instance of
:class:`~dgl.graphbolt.ItemSetDict` instead of :class:`~dgl.graphbolt.ItemSet`.
.. code:: python
sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
sampler = dgl.dataloading.as_edge_prediction_sampler(sampler)
dataloader = dgl.dataloading.DataLoader(
g, train_eid_dict, sampler,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=4)
import dgl.graphbolt as gb
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
g = gb.SamplingGraph()
node_pairs = torch.arange(0, 1000).reshape(-1, 2)
labels = torch.randint(0, 3, (500,))
node_pairs_labels = {
"user:like:item": gb.ItemSet(
(node_pairs, labels), names=("node_pairs", "labels")
),
"user:follow:user": gb.ItemSet(
(node_pairs, labels), names=("node_pairs", "labels")
),
}
train_set = gb.ItemSetDict(node_pairs_labels)
datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
datapipe = datapipe.fetch_feature(
feature, node_feature_keys={"item": ["feat"], "user": ["feat"]}
)
datapipe = datapipe.to_dgl()
datapipe = datapipe.copy_to(device)
dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)
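Iterating over this loader yields mini-batches whose node pairs and labels are
dictionaries keyed by edge type. A small sketch of inspecting one mini-batch,
assuming the same attribute names as in the homogeneous training loop:

.. code:: python

    data = next(iter(dataloader))
    # Both dictionaries are keyed by canonical edge type, e.g. "user:like:item".
    print(data.positive_node_pairs.keys())
    print(data.labels.keys())
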
Things become a little different if you wish to exclude the reverse
edges on heterogeneous graphs. On heterogeneous graphs, reverse edges
usually have a different edge type from the edges themselves, in order
to differentiate the forward and backward relationships (e.g.
``follow`` and ``followed by`` are reverse relations of each other,
``purchase`` and ``purchased by`` are reverse relations of each other,
``follow`` and ``followed_by`` are reverse relations of each other,
``like`` and ``liked_by`` are reverse relations of each other,
etc.).
If each edge in a type has a reverse edge with the same ID in another
......@@ -277,16 +292,17 @@ reverse edges then goes as follows.
.. code:: python
sampler = dgl.dataloading.as_edge_prediction_sampler(
sampler, exclude='reverse_types',
reverse_etypes={'follow': 'followed by', 'followed by': 'follow',
'purchase': 'purchased by', 'purchased by': 'purchase'})
dataloader = dgl.dataloading.DataLoader(
g, train_eid_dict, sampler,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=4)
exclude_seed_edges = partial(
gb.exclude_seed_edges,
include_reverse_edges=True,
reverse_etypes_mapping={
"user:like:item": "item:liked_by:user",
"user:follow:user": "user:followed_by:user",
},
)
datapipe = datapipe.transform(exclude_seed_edges)
The training loop is again almost the same as that on a homogeneous graph,
except for the implementation of ``compute_loss`` that will take in two
......@@ -309,7 +325,3 @@ dictionaries of node types and predictions here.
loss.backward()
opt.step()
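A minimal sketch of such a ``compute_loss`` is given below; it is an
illustrative assumption rather than part of the chapter, and simply sums the
cross-entropy over the edge types returned by the ``ScorePredictor`` above.

.. code:: python

    def compute_loss(labels, predictions):
        # Both arguments are dictionaries keyed by edge type.
        loss = 0
        for etype, y_hat in predictions.items():
            loss = loss + F.cross_entropy(y_hat, labels[etype])
        return loss
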
`GCMC <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcmc>`__
is an example of edge classification on a bipartite graph.