The motivation for the distributed optimizer is to save memory by distributing the optimizer state evenly across data parallel ranks, rather than replicating it on every rank as is done currently. As described in https://arxiv.org/abs/1910.02054 (ZeRO), this branch specifically implements the following:
- [yes] distribute all 'non-overlapping' optimizer state (i.e., model params already in fp32 are NOT distributed)
...
...
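As a rough sketch of the sharding idea described at the top of this section (the names below are illustrative, not Megatron-LM's actual API): each data parallel rank keeps the full model params and grads, but allocates fp32 main params and Adam moments only for its contiguous 1/d shard of the flat parameter space.

```python
# Illustrative only -- not Megatron-LM's implementation. Assumes a flat fp16/bf16
# parameter buffer and standard Adam state (fp32 main params, exp_avg, exp_avg_sq).
import torch


def shard_range(numel: int, rank: int, world_size: int) -> tuple:
    """Contiguous [start, end) slice of the flat param space owned by this rank."""
    per_rank = (numel + world_size - 1) // world_size  # ceil-divide; last shard may be short
    start = min(rank * per_rank, numel)
    end = min(start + per_rank, numel)
    return start, end


def build_sharded_optimizer_state(flat_params: torch.Tensor, rank: int, world_size: int) -> dict:
    """Allocate optimizer state only for this rank's shard (~1/d of the replicated cost)."""
    start, end = shard_range(flat_params.numel(), rank, world_size)
    local = flat_params[start:end]
    return {
        "main_params": local.detach().clone().float(),                # fp32 master copy of the shard
        "exp_avg": torch.zeros_like(local, dtype=torch.float32),      # Adam 1st moment, sharded
        "exp_avg_sq": torch.zeros_like(local, dtype=torch.float32),   # Adam 2nd moment, sharded
    }
```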
The grad buffer is used for performing reduce-scatter and all-gather operations.
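As a rough illustration of that communication pattern (hypothetical helper names built on torch.distributed collectives, not the actual Megatron-LM API): gradients are reduce-scattered so each rank ends up with the reduced grads for its own shard, and after the local param update the updated shards are all-gathered back into the full param buffer.

```python
# Illustrative only -- not Megatron-LM's implementation. Assumes the flat grad/param
# buffers are padded so their length is a multiple of the data-parallel world size.
import torch
import torch.distributed as dist


def reduce_scatter_grads(grad_buffer: torch.Tensor, group=None) -> torch.Tensor:
    """Reduce-scatter the flat grad buffer; return this rank's reduced grad shard."""
    world_size = dist.get_world_size(group)
    shard = torch.empty(grad_buffer.numel() // world_size,
                        dtype=grad_buffer.dtype, device=grad_buffer.device)
    # Sum-reduce across data-parallel ranks; any averaging/scaling is applied separately.
    dist.reduce_scatter_tensor(shard, grad_buffer, op=dist.ReduceOp.SUM, group=group)
    return shard


def all_gather_params(param_buffer: torch.Tensor, local_param_shard: torch.Tensor, group=None):
    """All-gather every rank's updated param shard back into the full flat param buffer."""
    dist.all_gather_into_tensor(param_buffer, local_param_shard, group=group)
```

Roughly, each rank would convert its reduced grad shard to fp32, run the optimizer step on its sharded main params, copy the result back into its slice of the (b)f16 param buffer, and then perform the all-gather.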
The figures below illustrate the grad buffer's sharding scheme and the key steps of the distributed optimizer's param update: