Unverified commit 4f1da61b authored by Quan (Andy) Gan, committed by GitHub

[Doc] fix inconsistencies in minibatch user guide (#2032)

parent e291f503
@@ -19,16 +19,18 @@ PyTorch node/edge DataLoaders
 General collating functions
 ```````````````````````````
-.. currentmodule:: dgl.dataloading
+.. currentmodule:: dgl.dataloading.dataloader
 .. autoclass:: Collator
    :members: dataset, collate
 .. autoclass:: NodeCollator
    :members: dataset, collate
+   :show-inheritance:
 .. autoclass:: EdgeCollator
    :members: dataset, collate
+   :show-inheritance:
 .. _api-dataloading-neighbor-sampling:
@@ -44,8 +46,14 @@ Base Multi-layer Neighborhood Sampling Class
 Uniform Node-wise Neighbor Sampling (GraphSAGE style)
 `````````````````````````````````````````````````````
+.. currentmodule:: dgl.dataloading.neighbor
 .. autoclass:: MultiLayerNeighborSampler
    :members: sample_frontier
+   :show-inheritance:
+.. autoclass:: MultiLayerFullNeighborSampler
+   :show-inheritance:
 .. _api-dataloading-negative-sampling:
......
@@ -59,7 +59,7 @@ computation dependencies needed for each layer given the nodes we wish
 to compute on.
 The simplest neighborhood sampler is
-:class:`~dgl.dataloading.pytorch.MultiLayerFullNeighborSampler`
+:class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler`
 which makes the node gather messages from all of its neighbors.
 To use a sampler provided by DGL, one also need to combine it with
@@ -241,7 +241,7 @@ removed for simplicity):
 Some of the samplers provided by DGL also support heterogeneous graphs.
 For example, one can still use the provided
-:class:`~dgl.dataloading.MultiLayerFullNeighborSampler` class and
+:class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler` class and
 :class:`~dgl.dataloading.pytorch.NodeDataLoader` class for
 stochastic training. For full-neighbor sampling, the only difference
 would be that you would specify a dictionary of node
@@ -1152,17 +1152,17 @@ classification.
 To implement your own neighborhood sampling strategy, you basically
 replace the ``sampler`` object with your own. To do that, lets first
-see what :class:`~dgl.dataloading.BlockSampler`, the parent class of
-:class:`~dgl.dataloading.MultiLayerFullNeighborSampler`, is.
+see what :class:`~dgl.dataloading.dataloader.BlockSampler`, the parent class of
+:class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler`, is.
-:class:`~dgl.dataloading.BlockSampler` is responsible for
+:class:`~dgl.dataloading.dataloader.BlockSampler` is responsible for
 generating the list of blocks starting from the last layer, with method
-:meth:`~dgl.dataloading.BlockSampler.sample_blocks`. The default implementation of
+:meth:`~dgl.dataloading.dataloader.BlockSampler.sample_blocks`. The default implementation of
 ``sample_blocks`` is to iterate backwards, generating the frontiers and
 converting them to blocks.
 Therefore, for neighborhood sampling, **you only need to implement
-the**\ :meth:`~dgl.dataloading.BlockSampler.sample_frontier`\ **method**. Given which
+the**\ :meth:`~dgl.dataloading.dataloader.BlockSampler.sample_frontier`\ **method**. Given which
 layer the sampler is generating frontier for, as well as the original
 graph and the nodes to compute representations, this method is
 responsible for generating a frontier for them.
@@ -1171,7 +1171,7 @@ Meanwhile, you also need to pass how many GNN layers you have to the
 parent class.
 For example, the implementation of
-:class:`~dgl.dataloading.MultiLayerFullNeighborSampler` can
+:class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler` can
 go as follows.
 .. code:: python
@@ -1184,7 +1184,7 @@ go as follows.
             frontier = dgl.in_subgraph(g, seed_nodes)
             return frontier
-:class:`dgl.dataloading.MultiLayerNeighborSampler`, a more
+:class:`dgl.dataloading.neighbor.MultiLayerNeighborSampler`, a more
 complicated neighbor sampler class that allows you to sample a small
 number of neighbors to gather message for each node, goes as follows.
@@ -1212,7 +1212,7 @@ nodes with a probability, one can simply define the sampler as follows:
 .. code:: python
-    class MultiLayerDropoutSampler(dgl.sampling.BlockSampler):
+    class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):
         def __init__(self, p, n_layers):
             super().__init__()
@@ -1241,7 +1241,7 @@ iterating over the seed nodes as usual.
 .. code:: python
     sampler = MultiLayerDropoutSampler(0.5, 2)
-    dataloader = dgl.sampling.NodeDataLoader(
+    dataloader = dgl.dataloading.NodeDataLoader(
         g, train_nids, sampler,
         batch_size=1024,
         shuffle=True,
@@ -1273,7 +1273,7 @@ all edge types, so that it can work on heterogeneous graphs as well.
 .. code:: python
-    class MultiLayerDropoutSampler(dgl.sampling.BlockSampler):
+    class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):
         def __init__(self, p, n_layers):
             super().__init__()
@@ -1565,8 +1565,8 @@ on how messages are aggregated and combined as well.
                 self.hidden_features
                 if l != self.n_layers - 1
                 else self.out_features)
-    sampler = dgl.sampling.MultiLayerFullNeighborSampler(1)
-    dataloader = dgl.sampling.DataLoader(
+    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
+    dataloader = dgl.dataloading.NodeDataLoader(
         g, torch.arange(g.number_of_nodes()), sampler,
         batch_size=batch_size,
         shuffle=True,
......
@@ -199,7 +199,7 @@ class EdgeDataLoader(DataLoader):
         See also
         --------
-        :class:`~dgl.dataloading.EdgeCollator`
+        :class:`~dgl.dataloading.dataloader.EdgeCollator`
         For end-to-end usages, please refer to the following tutorial/examples:
......
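The hunks above repeatedly touch the guide's ``MultiLayerDropoutSampler`` example, which drops each in-edge with some probability before message passing. For readers without DGL at hand, here is a dependency-free sketch of that Bernoulli edge-dropout idea; the graph representation and all function names below are illustrative, not part of the DGL API.

```python
import random

# A self-contained sketch (no DGL required) of the edge-dropout sampling idea:
# at each layer, every in-edge of the current seed nodes is kept independently
# with probability p. The graph here is a plain dict mapping each node to its
# list of in-neighbors.

def sample_frontier(graph, seed_nodes, p, rng=None):
    """Return the kept in-edges (src, dst) for the given seed nodes."""
    rng = rng or random.Random(0)
    frontier = []
    for dst in seed_nodes:
        for src in graph.get(dst, []):
            if rng.random() < p:  # keep this edge with probability p
                frontier.append((src, dst))
    return frontier

def sample_blocks(graph, seed_nodes, p, n_layers, rng=None):
    """Iterate backwards over the layers, growing the seed set each time."""
    rng = rng or random.Random(0)
    blocks = []
    nodes = set(seed_nodes)
    for _ in range(n_layers):
        edges = sample_frontier(graph, sorted(nodes), p, rng)
        blocks.insert(0, edges)             # outermost layer ends up first
        nodes |= {src for src, _ in edges}  # sources seed the next layer
    return blocks
```

With ``p = 1.0`` every edge survives, which degenerates to the full-neighbor sampling that ``MultiLayerFullNeighborSampler`` performs in the documented API.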