Package defining different checkpoint formats (backends) and saving/loading algorithms (strategies).
Strategies can be used for implementing new checkpoint formats or implementing new (more optimal for a given use case) ways of saving/loading of existing formats.
Strategies are passed to `dist_checkpointing.load` and `dist_checkpointing.save` functions and control the actual saving/loading procedure.
The motivation for the distributed optimizer is to save memory by distributing the optimizer state evenly across data parallel ranks (https://arxiv.org/abs/1910.02054), versus the naive method of replicating the optimizer state across data parallel ranks.
Theoretical memory savings vary depending on the combination of the datatype of the model's parameters (`param_dtype`) and main gradients accumulated across data-parallel replicas (`grad_dtype`). We always use `fp32` main parameters for optimizer steps. In the current implementation, the theoretical number of bytes per parameter is (where d is the data parallel size):
Our implementation of the distributed optimizer uses contiguous buffers for parameters and main gradients; model gradients are copied over to the main gradients as soon as they are fully computed.
The figures below illustrate the distributed optimizer's sharding scheme, and the key steps of the distributed optimizer's parameter update:
_(note: using illustrations above, assuming `bf16` model weights, `bf16` model gradients that are computed by the backward pass and `fp32` main gradients that are also used for optimizer steps; we always use `fp32` main weights for optimizer steps)_
- Each DP rank now has 4 elements within the gradient buffer that are fully reduced (remaining 12 elements are garbage).
- DP rank 0 has gradient values for elements [0:4].
- DP rank 1 has gradient values for elements [4:8].
- DP rank 2 has gradient values for elements [8:12].
- DP rank 3 has gradient values for elements [12:16].
- Optimizer.step().
- Each DP rank copies its 4 `fp32` main parameter elements into the corresponding `bf16` parameter buffer (each element is cast from fp32 to fp16).
- Call all-gather on each DP rank.
- The parameter buffer now contains all 16, fully updated, `bf16` model parameter elements. Parameters in PyTorch modules already point to the appropriate locations in this parameter buffer, and thus forward passes are ready to run after the all-gather completes.
- At this point, the gradient buffer is also ready to be zero'd for the next iteration.
This package provides modules that provide commonly fused
operations. Fusing operations improves compute efficiency by
increasing the amount of work done each time a tensor is read from
memory. To perform the fusion, modules in this either rely on PyTorch
functionality for doing just-in-time compilation
(i.e. `torch.jit.script` in older PyTorch versions of `torch.compile`
in recent versions), or call into custom kernels in external libraries
such as Apex or TransformerEngine.
Submodules
----------
fusions.fused\_bias\_dropout module
-----------------------------------
This module uses PyTorch JIT to fuse the bias add and dropout operations. Since dropout is not used during inference, different functions are used when in train mode and when in inference mode.
.. automodule:: core.fusions.fused_bias_dropout
:members:
:undoc-members:
:show-inheritance:
fusions.fused\_bias\_gelu module
--------------------------------
This module uses PyTorch JIT to fuse the bias add and GeLU nonlinearity operations.
.. automodule:: core.fusions.fused_bias_gelu
:members:
:undoc-members:
:show-inheritance:
fusions.fused\_layer\_norm module
---------------------------------
This module provides a wrapper around various fused LayerNorm implementation in Apex.
.. automodule:: core.fusions.fused_layer_norm
:members:
:undoc-members:
:show-inheritance:
fusions.fused\_softmax module
-----------------------------
This module provides wrappers around variations of Softmax in Apex.
.. automodule:: core.fusions.fused_softmax
:members:
:undoc-members:
:show-inheritance:
fusions.fused\_cross\_entropy\_loss module
------------------------------------------
This module uses PyTorch JIT to fuse the cross entropy loss calculation and batches communication calls.
This is the implementation of the popular GPT model. It supports several features like model parallelization (Tensor Parallel, Pipeline Parallel, Data Parallel) , mixture of experts, FP8 , Distributed optimizer etc. We are constantly adding new features. So be on the lookout or raise an issue if you want to have something added.
This package contains most of the popular LLMs . Currently we have support for GPT, Bert, T5 and Retro . This is an ever growing list so keep an eye out.