Model sharding using Pipeline Parallel
======================================

Let us start with a toy model that contains two linear layers.

.. code-block:: default


    import torch
    import torch.nn as nn

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.net1 = nn.Linear(10, 10)
            self.relu = nn.ReLU()
            self.net2 = nn.Linear(10, 5)

        def forward(self, x):
            x = self.relu(self.net1(x))
            return self.net2(x)

    model = ToyModel()

To run this model on 2 GPUs, we need to rewrite it as a
``torch.nn.Sequential`` and then wrap it with ``fairscale.nn.Pipe``.

.. code-block:: default


    import fairscale
    import torch
    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.Linear(10, 5),
    )

    model = fairscale.nn.Pipe(model, balance=[2, 1])

The ``balance`` argument specifies how many layers go to each device: here, the
first two layers run on ``cuda:0`` and the last layer runs on ``cuda:1``.
To learn more, visit the `Pipe <../api/nn/pipe.html>`_ documentation.
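
``Pipe`` also accepts a ``chunks`` argument that splits each mini-batch into
micro-batches, so the two GPUs can work on different micro-batches concurrently
instead of idling while the other partition runs. A minimal sketch (the value
``chunks=4`` is illustrative):

.. code-block:: default

    import fairscale
    import torch.nn as nn

    model = nn.Sequential(
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.Linear(10, 5),
    )

    # balance=[2, 1]: two layers on the first GPU, one on the second;
    # chunks=4: each mini-batch is split into 4 micro-batches that are
    # pipelined through the two partitions.
    model = fairscale.nn.Pipe(model, balance=[2, 1], chunks=4)

    # the devices hosting each partition can be inspected
    print(model.devices)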

You can then define any optimizer and loss function:

.. code-block:: default


    import torch.optim as optim
    import torch.nn.functional as F

    optimizer = optim.SGD(model.parameters(), lr=0.001)
    # the model emits raw scores, so use cross_entropy, which applies
    # log_softmax internally (nll_loss alone expects log-probabilities)
    loss_fn = F.cross_entropy

    optimizer.zero_grad()
    # 5 classes, matching the model's output dimension
    target = torch.randint(0, 5, (20,))
    data = torch.randn(20, 10)

Finally, run the model and compute the loss, making sure that the outputs and
the target are on the same device; ``Pipe`` returns its output on the last
device of the pipeline, so it has to be moved explicitly.

.. code-block:: default

    # inputs are fed to the first device in the pipeline
    device = model.devices[0]

    # forward step
    outputs = model(data.to(device))
    # compute the loss; outputs and target need to be on the same device
    loss = loss_fn(outputs.to(device), target.to(device))

    # backward + optimize
    loss.backward()
    optimizer.step()
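
Putting the pieces together, here is a minimal end-to-end sketch of a training
loop (assuming two CUDA devices are available; the batch size, class count,
number of steps, and learning rate mirror the toy example above and are
illustrative):

.. code-block:: default

    import fairscale
    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    model = nn.Sequential(
        nn.Linear(10, 10),
        nn.ReLU(),
        nn.Linear(10, 5),
    )
    model = fairscale.nn.Pipe(model, balance=[2, 1])

    optimizer = optim.SGD(model.parameters(), lr=0.001)
    device = model.devices[0]

    for step in range(10):
        data = torch.randn(20, 10)
        target = torch.randint(0, 5, (20,))

        optimizer.zero_grad()
        outputs = model(data.to(device))
        # move the output back next to the target before computing the loss
        loss = F.cross_entropy(outputs.to(device), target.to(device))
        loss.backward()
        optimizer.step()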