The following guide will show you how to quickly get started with Megatron Core. It will show you the following
* We will initalize megatron core on 2 GPUS.
* We will build a GPT model with tensor model parallel size 2, pipeline parallel size 1
* We will train it for a few iterations using megatron core schedules
* We will save the model using the distributed checkpointing format
* We will load the model saved above.
*NOTE: The following has been testing for megatron core version 0.5 and NGC Pytorch Container version 24.02
### Environment Setup
```
docker run --ipc=host --shm-size=512m --gpus all -it nvcr.io/nvidia/pytorch:24.02-py3
pip install megatron_core
pip install tensorstore==0.1.45
pip install zarr
```
<br>
### Writing Your First Training Loop
The following steps will walk you through how you can create a sample GPT model split across tensors (Tensor model parallel ) on 2 GPUS, and run a forward pass through it using a MockGPT dataset helper class that we created in Megatron core.
<br>
**NOTE: All of the folowing steps needs to be put into a script and then run as explained in the last step**
<br>
**STEP 1 - Initialize Distributed Training and Model parallel setup**
The following utility when called initalizes your distributed setup.
The following step shows you how you can quickly create a GPT model. For a list of other configs that you can pass into the model look into [transformer_config.py](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/transformer/transformer_config.py)
```
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
The following shows you how you can quickly get started with a mock dataset utility we created. In order to train with your data, please use the actual GPTDataset class in [gpt_dataset.py](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/datasets/gpt_dataset.py)
To find more information about megatron core data pipeline please refer to [this](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/datasets/readme.md?ref_type=heads)
```
from torch.utils.data import DataLoader
from megatron.core.datasets.utils import Split
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset
In megatron core, we use [schedules.py](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/pipeline_parallel/schedules.py) to run the model. So it is sufficient to define a forward step function which takes as input the data iterator and the model and produces as output the output tensor and a loss function
Megatron core uses distributed checkpoint for loading and saving model. This gives you the flexiblity to convert model from one model parallel setting to another when you load a model (i.e A model trained with tensor parallel size 2, can now be loaded as tensor model parallel size 4 etc.)
*NOTE: Make sure you have zarr and tensorstore pip package installed as shown in the environment setup*
All the above steps are put to gether in a [run_simple_mcore_train_loop.py](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/run_simple_mcore_train_loop.py) script in examples folder in megatron . You can run it as follows
The above example introduced you to a basic training loop in MCore. To see more advanced examples please look at [pretrain_gpt.py]. That will show you how you can write more complex training loops, involving pipeline parallel, context parallel, rope embeddings, mixture of experts and all other functionalities present in mcore.
Each metric is prefixed with `Mn` or `Mx` to represent `Minimum` or `Maximum`. Each metric is also suffixed with the rank where the metric was measured. The metrics are averaged over the logging interval. Between the prefix and the rank is the name of the metric as follows
- Rtt : RoundTrip Time (time spent in all the traced ops per iteration)
- Pwr : GPU Power
- Tmp : GPU Temperature
- Utl : GPU Utilization
- Clk : GPU Clock
- DRtt: get_batch latency
- Etpt: Estimated throughput. This is derived from actual computed throughput dividied by Rtt. Since we do not collect timing for backward pass, the value is further divided by three to come up with estimated throughput.
<hr>
### Command Line activation
To start using the StragglerDetector, need to pass the following argument `--log-straggler`. It optionally also takes two additional parameters. Default disabled
-`--disable-straggler-on-startup` - whether to keept the StragglerDetector disabled on startup and enable later. Default enabled
-`--straggler-ctrlr-port` - The StragglerDetector can toggle between on/off just by sending `curl Rank0Host:port`. Default port is 65535. Every time it is turned
-`--straggler-minmax-count` - If set to > 1 (N), it prints N Top and Bottom Etpt/Rank pairs as shown below