.. _guide-minibatch-customizing-neighborhood-sampler:

6.4 Implementing Custom Graph Samplers
----------------------------------------------

Implementing a custom sampler involves subclassing the :class:`dgl.dataloading.Sampler`
base class and implementing its abstract ``sample`` method. The ``sample``
method takes two arguments:

.. code:: python

   def sample(self, g, indices):
       pass

The first argument ``g`` is the original graph to sample from, while
the second argument ``indices`` holds the indices of the current mini-batch.
In general, ``indices`` can be anything, depending on what indices are given to
the accompanying :class:`~dgl.dataloading.DataLoader`, but they are typically
seed node or seed edge IDs. The method returns the mini-batch of samples for
the current iteration.

.. note::

    The design here is similar to PyTorch's ``torch.utils.data.DataLoader``,
    which iterates over a dataset and lets users customize how samples are
    batched through its ``collate_fn`` argument. In DGL,
    ``dgl.dataloading.DataLoader`` iterates over ``indices`` (e.g., training
    node IDs), while a ``Sampler`` converts each batch of indices into a batch
    of graph- or tensor-type samples.

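To make this division of labor concrete, here is a minimal pure-Python sketch that does not depend on DGL or PyTorch. ``ToySampler``, ``ToyDataLoader`` and ``OneHopSampler`` are hypothetical stand-ins, not DGL classes, and the graph is assumed to be a plain dictionary mapping each node ID to its in-neighbors. The loader only slices the index list into batches; everything graph-specific happens in ``sample``.

.. code:: python

    from abc import ABC, abstractmethod


    class ToySampler(ABC):
        """Hypothetical stand-in for dgl.dataloading.Sampler."""

        @abstractmethod
        def sample(self, g, indices):
            """Turn one batch of indices into one mini-batch."""


    class ToyDataLoader:
        """Hypothetical stand-in for dgl.dataloading.DataLoader.

        It only iterates over ``indices`` in fixed-size batches and delegates
        all graph-specific work to ``sampler.sample``.
        """

        def __init__(self, g, indices, sampler, batch_size):
            self.g = g
            self.indices = list(indices)
            self.sampler = sampler
            self.batch_size = batch_size

        def __iter__(self):
            for start in range(0, len(self.indices), self.batch_size):
                batch = self.indices[start:start + self.batch_size]
                yield self.sampler.sample(self.g, batch)


    class OneHopSampler(ToySampler):
        """Maps every seed node to the sorted list of its in-neighbors."""

        def sample(self, g, indices):
            return {seed: sorted(g.get(seed, [])) for seed in indices}


    # The toy graph: node ID -> list of in-neighbor IDs.
    graph = {0: [2, 1], 1: [0], 2: [0, 1], 3: [2]}
    loader = ToyDataLoader(graph, [0, 1, 2, 3], OneHopSampler(), batch_size=2)
    batches = list(loader)

Swapping in a different ``ToySampler`` subclass changes what each mini-batch contains without touching the loader, which is the same separation that lets ``dgl.dataloading.DataLoader`` work with any ``Sampler``.
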

The code below implements a classical neighbor sampler:

.. code:: python

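   import dgl

   class NeighborSampler(dgl.dataloading.Sampler):
       def __init__(self, fanouts: list[int]):
           super().__init__()
           # fanouts[i] is the number of neighbors sampled per node for layer i.
           self.fanouts = fanouts

       def sample(self, g, seed_nodes):
           output_nodes = seed_nodes
           subgs = []
           # Iterate from the last (output) layer back to the first (input) layer.
           for fanout in reversed(self.fanouts):
               # Sample a fixed number of neighbors of the current seed nodes.
               sg = g.sample_neighbors(seed_nodes, fanout)
               # Convert this subgraph to a message flow graph (block).
               sg = dgl.to_block(sg, seed_nodes)
               # The block's source nodes become the seeds of the previous layer.
               seed_nodes = sg.srcdata[dgl.NID]
               # Prepend so that subgs[0] is the block of the first layer.
               subgs.insert(0, sg)
           input_nodes = seed_nodes
           return input_nodes, output_nodes, subgs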

To use this sampler with ``DataLoader``:

.. code:: python

    graph = ...  # the graph to be sampled from
    train_nids = ...  # a 1-D tensor of training node IDs
    sampler = NeighborSampler([10, 15])  # create a sampler
    dataloader = dgl.dataloading.DataLoader(
        graph,
        train_nids,
        sampler,
        batch_size=32,    # batch_size decides how many IDs are passed to the sampler at once
        ...               # other arguments
    )
    for i, mini_batch in enumerate(dataloader):
        # unpack the mini batch
        input_nodes, output_nodes, subgs = mini_batch
        train(input_nodes, output_nodes, subgs)

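The layer-by-layer expansion performed by ``NeighborSampler.sample`` can be traced without DGL. The sketch below is a hypothetical pure-Python mimic: ``toy_neighbor_sample``, ``sample_in_neighbors`` and ``to_toy_block`` are illustrative names, a "block" is just a dictionary, and the graph is assumed to be a dictionary from node ID to in-neighbors. Like :func:`dgl.to_block` does by default, it lists the destination nodes first among the source nodes.

.. code:: python

    import random


    def sample_in_neighbors(adj, seeds, fanout, rng):
        """Pick at most ``fanout`` in-neighbors of every seed, without replacement.

        ``adj`` maps a node ID to the list of its in-neighbors.
        Returns the sampled edges as (src, dst) pairs.
        """
        edges = []
        for dst in seeds:
            nbrs = adj.get(dst, [])
            picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
            edges.extend((src, dst) for src in picked)
        return edges


    def to_toy_block(edges, seeds):
        """Build a dict-based 'block': the seeds are the destination nodes and
        also come first among the source nodes, followed by newly reached nodes."""
        src_nodes = list(seeds)
        seen = set(src_nodes)
        for src, _ in edges:
            if src not in seen:
                seen.add(src)
                src_nodes.append(src)
        return {"src_nodes": src_nodes, "dst_nodes": list(seeds), "edges": edges}


    def toy_neighbor_sample(adj, seed_nodes, fanouts, seed=0):
        rng = random.Random(seed)
        output_nodes = list(seed_nodes)
        blocks = []
        # Start at the output layer, so the last fanout is used first.
        for fanout in reversed(fanouts):
            edges = sample_in_neighbors(adj, seed_nodes, fanout, rng)
            block = to_toy_block(edges, seed_nodes)
            # This layer's source nodes are the seeds of the layer before it.
            seed_nodes = block["src_nodes"]
            # Prepend, so blocks[0] ends up being the first (input) layer.
            blocks.insert(0, block)
        input_nodes = seed_nodes
        return input_nodes, output_nodes, blocks


    adj = {0: [1, 2, 3], 1: [0, 4], 2: [0], 3: [4], 4: [1]}
    input_nodes, output_nodes, blocks = toy_neighbor_sample(adj, [0], [10, 15])

Because the fanouts are consumed in reverse, ``NeighborSampler([10, 15])`` samples up to 15 neighbors per node for the output layer and up to 10 for the first layer, and ``subgs[0]`` is the block fed to the first layer of the model.
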
Sampler for Heterogeneous Graphs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To write a sampler for heterogeneous graphs, keep in mind that the argument ``g``
will be a heterogeneous graph, while ``indices`` may be a dictionary of ID tensors
keyed by node (or edge) type. Most of DGL's graph sampling operators (e.g., the
``sample_neighbors`` and ``to_block`` functions in the example above) work on
heterogeneous graphs natively, so many samplers support heterogeneous graphs
without modification. For example, the ``NeighborSampler`` above can be used on
heterogeneous graphs:

.. code:: python

    hg = dgl.heterograph({
        ('user', 'like', 'movie') : ...,
        ('user', 'follow', 'user') : ...,
        ('movie', 'liked-by', 'user') : ...,
    })
    train_nids = {'user' : ..., 'movie' : ...}  # training IDs of 'user' and 'movie' nodes
    sampler = NeighborSampler([10, 15])  # create a sampler
    dataloader = dgl.dataloading.DataLoader(
        hg,
        train_nids,
        sampler,
        batch_size=32,    # batch_size decides how many IDs are passed to the sampler at once
        ...               # other arguments
    )
    for i, mini_batch in enumerate(dataloader):
        # unpack the mini-batch
        # input_nodes and output_nodes are dictionaries keyed by node type,
        # while subgs is a list of heterogeneous blocks
        input_nodes, output_nodes, subgs = mini_batch
        train(input_nodes, output_nodes, subgs)

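To see why the inputs and outputs become dictionaries on a heterogeneous graph, the sketch below performs one layer of neighbor sampling on a toy graph stored as a dictionary from canonical edge types to lists of ``(src, dst)`` pairs. It is a hypothetical illustration, not DGL's implementation; ``sample_one_layer_hetero`` is an illustrative name. It mirrors how an integer fanout behaves in :func:`dgl.sampling.sample_neighbors`, where that many in-edges are sampled per seed node for each edge type.

.. code:: python

    import random


    def sample_one_layer_hetero(edges_by_type, seed_nodes, fanout, rng):
        """One layer of toy in-neighbor sampling on a heterogeneous graph.

        edges_by_type: {(src_type, etype, dst_type): [(src, dst), ...]}
        seed_nodes:    {node_type: [seed IDs]}
        Up to ``fanout`` in-edges are kept per seed node *per edge type*.
        Returns (sampled edges per edge type, source nodes per node type).
        """
        sampled = {}
        # Seeds come first among the source nodes of their own type.
        src_nodes = {ntype: list(ids) for ntype, ids in seed_nodes.items()}
        for (stype, etype, dtype), pairs in edges_by_type.items():
            chosen = []
            for dst in seed_nodes.get(dtype, []):
                in_edges = [(s, d) for s, d in pairs if d == dst]
                if len(in_edges) > fanout:
                    in_edges = rng.sample(in_edges, fanout)
                chosen.extend(in_edges)
            sampled[(stype, etype, dtype)] = chosen
            # Newly reached source nodes are grouped by their node type.
            known = src_nodes.setdefault(stype, [])
            for s, _ in chosen:
                if s not in known:
                    known.append(s)
        return sampled, src_nodes


    edges = {
        ("user", "like", "movie"): [(0, 0), (1, 0), (2, 1)],
        ("user", "follow", "user"): [(1, 0), (2, 0)],
        ("movie", "liked-by", "user"): [(0, 0), (0, 1), (1, 2)],
    }
    seeds = {"user": [0], "movie": [1]}
    sampled, src_nodes = sample_one_layer_hetero(edges, seeds, 5, random.Random(0))

The ``user`` seed pulls in neighbors through both ``follow`` and ``liked-by``, so the source nodes that seed the next layer are again grouped by node type.
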
Exclude Edges During Sampling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The examples above are all *node-wise samplers*, because the ``indices`` argument
to the ``sample`` method represents a batch of seed node IDs. Another common type
is the *edge-wise sampler*, which, as the name suggests, takes a batch of seed
edge IDs to construct the mini-batch data. DGL provides a utility,
:func:`dgl.dataloading.as_edge_prediction_sampler`, to turn a node-wise sampler into
an edge-wise sampler. To prevent information leakage, it requires the node-wise sampler
to accept an additional third argument, ``exclude_eids``, listing the edges that must
not appear in the sampled subgraphs. The code below modifies the ``NeighborSampler``
defined above to properly exclude these edges:

.. code:: python

   import dgl

   class NeighborSampler(dgl.dataloading.Sampler):
       def __init__(self, fanouts):
           super().__init__()
           self.fanouts = fanouts

       # NOTE: There is an additional third argument. For homogeneous graphs,
       #   it is a 1-D tensor of integer IDs. For heterogeneous graphs, it
       #   is a dictionary of ID tensors. We usually set its default value to None.
       def sample(self, g, seed_nodes, exclude_eids=None):
           output_nodes = seed_nodes
           subgs = []
           for fanout in reversed(self.fanouts):
               # Sample a fixed number of neighbors of the current seed nodes,
               # never picking an edge listed in exclude_eids.
               sg = g.sample_neighbors(seed_nodes, fanout, exclude_edges=exclude_eids)
               # Convert this subgraph to a message flow graph.
               sg = dgl.to_block(sg, seed_nodes)
               seed_nodes = sg.srcdata[dgl.NID]
               subgs.insert(0, sg)
           input_nodes = seed_nodes
           return input_nodes, output_nodes, subgs

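The pure-Python sketch below shows what ``exclude_eids`` protects against. It assumes an edge-wise mini-batch with a single seed edge whose endpoints are handed to the node-wise sampler as seed nodes, which is roughly how an edge-wise wrapper reuses a node-wise sampler; ``sample_in_edges`` is a hypothetical helper, not a DGL function, and the graph is a plain list where position ``i`` holds the ``(src, dst)`` pair of edge ID ``i``.

.. code:: python

    import random


    def sample_in_edges(edge_list, seed_nodes, fanout, exclude_eids=None, rng=None):
        """Toy in-edge sampling that never returns an excluded edge.

        edge_list[i] is the (src, dst) pair of edge ID i. Up to ``fanout``
        incoming edge IDs are kept per seed node; IDs in ``exclude_eids`` are
        removed from the candidates *before* sampling.
        """
        rng = rng if rng is not None else random.Random(0)
        excluded = set(exclude_eids or [])
        sampled = []
        for dst in seed_nodes:
            candidates = [eid for eid, (_, d) in enumerate(edge_list)
                          if d == dst and eid not in excluded]
            if len(candidates) > fanout:
                candidates = rng.sample(candidates, fanout)
            sampled.extend(candidates)
        return sampled


    edge_list = [(1, 0), (2, 0), (0, 1), (3, 1)]   # edge IDs 0..3
    seed_eids = [0]                                 # the edge to predict: 1 -> 0
    # Endpoints of the seed edges act as the node-wise sampler's seed nodes.
    seed_nodes = sorted({n for eid in seed_eids for n in edge_list[eid]})
    leaky = sample_in_edges(edge_list, seed_nodes, fanout=10)
    safe = sample_in_edges(edge_list, seed_nodes, fanout=10, exclude_eids=seed_eids)

Without exclusion, edge 0 is sampled as a message-passing edge for its own endpoints, so a link-prediction model could see the very edge it is asked to predict; with ``exclude_eids`` it never enters the sampled subgraph.
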
Further Readings
~~~~~~~~~~~~~~~~~~
See :ref:`guide-minibatch-prefetching` for how to write a custom graph sampler
with feature prefetching.