# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
"""Export a GPTModel."""
import functools
import os
import sys
import warnings
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))
import modelopt.torch.export as mtex
import torch
from megatron.post_training.arguments import add_modelopt_args
from megatron.post_training.checkpointing import load_modelopt_checkpoint
from megatron.post_training.model_provider import model_provider
from megatron.training import get_args, get_model
from megatron.training.initialize import initialize_megatron
from megatron.training.utils import unwrap_model
warnings.filterwarnings('ignore')
def add_modelopt_export_args(parser):
"""Add additional arguments for ModelOpt hf-like export."""
group = parser.add_argument_group(title='ModelOpt hf-like export')
group.add_argument(
"--export-extra-modules",
action="store_true",
help="Export extra modules such as Medusa, EAGLE, or MTP.",
)
group.add_argument(
"--pretrained-model-name",
type=str,
help="A pretrained model hosted inside a model repo on huggingface.co.",
)
group.add_argument("--export-dir", type=str, help="The target export path.")
add_modelopt_args(parser)
return parser
if __name__ == "__main__":
initialize_megatron(
extra_args_provider=add_modelopt_export_args,
args_defaults={
'tokenizer_type': 'HuggingFaceTokenizer',
'no_load_rng': True,
'no_load_optim': True,
},
)
args = get_args()
model = get_model(functools.partial(model_provider, parallel_output=True), wrap_with_ddp=False)
if args.load is not None:
_ = load_modelopt_checkpoint(model)
unwrapped_model = unwrap_model(model)[0]
mtex.export_mcore_gpt_to_hf(
unwrapped_model,
args.pretrained_model_name,
export_extra_modules=args.export_extra_modules,
dtype=torch.bfloat16,
export_dir=args.export_dir,
)
#!/bin/bash
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
# Common arguments and base model specific arguments
source "${SCRIPT_DIR}/conf/arguments.sh"
# Default arguments of this script
MLM_DEFAULT_ARGS="--finetune --auto-detect-ckpt-format --export-te-mcore-model --use-cpu-initialization"
if [ -z ${HF_MODEL_CKPT} ]; then
HF_MODEL_CKPT=${1}
fi
if [ -z ${HF_TOKEN} ]; then
printf "${MLM_WARNING} Variable ${PURPLE}HF_TOKEN${WHITE} is not set! Pretrained config download may fail!\n"
fi
if [ -z ${EXPORT_DIR} ]; then
EXPORT_DIR=${MLM_WORK_DIR}/${MLM_MODEL_CFG}_export
printf "${MLM_WARNING} Variable ${PURPLE}EXPORT_DIR${WHITE} is not set (default: ${EXPORT_DIR})!\n"
fi
if [ "${TP}" != "1" ]; then
TP=1
printf "${MLM_WARNING} Variable ${PURPLE}TP${WHITE} is forced to be 1 during export!!\n"
fi
${LAUNCH_SCRIPT} ${SCRIPT_DIR}/export.py \
${MODEL_ARGS} \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--tokenizer-model ${TOKENIZER_MODEL} \
--load ${MLM_MODEL_CKPT} \
--pretrained-model-name ${HF_MODEL_CKPT} \
--export-dir ${EXPORT_DIR} \
${MLM_DEFAULT_ARGS} ${MLM_EXTRA_ARGS}
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Supervised Finetuning GPT."""
import itertools
import json
import os
import sys
from functools import partial
from typing import Any, Dict, Optional
import jsonlines
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))
import datasets
import torch
import transformers
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.core.models.gpt import GPTModel
from megatron.post_training.arguments import add_modelopt_args
from megatron.post_training.model_provider import model_provider
from megatron.post_training.non_loss_data_func import report_draft_acceptance_length
from megatron.training import get_args, get_timers, get_tokenizer, pretrain
from megatron.training.utils import (
average_losses_across_data_parallel_group,
get_batch_on_this_cp_rank,
get_ltor_masks_and_position_ids,
print_rank_0,
unwrap_model,
)
REMOVE_THINK_CHAT_TEMPLATE = (
"{% if '</think>' in content %}{% set content = content.split('</think>')[-1] %}{% endif %}"
)
def get_eos_id():
tokenizer = get_tokenizer()
hf_tokenizer = tokenizer._tokenizer
if hf_tokenizer.eos_token == "<|eot_id|>":
return 128001
if hf_tokenizer.eos_token == "<|eot|>":
return 200001
return hf_tokenizer.eos_token_id
class SFTDataset(torch.utils.data.Dataset):
hf_dataset_to_kwargs = {
"Open-Orca/OpenOrca": {"split": "train"},
"Open-Orca/SlimOrca": {"split": "train"},
"nvidia/HelpSteer2": {"split": "train"},
"nvidia/Daring-Anteater": {"split": "train"},
"Magpie-Align/Magpie-Llama-3.1-Pro-MT-300K-Filtered": {"split": "train"},
"/hf-local/modelopt/AA-Synthetic-Scout": {"split": "train"},
"/hf-local/modelopt/Multilingual": {"split": "train"},
}
hf_dataset_to_conversation = {
"Open-Orca/OpenOrca": lambda data: SFTDataset._to_conversation(
data["question"], data["response"]
),
"Open-Orca/SlimOrca": lambda data: SFTDataset._sharegpt_to_openai_conversations(data),
"nvidia/HelpSteer2": lambda data: SFTDataset._to_conversation(
data["prompt"], data["response"]
),
"nvidia/Daring-Anteater": lambda data: SFTDataset._sharegpt_to_openai_conversations(data),
"Magpie-Align/Magpie-Llama-3.1-Pro-MT-300K-Filtered": lambda data: SFTDataset._sharegpt_to_openai_conversations(
data
),
"/hf-local/modelopt/AA-Synthetic-Scout": lambda data: SFTDataset._special_to_openai_conversations(
data
),
}
hf_dataset_to_prompt_template = {
"Open-Orca/OpenOrca": "{{ messages['question'] + ' ' + messages['response'] + ' ' }}",
"nvidia/HelpSteer2": "{{ messages['prompt'] + ' ' + messages['response'] + ' ' }}",
}
def __init__(
self,
num_packed_samples: int,
data_path: Optional[str],
tokenizer: transformers.PreTrainedTokenizerBase,
seq_length: int,
hf_dataset: Optional[str] = None,
num_medusa_heads: int = 0,
num_shards: int = 1,
shard_index: int = 0,
):
"""A simple dataset implementation for supervised fine-tuning.
        The raw data is processed and packed into an indexed dataset on the fly. Users
        specify the total number of packed samples, and the dataloader (or sampler)
        accesses the packed dataset by index. When the packed dataset is shorter than
        the requested index, the packing process fetches raw data in a cyclic fashion
        until the packed dataset has sufficient length.
Args:
data_path: Path to the json or jsonl file
num_packed_samples: total number of packed samples (cyclic access)
tokenizer: hf tokenizer
seq_length: max sequence length
hf_dataset: not supported yet
            num_medusa_heads: number of medusa heads; increases the sample sequence
                length to train the additional medusa prediction heads
"""
if not isinstance(tokenizer, transformers.PreTrainedTokenizerBase):
raise ValueError("SFTDataset only supports transformers.PreTrainedTokenizerBase!")
self.num_packed_samples = num_packed_samples
self.data_path = data_path
self.tokenizer = tokenizer
self.seq_length = seq_length
self.hf_dataset = hf_dataset
self.data_transformation = lambda data: data
self.num_shards = num_shards
self.shard_index = shard_index
self.num_medusa_heads = num_medusa_heads
self.indexed_dataset = []
self._raw_sample_index = 0
# [WAR]: For DeepSeek-V3/R1 tokenizer, we modify the chat_template such that the <think>
# tokens are preserved for supervised learning.
self.tokenizer.chat_template = self.tokenizer.chat_template.replace(
REMOVE_THINK_CHAT_TEMPLATE, ""
)
if data_path is not None:
if data_path.endswith(".json"):
                with open(data_path) as f:
                    self._raw_samples = json.load(f)
elif data_path.endswith(".jsonl"):
with jsonlines.open(data_path, mode='r') as reader:
self._raw_samples = [obj for obj in reader]
else:
raise ValueError("data_path must be json or jsonl")
elif self.hf_dataset is not None:
hf_dataset_kwargs = SFTDataset.hf_dataset_to_kwargs.get(
self.hf_dataset, {"split": "train"}
)
self._raw_samples = datasets.load_dataset(self.hf_dataset, **hf_dataset_kwargs)
self._raw_samples = self._raw_samples.shard(
num_shards=self.num_shards, index=shard_index
)
print(
"Rank {:3}/{:3} creates SFT data shard {:3}/{:3} with {:10} raw samples".format(
torch.distributed.get_rank(),
torch.distributed.get_world_size(),
self.shard_index,
self.num_shards,
len(self._raw_samples),
),
flush=True,
)
else:
raise ValueError("Either hf_dataset or data_path must be provided!")
        if self.tokenizer.chat_template is None:
            self.tokenizer.chat_template = SFTDataset.hf_dataset_to_prompt_template.get(
                self.hf_dataset, None
            )
elif self.hf_dataset is not None:
self.data_transformation = SFTDataset.hf_dataset_to_conversation.get(
self.hf_dataset, lambda data: data
)
if self.tokenizer.chat_template is None:
raise ValueError("No valid chat template!")
def __len__(self):
return self.num_packed_samples
def __getitem__(self, idx):
"""Get the idx packed data.
        The packed data index is different from the raw data index: a packed sample
        of sequence length may require concatenating multiple raw samples. When all raw
        data are used up, the last incomplete packed sample is thrown away, and we have
        a packed dataset in memory. The packed data index may exceed the length of the
        packed dataset, in which case it simply wraps around in a cyclic fashion.
"""
idx = idx // self.num_shards
while idx >= len(self.indexed_dataset):
packed_samples = self._process_and_pack_example()
if packed_samples is None:
break
else:
self.indexed_dataset.append(packed_samples)
if len(self.indexed_dataset) % 10000 == 0:
print(
"Rank {:3}/{:3} requests {:10}/{:10} packed SFT sample".format(
torch.distributed.get_rank(),
torch.distributed.get_world_size(),
idx,
len(self.indexed_dataset),
),
flush=True,
)
idx = idx % len(self.indexed_dataset)
torch_sample = {}
for key, val in self.indexed_dataset[idx].items():
torch_sample[key] = torch.LongTensor(val)
return torch_sample
def _process_and_pack_example(self):
"""Process multiple raw data and pack them into fixed sequence length."""
required_packed_tokens = self.seq_length + 1 + self.num_medusa_heads
current_packed_samples = []
current_packed_samples_token_count = 0
while current_packed_samples_token_count < required_packed_tokens:
if self._raw_sample_index >= len(self._raw_samples):
return None
raw_sample = self._raw_samples[self._raw_sample_index]
self._raw_sample_index += 1
processed_sample = self._process_example(raw_sample)
if processed_sample is not None:
current_packed_samples.append(processed_sample)
current_packed_samples_token_count += processed_sample["token_count"]
packed_samples = {}
for key in ['input_ids', 'loss_mask']:
packed_samples[key] = list(
itertools.chain.from_iterable([obj[key] for obj in current_packed_samples])
)
for key in ['token_count']:
packed_samples[key] = [obj[key] for obj in current_packed_samples]
return packed_samples
def _process_example(self, example: Dict[str, Any]):
"""Apply the chat template and compute the answer-only loss mask."""
if not isinstance(example, Dict):
raise ValueError("The sample must be a Dict but got {}".format(type(example)))
# Several things can happen here after the transformation is applied:
#
# 1. If the transformation is identity transformation, then either the chat data
# is already in OpenAI chat format or there is a custom prompt template used.
# 2. Otherwise, the tokenizer must have a default chat template and we are either
# converting the ShareGPT chat data or standard SFT data to OpenAI chat data.
example = self.data_transformation(example)
        # Check whether this is OpenAI chat data.
        conversations = example.get("conversations", None)
        if conversations is None:
            conversations = example.get("messages", None)
        # We don't use the data if there is no assistant reply or the conversation
        # starts with the assistant.
        if conversations is not None:
            example = conversations
            if len(conversations) < 2 or example[0]["role"] == "assistant":
                return None
# We always add eos between samples for training purpose.
input_ids = self.tokenizer.apply_chat_template(example)
current_loss_mask = [1] * len(input_ids)
input_ids = input_ids + [get_eos_id()]
current_loss_mask += [0]
assert len(input_ids) == len(current_loss_mask)
if len(input_ids) > self.seq_length:
input_ids = input_ids[: self.seq_length]
current_loss_mask = current_loss_mask[: self.seq_length]
processed_example = {
'input_ids': input_ids,
'loss_mask': current_loss_mask,
'token_count': len(input_ids),
}
return processed_example
@classmethod
def _to_conversation(cls, question, response):
msg_question = {"role": "user", "content": question}
msg_response = {"role": "assistant", "content": response}
return {"conversations": [msg_question, msg_response]}
@classmethod
def _sharegpt_to_openai_conversations(cls, data):
role_mapping = {
"user": "user",
"User": "user",
"human": "user",
"assistant": "assistant",
"Assistant": "assistant",
"gpt": "assistant",
"system": "system",
"System": "system",
}
processed_data = {"conversations": []}
for msg in data["conversations"]:
role = role_mapping[msg["from"]]
content = msg["value"]
processed_data["conversations"].append({"role": role, "content": content})
return processed_data
@classmethod
def _special_to_openai_conversations(cls, data):
processed_data = {"conversations": data["input"]["messages"]}
return processed_data
def train_valid_test_sft_datasets_provider(train_val_test_num_samples):
"""Build the train test and validation datasets.
Args:
train_val_test_num_samples : A list containing the number of samples
in train test and validation.
"""
print_rank_0("> building train, validation, and test SFT datasets ...")
args = get_args()
tokenizer = get_tokenizer()
if not isinstance(tokenizer._tokenizer, transformers.PreTrainedTokenizerBase):
raise ValueError("SFTDataset only supports transformers.PreTrainedTokenizerBase!")
if args.micro_batch_size > 1:
raise ValueError("SFTDataloader only supports micro_batch_size=1.")
# Providing additional Medusa arguments to prepare the data
kwargs = {
"tokenizer": tokenizer._tokenizer,
"seq_length": args.seq_length,
# Optional kwargs
"hf_dataset": args.finetune_hf_dataset,
"num_shards": mpu.get_data_parallel_world_size(),
"shard_index": mpu.get_data_parallel_rank(),
"num_medusa_heads": args.export_num_medusa_heads,
}
data_path = [
args.train_data_path[0] if args.train_data_path else None,
args.valid_data_path[0] if args.valid_data_path else None,
args.test_data_path[0] if args.test_data_path else None,
]
train_ds = SFTDataset(train_val_test_num_samples[0], data_path[0], **kwargs)
valid_ds = SFTDataset(train_val_test_num_samples[1], data_path[1], **kwargs)
test_ds = SFTDataset(train_val_test_num_samples[2], data_path[2], **kwargs)
print_rank_0("> finished creating SFT datasets ...")
return train_ds, valid_ds, test_ds
def get_batch(data_iterator):
"""Generate a batch."""
# TODO: this is pretty hacky, find a better way
if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
return None, None, None, None, None
args = get_args()
# Items and their type.
keys = ["input_ids", "loss_mask"]
datatype = torch.int64
# Broadcast data since only TP rank-0 has the data_iterator.
if data_iterator is not None:
data = next(data_iterator)
else:
data = None
data_b = tensor_parallel.broadcast_data(keys, data, datatype)
# Unpack the data received.
tokens_ = data_b["input_ids"]
tokens = tokens_[:, 0 : 0 + args.seq_length].contiguous()
labels = tokens_[:, 1 : 1 + args.seq_length].contiguous()
answer_only_loss_mask = data_b["loss_mask"][:, 1 : 1 + args.seq_length].contiguous()
# Get the masks and postition ids.
attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids(
tokens, get_eos_id(), args.reset_position_ids, args.reset_attention_mask, args.eod_mask_loss
)
loss_mask = loss_mask * answer_only_loss_mask.to(dtype=loss_mask.dtype)
# Medusa label and loss_mask preparation
#
# Explanation:
#
    # To predict 1 + k labels, the input needs an additional k tokens. Given
    # sequence length s, overall s + 1 + k tokens are fed from the dataset.
#
# inputs = tokens[0:s]
# labels = tokens[1:1+s]
# kth medusa head labels = tokens[1+k:1+k+s]
#
# Examples: (s=5, k=2)
#
# | 0 1 2 3 4 5 6 7 |
# ------------------|-----------------|
# tokens | x x x x x x x x |
# inputs | x x x x x |
# lm_head labels | x x x x x | (next token prediction)
# 1st medusa labels | x x x x x | (next-next token prediction)
# 2nd medusa labels | x x x x x | (next-next-next token prediction)
#
for i in range(args.export_num_medusa_heads):
new_labels = tokens_[:, 2 + i : 2 + i + args.seq_length]
new_loss_mask = data_b["loss_mask"][:, 2 + i : 2 + i + args.seq_length].to(
dtype=loss_mask.dtype
)
labels = torch.cat((labels, new_labels), dim=-1)
loss_mask = torch.cat((loss_mask, new_loss_mask), dim=-1)
if args.export_num_medusa_heads > 0:
loss_mask = loss_mask.view(args.export_num_medusa_heads + 1, -1)
# if args.export_num_eagle_layers > 0:
# loss_mask = loss_mask[:, 1:]
# MTP label and loss_mask preparation
# Examples: (s=5, k=2)
#
# | 0 1 2 3 4 5 |
# ------------------|-------------|
# tokens | x x x x x x |
# inputs | x x x x x |
# lm_head labels | x x x x x | (next token prediction)
# mtp_0 labels | x x x x | (next-next token prediction)
# mtp_1 labels | x x x | (next-next-next token prediction)
#
# mtp_i_labels = labels[:, 1 + i :]
# So we do not need to prepare extra labels for mtp
# Modelopt will shift labels and reuse them
#
# loss_mask
if args.export_num_mtp > 0:
loss_masks = []
for i in range(args.export_num_mtp):
new_loss_mask = data_b["loss_mask"][:, 2 + i : 1 + args.seq_length].to(
dtype=loss_mask.dtype, device=loss_mask.device
)
if i in args.export_freeze_mtp:
new_loss_mask = torch.zeros_like(
new_loss_mask, dtype=loss_mask.dtype, device=loss_mask.device
)
loss_masks.append(new_loss_mask)
loss_mask = torch.cat(loss_masks, dim=-1)
labels = labels.contiguous()
loss_mask = loss_mask.contiguous()
batch = {
"tokens": tokens,
"labels": labels,
"loss_mask": loss_mask,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank(batch)
return batch.values()
def _mask_loss(output_tensor, loss_mask, mp_reduce=False):
"""Apply mask to the unreduced loss tensor."""
args = get_args()
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
if args.context_parallel_size > 1:
loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), loss_mask.sum().view(1)])
torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
loss = loss[0] / loss[1]
else:
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
if mp_reduce and args.tensor_model_parallel_size > 1:
# KD loss requires extra all-reduce to ensure same values across MP-TP partitions.
loss = torch.sum(tensor_parallel.gather_from_tensor_model_parallel_region(loss.reshape(1)))
return loss
def _allreduce_loss(loss):
"""Reduce loss for reporting purposes."""
args = get_args()
# Check individual rank losses are not NaN prior to DP all-reduce.
if args.check_for_nan_in_loss_and_grad:
global_rank = torch.distributed.get_rank()
assert not loss.isnan(), (
f'Rank {global_rank}: found NaN in local forward loss calculation. '
f'Device: {torch.cuda.current_device()}, node: {os.uname()[1]}'
)
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss * args.context_parallel_size, averaged_loss[0]
def loss_func(loss_mask: torch.Tensor, model: GPTModel, output_tensor: torch.Tensor):
"""Loss function (with KD Loss support).
Args:
loss_mask (Tensor): Used to mask out some portions of the loss
model (GPTModel): The model (can be wrapped)
output_tensor (Tensor): The tensor with the losses
"""
args = get_args()
# Unwrap for both Distillation and LANA
model = unwrap_model(model)
# Standard lm loss
output_tensor = output_tensor.float() # cache
loss_lm = _mask_loss(output_tensor, loss_mask)
loss_lm, loss_lm_avg = _allreduce_loss(loss_lm)
loss, report = loss_lm, {'lm loss': loss_lm_avg}
return loss, report
def non_loss_data_func(model: GPTModel):
"""Callback to compute the acceptance length."""
report_draft_acceptance_length(model)
def forward_step(data_iterator, model: GPTModel):
"""Forward training step.
Args:
data_iterator: Input data iterator
model: The GPT Model
"""
timers = get_timers()
# Get the batch.
timers("batch-generator", log_level=2).start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
timers("batch-generator").stop()
output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
return output_tensor, partial(loss_func, loss_mask, model)
if __name__ == "__main__":
pretrain(
train_valid_test_sft_datasets_provider,
model_provider,
ModelType.encoder_or_decoder,
forward_step,
extra_args_provider=add_modelopt_args,
args_defaults={"tokenizer_type": "HuggingFaceTokenizer"},
non_loss_data_func=non_loss_data_func,
)
#!/bin/bash
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
# Common arguments and base model specific arguments
source "${SCRIPT_DIR}/conf/arguments.sh"
# Extra arguments of this script
MLM_DEFAULT_ARGS=" \
--distributed-timeout-minutes 30 \
--auto-detect-ckpt-format \
--export-te-mcore-model \
--finetune \
"
if [ -z ${MLM_MODEL_SAVE} ]; then
MLM_MODEL_SAVE=${MLM_MODEL_CKPT}
printf "${MLM_WARNING} Variable ${PURPLE}MLM_MODEL_SAVE${WHITE} is not set (default: ${MLM_MODEL_CKPT})!\n"
fi
if [ -z ${MLM_DATA_ARGS} ]; then
MLM_DATA_ARGS=" \
--train-samples 128000 \
--lr-decay-samples 128000 \
--lr-warmup-samples 0 \
--split 100,0,0 \
--finetune-hf-dataset nvidia/Daring-Anteater \
"
fi
if [ -z ${MLM_TRAIN_ARGS} ]; then
MLM_TRAIN_ARGS=" \
--recompute-activations \
--no-gradient-accumulation-fusion \
--reset-position-ids \
--reset-attention-mask \
--eod-mask-loss \
--global-batch-size 128 \
--micro-batch-size 1 \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--no-check-for-nan-in-loss-and-grad \
"
fi
if [ -z ${MLM_OPTIM_ARGS} ]; then
MLM_OPTIM_ARGS=" \
--lr 1.0e-5 \
--min-lr 1.0e-7 \
--lr-decay-style cosine \
--clip-grad 1.0 \
--weight-decay 0.0 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.010 \
"
fi
if [ -z ${MLM_EVAL_ARGS} ]; then
MLM_EVAL_ARGS=" \
--eval-iters 1 \
--eval-interval 1000 \
--save-interval 1000 \
--log-interval 100 \
"
fi
${LAUNCH_SCRIPT} ${SCRIPT_DIR}/finetune.py \
${MODEL_ARGS} \
--tensor-model-parallel-size ${TP} \
--expert-model-parallel-size ${EP} \
--pipeline-model-parallel-size ${PP} \
--tokenizer-model ${TOKENIZER_MODEL} \
--load ${MLM_MODEL_CKPT} \
--save ${MLM_MODEL_SAVE} \
${MLM_DATA_ARGS} \
${MLM_OPTIM_ARGS} \
${MLM_TRAIN_ARGS} \
${MLM_EVAL_ARGS} \
${MLM_RESUME_ARGS} \
${MLM_DEFAULT_ARGS} ${MLM_EXTRA_ARGS}
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Sample Generate GPT."""
import functools
import os
import sys
import warnings
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))
import torch
from datasets import load_dataset
from megatron.post_training.arguments import add_modelopt_args
from megatron.post_training.checkpointing import load_modelopt_checkpoint
from megatron.post_training.generate import simple_generate
from megatron.post_training.model_provider import model_provider
from megatron.post_training.utils import report_current_memory_info
from megatron.training import get_args, get_model, get_tokenizer, initialize_megatron
from megatron.training.utils import print_rank_0, unwrap_model
warnings.filterwarnings('ignore')
def add_generate_args(parser):
"""Add additional arguments for ModelOpt acceptance rate validation."""
group = parser.add_argument_group(title='ModelOpt ar validation')
group.add_argument("--osl", type=int, default=128, help="Output sequence length.")
group.add_argument("--draft-length", type=int, default=0, help="Only used in EAGLE.")
group.add_argument("--draft-topk", type=int, default=1, help="Only used in EAGLE.")
group.add_argument("--disable-tqdm", action="store_true", help="Disable tqdm.")
group.add_argument("--percentage", type=float, default=1.0)
add_modelopt_args(parser)
return parser
def check_arguments():
"""Checking user arguments."""
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print_rank_0("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
    if hasattr(args, 'moe_grouped_gemm') and args.moe_grouped_gemm:
        print_rank_0("WARNING: Forcing moe_grouped_gemm to False for PTQ and export.")
        args.moe_grouped_gemm = False
def mtbench_to_oai_chat(example):
"""Convert MTBench data to OpenAI chat completion format."""
conversations = []
for prompt in example["prompt"]:
conversations.append({"role": "user", "content": prompt})
example["conversations"] = conversations
return example
def get_conversations(example):
"""Extract the input for tokenizer.apply_chat_template."""
conversations = example.get("conversations", None)
if conversations is None:
conversations = example.get("messages", None)
if conversations is None:
raise ValueError(
"The data must either have conversations or messages field, but got {}".format(example)
)
return conversations
if __name__ == "__main__":
initialize_megatron(
extra_args_provider=add_generate_args,
args_defaults={
'tokenizer_type': 'HuggingFaceTokenizer',
'no_load_rng': True,
'no_load_optim': True,
},
)
check_arguments()
args = get_args()
default_conversations = [
{
"role": "user",
"content": "Write an email to a wine expert, requesting a guest "
"article contribution for your wine blog.",
}
]
if args.finetune_hf_dataset is None:
if args.draft_length > 0:
dataset = load_dataset("HuggingFaceH4/mt_bench_prompts", split="train")
dataset = dataset.map(mtbench_to_oai_chat)
else:
dataset = [{"conversations": default_conversations}]
else:
dataset = load_dataset(args.finetune_hf_dataset, split=args.finetune_data_split)
tokenizer = get_tokenizer()._tokenizer
model = get_model(functools.partial(model_provider, parallel_output=True), wrap_with_ddp=False)
report_current_memory_info()
if args.load is not None:
load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights)
print_rank_0("Done loading checkpoint")
unwrapped_model = unwrap_model(model)[0]
unwrapped_model.eval()
for idx, example in enumerate(dataset):
if idx > args.percentage * len(dataset):
break
ref_conversations = get_conversations(example)
new_conversations = []
for message in ref_conversations:
ground_truth = None
if message["role"] == "assistant":
ground_truth = message["content"]
if message["role"] == "user":
new_conversations.append(message)
print_rank_0(
"{}".format(
tokenizer.apply_chat_template(
new_conversations, tokenize=False, add_generation_prompt=True
)
)
)
input_ids = tokenizer.apply_chat_template(
new_conversations, return_tensors="pt", add_generation_prompt=True
)
output_ids = simple_generate(
unwrapped_model, input_ids.cuda(), osl=args.osl, disable_tqdm=args.disable_tqdm
)
output_texts = tokenizer.batch_decode(output_ids)[0]
print_rank_0("{}".format(output_texts))
new_conversations.append({"role": "assistant", "content": output_texts})
torch.distributed.barrier()
#!/bin/bash
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
# Common arguments and base model specific arguments
source "${SCRIPT_DIR}/conf/arguments.sh"
# Extra arguments of this script
MLM_DEFAULT_ARGS="--finetune --auto-detect-ckpt-format --export-te-mcore-model"
if [ -z ${MLM_MODEL_CKPT} ]; then
printf "${MLM_ERROR} Variable ${PURPLE}MLM_MODEL_CKPT${WHITE} must be set!\n"
exit 1
fi
if [ -z ${DRAFT_LEN} ]; then
DRAFT_LEN=0
fi
if [ -z ${PROMPTS_PATH} ]; then
${LAUNCH_SCRIPT} ${SCRIPT_DIR}/generate.py \
${MODEL_ARGS} \
--tensor-model-parallel-size ${TP} \
--expert-model-parallel-size ${EP} \
--pipeline-model-parallel-size ${PP} \
--tokenizer-model ${TOKENIZER_MODEL} \
--load ${MLM_MODEL_CKPT} \
--draft-length ${DRAFT_LEN} \
${MLM_DEFAULT_ARGS} ${MLM_EXTRA_ARGS}
else
${LAUNCH_SCRIPT} ${SCRIPT_DIR}/generate.py \
${MODEL_ARGS} \
--tensor-model-parallel-size ${TP} \
--expert-model-parallel-size ${EP} \
--pipeline-model-parallel-size ${PP} \
--tokenizer-model ${TOKENIZER_MODEL} \
--load ${MLM_MODEL_CKPT} \
--data ${PROMPTS_PATH} \
--draft-length ${DRAFT_LEN} \
${MLM_DEFAULT_ARGS} ${MLM_EXTRA_ARGS}
fi
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Sample Generate GPT."""
import functools
import os
import sys
import warnings
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))
import torch
from datasets import load_dataset
from megatron.post_training.arguments import add_modelopt_args
from megatron.post_training.checkpointing import load_modelopt_checkpoint
from megatron.post_training.generate import simple_generate
from megatron.post_training.model_provider import model_provider
from megatron.post_training.utils import report_current_memory_info
from megatron.training import get_args, get_model, get_tokenizer, initialize_megatron
from megatron.training.utils import print_rank_0, unwrap_model
warnings.filterwarnings('ignore')
def add_mmlu_args(parser):
"""Add additional arguments for ModelOpt text generation PTQ."""
group = parser.add_argument_group(title='ModelOpt text generation ptq')
group.add_argument("--disable-tqdm", action="store_true", help="Disable tqdm.")
group.add_argument("--percentage", type=float, default=1.0)
group.add_argument("--lower-bound", type=float, default=None)
add_modelopt_args(parser)
return parser
def get_all_subjects():
"""Return all MMLU subjects."""
return [
'abstract_algebra',
'anatomy',
'astronomy',
'business_ethics',
'clinical_knowledge',
'college_biology',
'college_chemistry',
'college_computer_science',
'college_mathematics',
'college_medicine',
'college_physics',
'computer_security',
'conceptual_physics',
'econometrics',
'electrical_engineering',
'elementary_mathematics',
'formal_logic',
'global_facts',
'high_school_biology',
'high_school_chemistry',
'high_school_computer_science',
'high_school_european_history',
'high_school_geography',
'high_school_government_and_politics',
'high_school_macroeconomics',
'high_school_mathematics',
'high_school_microeconomics',
'high_school_physics',
'high_school_psychology',
'high_school_statistics',
'high_school_us_history',
'high_school_world_history',
'human_aging',
'human_sexuality',
'international_law',
'jurisprudence',
'logical_fallacies',
'machine_learning',
'management',
'marketing',
'medical_genetics',
'miscellaneous',
'moral_disputes',
'moral_scenarios',
'nutrition',
'philosophy',
'prehistory',
'professional_accounting',
'professional_law',
'professional_medicine',
'professional_psychology',
'public_relations',
'security_studies',
'sociology',
'us_foreign_policy',
'virology',
'world_religions',
]
def format_example(example, include_answer: bool = True):
"""Format an example into a multi-choices problem."""
prompt = example["question"]
for choice, answer in zip(["A", "B", "C", "D"], example["choices"]):
prompt += "\n{}. {}".format(choice, answer)
    if include_answer:
        prompt += "\nAnswer: {}\n\n".format(["A", "B", "C", "D"][example["answer"]])
else:
prompt += "\nAnswer:"
return prompt
def generate_prompt(test_example, dev_examples, few_shots=0):
"""Generating few-shot prompts."""
prompt = "The following are multiple choice questions (with answers) about {}.\n\n".format(
" ".join(test_example["subject"].split("_"))
)
for i in range(few_shots):
prompt += format_example(dev_examples[i])
prompt += format_example(test_example, include_answer=False)
return prompt
if __name__ == "__main__":
initialize_megatron(
extra_args_provider=add_mmlu_args,
args_defaults={
'tokenizer_type': 'HuggingFaceTokenizer',
'no_load_rng': True,
'no_load_optim': True,
},
)
args = get_args()
disable_tqdm = args.disable_tqdm or torch.distributed.get_rank() > 0
tokenizer = get_tokenizer()._tokenizer
model = get_model(functools.partial(model_provider, parallel_output=True), wrap_with_ddp=False)
report_current_memory_info()
if args.load is not None:
load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights)
print_rank_0("Done loading checkpoint")
unwrapped_model = unwrap_model(model)[0]
all_subjects = get_all_subjects()
all_correct = {}
for subject in all_subjects:
test_data = load_dataset("cais/mmlu", subject, split="test")
dev_data = load_dataset("cais/mmlu", subject, split="dev")
correct = []
for idx, test_example in enumerate(test_data):
if idx > args.percentage * len(test_data):
break
prompt = generate_prompt(test_example, dev_data, few_shots=0)
label = ["A", "B", "C", "D"][test_example["answer"]]
tokens = tokenizer(prompt, return_tensors="pt")
generated_ids = simple_generate(
unwrapped_model, tokens.input_ids.cuda(), osl=2, disable_tqdm=disable_tqdm
)
predict = tokenizer.batch_decode(generated_ids)[0].strip()
correct += [True] if predict.startswith(label) else [False]
all_correct[subject] = correct
if torch.distributed.get_rank() == 0:
print(
"{:48}| {:.3f} | {:5}/{:5}".format(
subject, sum(correct) / len(correct), sum(correct), len(correct)
),
flush=True,
)
avg_correct = []
for subject, correct in all_correct.items():
avg_correct += correct
if torch.distributed.get_rank() == 0:
print(
"{:48}| {:.3f} | {:5}/{:5}".format(
"average", sum(avg_correct) / len(avg_correct), sum(avg_correct), len(avg_correct)
),
flush=True,
)
if args.lower_bound is not None:
assert sum(avg_correct) / len(avg_correct) > args.lower_bound
#!/bin/bash
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
# Common arguments and base model specific arguments
source "${SCRIPT_DIR}/conf/arguments.sh"
# Extra arguments of this script
MLM_DEFAULT_ARGS="--finetune --auto-detect-ckpt-format --export-te-mcore-model --sequence-parallel"
${LAUNCH_SCRIPT} ${SCRIPT_DIR}/mmlu.py \
${MODEL_ARGS} \
--tensor-model-parallel-size ${TP} \
--expert-model-parallel-size ${EP} \
--pipeline-model-parallel-size ${PP} \
--tokenizer-model ${TOKENIZER_MODEL} \
--load ${MLM_MODEL_CKPT} \
${MLM_DEFAULT_ARGS} ${MLM_EXTRA_ARGS}
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
"""Sample Generate GPT."""
import functools
import os
import sys
import warnings
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../")))
import modelopt
import modelopt.torch.quantization as mtq
import torch
from datasets import load_dataset
from packaging.version import Version
from tqdm import tqdm
from megatron.post_training.arguments import add_modelopt_args
from megatron.post_training.checkpointing import load_modelopt_checkpoint
from megatron.post_training.generate import simple_generate
from megatron.post_training.model_provider import model_provider
from megatron.post_training.utils import report_current_memory_info
from megatron.training import get_args, get_model, get_tokenizer, initialize_megatron
from megatron.training.checkpointing import save_checkpoint
from megatron.training.utils import print_rank_0, unwrap_model
warnings.filterwarnings('ignore')
QUANT_CFG_CHOICES = {
"int8_sq": mtq.INT8_SMOOTHQUANT_CFG,
"fp8": mtq.FP8_DEFAULT_CFG,
"fp8_real_quant": mtq.FP8_DEFAULT_CFG,
"fp8_blockwise": mtq.FP8_2D_BLOCKWISE_WEIGHT_ONLY_CFG,
"int4_awq": mtq.INT4_AWQ_CFG,
"w4a8_awq": mtq.W4A8_AWQ_BETA_CFG,
"nvfp4": mtq.NVFP4_DEFAULT_CFG,
}
def add_text_generate_ptq_args(parser):
"""Add additional arguments for ModelOpt text generation PTQ."""
group = parser.add_argument_group(title='ModelOpt text generation ptq')
group.add_argument(
"--calib-size", type=int, default=512, help="Samples to use for ptq calibration."
)
parser.add_argument(
"--prompts",
type=str,
default=("Hello!|Born in California, Soyer trained as a"),
help="Input texts. Please use | to separate different batches.",
)
parser.add_argument(
"--references",
type=str,
default="",
help="Reference texts. Please use | to separate different batches.",
)
parser.add_argument(
"--pretrained-model-path", type=str, default=None, help="HuggingFace pretrained model"
)
add_modelopt_args(parser)
return parser
def check_arguments():
"""Checking user arguments."""
args = get_args()
if args.num_layers_per_virtual_pipeline_stage is not None:
print_rank_0("Interleaved pipeline schedule is not yet supported for text generation.")
exit()
    if hasattr(args, 'moe_grouped_gemm') and args.moe_grouped_gemm:
        print_rank_0("WARNING: Forcing moe_grouped_gemm to False for PTQ and export.")
        args.moe_grouped_gemm = False
def get_modelopt_torch_quantization_config():
"""Return a quantization config."""
args = get_args()
mtq_config = QUANT_CFG_CHOICES[args.export_quant_cfg]
fp8_config = {"enable": True, "num_bits": (4, 3), "axis": None}
fp4_config = {
"num_bits": (2, 1),
"block_sizes": {-1: 16, "type": "dynamic", "scale_bits": (4, 3)},
"axis": None,
"enable": True,
}
if "fp8" == args.export_quant_cfg:
# Enable Medusa heads and kv-cache quantization
mtq_config["quant_cfg"]["*medusa_heads**"] = fp8_config
if "fp4" in args.export_quant_cfg:
# Enable Medusa heads and kv-cache quantization
mtq_config["quant_cfg"]["*medusa_heads**"] = fp4_config
if "awq" in args.export_quant_cfg:
weight_quantizer = mtq_config["quant_cfg"]["*weight_quantizer"] # type: ignore
if isinstance(weight_quantizer, list):
weight_quantizer = weight_quantizer[0]
weight_quantizer["block_sizes"][-1] = 128
if args.export_kv_cache_quant:
mtq_config["quant_cfg"]["*linear_qkv.output_quantizer"] = fp8_config
return mtq_config
def get_calib_dataloader(calib_size=512, max_sequence_length=512):
"""Return a dataloader for calibration."""
dataset = load_dataset("cnn_dailymail", name="3.0.0", split="train")
text_column = "article"
calib_size = min(len(dataset), calib_size)
for i in range(calib_size):
yield dataset[i][text_column][:max_sequence_length]
if __name__ == "__main__":
initialize_megatron(
extra_args_provider=add_text_generate_ptq_args,
args_defaults={
'tokenizer_type': 'HuggingFaceTokenizer',
'no_load_rng': True,
'no_load_optim': True,
},
)
check_arguments()
args = get_args()
tokenizer = get_tokenizer()._tokenizer
model = get_model(functools.partial(model_provider, parallel_output=True), wrap_with_ddp=False)
report_current_memory_info()
if args.load is not None:
load_modelopt_checkpoint(model, strict=not args.untie_embeddings_and_output_weights)
print_rank_0("Done loading checkpoint")
if args.pretrained_model_path is not None:
from modelopt.torch.export import import_mcore_gpt_from_hf
unwrapped_model = unwrap_model(model)[0]
workspace_dir = os.environ.get("MLM_WORK_DIR", "/tmp")
import_mcore_gpt_from_hf(unwrapped_model, args.pretrained_model_path, workspace_dir)
def _custom_prompt_forward_loop_func(model):
all_prompts = args.prompts.split("|")
if args.references == "":
all_references = [None] * len(all_prompts)
else:
all_references = args.references.split("|")
for idx, prompt in tqdm(enumerate(all_prompts), disable=torch.distributed.get_rank()):
tokens = tokenizer(prompt, return_tensors="pt")
generated_ids = simple_generate(model, tokens.input_ids.cuda(), osl=32)
generated_texts = tokenizer.batch_decode(generated_ids)
print_rank_0("{}".format(generated_texts))
if all_references[idx] is not None:
assert all_references[idx] == generated_texts[0], all_references[idx]
    def _hf_dataset_forward_loop_func(model):
dataloader = get_calib_dataloader(args.calib_size)
for prompt in tqdm(dataloader, total=args.calib_size, disable=torch.distributed.get_rank()):
tokens = tokenizer(prompt, return_tensors="pt")
generated_ids = simple_generate(model, tokens.input_ids.cuda(), osl=1)
unwrapped_model = unwrap_model(model)[0]
if args.export_quant_cfg in QUANT_CFG_CHOICES:
print_rank_0("Quantizing the model...")
mtq_config = get_modelopt_torch_quantization_config()
        ptq_forward_loop_func = _hf_dataset_forward_loop_func
if hasattr(unwrapped_model, "calibration_mode"):
unwrapped_model.calibration_mode = True
mtq.quantize(unwrapped_model, mtq_config, ptq_forward_loop_func)
unwrapped_model.calibration_mode = False
else:
mtq.quantize(unwrapped_model, mtq_config, ptq_forward_loop_func)
if "real_quant" in args.export_quant_cfg:
mtq.compress(unwrapped_model)
print_rank_0(f"Fake Quantized Model:\n {unwrapped_model}")
if torch.distributed.get_rank() == 0:
for k, v in unwrapped_model.state_dict().items():
if "amax" not in k:
continue
if isinstance(v, torch.Tensor):
print("{:80} {:32} max {:.4e}".format(k, str(v.shape), torch.max(torch.abs(v))))
else:
print("{:80}".format(k))
_custom_prompt_forward_loop_func(unwrapped_model)
if args.save is not None and args.export_quant_cfg in QUANT_CFG_CHOICES:
save_checkpoint(1, model, None, None, 0)
#!/bin/bash
SCRIPT_DIR="$(dirname "$(readlink -f "$0")")"
# Common arguments and base model specific arguments
source "${SCRIPT_DIR}/conf/arguments.sh"
# Extra arguments of this script
MLM_DEFAULT_ARGS="--finetune --auto-detect-ckpt-format --export-te-mcore-model --sequence-parallel"
QUANT_CFG=$2
if [ -z ${QUANT_CFG} ]; then
QUANT_CFG=fp8
printf "${MLM_WARNING} Variable ${PURPLE}QUANT_CFG${WHITE} is not set (default: ${QUANT_CFG})!\n"
fi
if [ -z ${MLM_QUANT_CKPT} ]; then
MLM_QUANT_CKPT=${MLM_WORK_DIR}/${MLM_MODEL_CFG}_quant
printf "${MLM_WARNING} Variable ${PURPLE}MLM_QUANT_CKPT${WHITE} is not set (default: ${MLM_QUANT_CKPT})!\n"
fi
if [ -z ${MLM_MODEL_CKPT} ]; then
${LAUNCH_SCRIPT} ${SCRIPT_DIR}/quantize.py \
${MODEL_ARGS} \
--tensor-model-parallel-size ${TP} \
--expert-model-parallel-size ${EP} \
--pipeline-model-parallel-size ${PP} \
--tokenizer-model ${TOKENIZER_MODEL} \
--pretrained-model-path ${HF_MODEL_CKPT} \
--save ${MLM_QUANT_CKPT} \
--export-quant-cfg ${QUANT_CFG} \
--references "${MLM_REF_LABEL}" \
${MLM_DEFAULT_ARGS} ${MLM_EXTRA_ARGS}
else
${LAUNCH_SCRIPT} ${SCRIPT_DIR}/quantize.py \
${MODEL_ARGS} \
--tensor-model-parallel-size ${TP} \
--expert-model-parallel-size ${EP} \
--pipeline-model-parallel-size ${PP} \
--tokenizer-model ${TOKENIZER_MODEL} \
--load ${MLM_MODEL_CKPT} \
--save ${MLM_QUANT_CKPT} \
--export-quant-cfg ${QUANT_CFG} \
--references "${MLM_REF_LABEL}" \
${MLM_DEFAULT_ARGS} ${MLM_EXTRA_ARGS}
fi
datasets
jsonlines
mamba-ssm
causal-conv1d
nvidia-modelopt
omegaconf
pulp
tensorstore!=0.1.46,!=0.1.72
torchprofile
transformers
zarr
# Speculative Decoding
[Medusa](https://arxiv.org/abs/2401.10774) and [EAGLE](https://arxiv.org/pdf/2401.15077)
training and model export are supported (fast decoding is supported through TensorRT-LLM).
To run the examples, follow [README.md](README.md) to set up the containerized environment
and `NGC_CLI_API_KEY`, then
```sh
TP=8 bash medusa_sft.sh meta-llama/Llama-3.1-8B-Instruct
```
EAGLE training is similar. Just replace `medusa_sft.sh` with `eagle_sft.sh`
(requires `nvidia-modelopt>=0.20.0`).
Medusa head top-1 accuracy is reported per step (**NOTE:** the accuracy here does not
translate directly to the acceptance rate described in the writeup; the top-1 accuracy of
the 1st head can, however, signal whether training has converged). By the end of the
example, the results are stored in the following locations.
```sh
/tmp/megatron_workspace/meta-llama/
├── Llama-3.1-8B-Instruct_medusa
│   ├── iter_0000001
│   └── ...
├── Llama-3.1-8B-Instruct_medusa_quant
│   ├── iter_0000001
│   └── ...
└── Llama-3.1-8B-Instruct_medusa_quant_trtllm_export
```
`Llama-3.1-8B-Instruct_medusa_quant_trtllm_export` is the TensorRT-LLM checkpoint. To
deploy, check the TensorRT-LLM section below.
> **IMPORTANT:** The sample flow `medusa_sft.sh` does not contain synthetic data generation.
> To achieve the best acceptance rate, check the whole recipe and the options in the following sections.
## Table of Contents
[[_TOC_]]
## Training and Export Workflow
In practice, speculative decoding should be combined with quantization (weights and kv-cache)
to achieve the highest tokens-per-second-per-user (TPS) without changing the quality of
the model. We provide a quantization-aware training (QAT) recipe with self-distillation in the following sections.
### Model Conversion
To ensure no quality degradation, the base model is frozen and the draft model is attached as a
model transformation. By providing `--export-num-medusa-heads` or `--export-num-eagle-layers`,
the resulting model stored in `${MLM_MODEL_SAVE}` will have randomly initialized draft model weights.
```
python examples/post_training_opt/convert_gpt.py \
--export-num-medusa-heads 4 \
--load ${MLM_MODEL_CKPT} --save ${MLM_MODEL_SAVE} ${OTHER_MLM_ARGS}
```
> **NOTE:** `MLM_MODEL_SAVE=Llama-3.1-8B-Instruct_medusa` in the example.
### Synthetic Data Generation
Rather than learning the language and syntax, the draft model is trained to mimic the base
model output. As a result, self-synthesized data is crucial for the draft model accuracy
and acceptance rate (AR). In EAGLE training, hidden state and logits distillation are also
applied.
For simplicity and efficiency, we use `vllm serve --quantization modelopt` to host a quantized
endpoint and feed multi-turn conversation data through it to synthesize the assistant output.
See ModelOpt's example (https://github.com/NVIDIA/TensorRT-Model-Optimizer/tree/main/speculative_decoding)
for more details. The final output is stored as jsonlines in an OpenAI chat completion format, as sketched below.
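As a rough, hypothetical sketch (not a script shipped in this repo), the loop below assumes the
`openai` Python client, a vLLM server already hosting the quantized base model at a placeholder
URL, and placeholder input/output file names; it replaces every assistant turn with the served
model's own completion and writes the result back out as jsonlines in OpenAI chat format.
```python
# Hypothetical sketch: regenerate assistant turns with the served base model.
# The endpoint URL, model name, and file names are placeholders, not repo defaults.
import jsonlines
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

with jsonlines.open("raw_conversations.jsonl") as reader, \
     jsonlines.open("synthetic_sft.jsonl", mode="w") as writer:
    for sample in reader:
        history = []
        for turn in sample["conversations"]:
            if turn["role"] != "assistant":
                history.append(turn)
                continue
            # Replace the reference answer with the base model's own output.
            reply = client.chat.completions.create(model="base-model", messages=history)
            history.append({"role": "assistant", "content": reply.choices[0].message.content})
        writer.write({"conversations": history})
```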
### Quantization-Aware Training (QAT)
For quantization-aware training (QAT), the process is `bf16 training`, then `fake quantization`, then `qat`.
Since the base model weights are frozen, the initial training is mainly to get a more accurate
range of the draft model activations and weights. We store a new checkpoint where the model
now has additional quantization scalars for both the base and draft models. We launch the
finetuning again to continue the training with fake quantization until convergence.
```sh
python examples/post_training_opt/finetune_gpt.py \
--export-num-medusa-heads 4 \
--load ${MLM_MODEL_SAVE} --save ${MLM_MODEL_SAVE} ${OTHER_MLM_ARGS}
python examples/post_training_opt/text_generation_ptq.py \
--export-quant-cfg fp8 \
--decoder llama \
--export-num-medusa-heads 4 \
--load ${MLM_MODEL_SAVE} --save ${MLM_QUANT_SAVE} ${OTHER_MLM_ARGS}
python examples/post_training_opt/finetune_gpt.py \
--export-num-medusa-heads 4 \
--load ${MLM_QUANT_SAVE} --save ${MLM_QUANT_SAVE} ${OTHER_MLM_ARGS}
```
> **NOTE:** `MLM_QUANT_SAVE=Llama-3.1-8B-Instruct_medusa_quant` in the example.
### Export TensorRT-LLM Checkpoint
To finally export a TensorRT-LLM checkpoint, we leverage the same script by providing
`${TRTLLM_CKPT}` and the inference `${TP}`.
```sh
python examples/post_training_opt/text_generation_ptq.py \
--export-dir ${TRTLLM_CKPT} \
--inference-tensor-parallel ${TP} \
--export-quant-cfg None \
--decoder llama \
--export-num-medusa-heads 4 \
--load ${MLM_QUANT_SAVE} ${OTHER_MLM_ARGS}
```
> **NOTE:** `TRTLLM_CKPT=Llama-3.1-8B-Instruct_medusa_quant_trtllm_export` in the example.
**TensorRT-LLM deployment:** To build (`trtllm-build`) and run a TensorRT-LLM engine, follow the steps at
https://github.com/NVIDIA/TensorRT-Model-Optimizer#installation--docker to prepare the container.
For `tensorrt-llm>0.12`, the builder can detect that this is a Medusa checkpoint directly:
```sh
trtllm-build --checkpoint_dir Llama-3.1-8B-Instruct_medusa_quant_trtllm_export --output_dir /tmp/trtllm_engine ${other args}
```
The `run.py` (https://github.com/NVIDIA/TensorRT-LLM/blob/main/examples/run.py) and `gptManagerBenchmark` (https://github.com/NVIDIA/TensorRT-LLM/tree/main/benchmarks/cpp)
both support Medusa decoding by supplying the argument `--medusa_choices`. This argument describes the sparse attention tree structure used in the Medusa writeup. For example,
the following option is a tree with 63 nodes, which represent 63 draft tokens proposed by the 4 Medusa heads.
```sh
--medusa_choices="[[0], [0, 0], [1], [0, 1], [2], [0, 0, 0], [1, 0], [0, 2], [3], [0, 3], [4], [0, 4], [2, 0], [0, 5], [0, 0, 1], [5], [0, 6], [6], [0, 7], [0, 1, 0], [1, 1], [7], [0, 8], [0, 0, 2], [3, 0], [0, 9], [8], [9], [1, 0, 0], [0, 2, 0], [1, 2], [0, 0, 3], [4, 0], [2, 1], [0, 0, 4], [0, 0, 5], [0, 0, 0, 0], [0, 1, 1], [0, 0, 6], [0, 3, 0], [5, 0], [1, 3], [0, 0, 7], [0, 0, 8], [0, 0, 9], [6, 0], [0, 4, 0], [1, 4], [7, 0], [0, 1, 2], [2, 0, 0], [3, 1], [2, 2], [8, 0], [0, 5, 0], [1, 5], [1, 0, 1], [0, 2, 1], [9, 0], [0, 6, 0], [0, 0, 0, 1], [1, 6], [0, 7, 0]]"
```
> **ADVANCED USAGE:** When training, we typically train `4` heads if memory is sufficient, and by default the max draft length is `63`.
> Optionally, users can change these values to something smaller in the TensorRT-LLM checkpoint's `config.json` before calling `trtllm-build`.
> For example, it is possible to use only 2 heads with a maximum of 7 draft tokens if that is a sweet spot. You must also change
> `--medusa_choices` to make sure you are not accessing draft tokens from the 3rd and 4th heads, as well as shortening the list to
> length 7 (see the sketch after this note).
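For illustration only, a reduced tree could look like the sketch below. It keeps the path notation
of the 63-node example above (each inner list is a path of per-head top-k indices), uses paths of
length at most 2 so draft tokens from the 3rd and 4th heads are never read, and contains exactly 7
paths for 7 draft tokens; the exact tree shape is a tuning choice, not a value prescribed here.
```python
# Illustrative only: a 7-node tree that touches only the first two Medusa heads.
# Every prefix of a path is also in the tree, and no path is longer than 2.
medusa_choices = [[0], [1], [2], [0, 0], [0, 1], [1, 0], [2, 0]]
assert len(medusa_choices) == 7
assert max(len(path) for path in medusa_choices) == 2

# Pass it on the command line, e.g. --medusa_choices="[[0], [1], [2], ...]"
print('--medusa_choices="{}"'.format(medusa_choices))
```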
# RETRO MODEL
## Table of contents
- [1. Training Setup](#1-training-setup)
- [2. Data Preprocessing](#2-data-preprocessing)
- [3. Configurations](#3-configurations)
## 1. Training setup
<a id="markdown-training-setup" name="training-setup"></a>
To run the model using a docker container, run it as follows:
```
PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.09-py3
CHECKPOINT_PATH="" #<Specify path>
TENSORBOARD_LOGS_PATH=""#<Specify path>
docker run \
--gpus=all \
--ipc=host \
--workdir /workspace/megatron-lm \
-v /path/to/data:/path/to/data \
-v /path/to/megatron-lm:/workspace/megatron-lm \
megatron-lm nvcr.io/nvidia/pytorch:23.09-py3 \
  bash examples/retro/train_retro_2b_distributed.sh $CHECKPOINT_PATH $TENSORBOARD_LOGS_PATH
```
NOTE: Depending on the environment you are running it in, the above command might look slightly different.
NOTE: Because Retro preprocesses and caches elements of the pretraining dataset before training begins, some arguments are auto-loaded from the Retro preprocessing configuration. These loaded arguments include:
- `--data-path`
- `--data-cache-path`
- `--eval-interval`
- `--eval-iters`
- `--global-batch-size`
- `--tokenizer-type`
- `--tokenizer-model`
- `--vocab-file`
- `--merge-file`
- `--seed`
- `--seq-length`
- `--train-samples`
## 2. Data Preprocessing
<a id="markdown-data-preprocessing" name="data-preprocessing"></a>
Retro preprocesses and caches data prior to pretraining, to greatly speed up pretraining. During data preprocessing, the retrieval database is built, and neighbor IDs are queried for each sample within the pretraining dataset. Please see `preprocess_data.sh` for an example script to preprocess data for Retro. The reference documentation for data preprocessing can be found [here](tools/retro/README.md).
## 3. Configurations
<a id="markdown-configurations" name="configurations"></a>
The example in this folder shows you how to run a 2B model. Below are a few other example configurations.
### 857M
```
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--seq-length 2048 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
```
### 4B
```
--num-layers 48 \
--hidden-size 2560 \
--num-attention-heads 32 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
```
#!/bin/bash
set -u
unset NCCL_DEBUG
######## Megatron, Retro dirs. ########
REPO_DIR="<path/to/megatron/repo>"
RETRO_PROJECT_DIR="<path/to/retro/project/directory>"
######## Task (e.g., db, index, query). ########
# This script takes a single argument, which specifies the retro task to be
# performed. The available tasks are: db-build, index-train, index-add, and
# query-neighbors.
# ~~ Examples ~~
# RETRO_TASKS="db-build" # Build the retrieval database
# RETRO_TASKS="index-train" # Train the index
# RETRO_TASKS="index-add" # Add data to the index
# RETRO_TASKS="query-neighbors" # Perform query pretraining for neighbors
# You can also provide the task as a command-line argument when executing the
# script. Example: ./preprocess_data.sh index-add
RETRO_TASKS=$1
######## Data. ########
DATA_BLEND="<see --data-path in arguments.py>"
######## Index. ########
RETRO_INDEX_STR="OPQ32_64,IVF65536_HNSW8,PQ32"
RETRO_INDEX_NTRAIN=66625331
RETRO_INDEX_TRAIN_LOAD_FRACTION=0.97
RETRO_INDEX_ADD_LOAD_FRACTION=0.95
######## GPT. ########
RETRO_GPT_SEED=1234
RETRO_GPT_SPLIT="98,2,0"
RETRO_GPT_DATA_PATH=${DATA_BLEND}
RETRO_GPT_TRAIN_SAMPLES=200000
RETRO_GPT_EVAL_INTERVAL=2000
RETRO_GPT_EVAL_ITERS=50
RETRO_GPT_LR_DECAY_SAMPLES=175000
RETRO_GPT_LR_WARMUP_SAMPLES=10000
RETRO_GPT_SEQ_LENGTH=2048
RETRO_GPT_GLOBAL_BATCH_SIZE=256
RETRO_GPT_CHUNK_LENGTH=64
######## Query. ########
RETRO_QUERY_NUM_NEIGHBORS_QUERY=200
RETRO_QUERY_NUM_NEIGHBORS_SAVE=20
RETRO_QUERY_EF_SEARCH=32
RETRO_QUERY_NPROBE=4096
######## Args. ########
ARGS=" \
--distributed-timeout-minutes 600 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--num-layers 24 \
--hidden-size 1024 \
--num-attention-heads 16 \
--micro-batch-size 1 \
--global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
--seq-length 512 \
--max-position-embeddings 512 \
--load ${RETRO_PROJECT_DIR}/checkpoints/bert \
--exit-on-missing-checkpoint \
--no-load-optim \
--data-path [null] \
--tokenizer-type BertWordPieceLowerCase \
--vocab-file ${RETRO_PROJECT_DIR}/tokenizer/bert-large-uncased-vocab.txt \
--split ${RETRO_GPT_SPLIT} \
--distributed-backend nccl \
--lr 0.0001 \
--lr-decay-style linear \
--min-lr 1.0e-5 \
--train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
--lr-decay-samples ${RETRO_GPT_LR_DECAY_SAMPLES} \
--lr-warmup-samples ${RETRO_GPT_LR_WARMUP_SAMPLES} \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--eval-iters ${RETRO_GPT_EVAL_ITERS} \
--bf16 \
--no-data-sharding \
--no-gradient-accumulation-fusion \
--no-async-tensor-model-parallel-allreduce \
--bert-embedder-type megatron \
--output-bert-embeddings \
\
--retro-project-dir ${RETRO_PROJECT_DIR} \
--retro-tasks ${RETRO_TASKS} \
--retro-bert-vocab-file tokenizer/bert-large-uncased-vocab.txt \
--retro-bert-tokenizer-type BertWordPieceLowerCase \
\
--retro-gpt-seed ${RETRO_GPT_SEED} \
--retro-gpt-tokenizer-type GPTSentencePieceTokenizer \
--retro-gpt-tokenizer-model /path/to/tokenizer/model \
--retro-gpt-seq-length ${RETRO_GPT_SEQ_LENGTH} \
--retro-gpt-chunk-length ${RETRO_GPT_CHUNK_LENGTH} \
--retro-gpt-global-batch-size ${RETRO_GPT_GLOBAL_BATCH_SIZE} \
--retro-gpt-eval-interval ${RETRO_GPT_EVAL_INTERVAL} \
--retro-gpt-eval-iters ${RETRO_GPT_EVAL_ITERS} \
--retro-gpt-split ${RETRO_GPT_SPLIT} \
--retro-gpt-data-path ${RETRO_GPT_DATA_PATH} \
--retro-gpt-train-samples ${RETRO_GPT_TRAIN_SAMPLES} \
\
--retro-index-str ${RETRO_INDEX_STR} \
--retro-index-ntrain ${RETRO_INDEX_NTRAIN} \
--retro-index-train-load-fraction ${RETRO_INDEX_TRAIN_LOAD_FRACTION} \
--retro-index-add-load-fraction ${RETRO_INDEX_ADD_LOAD_FRACTION} \
--no-retro-index-delete-training-embeddings \
--no-retro-index-delete-added-codes \
\
--retro-query-num-neighbors-query ${RETRO_QUERY_NUM_NEIGHBORS_QUERY} \
--retro-query-num-neighbors-save ${RETRO_QUERY_NUM_NEIGHBORS_SAVE} \
--retro-query-ef-search ${RETRO_QUERY_EF_SEARCH} \
--retro-query-nprobe ${RETRO_QUERY_NPROBE} \
"
######## Command. ########
NPROCS=8 # Number of GPUs.
CMD="\
cd ${REPO_DIR} && pwd && \
export PYTHONPATH=$PYTHONPATH:${REPO_DIR} && \
python -m torch.distributed.run \
--nproc_per_node ${NPROCS} \
--nnodes 1 \
--node_rank ${NODE_RANK} \
--master_addr ${MASTER_ADDR} \
--master_port 6000 \
tools/retro/preprocess_data.py ${ARGS} \
"
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
echo "CMD = '$CMD'."
echo "~~~~~~~~~~~~~~~~~~~~~~~~~~"
eval $CMD
#!/bin/bash
# Runs the "307M" parameter Retro model.
export CUDA_DEVICE_MAX_CONNECTIONS=1
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NUM_NODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))
CHECKPOINT_PATH=$1 #<Specify path>
TENSORBOARD_LOGS_PATH=$2 #<Specify path>
DISTRIBUTED_ARGS=(
--nproc_per_node $GPUS_PER_NODE
--nnodes $NUM_NODES
--master_addr $MASTER_ADDR
--master_port $MASTER_PORT
)
######## GPT or Retro? ########
# 0 : GPT.
# 1 : Retro
ADD_RETRIEVER=1
######## Megatron, Retro dirs. ########
RETRO_PROJECT_DIR="<path/to/retro/project/directory>"
######## Model, training args. ########
# ** Note: --seq-length auto loaded from Retro project dir.
RETRO_MODEL_ARGS=(
--num-layers 32
--hidden-size 2048
--num-attention-heads 32
)
# ** Note: --data-path, --tokenizer-type, and --tokenizer-model auto loaded from Retro project dir.
DATA_ARGS=(
--split 98,2,0
)
MODEL_PARALLEL_ARGS=(
--tensor-model-parallel-size 8
--pipeline-model-parallel-size 1
)
# ** Note: --eval-interval, --eval-iters auto loaded from Retro project dir.
EVAL_AND_LOGGING_ARGS=(
--log-interval 100
--save-interval 10000
--eval-interval 1000
--save $CHECKPOINT_PATH
--load $CHECKPOINT_PATH
--eval-iters 10
--tensorboard-dir $TENSORBOARD_LOGS_PATH
)
TRAINING_ARGS=" \
--retro-project-dir ${RETRO_PROJECT_DIR} \
--transformer-impl transformer_engine \
--num-workers 8 \
--micro-batch-size 4 \
--lr-decay-samples 166400000 \
--lr-warmup-samples 162761 \
--lr 6.0e-4 \
--min-lr 6.0e-5 \
--lr-decay-style cosine \
--clip-grad 1.0 \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--init-method-std 0.023 \
--log-params-norm \
--log-num-zeros-in-grad \
--bf16 \
--no-data-sharding \
"
if [ "$ADD_RETRIEVER" = "1" ]; then
TRAINING_ARGS+=" --retro-add-retriever"
fi
######## Command. ########
torchrun ${DISTRIBUTED_ARGS[@]} pretrain_retro.py \
${RETRO_MODEL_ARGS[@]} \
${TRAINING_ARGS} \
${MODEL_PARALLEL_ARGS[@]} \
${DATA_ARGS[@]} \
${EVAL_AND_LOGGING_ARGS[@]}
import os
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from functools import partial
from pathlib import Path
from megatron.core import parallel_state
from megatron.core import dist_checkpointing
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed
from megatron.core.transformer.transformer_config import TransformerConfig
from megatron.core.models.gpt.gpt_model import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_local_spec
from megatron.core.datasets.utils import compile_helpers
from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
from megatron.core.datasets.gpt_dataset import GPTDatasetConfig, MockGPTDataset
from megatron.training.tokenizer.tokenizer import _NullTokenizer
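# Sequence length shared by the mock dataset and the model's maximum sequence length.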
_SEQUENCE_LENGTH = 64
def initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1):
parallel_state.destroy_model_parallel()
# Torch setup for distributed training
rank = int(os.environ['LOCAL_RANK'])
world_size = torch.cuda.device_count()
torch.cuda.set_device(rank)
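    # init_process_group defaults to the env:// rendezvous, so MASTER_ADDR/MASTER_PORT (set by torchrun) must be present in the environment.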
torch.distributed.init_process_group(world_size=world_size, rank=rank)
# Megatron core distributed training initialization
parallel_state.initialize_model_parallel(tensor_model_parallel_size, pipeline_model_parallel_size)
def model_provider():
"""Build the model."""
transformer_config = TransformerConfig(
num_layers=2,
hidden_size=12,
num_attention_heads=4,
use_cpu_initialization=True,
pipeline_dtype=torch.float32,
)
gpt_model = GPTModel(
config=transformer_config,
transformer_layer_spec=get_gpt_layer_local_spec(),
vocab_size=100,
max_sequence_length=_SEQUENCE_LENGTH,
)
return gpt_model
def get_train_data_iterator():
if torch.distributed.is_available() and torch.distributed.is_initialized():
if torch.distributed.get_rank() == 0:
compile_helpers()
torch.distributed.barrier()
else:
compile_helpers()
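    # Configure a mock GPT dataset; no data files on disk are needed for this example.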
config = GPTDatasetConfig(
random_seed=0,
sequence_length=_SEQUENCE_LENGTH,
reset_position_ids=False,
reset_attention_mask=False,
eod_mask_loss=False,
tokenizer=_NullTokenizer(vocab_size=_SEQUENCE_LENGTH),
mid_level_dataset_surplus=0.005,
)
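    # Build the blended dataset splits; only a 1000-sample train split is requested ([train, valid, test] sizes).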
datasets = BlendedMegatronDatasetBuilder(
MockGPTDataset, [1000, None, None], lambda: True, config
).build()
train_dataloader = DataLoader(datasets[0], batch_size=8, shuffle=True)
train_iterator = iter(train_dataloader)
return train_iterator
def forward_step_func(data_iterator, model):
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
losses = output_tensor.float()
loss_mask = loss_mask.view(-1).float()
loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
# If you have data parallel reduce loss across data parallel groups.
# If pipeline parallel, loss computation is done only in last stage.
return loss, {'lm loss': loss}
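    # Pull the next micro-batch and move it to the global `device` set in the main block.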
data = next(data_iterator)
tokens = data['tokens'].to(device)
attention_mask = data['attention_mask'].to(device)
position_ids = data['position_ids'].to(device)
labels = data['labels'].to(device)
loss_mask = data['loss_mask'].to(device)
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)
return output_tensor, partial(loss_func, loss_mask)
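# Save/load helpers using Megatron Core distributed (sharded) checkpointing, so each rank handles only its own shards.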
def save_distributed_checkpoint(checkpoint_path, gpt_model):
sharded_state_dict = gpt_model.sharded_state_dict(prefix='')
dist_checkpointing.save(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path)
def load_distributed_checkpoint(checkpoint_path, gpt_model):
    sharded_state_dict = gpt_model.sharded_state_dict(prefix='')
checkpoint = dist_checkpointing.load(sharded_state_dict=sharded_state_dict, checkpoint_dir=checkpoint_path)
gpt_model.load_state_dict(checkpoint)
return gpt_model
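# This example is meant to be launched with torchrun, which sets LOCAL_RANK, e.g.
# (the script name here is illustrative): torchrun --nproc_per_node 2 run_simple_mcore_train_loop.py
# Note that tensor_model_parallel_size=2 below requires at least two GPUs.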
if __name__ == "__main__":
initialize_distributed(tensor_model_parallel_size=2, pipeline_model_parallel_size=1)
model_parallel_cuda_manual_seed(123)
gpt_model = model_provider()
device = torch.device("cuda")
gpt_model.to(device)
optim = Adam(gpt_model.parameters())
train_iterator = get_train_data_iterator()
forward_backward_func = get_forward_backward_func()
# Running the model for 5 iterations
for _ in range(5):
optim.zero_grad()
losses_reduced = forward_backward_func(
forward_step_func=forward_step_func,
data_iterator=train_iterator,
model=gpt_model,
num_microbatches=1,
seq_length=_SEQUENCE_LENGTH,
micro_batch_size=8,
decoder_seq_length=_SEQUENCE_LENGTH,
forward_only=False)
optim.step()
print(f'Losses reduced : {losses_reduced}')
# Saving the model
ckpt_path = os.getcwd() + '/ckpt'
Path(ckpt_path).mkdir(exist_ok=True)
save_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path=ckpt_path)
# Loading the model
gpt_model = load_distributed_checkpoint(gpt_model=gpt_model, checkpoint_path=ckpt_path)
gpt_model.to(device)
print('Successfully loaded the model')
# T5 MODEL
## Table of contents
- [1. Training Setup](#1-training-setup)
- [2. Configurations](#2-configurations)
- [3. Training Results](#3-training-results)
## 1. Training setup
<a id="markdown-training-setup" name="training-setup"></a>
To run the model on a Slurm-based cluster:
```
PYTORCH_IMAGE=nvcr.io/nvidia/pytorch:23.09-py3
ACCOUNT_NAME=""
PARTITION=""
JOB_NAME=""
NUM_NODES=1
CHECKPOINT_PATH="" #<Specify path to checkpoint>
TENSORBOARD_LOGS_PATH=""#<Specify path to tensorboard log>
VOCAB_FILE="" #<Specify path to file>/bert-large-cased-vocab.txt
DATA_PATH="" #<Specify path and file prefix>_text_document
srun -N $NUM_NODES --container-image $PYTORCH_IMAGE --container-mounts "/path/to/data:/path/to/data,/path/to/megatron-lm:/workspace/megatron-lm" --account $ACCOUNT_NAME -J $JOB_NAME -p $PARTITION --no-container-mount-home bash -c "
cd /workspace/megatron-lm
./examples/t5/train_t5_220m_distributed.sh $CHECKPOINT_PATH $TENSORBOARD_LOGS_PATH $VOCAB_FILE $DATA_PATH"
```
## 2. Configurations
<a id="markdown-configurations" name="configurations"></a>
The architecture arguments below show the configuration for the T5 220M model.
### 220M
```
--num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--kv-channels 64 \
--ffn-hidden-size 3072 \
--encoder-seq-length 512 \
--decoder-seq-length 128 \
--max-position-embeddings 512 \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
```
## 3. Training Results
<a id="markdown-training-results" name="training-results"></a>
Below is the training curve for the 220M model on the Pile dataset. Training takes 4 days on 32 GPUs with a batch size of 2048.
After finetuning on the SQuAD dataset, the validation result is 63.44%.
<p align="center">
<img src="./t5_mcore_train_curve.png" width="800" height="400">
</p>
#!/bin/bash
# Runs the "220M" parameter model
export CUDA_DEVICE_MAX_CONNECTIONS=1
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NUM_NODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NUM_NODES))
CHECKPOINT_PATH=$1 #<Specify path>
TENSORBOARD_DIR=$2 #<Specify path>
VOCAB_FILE=$3 #<Specify path to file>/bert-large-cased-vocab.txt
DATA_PATH=$4 #<Specify path and file prefix>_text_document
DISTRIBUTED_ARGS="
--nproc_per_node $GPUS_PER_NODE \
--nnodes $NUM_NODES \
--node_rank $NODE_RANK \
--master_addr $MASTER_ADDR \
--master_port $MASTER_PORT
"
T5_ARGS="
--encoder-num-layers 12 \
--decoder-num-layers 12 \
--hidden-size 768 \
--num-attention-heads 12 \
--kv-channels 64 \
--ffn-hidden-size 3072 \
--encoder-seq-length 512 \
--decoder-seq-length 128 \
--max-position-embeddings 512 \
--micro-batch-size 64 \
--global-batch-size 512 \
--lr 0.0001 \
--train-iters 1000000 \
--lr-decay-iters 1000000 \
--lr-decay-style linear \
--min-lr 0.00001 \
--weight-decay 1e-2 \
--lr-warmup-fraction .01 \
--clip-grad 1.0 \
--bf16 \
--vocab-extra-ids 100 \
--init-method-std 0.015 \
--transformer-impl transformer_engine \
--tensor-model-parallel-size 1 \
--pipeline-model-parallel-size 1 \
--attention-backend auto \
"
DATA_ARGS="
--data-path $DATA_PATH \
--vocab-file $VOCAB_FILE \
--tokenizer-type BertWordPieceCase \
--split 99982,9,9 \
"
OUTPUT_ARGS="
--log-interval 100 \
--tensorboard-dir ${TENSORBOARD_DIR} \
--save-interval 500 \
--eval-interval 1000 \
--eval-iters 10
"
torchrun $DISTRIBUTED_ARGS pretrain_t5.py \
$T5_ARGS \
$DATA_ARGS \
$OUTPUT_ARGS \
--distributed-backend nccl \
--save $CHECKPOINT_PATH \
    --load $CHECKPOINT_PATH