---
title: "Megatron-LM GPT2"
---

If you haven't already, we advise you to first read through the [Getting
Started](/getting-started/) guide before stepping through this tutorial.

In this tutorial we will add DeepSpeed to the Megatron-LM GPT2 model, which
is a large, powerful transformer. Megatron-LM supports model-parallel and multi-node
training. Please see the corresponding paper for more details: [Megatron-LM:
Training Multi-Billion Parameter Language Models Using Model
Parallelism](https://arxiv.org/abs/1909.08053).

First, we discuss data and environment setup and how to train the GPT-2 model with the
original Megatron-LM. Next, we proceed step by step to enable this model to run with
DeepSpeed. Finally, we demonstrate the **_performance gains_** and **_memory footprint
reduction_** from using DeepSpeed.

## Training GPT-2 with the Original Megatron-LM

The original model code comes from
[Megatron-LM](https://github.com/NVIDIA/Megatron-LM). We've copied this repo
under
[DeepSpeedExamples/Megatron-LM/](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM)
and made it available as a submodule. To download, execute:
```bash
git submodule update --init --recursive
```

### Training Data Setup
* Follow Megatron's [instructions](https://github.com/NVIDIA/Megatron-LM#collecting-gpt2-webtext-data)
  to download the webtext data and place a symbolic link under `DeepSpeedExamples/Megatron-LM/data`.

### Running Unmodified Megatron-LM GPT2 model

* For a single-GPU run:
    - edit `scripts/pretrain_gpt2.sh` and set its `--train-data` argument to `"webtext"`
    - run `bash scripts/pretrain_gpt2.sh`

* For a run on multiple GPUs and/or nodes:
    - edit `scripts/pretrain_gpt2_model_parallel.sh`
        - set its `--train-data` argument to `"webtext"`
        - set `GPUS_PER_NODE` to the number of GPUs to use on each node
        - set `NNODES` to the number of nodes to use

    - run `bash scripts/pretrain_gpt2_model_parallel.sh`


## Enabling DeepSpeed

To use DeepSpeed we will modify three files:

* `arguments.py` : Arguments configurations
* `pretrain_gpt2.py` : Main entry point for training
* `utils.py` : Checkpoints saving and loading utilities


### Argument Parsing
The first step in applying DeepSpeed is to add the DeepSpeed arguments to the
Megatron-LM GPT2 model, using `deepspeed.add_config_arguments()` in
`arguments.py`.

```python
def get_args():
    """Parse all the args."""

    parser = argparse.ArgumentParser(description='PyTorch BERT Model')
    parser = add_model_config_args(parser)
    parser = add_fp16_config_args(parser)
    parser = add_training_args(parser)
    parser = add_evaluation_args(parser)
    parser = add_text_generate_args(parser)
    parser = add_data_args(parser)

    # Include DeepSpeed configuration arguments
    parser = deepspeed.add_config_arguments(parser)
```
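
As a rough illustration of what this call contributes, the sketch below uses a hypothetical stand-in, `add_deepspeed_like_arguments()`, rather than the DeepSpeed function itself. It registers only `--deepspeed` and `--deepspeed_config`, the two flags the rest of this tutorial reads as `args.deepspeed` and `args.deepspeed_config`. The real `deepspeed.add_config_arguments()` may register further flags.

```python
import argparse


def add_deepspeed_like_arguments(parser):
    """Hypothetical stand-in for deepspeed.add_config_arguments()."""
    group = parser.add_argument_group("DeepSpeed (stand-in)")
    # Boolean switch that the training code checks as `args.deepspeed`.
    group.add_argument("--deepspeed", action="store_true",
                       help="Enable the DeepSpeed code paths.")
    # Path to the JSON configuration file, read as `args.deepspeed_config`.
    group.add_argument("--deepspeed_config", type=str, default=None,
                       help="Path to the DeepSpeed JSON config file.")
    return parser


def parse_example_args(argv):
    """Build a small parser the same way get_args() chains its helpers."""
    parser = argparse.ArgumentParser(description="stand-in example")
    parser.add_argument("--lr", type=float, default=1.5e-4)
    parser = add_deepspeed_like_arguments(parser)
    return parser.parse_args(argv)


example_args = parse_example_args(
    ["--deepspeed", "--deepspeed_config", "ds_config.json"])
print(example_args.deepspeed, example_args.deepspeed_config)
```

Because `--deepspeed` defaults to off, scripts that do not pass it keep following the original Megatron-LM code paths.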



### Initialization and Training
We modify `pretrain_gpt2.py` to enable training with DeepSpeed.

#### Initialization
We use `deepspeed.initialize` to create `model_engine`, `optimizer` and LR
`scheduler`. Below is its definition:
```python
def initialize(args,
               model,
               optimizer=None,
               model_parameters=None,
               training_data=None,
               lr_scheduler=None,
               mpu=None,
               dist_init_required=True,
               collate_fn=None):
```

For the Megatron-LM GPT2 model, we initialize DeepSpeed in its
`setup_model_and_optimizer()` function as below, to pass the raw `model`,
`optimizer`, `args`, `lr_scheduler` and `mpu`.
```python
def setup_model_and_optimizer(args):
    """Setup model and optimizer."""

    model = get_model(args)
    optimizer = get_optimizer(model, args)
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.deepspeed:
        import deepspeed

        print_rank_0("DeepSpeed is enabled.")

        model, optimizer, _, lr_scheduler = deepspeed.initialize(
            model=model,
            optimizer=optimizer,
            args=args,
            lr_scheduler=lr_scheduler,
            mpu=mpu,
            dist_init_required=False
        )
```
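
The engine created here is driven by the JSON file named by `args.deepspeed_config`. The sketch below builds a minimal configuration of that kind in Python. The helper `build_ds_config()` and all numeric values are illustrative assumptions, not the settings used by the tutorial scripts. It also encodes the batch-size relation DeepSpeed expects: the global batch size equals the per-GPU micro-batch size times the gradient accumulation steps times the number of data-parallel ranks.

```python
import json


def build_ds_config(micro_batch_per_gpu, grad_accum_steps, data_parallel_size):
    """Build an illustrative DeepSpeed config dict (placeholder values)."""
    return {
        # Global batch = micro batch * accumulation steps * data-parallel ranks.
        "train_batch_size": (micro_batch_per_gpu * grad_accum_steps
                             * data_parallel_size),
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum_steps,
        "gradient_clipping": 1.0,
        "fp16": {"enabled": True},
        "zero_optimization": {"stage": 2},
    }


ds_config = build_ds_config(micro_batch_per_gpu=4, grad_accum_steps=2,
                            data_parallel_size=8)
ds_config_json = json.dumps(ds_config, indent=2)
print(ds_config_json)
```

With model parallelism enabled, the number of data-parallel ranks is the total number of GPUs divided by the model-parallel degree, which is why `mpu` is passed to `deepspeed.initialize`.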


Note that when FP16 is enabled, Megatron-LM GPT2 adds a wrapper to the `Adam`
optimizer. DeepSpeed has its own FP16 Optimizer, so we need to pass the `Adam`
optimizer to DeepSpeed directly without any wrapper. We return the unwrapped
Adam optimizer from `get_optimizer()` when DeepSpeed is enabled.
```python
def get_optimizer(model, args):
    """Setup the optimizer."""

    ......

    # Use Adam.
    optimizer = Adam(param_groups,
                     lr=args.lr, weight_decay=args.weight_decay)

    if args.deepspeed:
        # fp16 wrapper is not required for DeepSpeed.
        return optimizer
```

#### Using the Training API
The `model` returned by `deepspeed.initialize` is the _DeepSpeed Model Engine_
that we will use to train the model using the forward, backward and step API.


##### Forward Propagation
The forward propagation API is compatible with PyTorch, so no change is required.


##### Backward Propagation
Backward propagation is done by calling `backward(loss)` directly on the model engine.

```python
    def backward_step(optimizer, model, lm_loss, args, timers):
        """Backward step."""

        # Total loss.
        loss = lm_loss

        # Backward pass.
        if args.deepspeed:
            model.backward(loss)
        else:
            optimizer.zero_grad()
            if args.fp16:
                optimizer.backward(loss, update_master_grads=False)
            else:
                loss.backward()
```

Zeroing the gradients is handled automatically by DeepSpeed after the weights
have been updated using a mini-batch.

Furthermore, DeepSpeed addresses distributed data parallel and FP16 under the
hood, simplifying code in multiple places.

(A) DeepSpeed also performs gradient averaging automatically at the gradient
accumulation boundaries. So we skip the allreduce communication.

   ```python
        if args.deepspeed:
            # DeepSpeed backward propagation already addressed all reduce communication.
            # Reset the timer to avoid breaking timer logs below.
            timers('allreduce').reset()
        else:
            torch.distributed.all_reduce(reduced_losses.data)
            reduced_losses.data = reduced_losses.data / args.world_size
            if not USE_TORCH_DDP:
                timers('allreduce').start()
                model.allreduce_params(reduce_after=False,
                                       fp32_allreduce=args.fp32_allreduce)
                timers('allreduce').stop()

   ```

(B) We also skip updating master gradients, since DeepSpeed addresses it internally.

   ```python
        # Update master gradients.
        if not args.deepspeed:
            if args.fp16:
                optimizer.update_master_grads()

            # Clipping gradients helps prevent the exploding gradient.
            if args.clip_grad > 0:
                if not args.fp16:
                    mpu.clip_grad_norm(model.parameters(), args.clip_grad)
                else:
                    optimizer.clip_master_grads(args.clip_grad)

        return lm_loss_reduced

   ```
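
The boundary behaviour described above (accumulate over micro-batches, update once per mini-batch, then zero the gradients) can be pictured with a toy, single-process sketch. `ToyEngine` and its methods are hypothetical names for illustration only. This is not DeepSpeed's implementation, which also averages gradients across data-parallel ranks and handles FP16.

```python
class ToyEngine:
    """Toy single-parameter stand-in for a training engine (not DeepSpeed)."""

    def __init__(self, weight, lr, grad_accum_steps):
        self.weight = weight
        self.lr = lr
        self.grad_accum_steps = grad_accum_steps
        self.grad = 0.0
        self.micro_steps = 0

    def backward(self, micro_batch_grad):
        # Accumulate the average gradient over the micro-batches of a mini-batch.
        self.grad += micro_batch_grad / self.grad_accum_steps

    def is_boundary(self):
        # True on the last micro-batch of each mini-batch.
        return (self.micro_steps + 1) % self.grad_accum_steps == 0

    def step(self):
        updated = False
        if self.is_boundary():
            self.weight -= self.lr * self.grad
            self.grad = 0.0  # gradients are zeroed after the weight update
            updated = True
        self.micro_steps += 1
        return updated


engine = ToyEngine(weight=4.0, lr=0.5, grad_accum_steps=2)
updates = []
for micro_grad in [1.0, 3.0, 2.0, 2.0]:
    engine.backward(micro_grad)
    updates.append(engine.step())
print(updates, engine.weight)
```

The caller invokes `backward()` and `step()` on every micro-batch and never calls `zero_grad()`, which mirrors how the DeepSpeed branch of the training loop is written.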

##### Updating the Model Parameters
The `step()` function in DeepSpeed engine updates the model parameters as well
as the learning rate.

```python
     if args.deepspeed:
         model.step()
     else:
         optimizer.step()

         # Update learning rate.
         if not (args.fp16 and optimizer.overflow):
             lr_scheduler.step()
         else:
             skipped_iter = 1

```



##### Loss Scaling
The GPT2 training script logs the loss scaling value during training. Inside
the DeepSpeed optimizer, this value is stored as `cur_scale` rather than the
`loss_scale` used in Megatron's optimizer. Therefore, we replace it accordingly
in the logging string.

```python
             if args.fp16:
                 log_string += ' loss scale {:.1f} |'.format(
                     optimizer.cur_scale if args.deepspeed else optimizer.loss_scale)

```
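
To show where a value like `cur_scale` comes from, here is a deliberately simplified dynamic loss scaler. `ToyLossScaler`, its parameters, and its update rule are assumptions made for illustration, not DeepSpeed's actual FP16 optimizer. The scale is halved and the step skipped when a gradient overflows, and the scale is raised again after a window of overflow-free steps.

```python
import math


class ToyLossScaler:
    """Simplified dynamic loss scaler; illustrative only."""

    def __init__(self, init_scale=2.0 ** 16, scale_factor=2.0, scale_window=3):
        self.cur_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.good_steps = 0

    def update(self, grads):
        """Return True if the step should be applied, False if it is skipped."""
        overflow = any(math.isinf(g) or math.isnan(g) for g in grads)
        if overflow:
            # Back off the scale and skip this parameter update.
            self.cur_scale /= self.scale_factor
            self.good_steps = 0
            return False
        self.good_steps += 1
        if self.good_steps % self.scale_window == 0:
            # Enough clean steps in a row: try a larger scale again.
            self.cur_scale *= self.scale_factor
        return True


scaler = ToyLossScaler()
step_grads = ([0.1], [float("inf")], [0.2], [0.3], [0.1])
applied = [scaler.update(g) for g in step_grads]
log_string = " loss scale {:.1f} |".format(scaler.cur_scale)
print(applied, log_string)
```

The skipped step in this sketch corresponds to the `optimizer.overflow` check in the non-DeepSpeed branch, where the learning rate scheduler is not advanced.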


### Checkpoints Saving & Loading

The DeepSpeed engine has flexible APIs for saving and loading checkpoints, which
handle both the client model's state and the engine's own internal state.

```python
def save_checkpoint(self, save_dir, tag, client_state={})
def load_checkpoint(self, load_dir, tag)
```

To use these APIs, we update `utils.py`, where Megatron-LM GPT2 saves and
loads its checkpoints.

We create a new function, `save_ds_checkpoint()`, for DeepSpeed, as shown below.
It collects the client model states and passes them to the DeepSpeed engine by
calling DeepSpeed's `save_checkpoint()`.

```python
 def save_ds_checkpoint(iteration, model, args):
     """Save a model checkpoint."""

     sd = {}
     sd['iteration'] = iteration
     # rng states.
     if not args.no_save_rng:
         sd['random_rng_state'] = random.getstate()
         sd['np_rng_state'] = np.random.get_state()
         sd['torch_rng_state'] = torch.get_rng_state()
         sd['cuda_rng_state'] = torch.cuda.get_rng_state()
         sd['rng_tracker_states'] = mpu.get_cuda_rng_tracker().get_states()

     model.save_checkpoint(args.save, iteration, client_state = sd)

```

In the Megatron-LM GPT2 `save_checkpoint()` function, add the following lines to
invoke the above function when DeepSpeed is enabled.

```python
 def save_checkpoint(iteration, model, optimizer,
                     lr_scheduler, args):
     """Save a model checkpoint."""
     if args.deepspeed:
         save_ds_checkpoint(iteration, model, args)
     else:
         ......

```

In the `load_checkpoint()` function, use the DeepSpeed checkpoint loading API as
shown below, and return the states for the client model.

```python
 def load_checkpoint(model, optimizer, lr_scheduler, args):
     """Load a model checkpoint."""

     iteration, release = get_checkpoint_iteration(args)

     if args.deepspeed:
         checkpoint_name, sd = model.load_checkpoint(args.load, iteration)

         if checkpoint_name is None:
             if mpu.get_data_parallel_rank() == 0:
                 print("Unable to load checkpoint.")
             return iteration
     else:
         ......

```
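
The role of the client state can be sketched without DeepSpeed. The example below stores an iteration counter and Python's RNG state under a tagged directory and restores them, so that a resumed run draws the same random numbers. The helpers `save_client_state()` and `load_client_state()` and the file layout are hypothetical. DeepSpeed's own `save_checkpoint()` and `load_checkpoint()` manage their files differently and also store model and optimizer state.

```python
import os
import pickle
import random
import tempfile


def save_client_state(save_dir, tag, iteration):
    """Write a client-state dict under <save_dir>/<tag>/ (hypothetical layout)."""
    sd = {"iteration": iteration, "random_rng_state": random.getstate()}
    ckpt_dir = os.path.join(save_dir, str(tag))
    os.makedirs(ckpt_dir, exist_ok=True)
    with open(os.path.join(ckpt_dir, "client_state.pkl"), "wb") as f:
        pickle.dump(sd, f)


def load_client_state(load_dir, tag):
    """Return the saved dict and restore the RNG, or None if nothing is found."""
    path = os.path.join(load_dir, str(tag), "client_state.pkl")
    if not os.path.exists(path):
        return None
    with open(path, "rb") as f:
        sd = pickle.load(f)
    random.setstate(sd["random_rng_state"])
    return sd


with tempfile.TemporaryDirectory() as tmp:
    random.seed(1234)
    save_client_state(tmp, 100, iteration=100)
    expected = [random.random() for _ in range(3)]  # draws after the checkpoint
    random.seed(0)                                  # simulate a fresh process
    sd = load_client_state(tmp, 100)
    resumed = [random.random() for _ in range(3)]   # same draws after restoring
    missing = load_client_state(tmp, 999)           # tag that was never saved
print(sd["iteration"], resumed == expected, missing)
```

Returning `None` for a missing tag plays the same role as the `checkpoint_name is None` check in `load_checkpoint()` above.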

### DeepSpeed Activation Checkpoints (Optional)

DeepSpeed can reduce the activation memory during model-parallel training by partitioning activation checkpoints across model-parallel GPUs, or by offloading them to CPU. These optimizations are optional, and can be skipped unless activation memory becomes a bottleneck. To enable activation partitioning, we use the `deepspeed.checkpointing` API to replace Megatron's activation checkpointing and random state tracker APIs. The replacement should happen before the first invocation of these APIs.

a) Replace in `pretrain_gpt2.py`:

 ```python
    # Optional DeepSpeed Activation Checkpointing Features
    #
    if args.deepspeed and args.deepspeed_activation_checkpointing:
        set_deepspeed_activation_checkpointing(args)

def set_deepspeed_activation_checkpointing(args):

    deepspeed.checkpointing.configure(mpu,
                            deepspeed_config=args.deepspeed_config,
                            partition_activation=True)

    mpu.checkpoint = deepspeed.checkpointing.checkpoint
    mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
    mpu.model_parallel_cuda_manual_seed = (
        deepspeed.checkpointing.model_parallel_cuda_manual_seed)
```

b) Replace in `mpu/transformer.py`:

```python
if deepspeed.checkpointing.is_configured():
    global get_cuda_rng_tracker, checkpoint
    get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
    checkpoint = deepspeed.checkpointing.checkpoint

```

With these replacements, various DeepSpeed activation checkpointing optimizations such as activation partitioning, contiguous checkpointing, and CPU checkpointing, can be specified with either `deepspeed.checkpointing.configure` or in the `deepspeed_config` file.
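
As a sketch of the configuration-file route, the snippet below writes out an `activation_checkpointing` section of the kind that could be placed in the `deepspeed_config` file. The values are placeholders chosen for illustration. Option names and defaults may differ between DeepSpeed versions, so check them against the configuration documentation for the version you use.

```python
import json

# Illustrative values only; enable the options that match your memory needs.
activation_checkpointing_config = {
    "activation_checkpointing": {
        # Split each activation checkpoint across the model-parallel GPUs.
        "partition_activations": True,
        # Offload checkpointed activations to CPU memory.
        "cpu_checkpointing": False,
        # Copy checkpoints into a pre-allocated contiguous buffer.
        "contiguous_memory_optimization": False,
        # Number of checkpoints, used to size that contiguous buffer.
        "number_checkpoints": None,
        "synchronize_checkpoint_boundary": False,
        "profile": False,
    }
}

activation_checkpointing_json = json.dumps(activation_checkpointing_config,
                                           indent=2)
print(activation_checkpointing_json)
```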


### Training Scripts
Assuming the webtext data was prepared in the previous step, run one of the
following commands to start training the Megatron-LM GPT2 model with DeepSpeed.

- Single GPU run
  - run `bash scripts/ds_pretrain_gpt2.sh`
- Multiple GPUs/Nodes run
  - run `bash scripts/ds_zero2_pretrain_gpt2_model_parallel.sh`

## DeepSpeed Evaluation using GPT-2

DeepSpeed enables training very large models effectively via the advanced [ZeRO
optimizer](https://arxiv.org/abs/1910.02054v2). In February 2020, we released a subset
of the ZeRO optimizations in DeepSpeed that performs optimizer state partitioning.
We refer to this subset as ZeRO-1. In May 2020, we extended ZeRO-1 in DeepSpeed to include
additional optimizations from ZeRO, including gradient and activation partitioning
as well as contiguous memory optimizations. We refer to this release as ZeRO-2.

ZeRO-2 significantly reduces the memory
footprint for training large models which means large models can be trained with i) less
model parallelism and ii) larger batch sizes. A lower model parallelism degree improves
training efficiency by increasing the granularity of the computation such as the matrix
multiplication where performance is directly related to the size of the matrices.
Furthermore, less model parallelism also results in less communication between model
parallel GPUs, which further boosts performance.  Larger batch size has a similar effect
of increasing the computational granularity as well as reducing communication, also
resulting in better performance. Therefore, with DeepSpeed and ZeRO-2 integration into Megatron,
we elevate the model scale and speed to an entirely new level compared to Megatron alone.

![DeepSpeed-vs-Megatron](/assets/images/zero-full.png)
<p align="center">
<em>Figure 2: ZeRO-2 scales to 170 billion parameters, has up to 10x higher throughput, obtains super linear speedup, and improves usability by avoiding the need for code refactoring for models up to 13 billion parameters.</em>
</p>

More concretely, DeepSpeed and ZeRO-2 excel in four aspects (as visualized in Figure 2): support for order-of-magnitude larger models, up to 10x faster training, superlinear scalability, and improved usability that democratizes large model training. These four aspects are detailed below.

Model size: State-of-the-art large models such as OpenAI GPT-2, NVIDIA Megatron-LM, Google T5, and Microsoft Turing-NLG have sizes of 1.5B, 8.3B, 11B, and 17B parameters respectively. ZeRO-2 provides system support to efficiently run models of 170 billion parameters, an order-of-magnitude bigger than these largest models (Figure 2, top left).

Speed: Improved memory efficiency powers higher throughput and faster training. Figure 2 (bottom left) shows the system throughput of ZeRO-2 and ZeRO-1 (both combining ZeRO-powered data parallelism with NVIDIA Megatron-LM model parallelism), as well as that of the state-of-the-art model-parallel approach, Megatron-LM alone (the baseline in Figure 2, bottom left). ZeRO-2 runs 100-billion-parameter models on a cluster of 400 NVIDIA V100 GPUs at over 38 teraflops per GPU, for an aggregate performance of over 15 petaflops. For models of the same size, ZeRO-2 trains 10x faster than Megatron-LM alone and 5x faster than ZeRO-1.

Scalability: We observe superlinear speedup (Figure 2, top right), where performance more than doubles when the number of GPUs is doubled. ZeRO-2 reduces the memory footprint of the model states as we increase the data parallelism degree, allowing us to fit larger batch sizes per GPU and resulting in better performance.

Democratizing large model training: ZeRO-2 empowers model scientists to train models of up to 13 billion parameters efficiently without any model parallelism, which typically requires model refactoring (Figure 2, bottom right). 13 billion parameters is larger than most of the largest state-of-the-art models (such as Google T5, with 11 billion parameters). Model scientists can therefore experiment freely with large models without worrying about model parallelism. In comparison, implementations of classic data parallelism (such as PyTorch Distributed Data Parallel) run out of memory with 1.4-billion-parameter models, while ZeRO-1 supports up to 6 billion parameters.

Furthermore, in the absence of model parallelism, these models can be trained on low-bandwidth clusters while still achieving significantly better throughput than with model parallelism. For example, the GPT-2 model can be trained nearly 4x faster with ZeRO-powered data parallelism than with model parallelism on a four-node cluster connected with a 40 Gbps InfiniBand interconnect, where each node has four NVIDIA 16GB V100 GPUs connected with PCI-E. Therefore, with this performance improvement, large model training is no longer limited to GPU clusters with ultra-fast interconnects but is also accessible on modest clusters with limited bandwidth.
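
The memory-footprint argument in this section can be made concrete with the per-GPU model-state estimates from the ZeRO paper for mixed-precision Adam: 2 bytes per parameter for FP16 weights, 2 bytes for FP16 gradients, and 12 bytes for FP32 optimizer states. The function `model_state_gb()` below is an illustrative helper, not part of DeepSpeed. It counts model states only and ignores activations, temporary buffers, and fragmentation.

```python
def model_state_gb(num_params_billion, data_parallel_degree, stage):
    """Estimate per-GPU model-state memory in GB (10^9 bytes).

    stage 0: plain data parallelism (everything replicated)
    stage 1: optimizer states partitioned (ZeRO-1)
    stage 2: optimizer states and gradients partitioned (ZeRO-2)
    """
    psi = num_params_billion  # billions of parameters * bytes per parameter = GB
    n = data_parallel_degree
    if stage == 0:
        return 16 * psi
    if stage == 1:
        return 4 * psi + 12 * psi / n
    if stage == 2:
        return 2 * psi + 14 * psi / n
    raise ValueError("stage must be 0, 1, or 2")


for stage in (0, 1, 2):
    gb = model_state_gb(7.5, 64, stage)
    print("stage {}: {:.1f} GB per GPU".format(stage, gb))
```

Under these assumptions, a 1.4-billion-parameter model already needs about 22 GB per GPU for model states alone with plain data parallelism, which is consistent with classic data parallelism running out of memory at roughly that size.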