Optimizer state sharding
========================

Using torch.nn.parallel.DistributedDataParallel leads to some wasted communications when combined with OSS, but it is possible, and it makes OSS a drop-in solution for your existing torch distributed code.
Let's suppose that your trainer looks like this:

.. code-block:: python


    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP


    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer-specific arguments, e.g. LR, momentum, etc...
        base_optimizer_arguments = {"lr": 1e-4}
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss /= world_size
                loss.backward()
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                optimizer.step()


Then sharding the optimizer state is merely a matter of wrapping your optimizer in `fairscale.optim.OSS`, as follows:

.. code-block:: python


    import torch
    from fairscale.optim.oss import OSS
    from torch.nn.parallel import DistributedDataParallel as DDP

    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # DDP
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer-specific arguments, e.g. LR, momentum, etc...
        base_optimizer_arguments = {"lr": 1e-4}

        # ** NEW ** Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss /= world_size
                loss.backward()
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
                optimizer.step()


The above `train` function will then need to be run via `torch.multiprocessing.spawn`, which launches one process per rank.

.. code-block:: python


    import torch.multiprocessing as mp

    # WORLD_SIZE and EPOCHS are assumed to be defined elsewhere (e.g. module-level constants)
    mp.spawn(
        train,
        args=(WORLD_SIZE, EPOCHS),
        nprocs=WORLD_SIZE,
        join=True,
    )
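
Both snippets also rely on a `dist_init` helper to set up the process group on every rank before DDP and OSS can communicate. Its implementation is not shown here; a minimal sketch, assuming a TCP rendezvous on localhost (the backend, address, and port below are placeholder assumptions), could look like this:

.. code-block:: python

    import torch.distributed as dist

    def dist_init(rank: int, world_size: int, backend: str = "gloo") -> None:
        # Hypothetical helper: initialize the default process group so that DDP,
        # OSS and the all_reduce call above can communicate across ranks.
        # Use backend="nccl" for multi-GPU training.
        dist.init_process_group(
            backend=backend,
            init_method="tcp://localhost:29501",
            rank=rank,
            world_size=world_size,
        )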

To see it in action, you can test it with the following script: `tutorial_oss.py <../../../examples/tutorial_oss.py>`_