# Tutorial: Megatron-LM GPT2 with DeepSpeed

If you haven't already, we advise you to first read through the [Getting
Started](../../README.md#getting-started) guide before stepping through this
tutorial.

In this tutorial we add DeepSpeed to the Megatron-LM GPT2 model, which
is a large, powerful transformer. Megatron-LM supports model-parallel and multi-node
training. Please see the corresponding paper for more details: [Megatron-LM:
Training Multi-Billion Parameter Language Models Using Model
Parallelism](https://arxiv.org/abs/1909.08053).

First, we discuss data and environment setup and how to train the GPT-2 model with the
original Megatron-LM. Next, we proceed step-by-step in enabling this model to run with
DeepSpeed. Finally, we demonstrate the **_performance gains_** and **_memory footprint
reduction_** from using DeepSpeed.

## 1 Training GPT-2 with the Original Megatron-LM

The original model code is from
[Megatron-LM](https://github.com/NVIDIA/Megatron-LM). We've copied this repo
under
[DeepSpeedExamples/Megatron-LM/](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM)
and made it available as a submodule. To download it, execute:
```bash
git submodule update --init --recursive
```

### 1.1 Training Data Setup
* Follow Megatron's [instructions](https://github.com/NVIDIA/Megatron-LM#collecting-gpt2-webtext-data)
  to download the webtext data and place a symbolic link under `DeepSpeedExamples/Megatron-LM/data`.

### 1.2 Running the Unmodified Megatron-LM GPT2 Model

* For a single GPU run:
    - edit `scripts/pretrain_gpt2.sh` and set its `--train-data` argument to `"webtext"`.
    - run `bash scripts/pretrain_gpt2.sh`

* For a run on multiple GPUs and/or nodes:
    - edit `scripts/pretrain_gpt2_model_parallel.sh`
        - set its `--train-data` argument to `"webtext"`
        - `GPUS_PER_NODE` sets the number of GPUs used on each node
        - `NNODES` sets the number of nodes used

    - run `bash scripts/pretrain_gpt2_model_parallel.sh`
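
As a rough illustration of how these two settings determine the size of a run, the sketch below computes the total number of GPUs and the resulting data-parallel degree. The function `parallel_layout` and the parameter `model_parallel_size` are hypothetical names used only for this example; they are not taken from the Megatron-LM scripts.

```python
def parallel_layout(gpus_per_node, nnodes, model_parallel_size):
    """Return (world_size, data_parallel_size) for a given cluster layout.

    Hypothetical helper for illustration only; not part of Megatron-LM.
    """
    world_size = gpus_per_node * nnodes
    if world_size % model_parallel_size != 0:
        raise ValueError("world size must be divisible by the model-parallel size")
    # Each group of `model_parallel_size` GPUs holds one model replica;
    # the replicas are then trained in data-parallel fashion.
    return world_size, world_size // model_parallel_size


# 4 nodes with 4 GPUs each, using 4-way model parallelism.
world_size, data_parallel_size = parallel_layout(
    gpus_per_node=4, nnodes=4, model_parallel_size=4)
print(world_size, data_parallel_size)
```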


## 2 Enabling DeepSpeed

To use DeepSpeed we will modify three files:

* `arguments.py`: Argument configurations
* `pretrain_gpt2.py`: Main entry point for training
* `utils.py`: Checkpoint saving and loading utilities


### 2.1 Argument Parsing
The first step in applying DeepSpeed is to add DeepSpeed arguments to the
Megatron-LM GPT2 model, using `deepspeed.add_config_arguments()` in
`arguments.py`.

```python
def get_args():
    """Parse all the args."""

    parser = argparse.ArgumentParser(description='PyTorch BERT Model')
    parser = add_model_config_args(parser)
    parser = add_fp16_config_args(parser)
    parser = add_training_args(parser)
    parser = add_evaluation_args(parser)
    parser = add_text_generate_args(parser)
    parser = add_data_args(parser)

    # Include DeepSpeed configuration arguments
    parser = deepspeed.add_config_arguments(parser)
```
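
To make the effect of this call concrete, here is a minimal sketch that mimics it with plain `argparse`, so it runs without DeepSpeed installed. The helper `add_deepspeed_style_arguments` is a hypothetical stand-in, and we assume that the real call registers at least the `--deepspeed` and `--deepspeed_config` flags; the `--lr` default is also just an illustrative value.

```python
import argparse


def add_deepspeed_style_arguments(parser):
    """Hypothetical stand-in for deepspeed.add_config_arguments()."""
    group = parser.add_argument_group("DeepSpeed", "DeepSpeed configurations")
    # Boolean switch that the training code later checks as `args.deepspeed`.
    group.add_argument("--deepspeed", default=False, action="store_true",
                       help="Enable DeepSpeed")
    # Path to a JSON configuration file (assumed flag name).
    group.add_argument("--deepspeed_config", default=None, type=str,
                       help="DeepSpeed JSON configuration file")
    return parser


def get_args(argv=None):
    """Parse a toy set of training arguments plus the DeepSpeed-style ones."""
    parser = argparse.ArgumentParser(description="Toy training script")
    parser.add_argument("--lr", type=float, default=1.5e-4)
    parser = add_deepspeed_style_arguments(parser)
    return parser.parse_args(argv)


args = get_args(["--deepspeed", "--deepspeed_config", "ds_config.json"])
print(args.deepspeed, args.deepspeed_config)
```

Because the flag defaults to `False`, the same script still runs without DeepSpeed when `--deepspeed` is omitted, which is what the `if args.deepspeed:` branches in the rest of this tutorial rely on.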



### 2.2 Initialization and Training
We modify `pretrain_gpt2.py` to enable training with DeepSpeed.

#### 2.2.1 Initialization
We use `deepspeed.initialize` to create the `model_engine`, `optimizer`, and LR
`scheduler`. Its definition is shown below:
```python
def initialize(args,
               model,
               optimizer=None,
               model_parameters=None,
               training_data=None,
               lr_scheduler=None,
               mpu=None,
               dist_init_required=True,
               collate_fn=None):
```

For the Megatron-LM GPT2 model, we initialize DeepSpeed in its
`setup_model_and_optimizer()` function as shown below, passing the raw `model`,
`optimizer`, `args`, `lr_scheduler`, and `mpu`.
```python
def setup_model_and_optimizer(args):
    """Setup model and optimizer."""

    model = get_model(args)
    optimizer = get_optimizer(model, args)
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.deepspeed:
        import deepspeed

        print_rank_0("DeepSpeed is enabled.")

        model, optimizer, _, lr_scheduler = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            args=args,
            lr_scheduler=lr_scheduler,
            mpu=mpu,
            dist_init_required=False
        )
```


Note that when FP16 is enabled, Megatron-LM GPT2 adds a wrapper to the `Adam`
optimizer. DeepSpeed has its own FP16 Optimizer, so we need to pass the `Adam`
optimizer to DeepSpeed directly without any wrapper. We return the unwrapped
Adam optimizer from `get_optimizer()` when DeepSpeed is enabled.
```python
def get_optimizer(model, args):
    """Setup the optimizer."""

    ......

    # Use Adam.
    optimizer = Adam(param_groups,
                     lr=args.lr, weight_decay=args.weight_decay)

    if args.deepspeed:
        # fp16 wrapper is not required for DeepSpeed.
        return optimizer
```

#### 2.2.2 Using the Training API
The `model` returned by `deepspeed.initialize` is the _DeepSpeed Model Engine_
that we will use to train the model using the forward, backward and step API.


##### Forward Propagation
The forward propagation API is compatible with PyTorch, so no change is required.


##### Backward Propagation
Backward propagation is done by calling `backward(loss)` directly on the model engine.

```python
    def backward_step(optimizer, model, lm_loss, args, timers):
        """Backward step."""

        # Total loss.
        loss = lm_loss

        # Backward pass.
        if args.deepspeed:
            model.backward(loss)
        else:
            optimizer.zero_grad()
            if args.fp16:
                optimizer.backward(loss, update_master_grads=False)
            else:
                loss.backward()
```

Zeroing the gradients is handled automatically by DeepSpeed after the weights
have been updated using a mini-batch.
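
The toy class below sketches the behavior described here: gradients are averaged across micro-batches, the weights are updated only at the gradient accumulation boundary, and the gradients are zeroed right after the update. `ToyEngine` is a hypothetical, pure-Python illustration and not DeepSpeed code. For simplicity its `backward()` takes a precomputed gradient rather than a loss tensor.

```python
class ToyEngine:
    """Toy single-weight engine with backward()/step(); not DeepSpeed code."""

    def __init__(self, weight, lr, accumulation_steps):
        self.weight = weight
        self.lr = lr
        self.accumulation_steps = accumulation_steps
        self.grad = 0.0
        self.micro_steps = 0

    def forward(self, x):
        return self.weight * x

    def backward(self, grad):
        # Divide by the number of micro-batches so the sum is an average.
        self.grad += grad / self.accumulation_steps

    def step(self):
        """Update the weight only at an accumulation boundary."""
        self.micro_steps += 1
        if self.micro_steps % self.accumulation_steps != 0:
            return False
        self.weight -= self.lr * self.grad
        # Zero the gradient automatically once the weight has been updated.
        self.grad = 0.0
        return True


engine = ToyEngine(weight=1.0, lr=0.1, accumulation_steps=2)
updates = []
for x, y in [(1.0, 3.0), (2.0, 2.0)]:
    pred = engine.forward(x)
    # Gradient of 0.5 * (pred - y) ** 2 with respect to the weight.
    engine.backward((pred - y) * x)
    updates.append(engine.step())
print(updates, engine.weight, engine.grad)
```

The caller never zeroes gradients explicitly, which mirrors why `optimizer.zero_grad()` appears only in the non-DeepSpeed branch above.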

Furthermore, DeepSpeed handles distributed data parallelism and FP16 under the
hood, simplifying code in multiple places.

(A) DeepSpeed also performs gradient averaging automatically at the gradient
accumulation boundaries, so we skip the explicit allreduce communication.

   ```python
        if args.deepspeed:
            # DeepSpeed's backward pass already handles the allreduce communication.
            # Reset the timer to avoid breaking timer logs below.
            timers('allreduce').reset()
        else:
            torch.distributed.all_reduce(reduced_losses.data)
            reduced_losses.data = reduced_losses.data / args.world_size
            if not USE_TORCH_DDP:
                timers('allreduce').start()
                model.allreduce_params(reduce_after=False,
                                       fp32_allreduce=args.fp32_allreduce)
                timers('allreduce').stop()

   ```

(B) We also skip updating master gradients, since DeepSpeed handles this internally.

   ```python
        # Update master gradients.
        if not args.deepspeed:
            if args.fp16:
                optimizer.update_master_grads()

            # Clipping gradients helps prevent the exploding gradient.
            if args.clip_grad > 0:
                if not args.fp16:
                    mpu.clip_grad_norm(model.parameters(), args.clip_grad)
                else:
                    optimizer.clip_master_grads(args.clip_grad)

        return lm_loss_reduced

   ```

##### Updating the Model Parameters
The `step()` function in the DeepSpeed engine updates the model parameters as well
as the learning rate.

```python
     if args.deepspeed:
         model.step()
     else:
         optimizer.step()

         # Update learning rate.
         if not (args.fp16 and optimizer.overflow):
             lr_scheduler.step()
         else:
             skipped_iter = 1

```



##### Loss Scaling
The GPT2 training script logs the loss scaling value during training. Inside
the DeepSpeed optimizer, this value is stored as `cur_scale`, whereas Megatron's
optimizer stores it as `loss_scale`. Therefore, we replace it accordingly in
the logging string.

```python
             if args.fp16:
                 log_string += ' loss scale {:.1f} |'.format(
                     optimizer.cur_scale if args.deepspeed else optimizer.loss_scale)

```
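
For readers unfamiliar with why the loss scale changes over time and why an iteration can be skipped, the sketch below shows a generic dynamic loss scaling policy. `ToyLossScaler` and its parameters (`scale_factor`, `scale_window`) are hypothetical names for this example; the policy shown is the general technique, not necessarily the exact rule that DeepSpeed or Megatron-LM implements. Only the attribute name `cur_scale` is taken from the text above.

```python
class ToyLossScaler:
    """Generic dynamic loss scaler; a sketch, not DeepSpeed's implementation."""

    def __init__(self, init_scale=2.0 ** 16, scale_factor=2.0, scale_window=2):
        self.cur_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0

    def update(self, overflow):
        """Adjust the scale; return True if the optimizer step should be applied."""
        if overflow:
            # Gradients overflowed in FP16: shrink the scale and skip this step.
            self.cur_scale = max(self.cur_scale / self.scale_factor, 1.0)
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps % self.scale_window == 0:
            # Several clean steps in a row: try a larger scale again.
            self.cur_scale *= self.scale_factor
        return True


scaler = ToyLossScaler()
skipped_iters = 0
for overflow in [False, True, False, False]:
    if not scaler.update(overflow):
        skipped_iters += 1

log_string = ' loss scale {:.1f} |'.format(scaler.cur_scale)
print(skipped_iters, log_string)
```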


### 2.3 Checkpoints Saving & Loading

The DeepSpeed engine has flexible APIs for checkpoint saving and loading, which handle
the states of both the client model and its own internal state.

```python
def save_checkpoint(self, save_dir, tag, client_state={})
def load_checkpoint(self, load_dir, tag)
```

To apply DeepSpeed, we update `utils.py`, in which Megatron-LM GPT2 saves and
loads its checkpoints.

We create a new function, `save_ds_checkpoint()`, for DeepSpeed as shown below. It
collects the client model states and passes them to the DeepSpeed engine by calling
DeepSpeed's `save_checkpoint()`.

```python
 def save_ds_checkpoint(iteration, model, args):
     """Save a model checkpoint."""

     sd = {}
     sd['iteration'] = iteration
     # rng states.
     if not args.no_save_rng:
         sd['random_rng_state'] = random.getstate()
         sd['np_rng_state'] = np.random.get_state()
         sd['torch_rng_state'] = torch.get_rng_state()
         sd['cuda_rng_state'] = torch.cuda.get_rng_state()
         sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

     model.save_checkpoint(args.save, iteration, client_state = sd)

```

In the Megatron-LM GPT2 `save_checkpoint()` function, we add the following lines to
invoke the above function for DeepSpeed.

```python
 def save_checkpoint(iteration, model, optimizer,
                     lr_scheduler, args):
     """Save a model checkpoint."""
     if args.deepspeed:
         save_ds_checkpoint(iteration, model, args)
     else:
         ......

```

In the `load_checkpoint()` function, we use the DeepSpeed checkpoint loading API as
shown below and return the states for the client model.

```python
 def load_checkpoint(model, optimizer, lr_scheduler, args):
     """Load a model checkpoint."""

     iteration, release = get_checkpoint_iteration(args)

     if args.deepspeed:
         checkpoint_name, sd = model.load_checkpoint(args.load, iteration)

         if checkpoint_name is None:
             if mpu.get_data_parallel_rank() == 0:
                 print("Unable to load checkpoint.")
             return iteration
     else:
         ......

```
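
The reason for storing RNG states in the client state can be seen in a small, self-contained round trip. The sketch below uses only Python's built-in `random` module and `pickle` in place of the DeepSpeed engine, so `save_client_state` and `load_client_state` are hypothetical helpers, not DeepSpeed or Megatron-LM functions. A real checkpoint would also cover the NumPy, PyTorch, and CUDA generators listed in `save_ds_checkpoint()`.

```python
import os
import pickle
import random
import tempfile


def save_client_state(path, iteration):
    """Store the iteration counter and the Python RNG state (toy example)."""
    sd = {"iteration": iteration, "random_rng_state": random.getstate()}
    with open(path, "wb") as f:
        pickle.dump(sd, f)


def load_client_state(path):
    """Restore the Python RNG state and return the saved iteration."""
    with open(path, "rb") as f:
        sd = pickle.load(f)
    random.setstate(sd["random_rng_state"])
    return sd["iteration"]


random.seed(1234)
with tempfile.TemporaryDirectory() as tmp:
    ckpt = os.path.join(tmp, "client_state.pkl")
    save_client_state(ckpt, iteration=100)
    # Values drawn right after the checkpoint was written.
    expected = [random.random() for _ in range(3)]

    # Simulate a restart with an unrelated RNG state, then resume.
    random.seed(0)
    resumed_iteration = load_client_state(ckpt)
    replayed = [random.random() for _ in range(3)]

print(resumed_iteration, replayed == expected)
```

Restoring the generator state makes the resumed run draw the same random numbers it would have drawn without the interruption.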

### 2.4 Training Scripts
Assuming the webtext data was prepared in the previous step, execute one of the
following commands to start training the Megatron-LM GPT2 model with DeepSpeed
applied.

- Single GPU run
  - run `bash scripts/ds_pretrain_gpt2.sh`
- Multiple GPUs/Nodes run
  - run `bash scripts/ds_pretrain_gpt2_model_parallel.sh`



## 3 Performance Improvements
DeepSpeed enables training very large models effectively via the advanced [ZeRO
optimizer](https://arxiv.org/abs/1910.02054v2). ZeRO significantly reduces the memory
footprint for training large models, which means large models can be trained with i) less
model parallelism and ii) larger batch sizes. A lower model parallelism degree improves
training efficiency by increasing the granularity of computations such as matrix
multiplication, where performance is directly related to the size of the matrices.
Furthermore, less model parallelism also results in less communication between model
parallel GPUs, which further boosts performance. A larger batch size has a similar effect
of increasing the computational granularity as well as reducing communication, also
resulting in better performance. Therefore, DeepSpeed combines ZeRO-powered data parallelism with
Megatron-LM tensor-slicing model parallelism, which is
significantly faster than using Megatron-LM alone.

The observed performance improvements depend on several factors such as the memory per
GPU, the local GPU interconnect (e.g., PCI-E vs NVLINK vs NVSwitch), the model size,
the inter-node network interconnect, etc. Below, we show some of the performance improvements
from using DeepSpeed over Megatron on a 16 GPU Low Bandwidth (40 Gbps) cluster and a 400 GPU DGX-2 High Bandwidth (800 Gbps) cluster.
For details please see the [ZeRO Paper](https://arxiv.org/abs/1910.02054v2). We also
present performance improvements on a 64 GPU cluster along with a detailed configuration
analysis to show where the improvements come from.

![DeepSpeed-vs-Megatron](../figures/DeepSpeed-vs-Megatron.png)
<p align="center">
<em>The figure depicts system throughput improvements of DeepSpeed (combining ZeRO-powered data parallelism with model parallelism of Nvidia Megatron-LM) over using Megatron-LM alone.</em>
</p>


### 3.1 On Low Bandwidth GPU Cluster
The figure above shows that training a 1.5B-parameter model with DeepSpeed is
nearly 4x faster than without DeepSpeed on a cluster with 4 nodes, 4 GPUs per
node, and 16 GPUs total. These GPUs have 16 GB of memory each; PCI-E connects
the GPUs within a node, and 40 Gbps InfiniBand connects the nodes.

The performance improvement comes from the lower model parallelism degree and
larger batch size discussed earlier. Training a 1.5B-parameter model with
Megatron-LM alone requires 4-way model parallelism, and can only fit an effective
batch size of 32 using all 16 GPUs. On the other hand, DeepSpeed does not
require any model parallelism to train this model, and can support an
effective batch size of 128 without running out of memory, resulting in
significantly higher performance.
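
A back-of-the-envelope estimate helps explain why the 1.5B-parameter model needs model parallelism on 16 GB GPUs without ZeRO but not with it. The sketch below assumes the per-parameter byte counts given in the ZeRO paper for mixed-precision Adam (2 bytes each for FP16 weights and gradients, 12 bytes for FP32 optimizer states), and assumes that only the optimizer states are partitioned across data-parallel GPUs. It ignores activations and temporary buffers, so real memory usage is higher. `model_state_gb` is a hypothetical helper for this example.

```python
BYTES_WEIGHTS_AND_GRADS = 4  # FP16 weights (2 bytes) + FP16 gradients (2 bytes)
BYTES_OPTIMIZER_STATES = 12  # FP32 weights, momentum, and variance (4 bytes each)


def model_state_gb(num_params, data_parallel_size=1, partition_optimizer=False):
    """Estimate per-GPU model state memory in GB (10**9 bytes)."""
    optimizer_bytes = BYTES_OPTIMIZER_STATES * num_params
    if partition_optimizer:
        # Each data-parallel GPU keeps only its own slice of the optimizer states.
        optimizer_bytes /= data_parallel_size
    return (BYTES_WEIGHTS_AND_GRADS * num_params + optimizer_bytes) / 1e9


params = 1.5e9
baseline_gb = model_state_gb(params)
zero_gb = model_state_gb(params, data_parallel_size=16, partition_optimizer=True)
print(baseline_gb, zero_gb)
```

Under these assumptions the unpartitioned model states (24 GB) exceed a single 16 GB GPU, while the partitioned ones (about 7.1 GB) fit, which is consistent with the configuration described above.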


### 3.2 On High Bandwidth DGX-2 GPU Cluster
Each GPU on the DGX-2 cluster has 32 GB of memory, and the GPUs inside a box are connected via
the high-bandwidth NVSwitch. DGX-2 nodes are connected to each other via an 800 Gbps (8 x 100 Gbps) InfiniBand interconnect. As such, running a 1.5B model on DGX-2 requires less model
parallelism, and the performance improvement from DeepSpeed for this model size is less
significant. However, at larger model sizes, Megatron still requires a significantly larger
model parallelism degree, and can only run much smaller batch sizes than DeepSpeed.
Therefore, as model sizes get larger, DeepSpeed, by combining ZeRO with Megatron model parallelism, starts to significantly outperform
using Megatron-LM alone.


### 3.3 Performance Improvements with Configuration Details
The figure below compares DeepSpeed with Megatron on a 64 GPU cluster with 4
DGX-2 nodes. To give readers a clear idea of the source of the performance
improvements, we also present the configuration tables for both Megatron and
DeepSpeed. They show the smallest model parallelism degree and the largest batch
size that can be used to train these models without running out of memory. As
discussed above, the tables demonstrate that DeepSpeed runs with a smaller model parallelism degree
and achieves better performance.

![DeepSpeed Performance SpeedUp](../figures/megatron-gpt2-perf-test.png)
<p align="center">
<em>The figure depicts system throughput improvements of DeepSpeed (combining ZeRO-powered data parallelism with model parallelism of Nvidia Megatron-LM) over using Megatron-LM alone.</em>
</p>


**a) Megatron-LM GPT2 Baseline**

|      | Model Parallelism | Data Parallelism | #gpus | batch size | layers | hidden size | attention heads | samples / sec |
| ---- | ----------------: | ---------------: | ----: | ---------: | -----: | -----------:| --------------: | ------------: |
| 1.5B | 2                 | 32               | 64    | 512        | 48     | 1600        | 16              | 128.56        |
| 4B   | 4                 | 16               | 64    | 128        | 64     | 2304        | 16              | 49.36         |
| 8B   | 4                 | 16               | 64    | 128        | 72     | 3072        | 24              | 24.57         |
| 20B  | 16                | 4                | 64    | 16         | 111    | 3808        | 32              | 3.42          |



**b) Megatron-LM GPT2 with DeepSpeed**

|      | Model Parallelism | Data Parallelism | #gpus | batch size | layers | hidden size | attention heads | samples / sec |
| ---- | ----------------: | ---------------: | ----: | ---------: | -----: | -----------:| --------------: | ------------: |
| 1.5B | 1                 | 64               | 64    | 2048       | 48     | 1600        | 16              | 151.35        |
| 4B   | 1                 | 64               | 64    | 512        | 64     | 2304        | 16              | 75.13         |
| 8B   | 2                 | 32               | 64    | 512        | 72     | 3072        | 24              | 43.52         |
| 20B  | 4                 | 16               | 64    | 128        | 111    | 3808        | 32              | 12.65         |
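
The two tables can be cross-checked with a few lines of arithmetic. The sketch below copies the parallelism degrees and throughput figures from the tables, verifies that the model-parallel and data-parallel degrees multiply to 64 GPUs in every row, and computes the DeepSpeed speedup for each model size. The variable names are chosen for this example only.

```python
# (model, Megatron MP, Megatron DP, Megatron samples/sec,
#         DeepSpeed MP, DeepSpeed DP, DeepSpeed samples/sec)
rows = [
    ("1.5B", 2, 32, 128.56, 1, 64, 151.35),
    ("4B", 4, 16, 49.36, 1, 64, 75.13),
    ("8B", 4, 16, 24.57, 2, 32, 43.52),
    ("20B", 16, 4, 3.42, 4, 16, 12.65),
]

NUM_GPUS = 64
speedups = {}
for name, meg_mp, meg_dp, meg_tput, ds_mp, ds_dp, ds_tput in rows:
    # Model-parallel degree times data-parallel degree gives the GPU count.
    assert meg_mp * meg_dp == NUM_GPUS
    assert ds_mp * ds_dp == NUM_GPUS
    speedups[name] = ds_tput / meg_tput

for name, speedup in speedups.items():
    print("{}: {:.2f}x".format(name, speedup))
```

The speedup grows with model size, from roughly 1.18x at 1.5B parameters to roughly 3.70x at 20B, in line with the discussion in Section 3.2.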