Commit 0816dd4a authored by libo11's avatar libo11
Browse files

Initial commit

parents
Pipeline #1728 canceled with stages
node116 slots=8
node117 slots=8
node118 slots=8
node119 slots=8
## Quick Start
The following guide will show you how to quickly get started with Megatron Core. It will show you the following
* We will initalize megatron core on 2 GPUS.
* We will build a GPT model with tensor model parallel size 2, pipeline parallel size 1
* We will train it for a few iterations using megatron core schedules
* We will save the model using the distributed checkpointing format
* We will load the model saved above.
*NOTE: The following has been testing for megatron core version 0.8.0 and NGC Pytorch Container version 24.02
### Environment Setup
```
docker run --ipc=host --shm-size=512m --gpus 2 -it nvcr.io/nvidia/pytorch:24.02-py3
git clone https://github.com/NVIDIA/Megatron-LM.git && cd Megatron-LM
```
<br>
### Writing Your First Training Loop
The following steps will walk you through how you can create a sample GPT model split across tensors (Tensor model parallel ) on 2 GPUS, and run a forward pass through it using a MockGPT dataset helper class that we created in Megatron core.
<br>
**NOTE: All of the following steps are already put into a script [run_simple_mcore_train_loop.py](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/run_simple_mcore_train_loop.py) which you can run as follows**
```
PYTHONPATH=$PYTHON_PATH:./megatron torchrun --nproc-per-node 2 examples/run_simple_mcore_train_loop.py
```
<br>
**STEP 1 - Initialize Distributed Training and Model parallel setup**
The following utility when called initalizes your distributed setup.
```python
import os
import torch
from megatron.core import parallel_state
def initialize_distributed(tensor_model_parallel_size = 1, pipeline_model_parallel_size = 1):
# Torch setup for distributed training
rank = int(os.environ['LOCAL_RANK'])
world_size = torch.cuda.device_count()
torch.cuda.set_device(rank)
torch.distributed.init_process_group(world_size=world_size, rank=rank)
# Megatron core distributed training initialization
parallel_state.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
```
<br>
**STEP 2 - GPT Model Setup**
The following step shows you how you can quickly create a GPT model. For a list of other configs that you can pass into the model look into [transformer_config.py](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/transformer/transformer_config.py)
```
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
def model_provider():
"""Build the model."""
transformer_config = TransformerConfig(
num_layers=2,
hidden_size=12,
num_attention_heads=4,
use_cpu_initialization=True,
pipeline_dtype=torch.float32)
gpt_model = GPTModel(
config=transformer_config,
transformer_layer_spec=get_gpt_layer_local_spec(),
vocab_size=100,
max_sequence_length=64)
return gpt_model
```
<br>
**STEP 3 - GPT Mock dataset setup**
The following shows you how you can quickly get started with a mock dataset utility we created. In order to train with your data, please use the actual GPTDataset class in [gpt_dataset.py](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/datasets/gpt_dataset.py)
To find more information about megatron core data pipeline please refer to [this](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/datasets/readme.md?ref_type=heads)
```
import torch
from torch.utils.data import DataLoader
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset
from megatron.training.tokenizer.tokenizer import _NullTokenizer
from megatron.core.datasets.utils import compile_helpers
_SEQUENCE_LENGTH = 64
def get_train_data_iterator():
if torch.distributed.is_available() and torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
compile_helpers()
torch.distributed.barrier()
else:
compile_helpers()
config = GPTDatasetConfig(
random_seed=0,
sequence_length=_SEQUENCE_LENGTH,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False,
tokenizer=_NullTokenizer(vocab_size=_SEQUENCE_LENGTH),
)
datasets = BlendedMegatronDatasetBuilder(
MockGPTDataset, [1000, None, None], lambda: True, config
).build()
train_dataloader = DataLoader(datasets[0], batch_size=8, shuffle=True)
train_iterator = iter(train_dataloader)
return train_iterator
```
<br>
**STEP 4 - Forward Step Function**
In megatron core, we use [schedules.py](https://github.com/NVIDIA/Megatron-LM/tree/main/megatron/core/pipeline_parallel/schedules.py) to run the model. So it is sufficient to define a forward step function which takes as input the data iterator and the model and produces as output the output tensor and a loss function
```python
from functools import partial
def forward_step_func(data_iterator, model):
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# If you have data parallel reduce loss across data parallel groups.
# If pipeline parallel, loss computation is done only in last stage.
return loss, {'lm loss': loss}
data = next(data_iterator)
tokens = data['tokens'].to(device)
attention_mask = data['attention_mask'].to(device)
position_ids = data['position_ids'].to(device)
labels = data['labels'].to(device)
loss_mask = data['loss_mask'].to(device)
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
return output_tensor, partial(loss_func, loss_mask)
```
<br>
**STEP 5 - Load and Save Distributed Checkpoint**
Megatron core uses distributed checkpoint for loading and saving model. This gives you the flexiblity to convert model from one model parallel setting to another when you load a model (i.e A model trained with tensor parallel size 2, can now be loaded as tensor model parallel size 4 etc.)
```python
from megatron.core import dist_checkpointing
def save_distributed_checkpoint(checkpoint_path, gpt_model):
sharded_state_dict = gpt_model.sharded_state_dict(prefix='')
dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path)
def load_distributed_checkpoint(checkpoint_path, gpt_model):
sharded_state_dict=gpt_model.sharded_state_dict(prefix='')
checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path)
gpt_model.load_state_dict(checkpoint)
return gpt_model
```
<br>
**STEP 6 - Main Function**
The following is the main function that needs to go into your script.
```python
from pathlib import Path
from torch.optim import Adam
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
if __name__ == "__main__":
initialize_distributed(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)
model_parallel_cuda_manual_seed(123)
gpt_model = model_provider()
device = torch.device("cuda")
gpt_model.to(device)
optim = Adam(gpt_model.parameters())
train_iterator = get_train_data_iterator()
forward_backward_func = get_forward_backward_func()
# Running the model for 5 iterations
for _ in range(5):
optim.zero_grad()
losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=train_iterator,
model=gpt_model,
num_microbatches=1,
seq_length=64,
micro_batch_size=8,
decoder_seq_length=64,
forward_only=False)
optim.step()
print(f'Losses reduced : {losses_reduced}')
# Saving the model
save_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path='/workspace/ckpt')
# Loading the model
gpt_model = load_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path='/workspace/ckpt')
gpt_model.to(device)
print('Successfully loaded the model')
```
<br>
### Extending Further
The above example introduced you to a basic training loop in MCore. To see more advanced examples please look at [pretrain_gpt.py]. That will show you how you can write more complex training loops, involving pipeline parallel, context parallel, rope embeddings, mixture of experts and all other functionalities present in mcore.
Megatron Core is a library for efficient and scalable training of transformer based models.
\ No newline at end of file
## StragglerDetector for a TP Group
The file `megatron/core/utils.py` has a class named `StragglerDetector` which supports Python Contexts.
It can be used to find straggling TP group based on the RTT of the ranks in the TP Group. It also collects
Power/Temp/Utilization for GPUs, which can additionally be used to narrow down to the exact GPU in the TP Group,
assuming the straggling was caused by hardware anomaly in a given GPU.<br>
This class supports collecting timing events for various steps of a given iteration. It
keeps collecting such timing events on a per rank basis, and when the reporter is invoked
during a logging interval, it computes the min and max of certain metric across all
ranks and logs the observed metric and the rank as follows
```
0: INFO:megatron.core.utils:[2024-03-14 23:07:56] | MnRtt/Rnk: 3453.08ms/8 | MxRtt/Rnk: 3468.20ms/0 | MnPwr/Rnk: 601796W/8 | MxPwr/Rnk: 683801W/18 | MnTmp/Rnk: 52C/0 | MxTmp/Rnk: 65C/21 | MnUtl/Rnk: 97%/8 | MxUtl/Rnk: 100%/6 | MnClk/Rnk: 1950MHz/28 | MxClk/Rnk: 1980MHz/0 | MnDRtt/Rnk: 14.27ms/23 | MxDRtt/Rnk: 34.65ms/3 | MnEtpt/Rnk: 296.02TF/0 | MxEtpt/Rnk: 297.32TF/8
```
<hr>
### Description of the metrics
Each metric is prefixed with `Mn` or `Mx` to represent `Minimum` or `Maximum`. Each metric is also suffixed with the rank where the metric was measured. The metrics are averaged over the logging interval. Between the prefix and the rank is the name of the metric as follows
- Rtt : RoundTrip Time (time spent in all the traced ops per iteration)
- Pwr : GPU Power
- Tmp : GPU Temperature
- Utl : GPU Utilization
- Clk : GPU Clock
- DRtt: get_batch latency
- Etpt: Estimated throughput. This is derived from actual computed throughput dividied by Rtt. Since we do not collect timing for backward pass, the value is further divided by three to come up with estimated throughput.
<hr>
### Command Line activation
To start using the StragglerDetector, need to pass the following argument `--log-straggler`. It optionally also takes two additional parameters. Default disabled
- `--disable-straggler-on-startup` - whether to keept the StragglerDetector disabled on startup and enable later. Default enabled
- `--straggler-ctrlr-port` - The StragglerDetector can toggle between on/off just by sending `curl Rank0Host:port`. Default port is 65535. Every time it is turned
- `--straggler-minmax-count` - If set to > 1 (N), it prints N Top and Bottom Etpt/Rank pairs as shown below
```
0: INFO:megatron.core.utils:^^^^ Bottom 4 Ranks with lowest Etpt(TF): 296.02/0, 296.17/2, 296.23/1, 296.23/4,
0: INFO:megatron.core.utils:^^^^ Top 4 Ranks with highest Etpt(TF): 297.28/15, 297.28/11, 297.32/12, 297.32/8,
```
<hr>
### Programming the StragglerDetector
The StragglerDetector class supports context, and its implementation is a Singleton.
- Initialization
```
# initialization, where StragglerDetector will be used
from megatron.core.utils import StragglerDetector
stimer = StragglerDetector()
```
- One time for each rank
```
# one time before the training loop starts
stimer.configure(world, rank, enabled=True, port=65545)
# Arguments to configure
# world : World Size
# rank : The rank of this trainer
# mmcnt : (Optional) Number of ranks to print for showing Min/Max Etpt
# amp : (Optional) Set to 3.0 if we only use timers in fwd pass
# port : (Optional) control port, useful only for rank-0
# prefill : (Optional) howmany Events to pre-populate
# enabled : (Optional) whether or not collection is enabled on startup
```
- To Capture time
```
# whereever timing need to be captured
with stimer:
do_operation()
# special case for get_batch
with stimer(bdata=True):
input,... = get_batch(iterator,...)
```
- Logging in main training loop
```
# logging
total_flops = 0.0
iteration = 0
# inside the main training loop
while training:
iteration += 1
do_step()
total_flops += get_computed_flops()
if iteration % log_interval:
stimer.report(total_flops, log_interval)
total_flops = 0.0
```
import megatron.core.tensor_parallel
import megatron.core.utils
from megatron.core import parallel_state
from megatron.core.distributed import DistributedDataParallel
from megatron.core.inference_params import InferenceParams
from megatron.core.model_parallel_config import ModelParallelConfig
from megatron.core.package_info import (
__contact_emails__,
__contact_names__,
__description__,
__download_url__,
__homepage__,
__keywords__,
__license__,
__package_name__,
__repository_url__,
__shortversion__,
__version__,
)
from megatron.core.timers import Timers
# Alias parallel_state as mpu, its legacy name
mpu = parallel_state
__all__ = [
"parallel_state",
"tensor_parallel",
"utils",
"DistributedDataParallel",
"InferenceParams",
"ModelParallelConfig",
"Timers",
]
CXXFLAGS += -O3 -Wall -shared -std=c++11 -fPIC -fdiagnostics-color
CPPFLAGS += $(shell python3 -m pybind11 --includes)
LIBNAME = helpers
LIBEXT = $(shell python3-config --extension-suffix)
default: $(LIBNAME)$(LIBEXT)
%$(LIBEXT): %.cpp
$(CXX) $(CXXFLAGS) $(CPPFLAGS) $< -o $@
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment