# Llama, Mistral and other Llama-like model support in Megatron-LM
NOTE: In order to simplify code we now only support converting llama-3.x and mistral checkpoints downloaded from Huggingface.
The [Llama-2](https://ai.meta.com/llama/) and [Llama-3.x](https://llama.meta.com/) family of models are an open-source set of pretrained & finetuned (for chat) models that have achieved strong results across a wide set of benchmarks. At their times of release, both Llama-2 and Llama-3 models achieved among the best results for open-source models, and were competitive with leading closed-source models (see https://arxiv.org/pdf/2307.09288.pdf and https://ai.meta.com/blog/meta-llama-3/).
Similarly, [Mistral-7b](https://mistral.ai/news/announcing-mistral-7b/) is an open-source model with pretrained and finetuned (for chat) variants that achieve strong benchmark results.
Architecturally Llama-2, Llama-3 and Mistral-7b are very similar. As such Megatron can support loading checkpoints from all three for inference and finetuning. Converting the checkpoints and loading them is slightly different for each model and is detailed for each below.
# Contents
-[Llama, Mistral and other Llama-like model support in Megatron-LM](#llama-mistral-and-other-llama-like-model-support-in-megatron-lm)
-[Contents](#contents)
-[Llama-2](#llama-2)
-[Download Meta or Huggingface checkpoints](#download-meta-or-huggingface-checkpoints)
-[Using legacy model format](#using-legacy-model-format)
# Llama-2
Llama-2 checkpoints can be loaded into Megatron for inference and for finetuning. Loading these checkpoints consists of three steps:
1. Get access to download the checkpoints.
2. Convert the checkpoints from Meta/Huggingface format to Megatron format.
3. Setup arguments for launching the model.
The following sections detail these steps. The final section lists benchmark result comparisons between: 1) Llama-2 inference code running the Meta-format checkpoints, and 2) Megatron inference code running the converted checkpoints.
## Download Meta or Huggingface checkpoints
Users must first apply for access to download the Llama-2 checkpoints either directly from [Meta](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) or through [Huggingface](https://huggingface.co/docs/transformers/main/model_doc/llama2)(HF). The checkpoints are available in two formats, Meta's native format (available from both the Meta and HF links), and HF's format (available only from HF). Either format can be converted to Megatron, as detailed next.
## Convert checkpoint format
We recommend passing `--dtype bf16` for training or finetuning. Inference can be done in bfloat16 or float16.
### Meta format
The Meta format checkpoints are converted to HF format as an intermediate step before converting to Megatron format. The `transformers` package is required, and must have version >=4.31.0 (e.g., `pip install transformers>=4.31.0`). (**Note**: we have specifically tested with versions `4.31.0` and `4.32.0`; your experience may vary with newer versions.) Assuming the downloaded checkpoints are in `$CHECKPOINT_DIR` (with separate sub-directories for 7B, 13B, 70B, etc.), the following example command can be used to convert from Llama-2 format to HF format in bfloat16:
```
python tools/checkpoint/convert.py \
> --model-type GPT \
> --loader llama_mistral \
> --load-dir ${META_FORMAT_DIR} \
> --model-size ${MODEL_SIZE} \
> --checkpoint-type meta \
> --tokenizer-model ${TOKENIZER_MODEL} \
> --saver core \
> --save-dir ${MEGATRON_FORMAT_DIR} \
> --target-tensor-parallel-size ${TP} \
> --target-pipeline-parallel-size ${PP} \
> --bf16
```
Valid values for `--model-size` are `llama2-7B`, `llama2-13B`, and `llama2-70B` (for pretrained-only models), and `llama2-7Bf`, `llama2-13Bf`, and `llama2-70Bf` (for chat-finetuned models).
### Huggingface format
The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-2 checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values:
| Model size | Tensor parallel size (`TP`) |
| ---------- | --------------------------- |
| 7B | 1 |
| 13B | 2 |
| 70B | 8 |
Using these values for `TP`, along with the path to the Llama-2 tokenizer model (automatically downloaded with original checkpoint download; see `${TOKENIZER_MODEL}` below), run the following command from the root of your Megatron source code to convert from HF format to Megatron format:
```
python tools/checkpoint/convert.py \
> --model-type GPT \
> --loader llama_mistral \
> --load-dir ${HF_FORMAT_DIR} \
> --model-size ${MODEL_SIZE} \
> --checkpoint-type hf \
> --tokenizer-model ${TOKENIZER_MODEL} \
> --saver core \
> --save-dir ${MEGATRON_FORMAT_DIR} \
> --target-tensor-parallel-size ${TP} \
> --target-pipeline-parallel-size ${PP} \
> --bf16
```
After this conversion, we are ready to load the checkpoints into a Megatron GPT model.
## Launch model
### Launch Megatron
If loading for either inference or finetuning, use the following arguments:
```
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size 1 \
--seq-length 4096 \
--max-position-embeddings 4096 \
--tokenizer-type Llama2Tokenizer \
--tokenizer-model ${TOKENIZER_MODEL} \
--load ${CHECKPOINT_DIR} \
--exit-on-missing-checkpoint \
--use-checkpoint-args \
--no-load-optim \
--no-load-rng \
--untie-embeddings-and-output-weights \
--use-rotary-position-embeddings \
--normalization RMSNorm \
--no-position-embedding \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32
```
**Note:** If you converted to the legacy model format (i.e., `--saver legacy`), please see [here](#using-legacy-model-format).
### Launch Meta
Meta checkpoints can be launched with: https://github.com/facebookresearch/llama
### Launch Huggingface
Huggingface checkpoints can be launched with: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
## Benchmark results
The tables below list the benchmark comparisons between native Llama-2 (using Meta's checkpoint and Meta's inference code) and Megatron (using a converted HF checkpoint and Megatron's inference code).
The values are the percent error between Megatron and Llama-2, calculated using the formula: `|<llama_score> - <megatron_score>| / <llama_score>`, where the type of score is detailed before each table. Across all tests (80 total per model size), the mean error is 0.15%. The small difference in benchmark scores between the two models is due to minor arithmetic differences in implementation that alter the numerics slightly. Some of the factors that influence this difference include:
- Megatron performs batch matrix multiplications in a couple places, such as within self attention and in SwiGLU, that Llama performs separately.
- Megatron uses `torch.baddbmm` within self attention, versus Llama using `torch.matmul`.
- Megatron uses a `sin`/`cos` implementation for rotary position embeddings, versus Llama using a `polar`/`complex` implementation.
- Llama calls `torch.set_default_dtype(torch.float16)` during initialization, which Megatron does not.
Llama-3.x checkpoints can be loaded into Megatron for inference and for finetuning. Loading these checkpoints consists of several steps:
1. Get access to download the checkpoints (weights and tokenizer).
2. Convert the checkpoints from Huggingface format to Megatron format.
3. (Optional) Validate converted checkpoints
4. Setup arguments for launching the model.
The following sections detail these steps.
## Download Huggingface checkpoints
Users must first apply for access to download the Llama-3.x checkpoints from [Huggingface](https://huggingface.co/meta-llama).
## Convert checkpoint format
We recommend passing `--dtype bf16` for training or finetuning. Inference can be done in bfloat16 or float16.
### Huggingface format
The HF checkpoints can be converted to Megatron format by using Megatron's own Llama-3.x checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`). One important argument that must be set correctly is the tensor parallel size (`TP`) for each model. The following table shows these values:
| Model size | Tensor parallel size (`TP`) |
| ---------- | --------------------------- |
| 1B | 1 |
| 3B | 1 |
| 8B | 1 |
| 70B | 8 |
Using these values for `TP`, along with the path to the Llama-3.x tokenizer model (automatically downloaded with original checkpoint download; see `${TOKENIZER_MODEL}` below), run the following command from the root of your Megatron source code to convert from HF format to Megatron format:
```
$>: python tools/checkpoint/convert.py \
> --bf16 \
> --model-type GPT \
> --loader llama_mistral \
> --saver core \
> --target-tensor-parallel-size ${TP} \
> --checkpoint-type hf \
> --load-dir ${HF_FORMAT_DIR} \
> --save-dir ${MEGATRON_FORMAT_DIR} \
> --tokenizer-model ${TOKENIZER_MODEL} \
> --model-size llama3 \
```
After this conversion, we are ready to load the checkpoints into a Megatron GPT model.
## (Optional) Validate checkpoints
A Megatron-LM text generation server for Llama3 can be launched using the script `examples/inference/llama_mistral/run_text_generation_llama3.sh <PATH_TO_CONVERTED_CORE_CHECKPOINT> <PATH_TO_DOWNLOADED_HUGGINGFACE_CHECKPOINT>`. For Llama3.1, please use `examples/inference/llama_mistral/run_text_generation_llama3.1.sh`.
Once running, query the server with `curl 'http://<TEXT_GENERATION_SERVER_IP>:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":["<SOME_PROMPT>"], "tokens_to_generate":100, "top_k":1}'`.
A reference generation for comparison can be obtained from the Huggingface transformers library by running `python examples/llama_mistral/huggingface_reference.py --model_path <PATH_TO_DOWNLOADED_HUGGINGFACE_CHECKPOINT> --prompt <SOME_PROMPT>`.
## Launch model
If loading for either inference or finetuning, use the following arguments for Llama 3.0:
```
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size 1 \
--seq-length 8192 \
--max-position-embeddings 8192 \
--tokenizer-type HuggingFaceTokenizer \
--tokenizer-model ${TOKENIZER_MODEL} \
--load ${CHECKPOINT_DIR} \
--exit-on-missing-checkpoint \
--use-checkpoint-args \
--no-load-optim \
--no-load-rng \
--untie-embeddings-and-output-weights \
--normalization RMSNorm \
--position-embedding-type rope \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--disable-bias-linear \
--transformer-impl transformer_engine \
--group-query-attention 8 \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--rotary-base 500000 \
--rotary-percent 1.0 \
--ffn-hidden-size 14336 \
--num-attention-heads 32 \
--swiglu \
--bf16 \
```
For Llama3.1 please use the following arguments:
```
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size 1 \
--seq-length 8192 \
--max-position-embeddings 131072 \
--tokenizer-type HuggingFaceTokenizer \
--tokenizer-model ${TOKENIZER_MODEL} \
--load ${CHECKPOINT_DIR} \
--exit-on-missing-checkpoint \
--use-checkpoint-args \
--no-load-optim \
--no-load-rng \
--untie-embeddings-and-output-weights \
--normalization RMSNorm \
--position-embedding-type rope \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32 \
--disable-bias-linear \
--transformer-impl transformer_engine \
--group-query-attention 8 \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--rotary-base 500000 \
--rotary-percent 1.0 \
--use-rope-scaling \
--ffn-hidden-size 14336 \
--num-attention-heads 32 \
--swiglu \
--bf16 \
```
**Note:** If you converted to the legacy model format (i.e., `--saver legacy`), please see [here](#using-legacy-model-format).
# Mistral-7b
Megatron currently supports loading the v0.3 release of Mistral-7b (which does not use sliding window attention and offers a larger 32768 vocabulary) for inference and finetuning. Loading these checkpoints consists of several steps:
1. Get access to download the checkpoints (weights and tokenizer).
2. Convert the checkpoints from HuggingFace format to Megatron format.
3. (Optional) Validate converted checkpoints
4. Setup arguments for launching the model.
The following sections detail these steps.
## Download Huggingface checkpoints
Users must first apply for access to download the Mistral-7b checkpoints through [Huggingface](https://huggingface.co/mistralai/Mistral-7B-v0.3)(HF).
## Convert checkpoint format
The HF checkpoints can be converted to Megatron format by using Megatron's own Mistral checkpoint converter for HF format (see script `tools/checkpoint/loader_llama_mistral.py`).
Using the path to the Mistral tokenizer model (downloaded alongside the HF checkpoint), run the following command from the root of your Megatron source code to convert from HF format to the Megatron core format:
```
$>: python tools/checkpoint/convert.py \
> --bf16 \
> --model-type GPT \
> --loader llama_mistral \
> --saver core \
> --target-tensor-parallel-size ${TP} \
> --checkpoint-type hf \
> --load-dir ${HF_FORMAT_DIR} \
> --save-dir ${MEGATRON_FORMAT_DIR} \
> --tokenizer-model ${TOKENIZER_MODEL} \
> --model-size mistral \
```
After this conversion, we are ready to load the checkpoints into a Megatron core GPT model.
## (Optional) Validate checkpoints
A Megatron-LM text generation server for Mistral-7B can be launched using the script `examples/inference/llama_mistral/run_text_generation_mistral.sh <PATH_TO_CONVERTED_MCORE_CHECKPOINT> <PATH_TO_DOWNLOADED_HUGGINGFACE_CHECKPOINT>`.
Once running, query the server with `curl 'http://<TEXT_GENERATION_SERVER_IP>:5000/api' -X 'PUT' -H 'Content-Type: application/json; charset=UTF-8' -d '{"prompts":["<SOME_PROMPT>"], "tokens_to_generate":100, "top_k":1}'`.
A reference generation for comparison can be obtained from the Huggingface transformers library by running `python examples/inference/llama_mistral/huggingface_reference.py --model_path <PATH_TO_DOWNLOADED_HUGGINGFACE_CHECKPOINT> --prompt <SOME_PROMPT>`.
## Launch model
If loading for either inference or finetuning, use the following arguments:
```
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size 1 \
--seq-length 4096 \
--max-position-embeddings 4096 \
--tokenizer-type HuggingFaceTokenizer \
--tokenizer-model ${TOKENIZER_MODEL} \
--load ${CHECKPOINT_DIR} \
--exit-on-missing-checkpoint \
--use-checkpoint-args \
--no-load-optim \
--no-load-rng \
--untie-embeddings-and-output-weights \
--normalization RMSNorm \
--position-embedding-type rope \
--no-masked-softmax-fusion \
--attention-softmax-in-fp32
--apply-layernorm-1p \
--transformer-impl transformer_engine \
--group-query-attention 8 \
--disable-bia-linear \
--rotary-base 1000000 \
--rotary-percent 1.0 \
--swiglu \
--ffn-hidden-size 14336 \
--num-attention-heads 32
```
**Note:** If you converted to the legacy model format (i.e., `--saver legacy`), please see [here](#using-legacy-model-format).
# Other Llama-like model support
*Note: Experimental*
Many models such as Yi-34B and Qwen2.x use the Llama architecture and may be converted from HuggingFace to Megatron using the commands in [Llama-3.x](#llama-3x).
# Known numerical differences
It is not expected that the megatron and Huggingface implementations of llama3.x and mistral models will produce numerically identical results. There are multiple points where small numerical differences are expected. This is a non-exhaustive list:
1. TransformerEngine (TE) uses the model params_dtype inside RMSNorm whereas the Huggingface implementation uses fp32. See for details: https://github.com/NVIDIA/TransformerEngine/issues/1132
2. Huggingface `transformers` implements the q, k and v projections in self-attention as separate GEMMs whereas Megatron core combines them into a single GEMM for efficiency. This leads to small numerical differences.
# Using legacy model format
In all the checkpoint conversion examples used in this document, the saver format `--saver core` is used, signifying that the newer (and recommended) Megatron GPT model class will be used. I.e.:
- old class: `megatron.legacy.model.gpt_model.GPTModel`
- new class: `megatron.core.models.gpt.gpt_model.GPTModel`
Using this new format is the recommended approach. However, if your use case requires using the older class (i.e., convert using `--saver legacy`), then when launching training or finetuning, the following args must be added:
-`--use-legacy-models`: use the older model class
-`--ckpt-format torch`: use the `torch` checkpoint format, which is the only checkpoint format that is compatible with the legacy model format
Figure 1: A transformer layer running with TP2CP2. Communications next to Attention are for CP, others are for TP. (AG/RS: all-gather in forward and reduce-scatter in backward, RS/AG: reduce-scatter in forward and all-gather in backward, /AG: no-op in forward and all-gather in backward).
Context Parallelism ("CP") is a parallelization scheme on the dimension of sequence length. Unlike prior SP (sequence parallelism) which only splits the sequence of Dropout and LayerNorm activations, CP partitions the network inputs and all activations along sequence dimension. With CP, all modules except attention (e.g., Linear, LayerNorm, etc.) can work as usual without any changes, because they do not have inter-token operations. As for attention, the Q (query) of each token needs to compute with the KV (key and value) of all tokens in the same sequence. Hence, CP requires additional all-gather across GPUs to collect the full sequence of KV. Correspondingly, reduce-scatter should be applied to the activation gradients of KV in backward propagation. To reduce activation memory footprint, each GPU only stores the KV of a sequence chunk in forward and gathers KV again in backward. KV communication happens between a GPU and its counterparts in other TP groups. The all-gather and reduce-scatter are transformed to point-to-point communications in ring topology under the hood. Exchanging KV also can leverage MQA/GQA to reduce communication volumes, as they only have one or few attention heads for KV.
For example, in Figure 1, assuming sequence length is 8K, each GPU processes 4K tokens. GPU0 and GPU2 compose a CP group, they exchange KV with each other. Same thing also happens between GPU1 and GPU3. CP is similar to `Ring Attention <https://arxiv.org/abs/2310.01889>`_ but provides better performance by (1) leveraging the latest OSS and cuDNN flash attention kernels; (2) removing unnecessary computation resulted from low-triangle causal masking and achieving optimal load balance among GPUs.
Figure 2: Speedup of 175B GPT with various TP+CP combinations vs. full recompute (i.e., TP8CP1).
LLM encounters OOM (out of memory) issue with long context (i.e., long sequence length) because of linearly increasing memory footprint of activations. Recomputing activations in backward can avoid OOM but also introduce significant overheads (~30% with full recompute). Enlarging TP (tensor model parallelism) can fix the OOM issue as well, but it potentially makes compute (e.g., Linear) too short to overlap communication latencies. To be clear, scaling out to more GPUs with bigger TP can hit the overlapping problem no matter if OOM happens.
CP can better address the issues. With CP, each GPU only computes on a part of the sequence, which reduces both computation and communication by CP times. Therefore, there are no concerns about the overlapping between them. The activation memory footprint per GPU is also CP times smaller, hence no OOM issue anymore. As Figure 2 shows, the combinations of TP and CP can achieve optimal performance by eliminating recompute overheads and making the best tradeoff between computation and communications.
Enabling context parallelism
----------------------------
CP support has been added to GPT. All models that share GPT code path also should be able to benefit from CP, such as Llama. CP can work with TP (tensor model parallelism), PP (pipeline model parallelism), and DP (data parallelism), where the total number of GPUs equals TPxCPxPPxDP. CP also can work with different attention variants, including MHA/MQA/GQA, uni-directional and bi-directional masking.
CP is enabled by simply setting context_parallel_size=<CP_SIZE> in command line. Default context_parallel_size is 1, which means CP is disabled. Running with CP requires Megatron-Core (>=0.5.0) and Transformer Engine (>=1.1).
-**Sharding Strategy**: Efficiently shards optimizer states, gradients, and parameters to reduce memory consumption.
-**Communication and Computation Overlap**: Optimized to enable concurrent execution of communication and computation, enhancing overall efficiency.
-**Supports automatic mixed precision training**: Compatible with BF16 O1/O2/O3 recipes, as well as FP8 compute with FP32 parameters and FP8 parameter training, allowing for flexible precision configurations.
-**Tensor Parallelism (TP), Expert Parallelism (EP) and Context Parallelism (CP)**: Compatible with TP, EP and CP configurations, enabling efficient scaling of large language models.
-**Distributed Model Initialization with Meta Device**: Allows model initialization using meta device, followed by layer-by-layer initialization of distributed model weight buffers via the `Module.reset_parameters` API, facilitating the initialization of extremely large models.
## Configuration Recommendations
### 1. Disable `CUDA_MAX_CONNECTIONS`
To ensure full parallelization of FSDP communication and computation, disable the CUDA_MAX_CONNECTIONS environment variable. This step avoids potential bubble in CUDA stream. (But it may slow down TP and CP to some extent.)
```bash
unset CUDA_MAX_CONNECTIONS
```
### 2. Add `--calculate-per-token-loss`
For gradients sharding mode optimization, include the `--calculate-per-token-loss` flag in your training script. This improves performance by reducing the frequency of gradient scaling, which is also a sizable drain on SM resources.
## Design of Custom FSDP
### 1. Overview
The custom Fully Sharded Data Parallelism (FSDP) implementation in Megatron-Core is specifically designed to optimize memory consumption and performance for large language models. The core design principles include:
-**Optimized for Large Language Models**: This custom FSDP implementation is tailored to efficiently scale with models containing billions of parameters, ensuring seamless execution and training of massive models.
-**Efficient Memory Consumption**: By strategically sharding optimizer states, gradients, and model parameters, the custom FSDP significantly reduces memory usage. This approach enables the training of models that would otherwise be too large to fit in memory.
-**Efficient Workflow & Overlapping Communication and Computation**: The implementation is engineered to minimize the number of communication steps required during training. It maximizes the overlap between communication and computation, thereby enhancing overall training efficiency and reducing latency.
-**Support for MCore's Efficient Training Methods**: The custom FSDP seamlessly integrates with Megatron-Core's advanced parallelism techniques, including tensor parallelism, expert parallelism and context parallelism. Additionally, it supports automatic mixed precision training, further optimizing training performance and efficiency.
The design of Custom FSDP draws inspiration from PyTorch FSDP [Zhao, Yanli, et al.](https://arxiv.org/pdf/2304.11277) and MCore's distributed optimizer. The introduction to PyTorch FSDP is referenced here to clarify the underlying concepts of the custom FSDP design.
> In DistributedDataParallel, (DDP) training, each process/ worker owns a replica of the model and processes a batch of data, finally it uses all-reduce to sum up gradients over different workers. In DDP the model weights and optimizer states are replicated across all workers. FSDP is a type of data parallelism that shards model parameters, optimizer states and gradients across DDP ranks.
> When training with FSDP, the GPU memory footprint is smaller than when training with DDP across all workers. This makes the training of some very large models feasible by allowing larger models or batch sizes to fit on device. This comes with the cost of increased communication volume. The communication overhead is reduced by internal optimizations like overlapping communication and computation.
*Notice that the unit processed in workflow here is the “FSDP instance 1: N layers”, where an FSDP instance is the smallest FSDP processing unit (also a PyTorch module), which means that we can safely release this module weights after using it (executing the forward or backward of this module), and there will be no other computations computations relying on these weights. This capability is the foundation of FSDP's layer-by-layer execution and memory-saving strategy. An FSDP instance is also referred to as an **FSDP Unit**.*
*It is worth noting that an FSDP instance can correspond to multiple FSDP parameter groups. These groups are separated by Data Parallel (DP) communication groups and the data type of the parameter or gradient. Consequently, an FSDP instance may require several parameter-gather tasks before execution (forward or backward). Each **FSDP parameter group** corresponds to one **Data Parallel Buffer** in custom FSDP.*
At a high level FSDP works as follow:
In constructor
- Shard model parameters and each rank only keeps its own shard
In forward path
- Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit
- Run forward computation
- Discard parameter shards it has just collected
In backward path
- Run all_gather to collect all shards from all ranks to recover the full parameter in this FSDP unit
- Run backward computation
- Run reduce_scatter to sync gradients
- Discard parameters.
One way to view FSDP’s sharding is to decompose the DDP gradient all-reduce into reduce-scatter and all-gather. Specifically, during the backward pass, FSDP reduces and scatters gradients, ensuring that each rank possesses a shard of the gradients. Then it updates the corresponding shard of the parameters in the optimizer step. Finally, in the subsequent forward pass, it performs an all-gather operation to collect and combine the updated parameter shards.
To implement the FSDP functionality described above, the custom FSDP is designed with the following Python classes and data structure:

### 3. The custom FSDP interface: FullyShardedDataParallel
The custom FSDP provides the same programming interface as PyTorch's DistributedDataParallel (DDP) as FullyShardedDataParallel (FSDP). For example, you can apply FSDP to models as follows:
- You can configure which modules should be treated as FSDP units via the `fsdp_unit_modules` argument. This configuration is mandatory.
- The custom FSDP must be used with a distributed optimizer since it provides distributed checkpointing.
- The data-parallel communication group for parameters is not explicitly shown. Custom FSDP configures these groups as either DP (data-parallel) or EDP (expert data-parallel) based on parameter markings.
#### 3.1 Initializing Models on the Meta Device
For training particularly large models with FSDP, you can initialize the model on the meta device. Using PyTorch's `reset_parameters` API, you can initialize model weights layer by layer during the construction of the `ParamAndGradBuffer`. Most PyTorch native modules and TransformerEngine modules support this API (e.g., [PyTorch Linear](https://github.com/pytorch/pytorch/blob/v2.6.0/torch/nn/modules/linear.py#L114), [TE LayerNormLinear](https://github.com/NVIDIA/TransformerEngine/blob/release_v2.0/transformer_engine/pytorch/module/layernorm_linear.py#L1107)).
1.*Custom Modules*: If your model contains custom modules, ensure they implement the `reset_parameters` API. Otherwise, you may need to force parameter initialization on a CUDA or CPU device.
2.*Tensor Initialization*: Be cautious of tensors created during model initialization without a specified device—they will default to the meta device. To avoid issues, explicitly specify the device for these tensors to ensure compatibility with this function.
### 4. Interaction between Custom FSDP and Model Forward/Backward Propagation
Custom FSDP implements Fully Sharded Data Parallelism (FSDP) through a series of module hooks, gradient hooks, or by adding functions between modules. This involves inserting communications and manipulating parameters and gradients during PyTorch's module forward or backward propagation.
Module hooks summary:
- Module pre-forward hook(`module.register_forward_pre_hook`): This hook unshards model weights before the forward pass. In the case of an FSDP Unit Module, add a RegisterFSDPBackwardFunction function that will reshard model weights and reduce gradients after module backward propagation.
- Module post-forward hook(`module.register_forward_hook`): This hook is used to reshard model weights after the forward pass.
- Root module pre-backward hook(`root_module.register_full_backward_pre_hook`): This hook checks that all model parameters are resharded, in order to avoid unnecessary memory spikes. It also marks all modules as being in the `TrainingState.PRE_BACKWARD` state.
- Module pre-backward hook(`module.register_full_backward_pre_hook`): This hook is used to unshard the model weights before the backward pass.
- Root module post-backward hook(`torch.autograd.Variable._execution_engine.queue_callback`): This hook is used to make sure all gradients in the backprop are properly handled / available.
The gradient reduction pipeline maintains a map of gradients to FSDP parameter groups. If all gradients in an FSDP parameter group are ready, it launches a gradient reduction. Note that this assumes that the model's gradients are always generated in a certain order (reverse of `module.parameters()`), as otherwise, FSDP would maintain too many parameter group grad buffers, leading to excessive memory usage.
#### 4.1 Optimized for Activation Recompute
Using the activation recompute will cause the same module to execute the forward function first and then the backward function in the backward prop, which will cause model weights unshard twice and model weights reshard twice. If we can tell program that this is a forward + backward operation, we can just call unshard once and reshard once.
To make this determination, we keep track of the model's state with training_state, `FORWARD`, `PRE_BACKWARD`, `POST_BACKWARD`, `IDLE`. It's worth noting that pre-backward hook act before pre-forward hook, and we'll let pre-backward hook execute the model weight unshard, and then mark the model as `PRE_BACKWARD`, and when pre-forward hook sees this marking it will not perform the unshard operation. Similarly, for model weight reshard duplicate, post-forward hook act before post-backward function, and checking for the `PRE_BACKWARD` flag in the post-forward hook will cancel the unshard.
### 5. Memory Mechanisms and Features of Custom FSDP
FSDP can fully distribute the model parameters, gradients, and optimizer states, and for mixed-precision training, it can also fully distribute the high-precision main weights. This is pretty much distributes all the memory except for the activation memory, but FSDP will also face some memory issues.
FSDP frequently unshards and reshards model weights, which can lead to busy memory allocation and deallocation. This results in untimely tensor releases, causing memory spikes (or even out-of-memory errors), crashes of the PyTorch memory allocator cache, and a large number of `cudaMalloc` and `cudaFree` calls. These issues can significantly slow down the system.
The problem of untimely tensor release can generally be addressed using the `tensor._typed_storage(). _resize_(0)` API, which immediately deallocates the storage's memory. Custom FSDP provides interfaces in `AllGatherPipeline` and `GradReducePipeline` to replace the temporary buffer memory allocator used for parameter gathering and gradient reduction with ` StorageResizeBasedBucketAllocator`. This replaces the tensor release operation with the `tensor._typed_storage(). _resize_(0)` API.
The PyTorch memory allocator cache crash is a complex issue that occurs frequently when the actual memory usage approaches the GPU memory limit, leading to poor performance. This problem is challenging and can only be mitigated by avoiding frequent hits on the GPU memory limit. Using a self-managed memory allocator like ` RotaryBucketAllocator` is another potential solution. However, note that `RotaryBucketAllocator` is not yet mature.
## References
-[Getting Started with Fully Sharded Data Parallel (FSDP)](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html)
Package defining different checkpoint formats (backends) and saving/loading algorithms (strategies).
Strategies can be used for implementing new checkpoint formats or implementing new (more optimal for a given use case) ways of saving/loading of existing formats.
Strategies are passed to `dist_checkpointing.load` and `dist_checkpointing.save` functions and control the actual saving/loading procedure.
The motivation for the distributed optimizer is to save memory by distributing the optimizer state evenly across data parallel ranks (https://arxiv.org/abs/1910.02054), versus the naive method of replicating the optimizer state across data parallel ranks.
Theoretical memory savings vary depending on the combination of the datatype of the model's parameters (`param_dtype`) and main gradients accumulated across data-parallel replicas (`grad_dtype`). We always use `fp32` main parameters for optimizer steps. In the current implementation, the theoretical number of bytes per parameter is (where d is the data parallel size):
Our implementation of the distributed optimizer uses contiguous buffers for parameters and main gradients; model gradients are copied over to the main gradients as soon as they are fully computed.
The figures below illustrate the distributed optimizer's sharding scheme, and the key steps of the distributed optimizer's parameter update:
_(note: using illustrations above, assuming `bf16` model weights, `bf16` model gradients that are computed by the backward pass and `fp32` main gradients that are also used for optimizer steps; we always use `fp32` main weights for optimizer steps)_
- Each DP rank now has 4 elements within the gradient buffer that are fully reduced (remaining 12 elements are garbage).
- DP rank 0 has gradient values for elements [0:4].
- DP rank 1 has gradient values for elements [4:8].
- DP rank 2 has gradient values for elements [8:12].
- DP rank 3 has gradient values for elements [12:16].
- Optimizer.step().
- Each DP rank copies its 4 `fp32` main parameter elements into the corresponding `bf16` parameter buffer (each element is cast from fp32 to fp16).
- Call all-gather on each DP rank.
- The parameter buffer now contains all 16, fully updated, `bf16` model parameter elements. Parameters in PyTorch modules already point to the appropriate locations in this parameter buffer, and thus forward passes are ready to run after the all-gather completes.
- At this point, the gradient buffer is also ready to be zero'd for the next iteration.
This package provides modules that provide commonly fused
operations. Fusing operations improves compute efficiency by
increasing the amount of work done each time a tensor is read from
memory. To perform the fusion, modules in this either rely on PyTorch
functionality for doing just-in-time compilation
(i.e. `torch.jit.script` in older PyTorch versions of `torch.compile`
in recent versions), or call into custom kernels in external libraries
such as Apex or TransformerEngine.
Submodules
----------
fusions.fused\_bias\_dropout module
-----------------------------------
This module uses PyTorch JIT to fuse the bias add and dropout operations. Since dropout is not used during inference, different functions are used when in train mode and when in inference mode.
.. automodule:: core.fusions.fused_bias_dropout
:members:
:undoc-members:
:show-inheritance:
fusions.fused\_bias\_gelu module
--------------------------------
This module uses PyTorch JIT to fuse the bias add and GeLU nonlinearity operations.
.. automodule:: core.fusions.fused_bias_gelu
:members:
:undoc-members:
:show-inheritance:
fusions.fused\_layer\_norm module
---------------------------------
This module provides a wrapper around various fused LayerNorm implementation in Apex.
.. automodule:: core.fusions.fused_layer_norm
:members:
:undoc-members:
:show-inheritance:
fusions.fused\_softmax module
-----------------------------
This module provides wrappers around variations of Softmax in Apex.
.. automodule:: core.fusions.fused_softmax
:members:
:undoc-members:
:show-inheritance:
fusions.fused\_cross\_entropy\_loss module
------------------------------------------
This module uses PyTorch JIT to fuse the cross entropy loss calculation and batches communication calls.