Commit a8a2bbea authored by dongcl

patch for megatron 4429e8ebe

parent 2ddbd4be
#!/bin/bash
for para in $*
do
if [[ $para == --profiling* ]];then
profiling=${para#*=}
# export GPU_FLUSH_ON_EXECUTION=1
# export HIP_DIRECT_DISPATCH=0
fi
done
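# Example invocation (assumed; the script name and address below are placeholders):
#   bash <this_script>.sh <master_node_ip> --profiling=torch
# The first positional argument is used as the rendezvous address for --dist-url
# further down; --profiling accepts "torch" or "hip" (see the profiler blocks near the end).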
CURRENT_DIR="$( cd "$( dirname "$0" )" && pwd )"
echo $CURRENT_DIR
MEGATRON_PATH=$( dirname $( dirname ${CURRENT_DIR}))
export CUDA_DEVICE_MAX_CONNECTIONS=1
export HSA_FORCE_FINE_GRAIN_PCIE=1
export OMP_NUM_THREADS=1
export GPU_MAX_HW_QUEUES=10
export NCCL_ALGO=Ring
export NCCL_MIN_NCHANNELS=32
export NCCL_MAX_NCHANNELS=32
export NCCL_NET_GDR_LEVEL=7
export NCCL_NET_GDR_READ=1
export RCCL_SDMA_COPY_ENABLE=0
export NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1
export NCCL_TOPO_FILE="/public/home/yuguo/check/rccl-tests-0204/topo-input.xml" #"your topo file"
export GLOG_minloglevel=3
export GROUPED_GEMM_BatchLinear=1
export LD_LIBRARY_PATH=/public/home/yuguo/data/rocblas-install-0224/lib:$LD_LIBRARY_PATH
LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
RANK=$OMPI_COMM_WORLD_RANK
WORLD_SIZE=$OMPI_COMM_WORLD_SIZE
### BASE CONFIG ###
MODEL_SIZE=A37B
BATCH_SIZE=1
GLOBAL_BATCH_SIZE=256
LR=1e-5
MIN_LR=1e-6
SEQ_LEN=4096
PR=bf16
### BASE CONFIG ###
### PARALLEL / BOOL OPTION ###
TP=1
PP=2
CP=1
EP=4
SP=true
DO=true
FL=true
SFT=false
### PARALLEL / BOOL OPTION ###
### OTHERS ###
AC=none
OPTIMIZER_OFFLOAD=false
SAVE_INTERVAL=500
DATASET_PATH=${MEGATRON_PATH}/deepseekv3_dataset/mmap_deepseekv3_datasets_text_document #"your data path"
VALID_DATASET_PATH=${MEGATRON_PATH}/deepseekv3_dataset/mmap_deepseekv3_datasets_text_document #"your data path"
PRETRAIN_CHECKPOINT_PATH=${MEGATRON_PATH}/deepseekv3_dataset #"your model path"
# the following two values will not be used when SFT is true
TRAIN_TOKENS=100000000
WARMUP_TOKENS=10000
###############################
OUTPUT_BASEPATH=./output
### OTHERS ###
if [ $FL = true ]; then
:
#exit -1
elif [ $FL = false ]; then
export NVTE_FLASH_ATTN=0 NVTE_FUSED_ATTN=1
attn_backend_option=" \
--attention-backend fused
"
fi
if [ $MODEL_SIZE = A37B ]; then
TRAIN_ITERS=2
HIDDEN_SIZE=7168
NUM_ATTENTION_HEADS=128
NUM_LAYERS=2
INTERMEDIATE_SIZE=18432
MOE_INTERMEDIATE_SIZE=2048
MAX_POSITION_EMBEDDINGS=${SEQ_LEN}
EXTRA_VOCAB_SIZE=467
Q_LORA_RANK=1536
KV_LORA_RANK=512
QK_NOPE_HEAD_DIM=128
QK_ROPE_HEAD_DIM=64
V_HEAD_DIM=128
ROPE_THETA=10000
SCALE_FACTOR=40
NUM_EXPERTS=8 #256
ROUTER_TOPK=8
NUM_SHARED_EXPERTS=1
RMS_NORM_EPS=1e-6
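# NOTE: with NUM_LAYERS=2, the pattern "([0]*0+[1]*2)" passed to --moe-layer-freq below
# expands to [1,1], i.e. both transformer layers are MoE layers and none are dense.
# Adjust this pattern together with NUM_LAYERS when scaling the model up.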
moe_options=" \
--moe-grouped-gemm \
--moe-expert-capacity-factor 1 \
--moe-pad-expert-input-to-capacity \
--moe-token-dispatcher-type alltoall \
--moe-router-topk ${ROUTER_TOPK} \
--num-experts ${NUM_EXPERTS} \
--expert-model-parallel-size ${EP} \
--expert-tensor-parallel-size 1 \
--moe-ffn-hidden-size ${MOE_INTERMEDIATE_SIZE} \
--moe-router-load-balancing-type aux_loss \
--moe-aux-loss-coeff 0.001 \
--moe-layer-freq ([0]*0+[1]*2) \
--q-lora-rank ${Q_LORA_RANK} \
--kv-lora-rank ${KV_LORA_RANK} \
--qk-head-dim ${QK_NOPE_HEAD_DIM} \
--qk-pos-emb-head-dim ${QK_ROPE_HEAD_DIM} \
--v-head-dim ${V_HEAD_DIM} \
--moe-shared-expert-intermediate-size $((${MOE_INTERMEDIATE_SIZE} * ${NUM_SHARED_EXPERTS} )) \
"
fi
# Here are some configs controlled by env vars; see the example override after this block
if [ -z ${MP_DATASET_TYPE} ];then
MP_DATASET_TYPE="idxmap"
fi
if [ -z ${MP_AC_LAYERS} ];then
MP_AC_LAYERS=1
fi
if [ -z ${MP_VP} ]; then
vp_option=""
else
vp_option=" \
--num-layers-per-virtual-pipeline-stage ${MP_VP}"
fi
if [ -z ${MP_SFT_PACKING} ]; then
MP_SFT_PACKING=false
fi
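# Example (assumed): the MP_* knobs above are plain environment overrides, e.g.
#   MP_PP0_LAYERS=1 bash <this_script>.sh <master_node_ip>
# MP_VP feeds --num-layers-per-virtual-pipeline-stage, MP_PP0_LAYERS feeds
# --decoder-first-pipeline-num-layers, and both stay disabled when unset.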
TP_COMM_OVERLAP=$(( ($TP > 1) ? 1 : 0 ))
comm_overlap_option="\
--overlap-grad-reduce \
--overlap-param-gather"
if [ $AC = full ]; then
_check=$(( ($NUM_LAYERS / $PP) % ${MP_AC_LAYERS} ))
if [ $_check != 0 ]; then
echo "the num layers per pp rank must be a multiple of the recompute layers."
exit -1
fi
activation_checkpoint_options=" \
--recompute-method uniform \
--recompute-num-layers ${MP_AC_LAYERS} \
--recompute-granularity full"
elif [ $AC = sel ]; then
activation_checkpoint_options=" \
--recompute-activations"
elif [ $AC = none ]; then
activation_checkpoint_options=" \
"
elif [ $AC = offload ]; then
activation_checkpoint_options=" \
--cpu-offloading \
--cpu-offloading-num-layers ${MP_AC_LAYERS}"
if [ $TP_COMM_OVERLAP -eq 1 ]; then
echo "Disable --overlap-grad-reduce and --overlap-param-gather when cpu offloading is on..."
comm_overlap_option="\
--tp-comm-overlap"
else
echo "Disable --overlap-grad-reduce and --overlap-param-gather when cpu offloading is on..."
comm_overlap_option=""
fi
fi
if [ $PR = fp16 ]; then
pr_options=" \
--fp16 \
--apply-query-key-layer-scaling"
export NVTE_APPLY_QK_LAYER_SCALING=1
elif [ $PR = bf16 ]; then
pr_options=" \
--bf16"
elif [ $PR = fp8 ]; then
pr_options=" \
--bf16 \
--fp8-format hybrid \
--fp8-amax-compute-algo max \
--fp8-amax-history-len 1024"
fi
if [ $OPTIMIZER_OFFLOAD != false ] && [ $DO = false ]; then
echo "Offload optimizer is valid only if \$DO=true"
DO=true
fi
if [ $DO = true ]; then
do_option=" \
--use-distributed-optimizer"
elif [ $DO = false ]; then
do_option=" \
"
fi
if [ $SP = true ] && [ $TP -gt 1 ]; then
sp_option=" \
--sequence-parallel"
elif [ $SP = false ]; then
sp_option=" \
"
fi
if [ -z ${MP_PP0_LAYERS} ];then
uneven_split_option=""
elif [ ${PP} -gt 1 ]; then
_check=$(( ( $NUM_LAYERS - ${MP_PP0_LAYERS} ) % ( ${PP} - 1 ) ))
if [ $_check != 0 ]; then
echo "With uneven pipelineing the left over layers must be divisible by left over stages."
exit -1
fi
uneven_split_option=" \
--decoder-first-pipeline-num-layers ${MP_PP0_LAYERS}
"
else
echo "uneven pipeline split must be used when PP > 1"
exit -1
fi
if [ $PRETRAIN_CHECKPOINT_PATH != none ]; then
load_option=" \
--tokenizer-model $PRETRAIN_CHECKPOINT_PATH"
fi
if [ $OPTIMIZER_OFFLOAD != false ]; then
offload_option=" \
--optimizer-cpu-offload \
--use-precision-aware-optimizer \
--optimizer-offload-fraction ${OPTIMIZER_OFFLOAD}"
fi
if [ $SFT = true ]; then
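# NOTE: ${24} and ${25} below read the 24th/25th positional arguments of this script
# (presumably inherited from a larger launcher template); when SFT=true, either pass
# those arguments or replace them with literal iteration counts.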
TRAIN_ITERS=${24}
LR_WARMUP_ITERS=${25}
LR_DECAY_ITERS=$(( ${TRAIN_ITERS} - ${LR_WARMUP_ITERS}))
PREFIX="finetune-mcore-deepseek-v3"
else
# TRAIN_ITERS=$(( ${TRAIN_TOKENS} / ${GLOBAL_BATCH_SIZE} / ${SEQ_LEN} ))
LR_WARMUP_ITERS=$(( ${WARMUP_TOKENS} / ${GLOBAL_BATCH_SIZE} / ${SEQ_LEN} ))
LR_DECAY_ITERS=$(( ${TRAIN_TOKENS} / ${GLOBAL_BATCH_SIZE} / ${SEQ_LEN} ))
PREFIX="pretrain-mcore-deepseek-v3"
fi
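# With the defaults above (TRAIN_TOKENS=100000000, GLOBAL_BATCH_SIZE=256, SEQ_LEN=4096):
#   LR_DECAY_ITERS  = 100000000 / 256 / 4096 = 95   (integer division)
#   LR_WARMUP_ITERS = 10000 / 256 / 4096     = 0    (raise WARMUP_TOKENS for a real warmup)
# TRAIN_ITERS itself is pinned to 2 in the A37B block above, presumably for a short test run.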
if [ ${MP_DATASET_TYPE} = "raw" ]; then
dataset_options=" \
--train-data-path ${DATASET_PATH} \
--valid-data-path ${VALID_DATASET_PATH} \
--dataloader-type cyclic \
--dataset JSON-SFT"
else
dataset_options=" \
--data-path ${DATASET_PATH} \
--split 99,1,0"
fi
if [ ${MP_SFT_PACKING} = true ]; then
echo "Currently MLA do not support THD format attention, thus sequence packing can not be used..."
packing_options=""
else
packing_options=""
fi
##### Prepare logdirs #######
NAME="${PREFIX}"
mkdir -p "${OUTPUT_BASEPATH}/tensorboard/"
mkdir -p "${OUTPUT_BASEPATH}/checkpoint/"
mkdir -p "${OUTPUT_BASEPATH}/log/"
TENSORBOARD_DIR="${OUTPUT_BASEPATH}/tensorboard/${NAME}"
mkdir -p ${TENSORBOARD_DIR}
SAVED_PRETRAIN_CHECKPOINT_PATH="${OUTPUT_BASEPATH}/checkpoint/${NAME}"
mkdir -p ${SAVED_PRETRAIN_CHECKPOINT_PATH}
find -L ${PRETRAIN_CHECKPOINT_PATH} -maxdepth 1 -type f -name "*.json" -print0 | xargs -0 cp -t ${SAVED_PRETRAIN_CHECKPOINT_PATH}
megatron_options=" \
--lr ${LR} \
--min-lr ${MIN_LR} \
--lr-decay-style cosine \
--weight-decay 0.1 \
--adam-beta1 0.9 \
--adam-beta2 0.95 \
--clip-grad 1.0 \
--init-method-std 0.008 \
--attention-dropout 0.0 \
--hidden-dropout 0.0 \
--lr-decay-iters ${LR_DECAY_ITERS} \
--lr-warmup-iters ${LR_WARMUP_ITERS} \
--train-iters ${TRAIN_ITERS} \
--micro-batch-size ${BATCH_SIZE} \
--global-batch-size ${GLOBAL_BATCH_SIZE} \
--num-layers ${NUM_LAYERS} \
--hidden-size ${HIDDEN_SIZE} \
--num-attention-heads ${NUM_ATTENTION_HEADS} \
--ffn-hidden-size ${INTERMEDIATE_SIZE} \
--seq-length ${SEQ_LEN} \
--max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \
--log-interval 1 \
--log-throughput \
--eval-interval 10000 \
--eval-iters 5 \
--save-interval ${SAVE_INTERVAL} \
--tensorboard-queue-size 1 \
--tensorboard-dir ${TENSORBOARD_DIR} \
--log-timers-to-tensorboard \
--log-validation-ppl-to-tensorboard \
--tensor-model-parallel-size ${TP} \
--pipeline-model-parallel-size ${PP} \
--context-parallel-size ${CP} \
--no-load-optim \
--no-load-rng \
--num-workers 8 \
--extra-vocab-size ${EXTRA_VOCAB_SIZE} \
--tokenizer-type DeepSeekV2Tokenizer \
--swiglu \
--normalization RMSNorm \
--norm-epsilon ${RMS_NORM_EPS} \
--use-rotary-position-embeddings \
--no-bias-swiglu-fusion \
--no-rope-fusion \
--position-embedding-type rope \
--untie-embeddings-and-output-weights \
--disable-bias-linear \
--rotary-base ${ROPE_THETA} \
--rotary-scaling-factor ${SCALE_FACTOR} \
--no-save-optim \
--kv-channels ${V_HEAD_DIM} \
--qk-layernorm \
--ckpt-format torch \
--transformer-impl transformer_engine \
--use-rope-scaling \
--multi-latent-attention \
--mtp-num-layers 1 \
--use-mcore-models \
"
TORCH_PROFILE_ARGS=" \
--profile \
--profile-ranks 0 1 2 3 4 5 6 7 \
--profile-step-start 3 \
--profile-step-end 4 \
--profile-dir torch_prof_data_16nodes_dcu \
--use-pytorch-profiler \
"
HIP_PROFILE_ARGS=" \
--profile \
--profile-ranks 0 1 2 3 4 5 6 7 \
--profile-step-start 4 \
--profile-step-end 5 \
--use-hip-profiler \
"
APP="python3 -u ${MEGATRON_PATH}/pretrain_gpt.py
${megatron_options} \
${dataset_options} \
${pr_options} \
${load_option} \
${activation_checkpoint_options} \
${do_option} \
${sp_option} \
${moe_options} \
${offload_option} \
${sft_options} \
${vp_option} \
${packing_options} \
${uneven_split_option} \
${attn_backend_option} \
${comm_overlap_option} \
--rank ${RANK} \
--world-size ${WORLD_SIZE} \
--local-rank ${LOCAL_RANK} \
--dist-url tcp://${1}:25900 \
"
if [[ $profiling == "torch" ]]; then
APP+=" ${TORCH_PROFIE_ARGS}"
elif [[ $profiling == "hip" ]]; then
mkdir -p hip_prof_data
APP+=" ${HIP_PROFIE_ARGS}"
APP="hipprof -d hip_prof_data --hip-trace --trace-off ${APP}"
fi
# Every local rank (0-7) sees all eight devices and runs the same command.
case ${LOCAL_RANK} in
[0-7])
export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
${APP}
;;
esac
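# Example multi-node launch (assumed): this script reads OMPI_COMM_WORLD_* variables,
# so it is meant to be driven by an Open MPI-style launcher, e.g.
#   mpirun -np <world_size> --hostfile <hostfile> -x PATH -x LD_LIBRARY_PATH \
#       bash <this_script>.sh <master_node_ip> --profiling=hip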
@@ -5,49 +5,60 @@ import os, sys
 current_dir = os.path.dirname(os.path.abspath(__file__))
 megatron_path = os.path.join(current_dir, "Megatron-LM")
 sys.path.append(megatron_path)
-import torch
-from functools import partial
-from contextlib import nullcontext
-import inspect
+from functools import partial
 from typing import List, Optional, Tuple, Union
-from megatron.training import get_args
-from megatron.training import print_rank_0
-from megatron.training import get_timers
-from megatron.training import get_tokenizer
-from megatron.core import mpu
-from megatron.core.enums import ModelType
+
+import torch
+
+from megatron.core import parallel_state
 from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder
-from megatron.core.datasets.gpt_dataset import GPTDatasetConfig
-from megatron.core.datasets.gpt_dataset import MockGPTDataset, GPTDataset
-from megatron.core.rerun_state_machine import get_rerun_state_machine
-import megatron.legacy.model
+from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset
+from megatron.core.enums import ModelType
 from megatron.core.models.gpt import GPTModel
-from megatron.training import pretrain
-from megatron.core.utils import StragglerDetector
+from megatron.core.models.gpt.gpt_layer_specs import (
+    get_gpt_decoder_block_spec,
+    get_gpt_layer_local_spec,
+    get_gpt_layer_with_transformer_engine_spec,
+    get_gpt_mtp_block_spec,
+)
+from megatron.core.models.gpt.heterogeneous.heterogeneous_layer_specs import (
+    get_gpt_heterogeneous_layer_spec,
+)
+from megatron.core.rerun_state_machine import get_rerun_state_machine
 from megatron.core.transformer.spec_utils import import_module
+from megatron.core.utils import StragglerDetector
+from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0
+from megatron.training.arguments import core_transformer_config_from_args
 from megatron.training.utils import (
     get_batch_on_this_cp_rank,
     get_batch_on_this_tp_rank,
     get_blend_and_blend_per_split,
 )
-from megatron.training.arguments import core_transformer_config_from_args
 from megatron.training.yaml_arguments import core_transformer_config_from_yaml
-from megatron.core.models.gpt.gpt_layer_specs import (
-    get_gpt_decoder_block_spec,
-    get_gpt_layer_local_spec,
-    get_gpt_layer_with_transformer_engine_spec,
-)
-from megatron.core.transformer.transformer_block import TransformerBlockSubmodules
-from dcu_megatron.core.transformer.mtp.mtp_spec import get_mtp_spec
-from dcu_megatron.core.utils import tensor_slide
+
+import megatron.legacy.model  # isort: skip
+
+# NOTE: Loading `megatron.legacy.model` earlier fails due to circular import
+try:
+    from megatron.post_training.arguments import add_modelopt_args, modelopt_args_enabled
+    from megatron.post_training.loss_func import loss_func as loss_func_modelopt
+    from megatron.post_training.model_provider import model_provider as model_provider_modelopt
+
+    has_nvidia_modelopt = True
+except ImportError:
+    has_nvidia_modelopt = False
 
 from dcu_megatron import megatron_adaptor
 
 stimer = StragglerDetector()
 
-def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
+
+def model_provider(
+    pre_process=True, post_process=True
+) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
     """Builds the model.
 
     If you set the use_legacy_models to True, it will return the legacy GPT model and if not the mcore GPT model.
@@ -62,25 +73,33 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
     """
     args = get_args()
 
+    if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
+        return model_provider_modelopt(pre_process, post_process)
+
     if bool(int(os.getenv("USE_FLUX_OVERLAP", "0"))):
         assert args.transformer_impl == "transformer_engine"
 
     use_te = args.transformer_impl == "transformer_engine"
 
     if args.record_memory_history:
-        torch.cuda.memory._record_memory_history(True,
+        torch.cuda.memory._record_memory_history(
+            True,
             # keep 100,000 alloc/free events from before the snapshot
             trace_alloc_max_entries=100000,
             # record stack information for the trace events
-            trace_alloc_record_context=True)
+            trace_alloc_record_context=True,
+        )
 
         def oom_observer(device, alloc, device_alloc, device_free):
             # snapshot right after an OOM happened
             print('saving allocated state during OOM')
             snapshot = torch.cuda.memory._snapshot()
             from pickle import dump
-            dump(snapshot, open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'))
+
+            dump(
+                snapshot,
+                open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'),
+            )
 
         torch._C._cuda_attach_out_of_memory_observer(oom_observer)
@@ -99,64 +118,58 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
             pre_process=pre_process,
             post_process=post_process,
         )
     else:  # using core models
         if args.spec is not None:
             transformer_layer_spec = import_module(args.spec)
         else:
             if args.num_experts:
                 # Define the decoder block spec
-                transformer_layer_spec = get_gpt_decoder_block_spec(config, use_transformer_engine=use_te)
+                transformer_layer_spec = get_gpt_decoder_block_spec(
+                    config, use_transformer_engine=use_te, normalization=args.normalization
+                )
+            elif args.heterogeneous_layers_config_path is not None:
+                transformer_layer_spec = get_gpt_heterogeneous_layer_spec(config, use_te)
             else:
                 # Define the decoder layer spec
                 if use_te:
                     transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
-                        args.num_experts, args.moe_grouped_gemm,
-                        args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
+                        args.num_experts,
+                        args.moe_grouped_gemm,
+                        args.qk_layernorm,
+                        args.multi_latent_attention,
+                        args.moe_use_legacy_grouped_gemm,
+                    )
                 else:
                     transformer_layer_spec = get_gpt_layer_local_spec(
-                        args.num_experts, args.moe_grouped_gemm,
-                        args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
-
-        build_model_context = nullcontext
-        build_model_context_args = {}
-        if args.fp8_param_gather:
-            try:
-                from transformer_engine.pytorch import fp8_model_init
-
-                build_model_context = fp8_model_init
-                build_model_context_args["enabled"] = True
-
-                # Check if fp8_model_init supports preserve_high_precision_init_val
-                if "preserve_high_precision_init_val" in inspect.signature(fp8_model_init).parameters:
-                    build_model_context_args["preserve_high_precision_init_val"] = True
-            except:
-                raise RuntimeError("--fp8-param-gather requires `fp8_model_init` from TransformerEngine, but not found.")
-
-        # Define the mtp layer spec
-        if isinstance(transformer_layer_spec, TransformerBlockSubmodules):
-            mtp_transformer_layer_spec = transformer_layer_spec.layer_specs[-1]
-        else:
-            mtp_transformer_layer_spec = transformer_layer_spec
-
-        with build_model_context(**build_model_context_args):
-            config.mtp_spec = get_mtp_spec(mtp_transformer_layer_spec, use_te=use_te)
-            model = GPTModel(
-                config=config,
-                transformer_layer_spec=transformer_layer_spec,
-                vocab_size=args.padded_vocab_size,
-                max_sequence_length=args.max_position_embeddings,
-                pre_process=pre_process,
-                post_process=post_process,
-                fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
-                parallel_output=True,
-                share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
-                position_embedding_type=args.position_embedding_type,
-                rotary_percent=args.rotary_percent,
-                rotary_base=args.rotary_base,
-                rope_scaling=args.use_rope_scaling
-            )
-        # model = torch.compile(model,mode='max-autotune-no-cudagraphs')
-        print_rank_0(model)
+                        args.num_experts,
+                        args.moe_grouped_gemm,
+                        args.qk_layernorm,
+                        args.multi_latent_attention,
+                        args.moe_use_legacy_grouped_gemm,
+                        normalization=args.normalization,
+                    )
+
+        mtp_block_spec = None
+        if args.mtp_num_layers is not None:
+            mtp_block_spec = get_gpt_mtp_block_spec(
+                config, transformer_layer_spec, use_transformer_engine=use_te
+            )
+
+        model = GPTModel(
+            config=config,
+            transformer_layer_spec=transformer_layer_spec,
+            vocab_size=args.padded_vocab_size,
+            max_sequence_length=args.max_position_embeddings,
+            pre_process=pre_process,
+            post_process=post_process,
+            fp16_lm_cross_entropy=args.fp16_lm_cross_entropy,
+            parallel_output=True,
+            share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
+            position_embedding_type=args.position_embedding_type,
+            rotary_percent=args.rotary_percent,
+            rotary_base=args.rotary_base,
+            rope_scaling=args.use_rope_scaling,
+            mtp_block_spec=mtp_block_spec,
+        )
 
     return model
@@ -165,7 +178,9 @@ def get_batch(data_iterator):
     """Generate a batch."""
 
     # TODO: this is pretty hacky, find a better way
-    if (not mpu.is_pipeline_first_stage()) and (not mpu.is_pipeline_last_stage()):
+    if (not parallel_state.is_pipeline_first_stage(ignore_virtual=True)) and (
+        not parallel_state.is_pipeline_last_stage(ignore_virtual=True)
+    ):
         return None, None, None, None, None
 
     # get batches based on the TP rank you are on
@@ -181,12 +196,15 @@ def get_batch(data_iterator):
 SPIKY_LOSS_FACTOR = 10
 
 
-def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
+def loss_func(
+    loss_mask: torch.Tensor, output_tensor: torch.Tensor, model: Optional[GPTModel] = None
+):
     """Loss function.
 
     Args:
         loss_mask (torch.Tensor): Used to mask out some portions of the loss
         output_tensor (torch.Tensor): The tensor with the losses
+        model (GPTModel, optional): The model (can be wrapped)
 
     Returns:
         the loss scalar for this micro-batch
@@ -196,15 +214,16 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
     """
     args = get_args()
 
+    if has_nvidia_modelopt and modelopt_args_enabled(args):  # [ModelOpt]
+        return loss_func_modelopt(loss_mask, output_tensor, model=model)
+
     losses = output_tensor.float()
-    if getattr(args, "num_nextn_predict_layers", 0) > 0:
-        loss_mask = tensor_slide(loss_mask, args.num_nextn_predict_layers, return_first=True)[0]
     loss_mask = loss_mask.view(-1).float()
     total_tokens = loss_mask.sum()
     loss = torch.cat([torch.sum(losses.view(-1) * loss_mask).view(1), total_tokens.view(1)])
 
     if args.context_parallel_size > 1:
-        torch.distributed.all_reduce(loss, group=mpu.get_context_parallel_group())
+        torch.distributed.all_reduce(loss, group=parallel_state.get_context_parallel_group())
 
     # Check individual rank losses are not NaN prior to DP all-reduce.
     rerun_state_machine = get_rerun_state_machine()
@@ -213,14 +232,14 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
         result=loss[0],
         rejection_func=torch.isnan,
         message="found NaN in local forward loss calculation",
         tolerance=0.0,  # forward pass calculations are determinisic
         fatal=True,
     )
     rerun_state_machine.validate_result(
         result=loss[0],
         rejection_func=torch.isinf,
         message="found Inf in local forward loss calculation",
         tolerance=0.0,  # forward pass calculations are determinisic
         fatal=True,
     )
@@ -233,19 +252,18 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
             context="loss",
         ),
         message="Spiky loss",
         tolerance=0.0,  # forward pass calculations are determinisic
         fatal=False,
     )
 
     # Reduce loss for logging.
     reporting_loss = loss.clone().detach()
-    torch.distributed.all_reduce(reporting_loss, group=mpu.get_data_parallel_group())
+    torch.distributed.all_reduce(reporting_loss, group=parallel_state.get_data_parallel_group())
 
+    # loss[0] is a view of loss, so it has ._base not None, which triggers assert error
+    # in core/pipeline_parallel/schedule.py::deallocate_output_tensor, calling .clone()
+    # on loss[0] fixes this
     local_num_tokens = loss[1].clone().detach().to(torch.int)
-    return (
-        loss[0] * args.context_parallel_size,
-        local_num_tokens,
-        {'lm loss': (reporting_loss[0], reporting_loss[1])},
-    )
+    return (loss[0].clone(), local_num_tokens, {'lm loss': (reporting_loss[0], reporting_loss[1])})
 
 
 def forward_step(data_iterator, model: GPTModel):
@@ -262,21 +280,26 @@ def forward_step(data_iterator, model: GPTModel):
     timers('batch-generator', log_level=2).start()
     global stimer
     with stimer(bdata=True):
-        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
-            data_iterator)
+        tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
     timers('batch-generator').stop()
 
     with stimer:
-        output_tensor = model(tokens, position_ids, attention_mask,
-                              labels=labels)
+        if args.use_legacy_models:
+            output_tensor = model(tokens, position_ids, attention_mask, labels=labels)
+        else:
+            output_tensor = model(
+                tokens, position_ids, attention_mask, labels=labels, loss_mask=loss_mask
+            )
 
-    return output_tensor, partial(loss_func, loss_mask)
+    # [ModelOpt]: model is needed to access ModelOpt distillation losses
+    return output_tensor, partial(loss_func, loss_mask, model=model)
 
 
 def is_dataset_built_on_rank():
     return (
-        mpu.is_pipeline_first_stage() or mpu.is_pipeline_last_stage()
-    ) and mpu.get_tensor_model_parallel_rank() == 0
+        parallel_state.is_pipeline_first_stage(ignore_virtual=True)
+        or parallel_state.is_pipeline_last_stage(ignore_virtual=True)
+    ) and parallel_state.get_tensor_model_parallel_rank() == 0
 
 
 def core_gpt_dataset_config_from_args(args):
@@ -289,7 +312,7 @@ def core_gpt_dataset_config_from_args(args):
 
     return GPTDatasetConfig(
         random_seed=args.seed,
-        sequence_length=args.seq_length + getattr(args, "num_nextn_predict_layers", 0),
+        sequence_length=args.seq_length,
        blend=blend,
         blend_per_split=blend_per_split,
         split=args.split,
@@ -301,7 +324,8 @@ def core_gpt_dataset_config_from_args(args):
         reset_attention_mask=args.reset_attention_mask,
         eod_mask_loss=args.eod_mask_loss,
         create_attention_mask=args.create_attention_mask_in_dataloader,
-        s3_cache_path=args.s3_cache_path,
+        object_storage_cache_path=args.object_storage_cache_path,
+        mid_level_dataset_surplus=args.mid_level_dataset_surplus,
     )
@@ -323,10 +347,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
     print_rank_0("> building train, validation, and test datasets for GPT ...")
 
     train_ds, valid_ds, test_ds = BlendedMegatronDatasetBuilder(
-        dataset_type,
-        train_val_test_num_samples,
-        is_dataset_built_on_rank,
-        config
+        dataset_type, train_val_test_num_samples, is_dataset_built_on_rank, config
     ).build()
 
     print_rank_0("> finished creating GPT datasets ...")
@@ -345,4 +366,5 @@ if __name__ == "__main__":
         ModelType.encoder_or_decoder,
         forward_step,
         args_defaults={'tokenizer_type': 'GPT2BPETokenizer'},
+        extra_args_provider=add_modelopt_args if has_nvidia_modelopt else None,
     )