.. _guide-mixed_precision:

Chapter 8: Mixed Precision Training
===================================
DGL is compatible with the `PyTorch Automatic Mixed Precision (AMP) package
<https://pytorch.org/docs/stable/amp.html>`_
for mixed precision training, which reduces both training time and GPU memory
consumption. This feature requires DGL 0.9+.

Message-Passing with Half Precision
-----------------------------------
DGL supports message passing on ``float16`` (fp16) features with both
UDFs (User Defined Functions) and built-in functions (e.g., ``dgl.function.sum``,
``dgl.function.copy_u``).

The following example shows how to use DGL's message-passing APIs on half-precision
features:

    >>> import torch
    >>> import dgl
    >>> import dgl.function as fn
    >>> dev = torch.device('cuda')
    >>> g = dgl.rand_graph(30, 100).to(dev)  # Create a graph on GPU w/ 30 nodes and 100 edges.
    >>> g.ndata['h'] = torch.rand(30, 16).to(dev).half()  # Create fp16 node features.
    >>> g.edata['w'] = torch.rand(100, 1).to(dev).half()  # Create fp16 edge features.
    >>> # Use DGL's built-in functions for message passing on fp16 features.
    >>> g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'x'))
    >>> g.ndata['x'].dtype
    torch.float16
    >>> g.apply_edges(fn.u_dot_v('h', 'x', 'hx'))
    >>> g.edata['hx'].dtype
    torch.float16

    >>> # Use UDFs for message passing on fp16 features.
    >>> def message(edges):
    ...     return {'m': edges.src['h'] * edges.data['w']}
    ...
    >>> def reduce(nodes):
    ...     return {'y': torch.sum(nodes.mailbox['m'], 1)}
    ...
    >>> def dot(edges):
    ...     return {'hy': (edges.src['h'] * edges.dst['y']).sum(-1, keepdims=True)}
    ...
    >>> g.update_all(message, reduce)
    >>> g.ndata['y'].dtype
    torch.float16
    >>> g.apply_edges(dot)
    >>> g.edata['hy'].dtype
    torch.float16
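
To make the semantics of ``g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'x'))``
concrete, here is a minimal pure-Python sketch of the same computation on a toy graph.
It is an illustration only, not DGL's implementation: the helper names ``to_fp16`` and
``u_mul_e_sum`` are hypothetical, edge weights are plain scalars, and fp16 rounding is
emulated with the standard library's ``struct`` module. Real GPU kernels may accumulate
in a different order or precision.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack('e', struct.pack('e', x))[0]


def u_mul_e_sum(src, dst, h, w, num_nodes):
    """Hypothetical reference for u_mul_e followed by a sum reducer.

    src, dst: lists of edge endpoints (edge e goes src[e] -> dst[e])
    h: per-node feature vectors, w: per-edge scalar weights
    """
    dim = len(h[0])
    out = [[0.0] * dim for _ in range(num_nodes)]
    for e, (u, v) in enumerate(zip(src, dst)):
        for k in range(dim):
            m = to_fp16(h[u][k] * w[e])          # message carried by edge e
            out[v][k] = to_fp16(out[v][k] + m)   # sum messages at destination v
    return out


# Toy graph with edges 0->2, 1->2, 2->0; node 1 receives no messages.
src, dst = [0, 1, 2], [2, 2, 0]
h = [[1.0, 2.0], [0.5, 0.25], [4.0, 8.0]]
w = [0.5, 2.0, 0.25]
x = u_mul_e_sum(src, dst, h, w, num_nodes=3)
print(x)  # [[1.0, 2.0], [0.0, 0.0], [1.5, 1.5]]
```

Nodes without incoming edges keep a zero result in this sketch, and every intermediate
value is rounded to fp16, which is why the output dtype stays ``float16``.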

End-to-End Mixed Precision Training
-----------------------------------
DGL relies on PyTorch's AMP package for mixed precision training,
and the user experience is exactly
the same as `PyTorch's <https://pytorch.org/docs/stable/notes/amp_examples.html>`_.

By wrapping the forward pass with ``torch.cuda.amp.autocast()``, PyTorch automatically
selects the appropriate datatype for each op and tensor. Half-precision tensors are memory
efficient, and most operators on half-precision tensors are faster because they leverage
GPU Tensor Cores.

.. code::

    import torch.nn.functional as F
    from torch.cuda.amp import autocast

    def forward(g, feat, label, mask, model, use_fp16):
        # With enabled=False, autocast is a no-op and everything runs in fp32.
        with autocast(enabled=use_fp16):
            logit = model(g, feat)
            # Compute the loss only on the nodes selected by the mask.
            loss = F.cross_entropy(logit[mask], label[mask])
            return loss
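
As a rough illustration of the memory claim above, the following sketch compares the
storage needed for one scalar per edge in fp32 and fp16. The helper ``tensor_bytes`` is
hypothetical, and the edge count is the approximate Reddit figure quoted later in this
chapter. Overall savings are typically smaller than 2x, because autocast keeps model
parameters and some numerically sensitive ops in fp32.

```python
import struct


def tensor_bytes(shape, fmt):
    """Bytes needed for a dense tensor of `shape` with struct element format `fmt`."""
    n = 1
    for d in shape:
        n *= d
    return n * struct.calcsize(fmt)


shape = (114_000_000, 1)          # one scalar per edge (approximate edge count)
fp32 = tensor_bytes(shape, 'f')   # 'f' is a 4-byte single-precision float
fp16 = tensor_bytes(shape, 'e')   # 'e' is a 2-byte half-precision float
print(fp32 // 10**6, 'MB vs', fp16 // 10**6, 'MB')  # 456 MB vs 228 MB
```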

Small gradients in ``float16`` format can underflow (flush to zero).
PyTorch provides ``GradScaler`` to address this issue. It multiplies
the loss by a scale factor and runs the backward pass on the scaled loss to prevent
underflow. It then unscales the computed gradients before the optimizer
updates the parameters. The scale factor is determined automatically.

.. code::

    from torch.cuda.amp import GradScaler

    scaler = GradScaler()

    def backward(scaler, loss, optimizer):
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
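
To see why loss scaling helps, the sketch below emulates fp16 rounding with the standard
library and shows a tiny value flushing to zero unless it is scaled first. It also
includes a simplified version of the scale-adjustment policy documented for
``GradScaler`` (back off when inf/NaN gradients are found, grow after a run of clean
steps). The names ``to_fp16`` and ``update_scale`` and the gradient value are
hypothetical; this is not PyTorch's implementation.

```python
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack('e', struct.pack('e', x))[0]


grad = 1e-8        # hypothetical gradient, below fp16's smallest subnormal (~6e-8)
scale = 65536.0    # 2**16, GradScaler's documented default initial scale

unscaled = to_fp16(grad)          # flushes to zero in fp16
scaled = to_fp16(grad * scale)    # representable after scaling
recovered = scaled / scale        # unscale in higher precision
print(unscaled, recovered)


def update_scale(scale, found_inf, good_steps,
                 growth_factor=2.0, backoff_factor=0.5, growth_interval=2000):
    """Simplified dynamic loss-scale policy; returns (new_scale, new_good_steps)."""
    if found_inf:
        # Overflow detected: the optimizer step is skipped and the scale shrinks.
        return scale * backoff_factor, 0
    good_steps += 1
    if good_steps == growth_interval:
        # Enough consecutive clean steps: try a larger scale.
        return scale * growth_factor, 0
    return scale, good_steps
```

The unscaled value is lost entirely, while the scaled round trip recovers the gradient
to within fp16's relative precision.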

The following example trains a 3-layer GAT on the Reddit dataset (with 114 million edges).
Note how the training loop differs depending on whether ``use_fp16`` is enabled.

.. code::

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import dgl
    from dgl.data import RedditDataset
    from dgl.nn import GATConv
    from dgl.transforms import AddSelfLoop

    # forward(), backward() and scaler are defined in the snippets above.
    use_fp16 = True

    class GAT(nn.Module):
        def __init__(self,
                     in_feats,
                     n_hidden,
                     n_classes,
                     heads):
            super().__init__()
            self.layers = nn.ModuleList()
            self.layers.append(GATConv(in_feats, n_hidden, heads[0], activation=F.elu))
            self.layers.append(GATConv(n_hidden * heads[0], n_hidden, heads[1], activation=F.elu))
            self.layers.append(GATConv(n_hidden * heads[1], n_classes, heads[2], activation=F.elu))

        def forward(self, g, h):
            for l, layer in enumerate(self.layers):
                h = layer(g, h)
                if l != len(self.layers) - 1:
                    # Hidden layers: concatenate the outputs of all heads.
                    h = h.flatten(1)
                else:
                    # Output layer: average over heads.
                    h = h.mean(1)
            return h

    # Data loading
    transform = AddSelfLoop()
    # Pass the transform by keyword; the first positional argument is self_loop.
    data = RedditDataset(transform=transform)
    dev = torch.device('cuda')

    g = data[0]
    g = g.int().to(dev)
    train_mask = g.ndata['train_mask']
    feat = g.ndata['feat']
    label = g.ndata['label']

    in_feats = feat.shape[1]
    n_hidden = 256
    n_classes = data.num_classes
    heads = [1, 1, 1]
    model = GAT(in_feats, n_hidden, n_classes, heads)
    model = model.to(dev)
    model.train()

    # Create optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=5e-4)

    for epoch in range(100):
        optimizer.zero_grad()
        loss = forward(g, feat, label, train_mask, model, use_fp16)

        if use_fp16:
            # Backprop w/ gradient scaling
            backward(scaler, loss, optimizer)
        else:
            loss.backward()
            optimizer.step()

        print('Epoch {} | Loss {}'.format(epoch, loss.item()))

On an NVIDIA V100 (16 GB) machine, training this model without fp16 consumes
15.2 GB of GPU memory; with fp16 turned on, training consumes 12.8 GB of
GPU memory, and the loss converges to similar values in both settings.
If we change the number of heads to ``[2, 2, 2]``, training without fp16
triggers a GPU out-of-memory (OOM) error, while training with fp16 consumes
15.7 GB of GPU memory.

DGL is still improving its half-precision support, and the performance of its
compute kernels is far from optimal. Please stay tuned for future updates.