"docs/source/vscode:/vscode.git/clone" did not exist on "da36c557f763cb4e515f71ba7c49191fede8966f"
Unverified Commit 101b9a7e authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[deepspeed] performance docs (#21573)



* [deepspeed] performance docs

* fix

* re-org

* update

* update

* a new NCCL Collectives section

* inference

* Update docs/source/en/main_classes/deepspeed.mdx
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* suggestion

* Update docs/source/en/main_classes/deepspeed.mdx

* suggestion

---------
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 68eff403
......@@ -795,6 +795,39 @@ default value in the following cases:
the increased data buffers.
#### ZeRO-0 Config
Note that we're listing Stage 0 and 1 last since they are rarely used.
Stage 0 is disabling all types of sharding and just using DeepSpeed as DDP. You can turn it on with:
```json
{
"zero_optimization": {
"stage": 0
}
}
```
Ths will essentially disable ZeRO without you needing to change anything else.
#### ZeRO-1 Config
Stage 1 is Stage 2 minus gradient sharding. You can always try it to speed things a tiny bit to only shard the optimizer states with:
```json
{
"zero_optimization": {
"stage": 1
}
}
```
<a id='deepspeed-nvme'></a>
### NVMe Support
......@@ -1117,6 +1150,54 @@ values look like, but we highly recommend using the one with multiple `auto` set
}
```
#### How to Choose Which ZeRO Stage and Offloads To Use For Best Performance
So now you know there are all these different stages. How to decide which of them to use? This section will attempt to address this question.
In general the following applies:
- Speed-wise (left is faster than right)
Stage 0 (DDP) > Stage 1 > Stage 2 > Stage 2 + offload > Stage 3 > Stage 3 + offloads
- GPU Memory usage-wise (right is more GPU memory efficient than left)
Stage 0 (DDP) < Stage 1 < Stage 2 < Stage 2 + offload < Stage 3 < Stage 3 + offloads
So when you want to get the fastest execution while fitting into minimal number of GPUs, here is the process you could follow. We start with the fastest approach and if running into GPU OOM we then go to the next slower approach, but which will use less GPU memory. And so on and so forth.
First of all set batch size to 1 (you can always use gradient accumulation for any desired effective batch size).
1. Enable `--gradient_checkpointing 1` (HF Trainer) or directly `model.gradient_checkpointing_enable()` - if OOM then
2. Try ZeRO stage 2 first. if OOM then
3. Try ZeRO stage 2 + `offload_optimizer` - if OOM then
4. Switch to ZeRO stage 3 - if OOM then
5. Enable `offload_param` to `cpu` - if OOM then
6. Enable `offload_optimizer` to `cpu` - if OOM then
7. If you still can't fit a batch size of 1 first check various default values and lower them if you can. For example, if you use `generate` and you don't use a wide search beam make it narrower as it'd take a lot of memory.
8. Definitely use mixed half-precision over fp32 - so bf16 on Ampere and higher GPUs and fp16 on older gpu architectures.
9. If you still OOM you could add more hardware or enable ZeRO-Infinity - that is switch offloads `offload_param` and `offload_optimizer` to `nvme`. You need to make sure it's a very fast nvme. As an anecdote I was able to infer BLOOM-176B on a tiny GPU using ZeRO-Infinity except it was extremely slow. But it worked!
You can, of course, work through these steps in reverse by starting with the most GPU memory efficient config and then going backwards. Or try bi-secting it.
Once you have your batch size 1 not leading to OOM, measure your effective throughput.
Next try to increase the batch size to as large as you can, since the higher the batch size the more efficient the GPUs are as they perform the best when matrices they multiply are huge.
Now the performance optimization game starts. You can turn off some offload features or step down in ZeRO stages and increase/decrease batch size and again measure your effective throughput. Rinse and repeat until satisfied.
Don't spend forever on it, but if you're about to start a 3 months training - do spend a few days on it to find the most effective throughput-wise setup. So that your training cost will be the lowest and you will finish training faster. In the current crazy-paced ML world, if it takes you an extra month to train something you are likely to miss a golden opportunity. Of course, this is only me sharing an observation and in no way I'm trying to rush you. Before beginning to train BLOOM-176B I spent 2 days on this process and was able to increase throughput from 90 to 150 TFLOPs! This effort saved us more than one month of training time.
These notes were written primarily for the training mode, but they should mostly apply for inference as well. For example, during inference Gradient Checkpointing is a no-op since it is only useful during training. Additionally, we found out that if you are doing a multi-GPU inference and not using [DeepSpeed-Inference](https://www.deepspeed.ai/tutorials/inference-tutorial/), [Accelerate](https://huggingface.co/blog/bloom-inference-pytorch-scripts) should provide a superior performance.
Other quick related performance notes:
- if you are training something from scratch always try to have tensors with shapes that are divisible by 16 (e.g. hidden size). For batch size try divisible by 2 at least. There are [wave and tile quanitization](https://developer.nvidia.com/blog/optimizing-gpu-performance-tensor-cores/) divisibility that is hardware-specific if you want to squeeze even higher performance from your GPUs.
### Optimizer and Scheduler
As long as you don't enable `offload_optimizer` you can mix and match DeepSpeed and HuggingFace schedulers and
......@@ -1396,9 +1477,32 @@ As of `deepspeed==0.6.0` the bf16 support is new and experimental.
If you use [gradient accumulation](#gradient-accumulation) with bf16-enabled, you need to be aware that it'll accumulate gradients in bf16, which may not be what you want due to this format's low precision, as it may lead to a lossy accumulation.
A work is being done to fix that and provide an option to use a higher precision `dtype` (fp16 or fp32).
</Tip>
### NCCL Collectives
There is the `dtype` of the training regime and there is a separate `dtype` that is used for communication collectives like various reduction and gathering/scattering operations.
All gather/scatter ops are performed in the same `dtype` the data is in, so if you're using bf16 training regime it gets gathered in bf16 - gathering is a non-lossy operation.
Various reduce operations can be quite lossy, for example when gradients are averaged across multiple-gpus, if the communications are done in fp16 or bf16 the outcome is likely be lossy - since when one ads multiple numbers in low precision the result isn't exact. More so with bf16 as it has a lower precision than fp16. Often fp16 is good enough as the loss is minimal when averaging grads which are typically very small. Therefore, by default for half precision training fp16 is used as the default for reduction operations. But you have full control over this functionality and if you choose you can add a small overhead and ensure that reductions will be using fp32 as the accumulation dtype and only when the result is ready it'll get downcast to the half precision `dtype` you're training in.
In order to override the default you simply add a new configuration entry:
```json
{
"communication_data_type": "fp32"
}
```
The valid values as of this writing are "fp16", "bfp16", "fp32".
note: stage zero 3 had a bug with regards to bf16 comm dtype that was fixed in `deepspeed==0.8.1`
### apex
To configure apex AMP-like mode set:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment