.. _guide-minibatch-customizing-neighborhood-sampler:

6.4 Implementing Custom Graph Samplers
----------------------------------------------

Implementing custom samplers involves subclassing the
:class:`dgl.graphbolt.SubgraphSampler` base class and implementing its abstract
:meth:`_sample_subgraphs` method. The :meth:`_sample_subgraphs` method takes in
the seed nodes, which are the nodes to sample neighbors from:

.. code:: python

    def _sample_subgraphs(self, seed_nodes):
        return input_nodes, sampled_subgraphs

The method should return the input node IDs and a list of sampled subgraphs.
Each subgraph is a :class:`~dgl.graphbolt.SampledSubgraph` object.

Any other data required during sampling, such as the graph structure and the
fanout sizes, should be passed to the sampler via the constructor.

The code below implements a classical neighbor sampler:

.. code:: python

    @functional_datapipe("customized_sample_neighbor")
    class CustomizedNeighborSampler(dgl.graphbolt.SubgraphSampler):
        def __init__(self, datapipe, graph, fanouts):
            super().__init__(datapipe)
            self.graph = graph
            self.fanouts = fanouts

        def _sample_subgraphs(self, seed_nodes):
            subgs = []
            for fanout in reversed(self.fanouts):
                # Sample a fixed number of neighbors of the current seed nodes.
                input_nodes, sg = self.graph.sample_neighbors(seed_nodes, fanout)
                # Prepend so that subgraphs are ordered from input layer to output layer.
                subgs.insert(0, sg)
                # The sampled neighbors become the seeds of the next (outer) layer.
                seed_nodes = input_nodes
            return input_nodes, subgs

To use this sampler with :class:`~dgl.graphbolt.MultiProcessDataLoader`:

.. code:: python

    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
    datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

    for data in dataloader:
        input_features = data.node_features["feat"]
        output_labels = data.labels
        output_predictions = model(data.blocks, input_features)
        loss = compute_loss(output_labels, output_predictions)
        opt.zero_grad()
        loss.backward()
        opt.step()


Sampler for Heterogeneous Graphs
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

To write a sampler for heterogeneous graphs, one needs to be aware that
the argument ``graph`` is a heterogeneous graph, while ``seed_nodes`` could be a
dictionary of ID tensors keyed by node type. Most of DGL's graph sampling
operators (e.g., the ``sample_neighbors`` function used in the above example)
work on heterogeneous graphs natively, so many samplers are automatically ready
for heterogeneous graphs. For example, the above ``CustomizedNeighborSampler``
can be used on heterogeneous graphs:

.. code:: python

    import torch
    import dgl.graphbolt as gb

    hg = gb.FusedCSCSamplingGraph()  # The heterogeneous graph to sample from.
    train_set = gb.ItemSetDict(
        {
            "user": gb.ItemSet(
                (torch.arange(0, 5), torch.arange(5, 10)),
                names=("seed_nodes", "labels"),
            ),
            "item": gb.ItemSet(
                (torch.arange(5, 10), torch.arange(10, 15)),
                names=("seed_nodes", "labels"),
            ),
        }
    )
    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
    datapipe = datapipe.customized_sample_neighbor(hg, [10, 10])  # 2 layers.
    datapipe = datapipe.fetch_feature(
        feature, node_feature_keys={"user": ["feat"], "item": ["feat"]}
    )
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

    for data in dataloader:
        input_features = {
            ntype: data.node_features[(ntype, "feat")]
            for ntype in data.blocks[0].srctypes
        }
        output_labels = data.labels["user"]
        output_predictions = model(data.blocks, input_features)["user"]
        loss = compute_loss(output_labels, output_predictions)
        opt.zero_grad()
        loss.backward()
        opt.step()


Exclude Edges After Sampling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

In some cases, we may want to exclude seed edges from the sampled subgraph. For
example, in link prediction tasks, we want to exclude the edges in the
training set from the sampled subgraph to prevent information leakage. To
do so, we need to add an additional datapipe right after sampling as follows:

.. code:: python

    datapipe = datapipe.customized_sample_neighbor(g, [10, 10]) # 2 layers.
    datapipe = datapipe.transform(gb.exclude_seed_edges)

Please check the API page of :func:`~dgl.graphbolt.exclude_seed_edges` for more
details.
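For instance, additional options can be passed to it through
``functools.partial``; below is a minimal sketch, assuming the
``include_reverse_edges`` keyword of :func:`~dgl.graphbolt.exclude_seed_edges`
is available in your DGL version:

.. code:: python

    from functools import partial

    # Sketch: also drop the reverse direction of each seed edge (assumes the
    # ``include_reverse_edges`` keyword argument exists in this DGL version).
    datapipe = datapipe.transform(
        partial(gb.exclude_seed_edges, include_reverse_edges=True)
    )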

The above API is based on :meth:`~dgl.graphbolt.SampledSubgraph.exclude_edges`.
If you want to exclude edges from the sampled subgraph based on some other
criteria, you could write your own transform function. Please check the method
for reference.
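A minimal sketch of such a custom transform is shown below; the
``exclude_known_edges`` helper and the ``edges_to_exclude`` tensor are
hypothetical names used for illustration only:

.. code:: python

    from functools import partial

    # Hypothetical example: ``edges_to_exclude`` holds the node pairs (or a
    # dict keyed by edge type for heterogeneous graphs) to drop, computed
    # elsewhere by whatever criterion you need.
    def exclude_known_edges(minibatch, edges_to_exclude):
        # Apply SampledSubgraph.exclude_edges to every layer's subgraph.
        minibatch.sampled_subgraphs = [
            subgraph.exclude_edges(edges_to_exclude)
            for subgraph in minibatch.sampled_subgraphs
        ]
        return minibatch

    datapipe = datapipe.transform(
        partial(exclude_known_edges, edges_to_exclude=edges_to_exclude)
    )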

You could also refer to examples in
`Link Prediction <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/link_prediction.py>`__.