.. _guide-minibatch-gpu-sampling:

6.7 Using GPU for Neighborhood Sampling
---------------------------------------

Since 0.7, DGL has supported GPU-based neighborhood sampling, which has a significant
speed advantage over CPU-based neighborhood sampling.  If you estimate that your graph
can fit onto GPU memory and your model does not take much GPU memory, then it is best to
put the graph onto GPU memory and use GPU-based neighborhood sampling.

For example, `OGB Products <https://ogb.stanford.edu/docs/nodeprop/#ogbn-products>`_ has
2.4M nodes and 61M edges.  The graph takes less than 1GB of memory, since the memory
consumption of a graph is mostly determined by its number of edges.  Therefore it is
entirely possible to fit the whole graph onto GPU.
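
As a rough back-of-the-envelope check (a sketch; the exact footprint depends on which
sparse formats DGL materializes), a COO representation stores two 64-bit endpoint IDs
per edge:

.. code:: python

   # Rough structural memory estimate for ogbn-products in COO format:
   # two int64 endpoint arrays, i.e. 2 * 8 bytes per edge.
   num_edges = 61_000_000
   estimate_gib = num_edges * 2 * 8 / 2**30
   print(f"{estimate_gib:.2f} GiB")  # about 0.91 GiB, i.e. under 1GB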


Using GPU-based neighborhood sampling in DGL data loaders
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

One can use GPU-based neighborhood sampling with DGL data loaders via:

* Put the graph onto GPU.

* Put the ``train_nid`` onto GPU.

* Set the ``device`` argument to a GPU device.

* Set the ``num_workers`` argument to 0, because CUDA does not allow multiple
  processes to access the same context.

All the other arguments for the :class:`~dgl.dataloading.DataLoader` can be
the same as in the other user guides and tutorials.

.. code:: python

   g = g.to('cuda:0')
   train_nid = train_nid.to('cuda:0')
   dataloader = dgl.dataloading.DataLoader(
       g,                                # The graph must be on GPU.
       train_nid,                        # train_nid must be on GPU.
       sampler,
       device=torch.device('cuda:0'),    # The device argument must be GPU.
       num_workers=0,                    # Number of workers must be 0.
       batch_size=1000,
       drop_last=False,
       shuffle=True)

.. note::

  GPU-based neighbor sampling also works for custom neighborhood samplers as long as
  (1) your sampler is subclassed from :class:`~dgl.dataloading.BlockSampler`, and (2)
  your sampler entirely works on GPU.
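
With the graph, the seed nodes, and the sampler all on GPU, the minibatches produced
by the data loader already live on GPU, so the training loop needs no host-to-device
copies.  A minimal sketch, where ``model``, ``opt``, and the ``'feat'``/``'label'``
fields are assumptions carried over from the earlier minibatch training chapters:

.. code:: python

   import torch.nn.functional as F

   for input_nodes, output_nodes, blocks in dataloader:
       # The blocks are already on GPU; no blocks.to(device) is needed.
       x = blocks[0].srcdata['feat']     # input features of the first block
       y = blocks[-1].dstdata['label']   # labels of the seed nodes
       y_hat = model(blocks, x)          # assumed GNN model from earlier chapters
       loss = F.cross_entropy(y_hat, y)
       opt.zero_grad()
       loss.backward()
       opt.step()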


Using CUDA UVA-based neighborhood sampling in DGL data loaders
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. note::

   New feature introduced in DGL 0.8.

For the case where the graph is too large to fit onto the GPU memory, we introduce the
CUDA UVA (Unified Virtual Addressing)-based sampling, in which GPUs perform the sampling
on the graph pinned on CPU memory via zero-copy access.
You can enable UVA-based neighborhood sampling in DGL data loaders via:

* Put the ``train_nid`` onto GPU.

* Set the ``device`` argument to a GPU device.

* Set the ``num_workers`` argument to 0, because CUDA does not allow multiple
  processes to access the same context.

* Set ``use_uva=True``.

All the other arguments for the :class:`~dgl.dataloading.DataLoader` can be
the same as in the other user guides and tutorials.

.. code:: python

   train_nid = train_nid.to('cuda:0')
   dataloader = dgl.dataloading.DataLoader(
       g,                                # The graph stays in (pinned) CPU memory.
       train_nid,                        # train_nid must be on GPU.
       sampler,
       device=torch.device('cuda:0'),    # The device argument must be GPU.
       num_workers=0,                    # Number of workers must be 0.
       batch_size=1000,
       drop_last=False,
       shuffle=True,
       use_uva=True)                     # Set use_uva=True

UVA-based sampling is the recommended solution for mini-batch training on large graphs,
especially for multi-GPU training.

.. note::

  To use UVA-based sampling in multi-GPU training, you should first materialize all the
  necessary sparse formats of the graph before spawning training processes.
  Refer to our `GraphSAGE example <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/multi_gpu_node_classification.py>`_ for more details.
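
The multi-GPU setup can be sketched as follows (``run`` and ``num_gpus`` are
placeholder names; :meth:`dgl.DGLGraph.create_formats_` materializes all sparse
formats of the graph):

.. code:: python

   import torch.multiprocessing as mp

   def run(rank, g, train_nid):
       # Per-GPU training process: build the sampler and a DataLoader with
       # use_uva=True here, as in the code above.
       ...

   if __name__ == '__main__':
       # Materialize all sparse formats (e.g. CSC for in-neighbor sampling)
       # before spawning, so the child processes share them instead of each
       # constructing its own copy.
       g.create_formats_()
       mp.spawn(run, args=(g, train_nid), nprocs=num_gpus)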


Using GPU-based neighbor sampling with DGL functions
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can build your own GPU sampling pipelines with the following functions that support
operating on GPU:

* :func:`dgl.sampling.sample_neighbors`

  * Only supports uniform sampling; non-uniform sampling can only run on CPU.

Subgraph extraction ops:

* :func:`dgl.node_subgraph`
* :func:`dgl.edge_subgraph`
* :func:`dgl.in_subgraph`
* :func:`dgl.out_subgraph`

Graph transform ops for subgraph construction:

* :func:`dgl.to_block`
* :func:`dgl.compact_graphs`
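
For illustration, a one-layer sampling step composed from these ops might look like
the following (a small CPU-runnable sketch; for GPU sampling, ``g`` and ``seeds``
would both be moved to the same CUDA device first):

.. code:: python

   import dgl
   import torch

   # A toy 4-node cycle graph; replace with your own graph (on CPU or GPU).
   g = dgl.graph((torch.tensor([0, 1, 2, 3]), torch.tensor([1, 2, 3, 0])))
   seeds = torch.tensor([0, 1])

   # Uniformly sample up to 2 in-neighbors per seed node.
   frontier = dgl.sampling.sample_neighbors(g, seeds, 2)

   # Convert the sampled frontier into a block whose destination nodes are
   # the seeds, ready for one round of message passing.
   block = dgl.to_block(frontier, seeds)
   print(block.num_dst_nodes())  # 2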