Optimizer state sharding
========================

Using ``torch.nn.parallel.DistributedDataParallel`` on top of OSS leads to some wasted communications, but it is possible, and it makes OSS a drop-in solution in your existing torch distributed code.

Let's suppose that your trainer looks like this:

.. code-block:: python

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP


    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer-specific arguments, e.g. LR, momentum, etc.
        base_optimizer_arguments = {"lr": 1e-4}
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)

                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss /= world_size
                loss.backward()
                optimizer.step()
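
Note that ``dist_init`` is not defined in this snippet: it stands for whatever routine you use to initialize the ``torch.distributed`` process group. A minimal sketch, assuming a single node and a hard-coded rendezvous port, could look like this:

.. code-block:: python

    import os

    import torch.distributed as dist


    def dist_init(rank: int, world_size: int, backend: str = "nccl") -> None:
        # assumed single-node setup; adjust the address and port for your cluster
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29501"

        # use "gloo" instead of "nccl" when training on CPU
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)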


Then sharding the optimizer state is merely a matter of wrapping your optimizer in ``fairscale.optim.OSS``, as follows:

.. code-block:: python

    import torch
    from fairscale.optim.oss import OSS
    from torch.nn.parallel import DistributedDataParallel as DDP


    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer-specific arguments, e.g. LR, momentum, etc.
        base_optimizer_arguments = {"lr": 1e-4}

        # ** NEW ** Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any PyTorch-compliant optimizer
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)

                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss /= world_size
                loss.backward()
                optimizer.step()


The above ``train`` function then needs to be launched on each rank via ``torch.multiprocessing.spawn``:

.. code-block:: python

    import torch.multiprocessing as mp

    # WORLD_SIZE and EPOCHS are assumed to be defined by your launcher script
    mp.spawn(
        train,
        args=(WORLD_SIZE, EPOCHS),
        nprocs=WORLD_SIZE,
        join=True,
    )
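
Because the optimizer state is now sharded, no single rank holds the full state anymore, which matters when checkpointing. A hedged sketch of how this could be handled, assuming fairscale's ``OSS.consolidate_state_dict`` helper and a hypothetical ``checkpoint.pt`` destination:

.. code-block:: python

    # inside train(), at whatever point you want to save a checkpoint
    optimizer.consolidate_state_dict(recipient_rank=0)  # gather all shards on rank 0

    if rank == 0:
        torch.save(
            {"model": model.state_dict(), "optimizer": optimizer.state_dict()},
            "checkpoint.pt",
        )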

To see it in action, you can test it with the example script `tutorial_oss.py <../../../examples/tutorial_oss.py>`_.