Commit 885a0428 authored by Lawrence McAfee

fixes.

parent 69e25145
@@ -323,14 +323,14 @@ For cases where memory is very tight, `full` checkpointing saves just the inputs
Usage: `--use-distributed-optimizer`. Compatible with all model and data types.
The distributed optimizer is a memory savings technique, whereby the optimizer state is evenly distributed across data parallel ranks (versus the traditional method of replicating the optimizer state across data parallel ranks). As described in https://arxiv.org/abs/1910.02054, our implementation distributes all optimizer state that does not overlap with the model state. For example, when using fp16 model params, the distributed optimizer maintains its own separate copy of fp32 main params & grads, which are distributed across DP ranks. When using bf16 model params, however, the distributed optimizer's fp32 main grads are the same as the model's fp32 grads, and so the grads in this case are not distributed (although the fp32 main params are still distributed, as they are separate from the bf16 model params).
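The sketch below is only an illustration of the even-sharding idea described above, not Megatron-LM's actual implementation; the function name `shard_main_params` and the index-range representation are hypothetical.

```python
# Minimal sketch: split a flat main-parameter buffer into even shards,
# one per data parallel rank, so each rank owns 1/d of the optimizer state.
def shard_main_params(num_params: int, dp_size: int):
    """Return a (start, end) index range for each DP rank's shard."""
    shard_size = (num_params + dp_size - 1) // dp_size  # round up
    shards = []
    for rank in range(dp_size):
        start = rank * shard_size
        end = min(start + shard_size, num_params)
        shards.append((start, end))
    return shards

# Example: 10 parameters sharded across 4 data parallel ranks.
print(shard_main_params(10, 4))  # [(0, 3), (3, 6), (6, 9), (9, 10)]
```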
Theoretical memory savings vary depending on the combination of the model's param dtype and grad dtype. In our implementation, the theoretical number of bytes per parameter is (where 'd' is the data parallel size):
| | Non-distributed optim | Distributed optim |
|-|-|-|
| fp16 param, fp16 grads | 20 | 4 + 16/d |
| bf16 param, fp32 grads | 18 | 6 + 12/d |
| fp32 param, fp32 grads | 16 | 8 + 8/d |
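The following sketch reproduces the numbers in the table above. The component breakdown (fp32 main params plus two fp32 Adam moments, and a separate fp32 main-grad copy only in the fp16/fp16 case) is an assumption inferred from the description above, not code taken from the repository.

```python
# Sketch (assumed breakdown): theoretical bytes per parameter for the
# non-distributed vs. distributed optimizer.
def bytes_per_param(param_bytes: int, grad_bytes: int, dp_size: int = 1,
                    distributed: bool = False) -> float:
    model_bytes = param_bytes + grad_bytes        # model params + model grads
    # Optimizer state that does not overlap with the model state:
    # fp32 main params (only if params aren't already fp32), plus the
    # fp32 exp_avg and exp_avg_sq moments (4 bytes each).
    optim_bytes = (4 if param_bytes < 4 else 0) + 4 + 4
    # fp16 params + fp16 grads need a separate fp32 main-grad copy; with
    # bf16 params + fp32 grads the model grads already serve as main grads.
    if param_bytes == 2 and grad_bytes == 2:
        optim_bytes += 4
    if distributed:
        return model_bytes + optim_bytes / dp_size
    return model_bytes + optim_bytes

d = 8  # example data parallel size
print(bytes_per_param(2, 2))           # 20      (fp16 param, fp16 grads)
print(bytes_per_param(2, 2, d, True))  # 4 + 16/d = 6.0
print(bytes_per_param(2, 4))           # 18      (bf16 param, fp32 grads)
print(bytes_per_param(2, 4, d, True))  # 6 + 12/d = 7.5
print(bytes_per_param(4, 4))           # 16      (fp32 param, fp32 grads)
print(bytes_per_param(4, 4, d, True))  # 8 + 8/d = 9.0
```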