OpenDAS / dgl · Commit 0a3f6116 (unverified)
Authored Dec 01, 2023 by Rhett Ying, committed via GitHub Dec 01, 2023
Parent: 1e40c81b

[doc] update custom sampler (#6653)

Showing 1 changed file, ``docs/source/guide/minibatch-custom-sampler.rst``, with 89 additions and 99 deletions.

6.4 Implementing Custom Graph Samplers
----------------------------------------------

Implementing custom samplers involves subclassing the
:class:`dgl.graphbolt.SubgraphSampler` base class and implementing its abstract
:attr:`_sample_subgraphs` method. The :attr:`_sample_subgraphs` method takes in
the seed nodes, which are the nodes to sample neighbors from:

.. code:: python

    def _sample_subgraphs(self, seed_nodes):
        return input_nodes, sampled_subgraphs

The method should return the input node IDs and a list of subgraphs, where each
subgraph is a :class:`~dgl.graphbolt.SampledSubgraph` object.

Any other data required during sampling, such as the graph structure and the
fanout size, should be passed to the sampler via its constructor.

The code below implements a classical neighbor sampler:

.. code:: python

    @functional_datapipe("customized_sample_neighbor")
    class CustomizedNeighborSampler(dgl.graphbolt.SubgraphSampler):
        def __init__(self, datapipe, graph, fanouts):
            super().__init__(datapipe)
            self.graph = graph
            self.fanouts = fanouts

        def _sample_subgraphs(self, seed_nodes):
            subgs = []
            for fanout in reversed(self.fanouts):
                # Sample a fixed number of neighbors of the current seed nodes.
                input_nodes, sg = self.graph.sample_neighbors(seed_nodes, fanout)
                subgs.insert(0, sg)
                # The sampled input nodes become the seeds for the next hop.
                seed_nodes = input_nodes
            return input_nodes, subgs

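The loop above builds the list of layer subgraphs from the hop farthest from the seeds back down to the seeds themselves: iterating ``reversed(self.fanouts)`` together with ``subgs.insert(0, sg)`` leaves ``subgs`` ordered from input layer to output layer. A minimal pure-Python sketch of that ordering, with no DGL involved (``Layer`` and the fake neighbor generation are illustrative only):

```python
from dataclasses import dataclass

@dataclass
class Layer:
    fanout: int   # neighbors sampled per node at this hop
    seeds: list   # nodes this hop expanded from

def sample_layers(seed_nodes, fanouts):
    subgs = []
    for fanout in reversed(fanouts):            # innermost hop (next to seeds) first
        sg = Layer(fanout=fanout, seeds=list(seed_nodes))
        subgs.insert(0, sg)                     # prepend: final order is input -> output
        # Pretend each hop discovers one new neighbor named after its fanout.
        seed_nodes = seed_nodes + [f"n{fanout}"]
    return seed_nodes, subgs

input_nodes, subgs = sample_layers(["a"], [10, 15])
# subgs[0] uses fanout 10 (farthest from the seeds);
# subgs[-1] uses fanout 15 and was expanded from the original seeds.
```

Note how ``subgs[-1].seeds`` is the original seed list: the hop adjacent to the output nodes is sampled first, then each earlier hop expands the accumulated input nodes.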
To use this sampler with :class:`~dgl.graphbolt.MultiProcessDataLoader`:

.. code:: python

    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
    datapipe = datapipe.customized_sample_neighbor(g, [10, 10])  # 2 layers.
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

    for data in dataloader:
        input_features = data.node_features["feat"]
        output_labels = data.labels
        output_predictions = model(data.blocks, input_features)
        loss = compute_loss(output_labels, output_predictions)
        opt.zero_grad()
        loss.backward()
        opt.step()

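Each ``datapipe = datapipe.xxx(...)`` call above wraps the previous stage, so iterating the final data loader pulls each mini-batch through the whole chain. A toy sketch of that wrapping pattern using a plain Python class (``Pipe`` and its ``map`` method are illustrative stand-ins, not torchdata or graphbolt APIs):

```python
class Pipe:
    """A minimal stand-in for a chainable datapipe stage."""

    def __init__(self, source):
        self.source = source  # the upstream iterable this stage wraps

    def __iter__(self):
        return iter(self.source)

    def map(self, fn):
        # Return a new downstream stage, like datapipe.transform(...):
        # elements are transformed lazily as they flow through.
        return Pipe(fn(x) for x in self)

# Chaining mirrors ItemSampler -> sample_neighbor -> fetch_feature -> loader.
pipeline = Pipe(range(4)).map(lambda x: x * 10).map(lambda x: x + 1)
result = list(pipeline)  # each element passes through both stages in order
```

The real datapipes are lazier and multiprocessing-aware, but the composition principle is the same: the last object in the chain is the only one you iterate.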
Sampler for Heterogeneous Graphs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To write a sampler for heterogeneous graphs, one needs to be aware that the
argument ``graph`` is a heterogeneous graph while ``seeds`` could be a
dictionary of ID tensors. Most of DGL's graph sampling operators (e.g., the
``sample_neighbors`` and ``to_block`` functions in the above example) work on
heterogeneous graphs natively, so many samplers are automatically ready for
heterogeneous graphs. For example, the above ``CustomizedNeighborSampler`` can
be used on heterogeneous graphs:

.. code:: python

    import dgl.graphbolt as gb

    hg = gb.FusedCSCSamplingGraph()
    train_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)),
                names=("seed_nodes", "labels"),
            ),
            "item": gb.ItemSet(
                (torch.arange(5, 10), torch.arange(10, 15)),
                names=("seed_nodes", "labels"),
            ),
        }
    )
    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
    datapipe = datapipe.customized_sample_neighbor(g, [10, 10])  # 2 layers.
    datapipe = datapipe.fetch_feature(
        feature, node_feature_keys={"user": ["feat"], "item": ["feat"]}
    )
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

    for data in dataloader:
        input_features = {
            ntype: data.node_features[(ntype, "feat")]
            for ntype in data.blocks[0].srctypes
        }
        output_labels = data.labels["user"]
        output_predictions = model(data.blocks, input_features)["user"]
        loss = compute_loss(output_labels, output_predictions)
        opt.zero_grad()
        loss.backward()
        opt.step()

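The ``ItemSetDict`` above keys the training seeds (and labels) by node type, and the item sampler then yields per-type mini-batches. A rough, dictionary-based sketch of that grouping, using plain Python lists rather than the graphbolt classes (``batch_by_type`` is a hypothetical helper for illustration):

```python
# Seeds keyed by node type, mimicking the "user"/"item" ItemSetDict above.
seeds = {
    "user": list(range(0, 5)),
    "item": list(range(5, 10)),
}

def batch_by_type(seeds, batch_size):
    # Yield {ntype: chunk} mini-batches, one node type at a time,
    # so every downstream stage knows which type its IDs belong to.
    for ntype, ids in seeds.items():
        for i in range(0, len(ids), batch_size):
            yield {ntype: ids[i:i + batch_size]}

batches = list(batch_by_type(seeds, batch_size=3))
# e.g. {"user": [0, 1, 2]}, {"user": [3, 4]}, {"item": [5, 6, 7]}, ...
```

The key point is that a heterogeneous mini-batch carries its node type alongside the IDs, which is why the training loop indexes ``data.node_features[(ntype, "feat")]`` and ``data.labels["user"]`` by type.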
Exclude Edges After Sampling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, we may want to exclude seed edges from the sampled subgraph. For
example, in link prediction tasks, we want to exclude the edges in the training
set from the sampled subgraph to prevent information leakage. To do so, add an
additional datapipe right after sampling as follows:

.. code:: python

    datapipe = datapipe.customized_sample_neighbor(g, [10, 10])  # 2 layers.
    datapipe = datapipe.transform(gb.exclude_seed_edges)

Please check the API page of :func:`~dgl.graphbolt.exclude_seed_edges` for more
details.

The above API is based on :meth:`~dgl.graphbolt.SampledSubgraph.exclude_edges`.
If you want to exclude edges from the sampled subgraph based on some other
criteria, you can write your own transform function; see that method for
reference.

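Conceptually, excluding seed edges means dropping any sampled edge that appears among the mini-batch's seed (training) edges, typically in both directions for undirected link prediction. An illustrative stand-in on plain edge tuples (``exclude_seed_edges_sketch`` is hypothetical, not the graphbolt implementation):

```python
def exclude_seed_edges_sketch(sampled_edges, seed_edges):
    # Ban each seed edge and its reverse, then filter the sampled edges.
    banned = set(seed_edges) | {(v, u) for u, v in seed_edges}
    return [e for e in sampled_edges if e not in banned]

sampled = [(0, 1), (1, 2), (2, 0), (2, 1)]
kept = exclude_seed_edges_sketch(sampled, seed_edges=[(1, 2)])
# (1, 2) and its reverse (2, 1) are removed; (0, 1) and (2, 0) remain.
```

A custom transform passed to ``datapipe.transform(...)`` would apply the same idea to the ``SampledSubgraph`` objects in each mini-batch, using whatever exclusion criterion the task requires.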
You could also refer to examples in
`Link Prediction <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/link_prediction.py>`__.