Optimizer, Gradient and Model Sharding
=======================================

Using torch.nn.parallel.DistributedDataParallel leads to some wasted communications in the case of OSS,
but it is possible and makes OSS a drop-in solution in your existing torch distributed code.
Let's suppose that your trainer looks like this:

.. code-block:: python

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP


    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # process group init
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        model = DDP(model, device_ids=[rank])
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = { "lr": 1e-4}
        optimizer = torch.optim.SGD(
            params=model.parameters(),
            **base_optimizer_arguments)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)

                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss.backward()
                optimizer.step()


Then sharding the optimizer state is merely a matter of wrapping your optimizer in `fairscale.optim.OSS`,
as follows.
DDP can be used in place of ShardedDDP in the example below, but the memory savings will be reduced
(the gradients are not sharded as efficiently); a minimal sketch of that DDP variant follows the example.

.. code-block:: python

    import torch
    from fairscale.optim.oss import OSS
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP


    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # process group init
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = { "lr": 1e-4}

        # Wrap a base optimizer into OSS
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer
        optimizer = OSS(
            params=model.parameters(),
            optim=base_optimizer,
            **base_optimizer_arguments)

        # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
        model = ShardedDDP(model, optimizer)

        # Any relevant training loop, nothing specific to OSS. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)

                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss.backward()
                optimizer.step()

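As mentioned above, OSS can also be combined with a plain DDP wrapper instead of ShardedDDP. A minimal
sketch of that variant, reusing the same placeholder model and settings as in the example above, could
look as follows; only the wrapping changes, the training loop stays identical, and the gradients remain
replicated on every rank.

.. code-block:: python

    import torch
    from torch.nn.parallel import DistributedDataParallel as DDP
    from fairscale.optim.oss import OSS

    # DDP variant: only the optimizer state is sharded by OSS,
    # the gradients are still all-reduced to every rank by DDP
    model = myAwesomeModel().to(rank)
    model = DDP(model, device_ids=[rank])
    optimizer = OSS(
        params=model.parameters(),
        optim=torch.optim.SGD,
        lr=1e-4)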

The above `train` function can then be run via a `torch.multiprocessing.spawn` call. Note that any launcher
can be used, the only assumption being that each of the ranks lives in its own Python process.

.. code-block:: python

    import torch.multiprocessing as mp

    mp.spawn(
        train,
        args=(WORLD_SIZE, EPOCHS),
        nprocs=WORLD_SIZE,
        join=True,
    )

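The `dist_init` helper used in the snippets above is not part of fairscale. A minimal sketch, assuming a
single-node setup with an environment-variable rendezvous (address and port chosen arbitrarily here),
could look like this:

.. code-block:: python

    import os

    import torch.distributed as dist


    def dist_init(rank: int, world_size: int, backend: str = "nccl") -> None:
        # hypothetical helper: single-node rendezvous through environment variables
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "29501"
        dist.init_process_group(backend=backend, rank=rank, world_size=world_size)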

Using PyTorch Automatic Mixed Precision is possible, and its actual usage will depend on whether OSS
is used with DDP or with ShardedDDP.
If OSS is used with DDP, then the normal PyTorch GradScaler can be used, and nothing needs to be changed.
If OSS is used with ShardedDDP (to get the gradient sharding), then a very similar flow can be used,
but it requires a shard-aware GradScaler, which is available in `fairscale.optim.grad_scaler`.
In both cases autocast can be used as-is, and the loss will be scaled and handled in the same way.
See `the original documentation <https://pytorch.org/docs/stable/notes/amp_examples.html?highlight=automatic%20mixed%20precision>`_
for more information.

.. code-block:: python

    import torch.optim as optim
    from torch.cuda.amp import autocast

    from fairscale.optim.grad_scaler import ShardedGradScaler


    # Creates model and optimizer in default precision
    model = Net().cuda()
    optimizer = optim.SGD(model.parameters(), ...)

    # Creates a ShardedGradScaler once at the beginning of training.
    scaler = ShardedGradScaler()

    for epoch in epochs:
        for input, target in data:
            optimizer.zero_grad()

            # Runs the forward pass with autocasting.
            with autocast():
                output = model(input)
                loss = loss_fn(output, target)

            # Scales loss.  Calls backward() on scaled loss to create scaled gradients.
            # Backward passes under autocast are not recommended.
            # Backward ops run in the same dtype autocast chose for corresponding forward ops.
            scaler.scale(loss).backward()

            # scaler.step() first unscales the gradients of the optimizer's assigned params.
            # If these gradients do not contain infs or NaNs, optimizer.step() is then called,
            # otherwise, optimizer.step() is skipped.
            scaler.step(optimizer)

            # Updates the scale for next iteration.
            scaler.update()

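For the OSS with plain DDP case mentioned above, the stock scaler is sufficient. A minimal sketch, reusing
the hypothetical `Net` placeholder and assuming the process group and `rank` are already set up, only
differs in how the model, optimizer and scaler are created; the loop itself is identical to the
ShardedGradScaler example:

.. code-block:: python

    import torch
    from torch.cuda.amp import GradScaler
    from torch.nn.parallel import DistributedDataParallel as DDP
    from fairscale.optim.oss import OSS

    # With DDP the gradients are fully replicated on every rank and only the
    # optimizer state is sharded, so the standard GradScaler applies as-is.
    model = DDP(Net().to(rank), device_ids=[rank])
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4)
    scaler = GradScaler()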

Parameters can be sharded using the FullyShardedDataParallel (FSDP) API. It involves wrapping your model
similarly to the ShardedDDP API above.

.. code-block:: python


    import torch
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


    def train(
        rank: int,
        world_size: int,
        epochs: int):

        # process group init
        dist_init(rank, world_size)

        # Problem statement
        model = myAwesomeModel().to(rank)
        dataloader = mySuperFastDataloader()
        loss_fn = myVeryRelevantLoss()

        # optimizer specific arguments e.g. LR, momentum, etc...
        base_optimizer_arguments = { "lr": 1e-4}

        # FSDP uses a plain pytorch optimizer, no OSS wrapping is needed
        base_optimizer = torch.optim.SGD  # any pytorch compliant optimizer

        # Wrap the model into FSDP, which will shard the parameters across the ranks
        model = FSDP(model)

        # The optimizer operates on the wrapped model's (sharded) parameters,
        # so it is created after the wrap
        optimizer = base_optimizer(model.parameters(), **base_optimizer_arguments)

        # Any relevant training loop. For example:
        model.train()
        for e in range(epochs):
            for (data, target) in dataloader:
                data, target = data.to(rank), target.to(rank)
                # Train
                model.zero_grad()
                outputs = model(data)
                loss = loss_fn(outputs, target)
                loss.backward()
                optimizer.step()


Auto wrapping sub-modules with FSDP is a convenient way to improve training speed by overlapping 
the allgather step across the forward passes of different submodules. 
It also improves memory efficiency by freeing gathered parameters after each layer finishes executing. 
For example:

.. code-block:: python


    import torch
    from fairscale.nn.wrap import auto_wrap, enable_wrap, wrap
    from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
    from fairscale.utils.testing import DummyProcessGroup


    tfmr = torch.nn.Transformer(num_encoder_layers=2, num_decoder_layers=2)

    group = DummyProcessGroup(rank=0, size=1)
    fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
    with enable_wrap(wrapper_cls=FSDP, process_group=group, **fsdp_params):

        # Wraps layer in FSDP by default if within context
        l1 = wrap(torch.nn.Linear(5, 5))
        assert isinstance(l1, FSDP)
        assert l1.mixed_precision and l1.flatten_parameters
        # Separately wraps children modules with more than 1e6 params
        tfmr_auto_wrapped = auto_wrap(tfmr, min_num_params=1e6)
        assert isinstance(tfmr_auto_wrapped, torch.nn.Transformer)
        for l in tfmr_auto_wrapped.encoder.layers:
            assert isinstance(l, FSDP)
            assert l.mixed_precision and l.flatten_parameters
            assert isinstance(l.linear1, FSDP)
            assert isinstance(l.linear2, FSDP)
            assert not isinstance(l.self_attn, FSDP)  # self attention is not auto-wrapped