#!/bin/bash

# Scan the command-line arguments for a "--profiling" option and keep its
# value in $profiling (e.g. "--profiling=torch" yields "torch"). If the option
# appears several times, the last occurrence wins; if it never appears,
# $profiling is left untouched.
# "$@" is used instead of an unquoted $* so that every argument is examined
# exactly as it was passed: arguments containing whitespace are not split and
# arguments containing glob characters are not expanded to file names.
for para in "$@"; do
    if [[ $para == --profiling* ]]; then
        # Drop the shortest prefix ending in "=" so only the value remains.
        profiling=${para#*=}
    fi
done

# Runs DeepseekV2 236B model
# Load the environment defined by /opt/dtk/env.sh before anything else is set.
source /opt/dtk/env.sh

# default env
# Absolute path of the directory that contains this script.
CURRENT_DIR="$(cd "$(dirname "$0")" && pwd)"
# Directory two levels above the script. It and its Megatron-LM-241113
# subdirectory are prepended to PYTHONPATH. Every expansion is quoted so that
# a checkout path containing whitespace is passed to dirname as one argument.
MEGATRON_PATH="$(dirname "$(dirname "${CURRENT_DIR}")")"
export PYTHONPATH="${MEGATRON_PATH}:${MEGATRON_PATH}/Megatron-LM-241113:${PYTHONPATH}"
export GLOG_minloglevel=3
export CUDA_DEVICE_MAX_CONNECTIONS=1
export HSA_FORCE_FINE_GRAIN_PCIE=1
export OMP_NUM_THREADS=1
export GPU_MAX_HW_QUEUES=10

# nccl env
# Communication-library settings. They are exported in a single statement so
# that every process started from this script inherits them.
export \
    NCCL_ALGO=Ring \
    NCCL_MIN_NCHANNELS=32 \
    NCCL_MAX_NCHANNELS=32 \
    NCCL_NET_GDR_LEVEL=7 \
    NCCL_NET_GDR_READ=1 \
    RCCL_SDMA_COPY_ENABLE=0 \
    NCCL_IB_HCA=mlx5_2:1,mlx5_3:1,mlx5_4:1,mlx5_5:1,mlx5_6:1,mlx5_7:1,mlx5_8:1,mlx5_9:1 \
    NCCL_TOPO_FILE="./topo-input.xml"

# enable BatchLinear
export GROUPED_GEMM_BatchLinear=1

# Copy the OMPI_COMM_WORLD_* values of this process into the shorter names
# used by the rest of the script (presumably set by an Open MPI launcher --
# confirm). They stay empty when the script is started without them.
WORLD_SIZE=${OMPI_COMM_WORLD_SIZE}
RANK=${OMPI_COMM_WORLD_RANK}
LOCAL_RANK=${OMPI_COMM_WORLD_LOCAL_RANK}

# Settings controlled by environment variables. A variable that is unset or
# empty receives the default given here; any other value is kept as is.
: "${MP_DATASET_TYPE:=idxmap}"
: "${MP_AC_LAYERS:=1}"

# Cluster topology, selected by $ENV:
#   dsw - single node: 8 visible devices, rendezvous on localhost with a
#         randomly chosen port.
#   dlc - node count and node rank are taken from WORLD_SIZE and RANK, the
#         device count from KUBERNETES_CONTAINER_RESOURCE_GPU. MASTER_ADDR and
#         MASTER_PORT are not assigned here (presumably supplied by the
#         environment -- confirm).
# Any other value, including an unset ENV, leaves all of these untouched.
# The selector is quoted and defaulted so that an unset ENV simply matches no
# branch instead of making the test fail with "unary operator expected".
case "${ENV:-}" in
    dsw)
        export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
        MASTER_ADDR=localhost
        # One random port from the range 10000-65535.
        MASTER_PORT=$(shuf -n 1 -i 10000-65535)
        NNODES=1
        NODE_RANK=0
        GPUS_PER_NODE=8
        ;;
    dlc)
        NNODES=${WORLD_SIZE}
        NODE_RANK=${RANK}
        GPUS_PER_NODE=${KUBERNETES_CONTAINER_RESOURCE_GPU}
        ;;
esac

# Virtual pipeline parallelism is requested only when MP_VP is non-empty;
# otherwise no option is emitted.
vp_options=""
if [ -n "${MP_VP}" ]; then
    vp_options=" \
        --num-layers-per-virtual-pipeline-stage ${MP_VP}"
fi

# MP_SFT_PACKING defaults to "false" when it is unset or empty.
: "${MP_SFT_PACKING:=false}"


# Launcher-style argument string assembled from the topology variables.
# NOTE(review): DISTRIBUTED_ARGS is not referenced again in the visible part
# of this script -- confirm whether it is still needed.
DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES"
DISTRIBUTED_ARGS+=" --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

### BASE CONFIG ###
# MODEL_SIZE selects the hyper-parameter set below; this script knows
# "A2.4B" and "A21B".
MODEL_SIZE=A21B
# Passed to the trainer as --micro-batch-size / --global-batch-size.
BATCH_SIZE=1
GLOBAL_BATCH_SIZE=2048
# Passed as --lr / --min-lr.
LR=1e-5
MIN_LR=1e-6
# SEQ_LEN is passed as --seq-length and also becomes the maximum position
# embedding count; PAD_LEN is passed as --max-padding-length.
SEQ_LEN=4096
PAD_LEN=4096
# Precision mode handled further down: fp16, bf16 or fp8.
PR=bf16
### BASE CONFIG ###

### PARALLEL / BOOL OPTION ###
# Tensor / pipeline / context / expert model-parallel sizes.
TP=1
PP=2
CP=1
EP=8
# SP: sequence parallelism; the option is only emitted when TP is above 1.
SP=true
# DO: use the distributed optimizer.
DO=true
# FL: true exports NVTE_FLASH_ATTN=1 / NVTE_FUSED_ATTN=0, false the opposite.
FL=true
# SFT: true selects "--train-mode finetune", anything else "pretrain".
SFT=false
### PARALLEL / BOOL OPTION ###

### OTHERS ###
# AC: activation checkpointing mode handled further down: full, sel, none or
# offload.
AC=none
# OPTIMIZER_OFFLOAD: "false" disables optimizer CPU offload; any other value
# is passed as --optimizer-offload-fraction.
OPTIMIZER_OFFLOAD=false
# Passed as --save-interval.
SAVE_INTERVAL=500
# The three "path to ..." values are placeholders and presumably have to be
# replaced with real paths before running -- confirm.
# VALID_DATASET_PATH is only used when MP_DATASET_TYPE is "raw".
# PRETRAIN_CHECKPOINT_PATH may be set to "none" to skip the --load option.
DATASET_PATH="path to mmap_deepseekv2_datasets_text_document"
VALID_DATASET_PATH="path to mmap_deepseekv2_datasets_text_document"
PRETRAIN_CHECKPOINT_PATH="path to deepseekv2_dataset"

# The following two values are not used when SFT is true. In pretrain mode
# they are divided by GLOBAL_BATCH_SIZE and SEQ_LEN to obtain the LR decay
# and LR warmup iteration counts.
TRAIN_TOKENS=100000000
WARMUP_TOKENS=10000
###############################

# Root directory for the tensorboard/, checkpoint/ and log/ output folders.
OUTPUT_BASEPATH=./output
### OTHERS ###

# Attention backend switches derived from $FL:
#   true  -> NVTE_FLASH_ATTN=1, NVTE_FUSED_ATTN=0
#   false -> NVTE_FLASH_ATTN=0, NVTE_FUSED_ATTN=1
# Any other value exports nothing.
case $FL in
    true)
        export NVTE_FLASH_ATTN=1
        export NVTE_FUSED_ATTN=0
        ;;
    false)
        export NVTE_FLASH_ATTN=0
        export NVTE_FUSED_ATTN=1
        ;;
esac

# Hyper-parameters that differ between the two supported model sizes.
# An unrecognised MODEL_SIZE assigns nothing in this whole section.
case $MODEL_SIZE in
    A2.4B)
        HIDDEN_SIZE=2048
        NUM_ATTN_HEADS=16
        NUM_LAYERS=27
        INTERMEDIATE_SIZE=10944
        MOE_INTERMEDIATE_SIZE=1408
        NUM_EXPERTS=64
        ;;
    A21B)
        HIDDEN_SIZE=5120
        NUM_ATTN_HEADS=128
        NUM_LAYERS=8
        INTERMEDIATE_SIZE=12288
        MOE_INTERMEDIATE_SIZE=1536
        NUM_EXPERTS=160
        # Only this size sets Q_LORA_RANK and passes --q-lora-rank.
        Q_LORA_RANK=1536
        ;;
esac

# Settings common to both sizes, then the MoE / attention option string.
case $MODEL_SIZE in
    A2.4B | A21B)
        TRAIN_ITERS=10
        MAX_POSITION_EMBEDDINGS=${SEQ_LEN}
        EXTRA_VOCAB_SIZE=2400
        KV_LORA_RANK=512
        QK_NOPE_HEAD_DIM=128
        QK_ROPE_HEAD_DIM=64
        V_HEAD_DIM=128
        ROPE_THETA=10000
        SCALE_FACTOR=40
        ROUTER_TOPK=6
        NUM_SHARED_EXPERTS=2
        MOE_LAYER_FREQ=1
        RMS_NORM_EPS=1e-6

        # The shared-expert width is the per-expert width multiplied by the
        # number of shared experts.
        moe_options=" \
            --moe-ffn-hidden-size ${MOE_INTERMEDIATE_SIZE} \
            --moe-router-topk ${ROUTER_TOPK} \
            --num-experts ${NUM_EXPERTS} \
            --moe-layer-freq ${MOE_LAYER_FREQ} \
            --moe-aux-loss-coeff 1e-2 \
            --moe-shared-expert-intermediate-size $(( MOE_INTERMEDIATE_SIZE * NUM_SHARED_EXPERTS )) \
            --expert-model-parallel-size ${EP}"
        if [[ $MODEL_SIZE == A21B ]]; then
            moe_options+=" \
            --q-lora-rank ${Q_LORA_RANK}"
        fi
        moe_options+=" \
            --kv-lora-rank ${KV_LORA_RANK} \
            --qk-head-dim ${QK_NOPE_HEAD_DIM} \
            --qk-pos-emb-head-dim ${QK_ROPE_HEAD_DIM} \
            --v-head-dim ${V_HEAD_DIM} \
            --moe-router-load-balancing-type aux_loss"
        ;;
esac

# TP_COMM_OVERLAP is 1 when tensor parallelism is in use (TP > 1), else 0.
# TP is referenced by name inside $(( )) so that an unset TP evaluates to 0
# instead of producing an arithmetic syntax error.
TP_COMM_OVERLAP=$(( TP > 1 ? 1 : 0 ))

# Communication-overlap options; --tp-comm-overlap is added only with TP > 1.
# NOTE(review): comm_overlap_option is not part of the command line assembled
# in the visible part of this script -- confirm whether that is intended.
comm_overlap_option="\
    --overlap-grad-reduce \
    --overlap-param-gather"
if (( TP_COMM_OVERLAP == 1 )); then
    comm_overlap_option="\
        --tp-comm-overlap \
        --overlap-grad-reduce \
        --overlap-param-gather"
fi

# Activation checkpointing options selected by $AC:
#   full    - uniform full recompute every MP_AC_LAYERS layers; the number of
#             layers per pipeline stage must be a multiple of MP_AC_LAYERS
#   sel     - --recompute-activations
#   none    - no options
#   offload - CPU offloading of MP_AC_LAYERS layers; grad-reduce/param-gather
#             overlap is switched off in this mode
# The variable always starts out empty so that an unknown AC value cannot pick
# up a stale value from the environment.
activation_checkpoint_options=""
case "${AC:-}" in
    full)
        # Reject 0 or non-numeric values up front; they would otherwise make
        # the modulo below fail with a division-by-zero or syntax error.
        if ! [[ "${MP_AC_LAYERS:-}" =~ ^[1-9][0-9]*$ ]]; then
            echo "MP_AC_LAYERS must be a positive integer when AC=full (got '${MP_AC_LAYERS:-}')." >&2
            exit 1
        fi
        if (( (NUM_LAYERS / PP) % MP_AC_LAYERS != 0 )); then
            echo "the num layers per pp rank must be a multiple of the recompute layers." >&2
            # Exit statuses are limited to 0-255, so use 1 rather than -1.
            exit 1
        fi
        activation_checkpoint_options=" \
            --recompute-method uniform \
            --recompute-num-layers ${MP_AC_LAYERS} \
            --recompute-granularity full"
        ;;
    sel)
        activation_checkpoint_options=" \
            --recompute-activations"
        ;;
    none)
        ;;
    offload)
        activation_checkpoint_options=" \
            --cpu-offloading \
            --cpu-offloading-num-layers ${MP_AC_LAYERS}"
        echo "Disable --overlap-grad-reduce and --overlap-param-gather when cpu offloading is on..."
        if (( TP_COMM_OVERLAP == 1 )); then
            comm_overlap_option="\
                --tp-comm-overlap"
        else
            comm_overlap_option=""
        fi
        ;;
    *)
        echo "Unknown AC value '${AC:-}' (expected full, sel, none or offload); no activation checkpointing options are set." >&2
        ;;
esac

# Numeric precision options selected by $PR:
#   fp16 - --fp16 with query/key layer scaling (also exported for TE)
#   bf16 - --bf16
#   fp8  - --bf16 plus the hybrid fp8 format settings
# Any other value leaves pr_options untouched.
case $PR in
    fp16)
        export NVTE_APPLY_QK_LAYER_SCALING=1
        pr_options=" \
            --fp16 \
            --apply-query-key-layer-scaling"
        ;;
    bf16)
        pr_options=" \
            --bf16"
        ;;
    fp8)
        pr_options=" \
            --bf16 \
            --fp8-format hybrid \
            --fp8-amax-compute-algo max \
            --fp8-amax-history-len 1024"
        ;;
esac

# Optimizer offload requires the distributed optimizer, so DO is forced to
# true whenever offload is requested (OPTIMIZER_OFFLOAD non-empty and not
# "false") while DO is false.
# [[ ]] with quoted operands is used so that an empty variable compares as an
# empty string instead of breaking the test syntax.
if [[ -n "${OPTIMIZER_OFFLOAD:-}" && "${OPTIMIZER_OFFLOAD}" != false && "${DO:-}" == false ]]; then
    echo "Offload optimizer is valid only if \$DO=true"
    DO=true
fi

# do_options always starts out empty, so a value of DO other than "true"
# cannot leave a stale or inherited option string behind.
do_options=""
if [[ "${DO:-}" == true ]]; then
    do_options=" \
        --use-distributed-optimizer"
fi

# Always run with the Transformer Engine implementation.
te_options=" \
        --transformer-impl transformer_engine"

# Sequence parallelism is only requested when SP is true AND TP is above 1.
# sp_options is initialised first, so the remaining combinations (for example
# SP=true with TP=1) yield an empty string rather than leaving the variable
# unset or carrying a value inherited from the environment.
sp_options=""
if [[ "${SP:-}" == true ]] && (( TP > 1 )); then
    sp_options=" \
        --sequence-parallel"
fi

# Optional uneven pipeline split: MP_PP0_LAYERS is the number of layers put on
# the first pipeline stage.
#   - unset/empty: no option is emitted
#   - set with PP > 1: the remaining layers must divide evenly over the
#     remaining PP - 1 stages, otherwise the script aborts
#   - set with PP <= 1: rejected, there is no pipeline to split
# Failures exit with status 1 (exit statuses are limited to 0-255, so -1 is
# not a valid argument) and report on stderr.
if [[ -z "${MP_PP0_LAYERS:-}" ]]; then
    uneven_split_option=""
elif (( PP > 1 )); then
    if (( (NUM_LAYERS - MP_PP0_LAYERS) % (PP - 1) != 0 )); then
        echo "With uneven pipelineing the left over layers must be divisible by left over stages." >&2
        exit 1
    fi

    uneven_split_option=" \
        --decoder-first-pipeline-num-layers ${MP_PP0_LAYERS}
    "
else
    # This branch is reached when MP_PP0_LAYERS is set but PP is not above 1.
    echo "uneven pipeline split can only be used when PP > 1" >&2
    exit 1
fi

# --load is emitted unless PRETRAIN_CHECKPOINT_PATH is empty or "none".
# Both option variables are initialised to an empty string first so they never
# carry a value inherited from the environment, and the tests use [[ ]] with
# quoted operands so that an empty value or a value containing whitespace
# cannot break the test syntax.
# NOTE: the option string is word-split when the command is finally run, so a
# checkpoint path that itself contains whitespace is still not supported.
load_options=""
if [[ -n "${PRETRAIN_CHECKPOINT_PATH:-}" && "${PRETRAIN_CHECKPOINT_PATH}" != none ]]; then
    load_options=" \
            --load ${PRETRAIN_CHECKPOINT_PATH}"
fi

# Optimizer CPU offload: any non-empty value other than "false" is passed on
# as the offload fraction.
offload_option=""
if [[ -n "${OPTIMIZER_OFFLOAD:-}" && "${OPTIMIZER_OFFLOAD}" != false ]]; then
    offload_option=" \
        --optimizer-cpu-offload \
        --use-precision-aware-optimizer \
        --optimizer-offload-fraction ${OPTIMIZER_OFFLOAD}"
fi
 

# Training schedule and run-name prefix, depending on the training mode.
if [[ "${SFT:-}" == true ]]; then
    # Fine-tuning: the iteration counts are read from positional arguments 24
    # and 25 of this script.
    # NOTE(review): the rest of the visible script only consumes $1 and the
    # --profiling option, so these positions look inherited from a different
    # launcher -- confirm the intended source of these two values.
    TRAIN_ITERS=${24}
    LR_WARMUP_ITERS=${25}
    # Fail early with a clear message instead of running into an arithmetic
    # syntax error and passing empty option values to the trainer.
    if ! [[ "${TRAIN_ITERS}" =~ ^[0-9]+$ && "${LR_WARMUP_ITERS}" =~ ^[0-9]+$ ]]; then
        echo "SFT=true expects the train iterations and warmup iterations as positional arguments 24 and 25 (got '${TRAIN_ITERS}' and '${LR_WARMUP_ITERS}')." >&2
        exit 1
    fi
    LR_DECAY_ITERS=$(( TRAIN_ITERS - LR_WARMUP_ITERS ))
    PREFIX="finetune-mcore-deepseek-v2-${MODEL_SIZE}-lr-${LR}-minlr-${MIN_LR}-bs-${BATCH_SIZE}-gbs-${GLOBAL_BATCH_SIZE}-seqlen-${SEQ_LEN}"
    sft_option=" \
         --eod-mask-loss \
         --calculate-per-token-loss \
         --train-mode finetune"
else
    # Pretraining: TRAIN_ITERS keeps the value assigned with the model size;
    # the token-based computation below is intentionally left disabled.
    #TRAIN_ITERS=$(( ${TRAIN_TOKENS} / ${GLOBAL_BATCH_SIZE} / ${SEQ_LEN} ))
    # Token budgets are converted to iterations by dividing by the number of
    # tokens per global batch (integer division, rounding down).
    LR_WARMUP_ITERS=$(( WARMUP_TOKENS / GLOBAL_BATCH_SIZE / SEQ_LEN ))
    LR_DECAY_ITERS=$(( TRAIN_TOKENS / GLOBAL_BATCH_SIZE / SEQ_LEN ))
    PREFIX="pretrain-mcore-deepseek-v2-${MODEL_SIZE}-lr-${LR}-minlr-${MIN_LR}-bs-${BATCH_SIZE}-gbs-${GLOBAL_BATCH_SIZE}-seqlen-${SEQ_LEN}"
    sft_option=" \
        --train-mode pretrain"
fi

# Dataset options: "raw" selects the JSON-SFT dataset with separate train and
# validation paths; every other value selects the MMAP dataset with a fixed
# 99,1,0 split.
case ${MP_DATASET_TYPE} in
    raw)
        dataset_option=" \
            --train-data-path ${DATASET_PATH} \
            --valid-data-path ${VALID_DATASET_PATH} \
            --dataloader-type cyclic \
            --dataset JSON-SFT"
        ;;
    *)
        dataset_option=" \
            --data-path ${DATASET_PATH} \
            --split 99,1,0 \
            --dataset MMAP"
        ;;
esac

# Sequence packing never adds options; requesting it only prints a notice.
packing_options=""
case ${MP_SFT_PACKING} in
    true)
        echo "Currently MLA do not support THD format attention, thus sequence packing can not be used..."
        ;;
esac

##### Prepare logdirs #######
# Defensive fallback: never create the output tree directly under "/" when
# OUTPUT_BASEPATH happens to be empty.
: "${OUTPUT_BASEPATH:=./output}"

# Run name encoding the main settings; used for the tensorboard and checkpoint
# directory names.
NAME="${PREFIX}-pr-${PR}-tp-${TP}-pp-${PP}-cp-${CP}-ac-${AC}-do-${DO}-sp-${SP}-ti-${TRAIN_ITERS}-wi-${LR_WARMUP_ITERS}"
current_time=$(date "+%Y.%m.%d-%H.%M.%S")
TENSORBOARD_DIR="${OUTPUT_BASEPATH}/tensorboard/${NAME}_${current_time}"
SAVED_PRETRAIN_CHECKPOINT_PATH="${OUTPUT_BASEPATH}/checkpoint/${NAME}"

# All paths are quoted so that directories containing whitespace are created
# as a single directory.
mkdir -p -- \
    "${OUTPUT_BASEPATH}/tensorboard/" \
    "${OUTPUT_BASEPATH}/checkpoint/" \
    "${OUTPUT_BASEPATH}/log/" \
    "${TENSORBOARD_DIR}" \
    "${SAVED_PRETRAIN_CHECKPOINT_PATH}"

# Copy the top-level *.json files of the pretrained checkpoint next to the
# checkpoints written by this run. The copy is skipped when the source is not
# an existing directory (e.g. "none"), and "-exec ... {} +" only runs cp when
# at least one file matched, so cp is never invoked without operands.
if [[ -d "${PRETRAIN_CHECKPOINT_PATH:-}" ]]; then
    find -L "${PRETRAIN_CHECKPOINT_PATH}" -maxdepth 1 -type f -name "*.json" \
        -exec cp -t "${SAVED_PRETRAIN_CHECKPOINT_PATH}" -- {} +
fi
#find -L ${PRETRAIN_CHECKPOINT_PATH} -maxdepth 1 -type f -name "merges.txt" -print0 | xargs -0 cp -t ${SAVED_PRETRAIN_CHECKPOINT_PATH}

# Core trainer options, assembled group by group. The result is one
# whitespace-separated string; the order of the options is significant only in
# that it matches the order in which they are appended here.

# Optimizer, regularisation and learning-rate schedule.
megatron_options=" \
    --lr ${LR} \
    --min-lr ${MIN_LR} \
    --lr-decay-style cosine \
    --weight-decay 0.1 \
    --adam-beta1 0.9 \
    --adam-beta2 0.95 \
    --clip-grad 1.0 \
    --init-method-std 0.008 \
    --attention-dropout 0.0 \
    --hidden-dropout 0.0 \
    --lr-decay-iters ${LR_DECAY_ITERS} \
    --lr-warmup-iters ${LR_WARMUP_ITERS} \
    --train-iters ${TRAIN_ITERS}"

# Batch sizes and model dimensions.
megatron_options+=" \
    --micro-batch-size ${BATCH_SIZE} \
    --global-batch-size ${GLOBAL_BATCH_SIZE} \
    --num-layers ${NUM_LAYERS} \
    --hidden-size ${HIDDEN_SIZE} \
    --num-attention-heads ${NUM_ATTN_HEADS} \
    --ffn-hidden-size ${INTERMEDIATE_SIZE} \
    --seq-length ${SEQ_LEN} \
    --max-position-embeddings ${MAX_POSITION_EMBEDDINGS} \
    --max-padding-length ${PAD_LEN}"

# Logging, evaluation cadence, checkpoint interval and tensorboard output.
megatron_options+=" \
    --log-interval 1 \
    --log-throughput \
    --eval-interval 10000 \
    --eval-iters 3 \
    --save-interval ${SAVE_INTERVAL} \
    --tensorboard-queue-size 1 \
    --tensorboard-dir ${TENSORBOARD_DIR} \
    --log-timers-to-tensorboard \
    --log-validation-ppl-to-tensorboard"

# Parallel layout, checkpoint-loading behaviour and data-loader workers.
megatron_options+=" \
    --tensor-model-parallel-size ${TP} \
    --pipeline-model-parallel-size ${PP} \
    --context-parallel-size ${CP} \
    --no-load-optim \
    --no-load-rng \
    --num-workers 8"

# Tokenizer and architecture switches.
megatron_options+=" \
    --extra-vocab-size ${EXTRA_VOCAB_SIZE} \
    --patch-tokenizer-type DeepSeekV2Tokenizer \
    --swiglu \
    --normalization RMSNorm \
    --norm-epsilon ${RMS_NORM_EPS} \
    --use-rotary-position-embeddings \
    --no-bias-swiglu-fusion \
    --no-rope-fusion \
    --position-embedding-type rope \
    --untie-embeddings-and-output-weights \
    --disable-bias-linear \
    --rotary-base ${ROPE_THETA} \
    --rotary-scaling-factor ${SCALE_FACTOR} \
    --no-save-optim \
    --kv-channels ${V_HEAD_DIM} \
    --qk-layernorm \
    --multi-latent-attention \
    --ckpt-format torch \
    "

# Extra trainer arguments that enable the PyTorch profiler for ranks 0-7 on
# steps 3 to 4; they are appended only when profiling was requested.
TORCH_PROFIE_ARGS="  \
    --profile \
    --profile-ranks 0 1 2 3 4 5 6 7 \
    --profile-step-start 3 \
    --profile-step-end 4 \
    --profile-dir torch_prof_deepseekv2_64nodes_tp1-pp2-ep8-cp1 \
    --use-pytorch-profiler \
"

# Complete training command: the entry point followed by every option group
# built above, then the per-process rank information.
APP="python3 -u pretrain_deepseek.py"
APP+=" ${megatron_options} ${dataset_option} ${pr_options} ${load_options}"
APP+=" ${te_options} ${activation_checkpoint_options} ${do_options} ${sp_options}"
APP+=" ${moe_options} ${offload_option} ${sft_option} ${vp_options}"
APP+=" ${packing_options} ${uneven_split_option}"
APP+=" --rank ${RANK} --world-size ${WORLD_SIZE} --local-rank ${LOCAL_RANK}"
# The first positional argument of this script is the host part of the
# rendezvous URL; the port is fixed.
APP+=" --dist-url tcp://${1}:25900 "

# "--profiling=torch" on the command line switches the profiler on.
case $profiling in
    torch)
        APP+=" ${TORCH_PROFIE_ARGS}"
        ;;
esac

# Start the training command for this process. Every local rank from 0 to 7
# gets the same device visibility (all eight devices), so a single pattern
# replaces the eight identical per-rank branches. ${APP} is expanded unquoted
# on purpose: the command string has to be split into its individual words.
case "${LOCAL_RANK:-}" in
    [0-7])
        export HIP_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
        ${APP}
        ;;
    *)
        # Previously a missing or out-of-range LOCAL_RANK made the script do
        # nothing and still report success. Report it, and leave a non-zero
        # status behind (this is the final command of the script, so it
        # becomes the script's exit status).
        echo "LOCAL_RANK='${LOCAL_RANK:-}' is not in the range 0-7; the training command was not started." >&2
        false
        ;;
esac
