# NVMe offload

Author: Hongxin Liu

**Prerequisite:**
- [Zero Redundancy Optimizer with chunk-based memory management](../features/zero_with_chunk.md)

**Related Paper**

- [ZeRO-Offload: Democratizing Billion-Scale Model Training](https://arxiv.org/abs/2101.06840)
- [ZeRO-Infinity: Breaking the GPU Memory Wall for Extreme Scale Deep Learning](https://arxiv.org/abs/2104.07857)

## Introduction

If a model has `N` parameters, Adam maintains `8N` bytes of optimizer states (momentum and variance, each `N` fp32 values). For billion-scale models, optimizer states can take tens of gigabytes of memory; a model with 4 billion parameters needs 32 GB for its Adam states alone. GPU memory limits the model scale we can train, which is known as the GPU memory wall. If we offload optimizer states to disk, we can break through the GPU memory wall.
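
To make the arithmetic concrete, here is a small helper (a sketch assuming Adam's two fp32 states, i.e. 8 bytes per parameter, as used throughout this tutorial):

```python
def adam_state_bytes(num_params: int) -> int:
    # momentum + variance, each stored as fp32 (4 bytes per value)
    return num_params * 2 * 4

# A 4-billion-parameter model needs 32 GB just for Adam's optimizer states.
print(f'{adam_state_bytes(4_000_000_000) / 1e9:.0f} GB')
```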

We implement a user-friendly and efficient asynchronous tensor I/O library: [TensorNVMe](https://github.com/hpcaitech/TensorNVMe). With this library, implementing NVMe offload becomes straightforward.

> This library is compatible with all kinds of disks (HDD, SATA SSD, and NVMe SSD). Since the I/O bandwidth of HDDs and SATA SSDs is low, it's recommended to use this library only with NVMe disks.
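
For reference, here is a minimal sketch of using TensorNVMe directly; the `DiskOffloader` API below follows the library's README (you don't need to call it yourself to use NVMe offload in ColossalAI):

```python
import torch
from tensornvme import DiskOffloader

x = torch.rand(2, 2)
offloader = DiskOffloader('./offload')

offloader.sync_write(x)        # x is written to disk and its storage is freed
offloader.sync_read(x)         # x is read back into memory

offloader.async_write(x)       # the write happens in the background
offloader.sync_write_events()  # wait for all pending writes to finish
offloader.async_read(x)
offloader.sync_read_events()   # wait for all pending reads to finish
```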

When optimizing a parameter, the process can be divided into three stages: read, compute, and offload. We perform these stages in a pipelined fashion, which overlaps computation and I/O.

<figure style={{textAlign: "center"}}>
<img src="https://s2.loli.net/2022/08/16/CvRnowrsNyB4hza.jpg"/>
<figcaption>Optimization process</figcaption>
</figure>
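
Below is an illustrative sketch of this pipelining, not ColossalAI's actual implementation; `read_states`, `compute_update` and `offload_states` are hypothetical helpers standing in for the asynchronous reads, the Adam update, and the asynchronous write-back:

```python
def pipelined_step(params):
    # Illustrative only: overlap the next parameter's read with the
    # current parameter's compute, and write results back asynchronously.
    prefetch = read_states(params[0])                  # start the first read
    for i, p in enumerate(params):
        states = prefetch.wait()                       # finish this parameter's read
        if i + 1 < len(params):
            prefetch = read_states(params[i + 1])      # overlap the next read with compute
        compute_update(p, states)                      # compute stage (Adam update)
        offload_states(p, states)                      # offload stage (async write to NVMe)
```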

## Usage

First, please make sure you have installed [TensorNVMe](https://github.com/hpcaitech/TensorNVMe):

```shell
pip install packaging
pip install tensornvme
```

We implement NVMe offload of optimizer states for Adam ([CPUAdam](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.cpu_adam.html) and [HybridAdam](https://colossalai.readthedocs.io/en/latest/colossalai/colossalai.nn.optimizer.hybrid_adam.html)).

<!--- doc-test-ignore-start -->

```python
from colossalai.nn.optimizer import CPUAdam, HybridAdam

optimizer = HybridAdam(model.parameters(), lr=1e-3, nvme_offload_fraction=1.0, nvme_offload_dir='./')
```

<!--- doc-test-ignore-end -->

`nvme_offload_fraction` is the fraction of optimizer states to be offloaded to NVMe. `nvme_offload_dir` is the directory in which offload files are saved. If `nvme_offload_dir` is `None`, a random temporary directory will be used.
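
For example, to offload half of the optimizer states to an explicit directory (the 0.5 fraction and `./offload` path here are arbitrary illustrative values):

```python
from colossalai.nn.optimizer import HybridAdam

# Half of the optimizer states go to files under ./offload;
# the other half stay in CPU memory.
optimizer = HybridAdam(model.parameters(), lr=1e-3,
                       nvme_offload_fraction=0.5,
                       nvme_offload_dir='./offload')
```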

It's compatible with all parallel methods in ColossalAI.

> ⚠ It only offloads optimizer states that are on the CPU. This means it only takes effect for CPU training or for ZeRO/Gemini with offloading.

## Examples

Let's start with two simple examples: training GPT with different methods. These examples rely on `transformers`.

We should install dependencies first:

```shell
pip install psutil transformers
```

First, we import essential packages and modules:

```python
import os
import time
from typing import Dict, Optional

import psutil
import torch
import torch.nn as nn
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
from colossalai.utils.model.colo_init_context import ColoInitContext
```

Then we define a loss function:

```python
class GPTLMLoss(nn.Module):

    def __init__(self):
        super().__init__()
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, logits, labels):
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)),
                            shift_labels.view(-1))
```

We also define some utility functions, which generate random data, compute the number of parameters of a model, and get the memory usage of the current process:

```python
def get_data(batch_size: int, seq_len: int,
             vocab_size: int, device: Optional[str] = None) -> Dict[str, torch.Tensor]:
    device = torch.cuda.current_device() if device is None else device
    input_ids = torch.randint(vocab_size, (batch_size, seq_len),
                              device=device)
    attn_mask = torch.ones_like(input_ids)
    return dict(input_ids=input_ids, attention_mask=attn_mask)


def get_model_numel(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters())


def get_mem_usage() -> int:
    proc = psutil.Process(os.getpid())
    return proc.memory_info().rss
```

We first try to train the GPT model on CPU:

```python
def train_cpu(nvme_offload_fraction: float = 0.0):
    config = GPT2Config()
    model = GPT2LMHeadModel(config)
    criterion = GPTLMLoss()
    optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction)
    print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B')

    start = time.time()
    for step in range(3):
        data = get_data(4, 128, config.vocab_size, device='cpu')
        outputs = model(**data)
        loss = criterion(outputs.logits, data['input_ids'])
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        print(f'[{step}] loss: {loss.item():.3f}')

    print(f'Time: {time.time() - start:.3f} s')
    print(f'Mem usage: {get_mem_usage() / 1024**2:.3f} MB')
```

Run without NVMe offload:

```python
train_cpu(0.0)
```

We may get the following output:

```
Model numel: 0.116 B
[0] loss: 10.953
[1] loss: 10.974
[2] loss: 10.965
Time: 7.739 s
Mem usage: 5966.445 MB
```

And then run with (full) NVMe offload:

```python
train_cpu(1.0)
```

We may get:

```
Model numel: 0.116 B
[0] loss: 10.951
[1] loss: 10.994
[2] loss: 10.984
Time: 8.527 s
Mem usage: 4968.016 MB
```

GPT2-S has 0.116 billion parameters, so its optimizer states take about 0.928 GB of memory. NVMe offload saves about 998 MB of memory, which meets our expectation.
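
As a quick sanity check of these numbers (assuming Adam's 8 bytes of state per parameter, as above):

```python
numel = 0.116 * 1024**3           # ~124.6M parameters, matching the printed "0.116 B"
state_bytes = numel * 2 * 4       # momentum + variance, 4 bytes each
print(f'{state_bytes / 1024**3:.3f} GB')  # 0.928 GB
print(f'{state_bytes / 1024**2:.1f} MB')  # ~950 MB, close to the observed ~998 MB saving
```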

Then we can train the GPT model with Gemini. The placement policy of Gemini should be `"auto"`, `"cpu"` or `"const"`. Since the script below calls `colossalai.launch_from_torch()`, launch it with `torchrun`, e.g. `torchrun --standalone --nproc_per_node=1 your_script.py`.

```python
def train_gemini_cpu(nvme_offload_fraction: float = 0.0):
    colossalai.launch_from_torch({})
    config = GPT2Config()
    with ColoInitContext(device=torch.cuda.current_device()):
        model = GPT2LMHeadModel(config)
    criterion = GPTLMLoss()
    optimizer = HybridAdam(model.parameters(), nvme_offload_fraction=nvme_offload_fraction)
    print(f'Model numel: {get_model_numel(model) / 1024**3:.3f} B')

    gemini_config = dict(strict_ddp_mode=True, device=torch.cuda.current_device(),
                         placement_policy='cpu', pin_memory=True, hidden_dim=config.n_embd)
    model = zero_model_wrapper(model, zero_stage=3, gemini_config=gemini_config)
    optimizer = zero_optim_wrapper(model, optimizer, initial_scale=2**5)

    start = time.time()
    for step in range(3):
        data = get_data(4, 128, config.vocab_size)
        outputs = model(**data)
        loss = criterion(outputs.logits, data['input_ids'])
        optimizer.backward(loss)
        optimizer.step()
        optimizer.zero_grad()
        print(f'[{step}] loss: {loss.item():.3f}')

    print(f'Time: {time.time() - start:.3f} s')
    print(f'Mem usage: {get_mem_usage() / 1024**2:.3f} MB')
```

Run without NVMe offload:

```python
train_gemini_cpu(0.0)
```

We may get:

```
Model numel: 0.116 B
searching chunk configuration is completed in 0.27 s.
used number: 118.68 MB, wasted number: 0.75 MB
total wasted percentage is 0.63%
[0] loss: 10.953
[1] loss: 10.938
[2] loss: 10.969
Time: 2.997 s
Mem usage: 5592.227 MB
```

And run with (full) NVMe offload:

```python
train_gemini_cpu(1.0)
```

We may get:

```
Model numel: 0.116 B
searching chunk configuration is completed in 0.27 s.
used number: 118.68 MB, wasted number: 0.75 MB
total wasted percentage is 0.63%
[0] loss: 10.953
[1] loss: 10.938
[2] loss: 10.969
Time: 3.691 s
Mem usage: 5298.344 MB
```

NVMe offload saves about 294 MB of memory. Note that Gemini's `pin_memory` option can accelerate training but increases memory usage, so this result also meets our expectation. If we disable `pin_memory`, we can still observe a memory usage drop of about 900 MB.
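
To trade speed for memory, only the `pin_memory` field of the Gemini config from the example above needs to change:

```python
gemini_config = dict(strict_ddp_mode=True, device=torch.cuda.current_device(),
                     placement_policy='cpu',
                     pin_memory=False,  # saves CPU memory at the cost of slower CPU-GPU transfers
                     hidden_dim=config.n_embd)
```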

## API Reference

{{ autodoc:colossalai.nn.optimizer.HybridAdam }}

{{ autodoc:colossalai.nn.optimizer.CPUAdam }}


<!-- doc-test-command: torchrun --standalone --nproc_per_node=1 nvme_offload.py  -->