---
title: "ZeRO-Offload"
---
We recommend that you read the tutorials on [Getting Started](/getting-started/) and [ZeRO](/tutorials/zero/) before stepping through this tutorial.

ZeRO-Offload is a ZeRO optimization that offloads the optimizer memory and computation from the GPU to the host CPU. ZeRO-Offload enables large models with up to 13 billion parameters to be efficiently trained on a single GPU. In this tutorial we will use ZeRO-Offload to train a 10-billion parameter GPT-2 model in DeepSpeed. Furthermore, *using ZeRO-Offload in a DeepSpeed model is quick and easy because all you need to do is change a few configurations in the DeepSpeed configuration JSON*. No code changes are needed.

## ZeRO-Offload Overview
For large model training, optimizers such as [Adam](https://arxiv.org/abs/1412.6980) can consume a significant amount of GPU compute and memory. ZeRO-Offload reduces the GPU compute and memory requirements of such models by leveraging compute and memory resources on the host CPU to execute the optimizer. Furthermore, to prevent the optimizer from becoming a bottleneck, ZeRO-Offload uses DeepSpeed's highly optimized CPU implementation of Adam called [DeepSpeedCPUAdam](https://github.com/microsoft/DeepSpeed/tree/master/deepspeed/ops/adam). DeepSpeedCPUAdam is 5X--7X faster than the standard PyTorch implementation. For a deep dive into the design and performance of ZeRO-Offload, please see our [blog post](https://www.microsoft.com/en-us/research/blog/deepspeed-extreme-scale-model-training-for-everyone/#toc-heading-3).
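
For reference, here is a minimal sketch of how DeepSpeedCPUAdam can be constructed directly in place of `torch.optim.Adam`. The model and hyperparameters below are placeholders for illustration only, not part of this tutorial's scripts.

```python
import torch
from deepspeed.ops.adam import DeepSpeedCPUAdam

# Placeholder model and hyperparameters, for illustration only.
model = torch.nn.Linear(4096, 4096)

# The Adam optimizer states (momentum and variance) live in host memory and
# the parameter update runs on the CPU instead of the GPU.
optimizer = DeepSpeedCPUAdam(model.parameters(), lr=1.5e-4, weight_decay=0.01)
```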

## Training Environment
For this tutorial, we will configure a 10 billion parameter GPT-2 model using the DeepSpeed [Megatron-LM](https://github.com/microsoft/DeepSpeedExamples/tree/master/Megatron-LM) GPT-2 code. We advise stepping through the Megatron-LM [tutorial](/tutorials/megatron/) if you have not previously done so. We will use a single [NVIDIA Tesla V100-SXM3 Tensor Core GPU](https://www.nvidia.com/en-us/data-center/v100/) with 32 GB of memory for this exercise.
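
To see why offloading helps at this scale, here is a rough back-of-envelope estimate of the model-state memory for mixed-precision Adam training, following the commonly used 16-bytes-per-parameter accounting from the ZeRO work; the numbers are illustrative.

```python
# Approximate memory for the model states of a 10B-parameter model trained
# with mixed-precision Adam (fp16 weights/gradients plus fp32 optimizer states).
params = 10e9

bytes_per_param = (
    2     # fp16 parameters
    + 2   # fp16 gradients
    + 4   # fp32 master parameters
    + 4   # fp32 Adam momentum
    + 4   # fp32 Adam variance
)

total_gb = params * bytes_per_param / 1e9
print(f"~{total_gb:.0f} GB of model states")  # ~160 GB, far beyond a 32 GB V100
```

ZeRO-Offload keeps the fp16 parameters on the GPU for the forward and backward passes, while the fp32 optimizer states and the optimizer step move to the host, which is what lets the 10B configuration below fit on a single 32 GB device.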

## Training a 10B parameter GPT-2 on 1 V100 GPU
We need to make changes to the Megatron-LM launch script and to the DeepSpeed configuration JSON.

### Megatron-LM GPT-2 launch script changes
We need to apply two changes to the launch script for the DeepSpeed Megatron-LM GPT-2 model. The first change is to configure a 10B parameter GPT-2 model with activation checkpointing enabled, which can be achieved with the following flags:

```bash
       --model-parallel-size 1 \
       --num-layers 50 \
       --hidden-size 4096 \
       --num-attention-heads 32 \
       --batch-size 10 \
       --deepspeed_config ds_zero_offload.config \
       --cpu_optimizer \
       --checkpoint-activations
```

Most of the flags in the changes above should be familiar if you have stepped through the Megatron-LM [tutorial](/tutorials/megatron/), except for the **_--cpu_optimizer_**. This flag informs the model script to pass a CPU-based Adam optimizer, rather than a GPU-based one, to DeepSpeed as the client optimizer. It is very important that this flag be used when training with ZeRO-Offload to ensure correct operation of the DeepSpeed engine.  
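
As a rough sketch (not the actual Megatron-LM code), this is what honoring `--cpu_optimizer` amounts to: the client script constructs a CPU-based Adam optimizer and hands it to `deepspeed.initialize()` as the client optimizer. The `build_optimizer` helper and the `args` fields below are hypothetical.

```python
import torch
import deepspeed
from deepspeed.ops.adam import DeepSpeedCPUAdam

def build_optimizer(model, args):
    # Hypothetical helper: choose a CPU- or GPU-based Adam from --cpu_optimizer.
    params = [p for p in model.parameters() if p.requires_grad]
    if args.cpu_optimizer:
        # CPU-based Adam, required for correct operation with ZeRO-Offload.
        return DeepSpeedCPUAdam(params, lr=args.lr, weight_decay=args.weight_decay)
    # GPU-based Adam otherwise.
    return torch.optim.Adam(params, lr=args.lr, weight_decay=args.weight_decay)

# DeepSpeed wraps the client optimizer and applies the ZeRO-Offload settings
# from ds_zero_offload.config to it:
# model_engine, optimizer, _, _ = deepspeed.initialize(
#     args=args, model=model, optimizer=build_optimizer(model, args))
```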

Second, we need to apply the following changes to ensure that only one GPU is used for training.
```bash
   deepspeed --num_nodes 1 --num_gpus 1 ...
```

### DeepSpeed Configuration Changes
ZeRO-Offload leverages many of the ZeRO stage 2 mechanisms, so the configuration changes needed to enable ZeRO-Offload are an extension of those required to enable ZeRO stage 2. The **zero_optimization** key for enabling ZeRO-Offload is shown below:

```json
{
    "zero_optimization": {
        "stage": 2,
        "cpu_offload": true,
        "contiguous_gradients": true,
        "overlap_comm": true
    }
}
```

As seen above, in addition to setting the _stage_ field to **2** (to enable ZeRO stage 2), we also need to set the _cpu_offload_ flag to **true** to enable ZeRO-Offload optimizations. In addition, we can set other ZeRO stage 2 optimization flags, such as _overlap_comm_, to tune ZeRO-Offload performance. With these changes we can now run the model. We share some screenshots of the training below.

Here is a screenshot of the training log:

<a href="/assets/images/zero_offload_dp1_10B_log.png">
<img src="/assets/images/zero_offload_dp1_10B_log.png">
</a>


Here is a screenshot of nvidia-smi showing that only GPU 0 is active during training:

<a href="/assets/images/zero_offload_dp1_10B_smi.png">
<img src="/assets/images/zero_offload_dp1_10B_smi.png">
</a>

Finally, here is a screenshot of htop showing host CPU and memory activity during optimizer computation:

<a href="/assets/images/zero_offload_dp1_10B_cpu.png">
<img src="/assets/images/zero_offload_dp1_10B_cpu.png">
</a>

Congratulations! You have completed the ZeRO-Offload tutorial.