Optimizer state sharding
========================

Using torch.nn.parallel.DistributedDataParallel leads to some wasted communications in the case of OSS, but it is possible, and it makes OSS a drop-in solution in your existing torch distributed code.
Let's suppose that your trainer looks like this:

.. code-block:: python


    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP


    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # process group init
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer-specific arguments, e.g. LR, momentum, etc.
        base_optimizer_arguments = {"lr": 1e-4}
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)

                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss.backward()
                optimizer.step()


Then sharding the optimizer state is merely a matter of wrapping your optimizer in fairscale.optim.OSS, as follows.
DDP can be used in place of ShardedDDP in the example below (with the same wrapping call as in the first snippet), but the memory savings will be reduced, since the gradients are not sharded as efficiently.

.. code-block:: python


    import torch
    from fairscale.optim.oss import OSS
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP


    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # process group init
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer-specific arguments, e.g. LR, momentum, etc.
        base_optimizer_arguments = {"lr": 1e-4}

        # ** NEW ** Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)

        # ** NEW ** Wrap the model into ShardedDDP, which reduces each gradient
        # to the rank owning the corresponding optimizer shard
        model = ShardedDDP(model, optimizer)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)

                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss.backward()
                optimizer.step()


The above `train` function will then need to be run via `torch.multiprocessing.spawn`, for instance as follows.

.. code-block:: python


    import torch.multiprocessing as mp

    mp.spawn(
        train,
        args=(WORLD_SIZE, EPOCHS),
        nprocs=WORLD_SIZE,
        join=True,
    )

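Both snippets assume a small `dist_init` helper which sets up the torch distributed process group on each rank. This helper is not part of fairscale; a minimal sketch could look like the following, where the backend, address and port are placeholders to adapt to your setup.

.. code-block:: python


    import os

    import torch.distributed as dist


    def dist_init(rank: int, world_size: int) -> None:
        # rendezvous information, to adapt to your environment
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29501"

        # "nccl" is the usual backend for multi-GPU training, "gloo" also works on CPU
        dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)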

To see all of this in action, you can run the example script available `here <../../../examples/tutorial_oss.py>`_.

Using PyTorch Automatic Mixed Precision is possible, but it requires a shard-aware GradScaler, which is available in
`fairscale.optim.grad_scaler`. Autocast can be used as is, and the loss will be scaled and handled in the same way.
See `the original documentation <https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision>`_
for more information.

.. code-block:: python



    from torch import optim
    from torch.cuda.amp import autocast

    from fairscale.optim.grad_scaler import ShardedGradScaler


    # Creates model and optimizer in default precision
    model = Net().cuda()
    optimizer = optim.SGD(model.parameters(), ...)

    # Creates a ShardedGradScaler once at the beginning of training.
    scaler = ShardedGradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()

            # Runs the forward pass with autocasting.
            with autocast():
                output = model(input)
                loss = loss_fn(output, target)

            # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended.
            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
            scaler.scale(loss).backward()

            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(optimizer)

            # Updates the scale for next iteration.
            scaler.update()