#!/bin/bash

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# 
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# 
#     http://www.apache.org/licenses/LICENSE-2.0
# 
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

set -ex

# Runtime environment for the ROCm (DTK 21.04) toolchain. Values are
# site-specific.
export FLAGS_rocm_dir=/public/software/compiler/rocm/dtk-21.04/
export FLAGS_max_inplace_grad_add=2
export HSA_FORCE_FINE_GRAIN_PCIE=1
export NCCL_P2P_LEVEL=5

export USE_NV_INPUT=1

# 1 -> train on the uncompressed HDF5 shards (default).
# 0 -> train on the variable-length shards; only honoured when USE_NV_INPUT=1,
#      otherwise it is forced back to 1 below.
# Overridable from the environment, like BASE_DATA_DIR.
USE_UNCOMPRESSED_DATASET=${USE_UNCOMPRESSED_DATASET:-1}

BASE_DATA_DIR=${BASE_DATA_DIR:-"/public/software/apps/DeepLearning/Data/mlperf/bert"}
UNCOMPRESSED_DATA_DIR=$BASE_DATA_DIR/hdf5/training-4320/hdf5_4320_shards_uncompressed
VARLENGTH_DATA_DIR=$BASE_DATA_DIR/hdf5/training-4320/hdf5_4320_shards_varlength

export DATA_DIR=$UNCOMPRESSED_DATA_DIR
export EVAL_DIR=$BASE_DATA_DIR/hdf5/eval
if [[ "$USE_NV_INPUT" == "1" && "$USE_UNCOMPRESSED_DATASET" == "0" ]]; then
  export DATA_DIR="$VARLENGTH_DATA_DIR"
else
  USE_UNCOMPRESSED_DATASET=1
fi
export USE_UNCOMPRESSED_DATASET
export TF_CKPT_PATH=$BASE_DATA_DIR/phase1/model.ckpt-28252.tf_pickled
export BERT_CONFIG_PATH=$BASE_DATA_DIR/phase1/bert_config.json

export PYTHON=python3

# Default the MPI ranks *before* deriving anything from them. Otherwise a run
# outside mpirun gets an empty PADDLE_TRAINER_ID (and a log file named
# "worker."), and an empty local rank.
OMPI_COMM_WORLD_RANK=${OMPI_COMM_WORLD_RANK:-"0"}
export PADDLE_TRAINER_ID=${OMPI_COMM_WORLD_RANK}
export PADDLE_TRAINERS_NUM=${PADDLE_TRAINERS_NUM:-"1"}
export PADDLE_TRAINER_ENDPOINTS=${PADDLE_TRAINER_ENDPOINTS:-""}

# Node-local rank; selects the DCU / NUMA node / NIC at launch time.
lrank=${OMPI_COMM_WORLD_LOCAL_RANK:-"0"}

#######################################
# Print the device id assigned to this MPI rank.
# Globals:
#   OMPI_COMM_WORLD_RANK (read; treated as 0 when unset/empty)
#   CUDA_VISIBLE_DEVICES (read)
# Outputs:
#   The rank itself when CUDA_VISIBLE_DEVICES is unset; otherwise the
#   rank-th (0-based) comma-separated entry of CUDA_VISIBLE_DEVICES.
# Returns:
#   0 on success; 1 if the rank is not a non-negative integer or is past
#   the end of the CUDA_VISIBLE_DEVICES list (message on stderr).
#######################################
get_device_id() {
  local rank=${OMPI_COMM_WORLD_RANK:-0}
  local -a gpus

  if [[ ! $rank =~ ^[0-9]+$ ]]; then
    echo "get_device_id: invalid rank '${rank}'" >&2
    return 1
  fi
  # Force base 10 so values such as "08" are not parsed as octal.
  rank=$((10#$rank))

  if [[ -z ${CUDA_VISIBLE_DEVICES+set} ]]; then
    printf '%s\n' "$rank"
    return 0
  fi

  IFS=, read -r -a gpus <<< "$CUDA_VISIBLE_DEVICES"
  if (( rank >= ${#gpus[@]} )); then
    echo "get_device_id: rank ${rank} out of range for CUDA_VISIBLE_DEVICES='${CUDA_VISIBLE_DEVICES}'" >&2
    return 1
  fi
  printf '%s\n' "${gpus[rank]}"
}
# Ranks below PADDLE_TRAINERS_NUM are trainers and see all four devices.
# Any rank at or beyond that count gets the reader role (IS_READER=1) and
# no visible devices.
if (( PADDLE_TRAINER_ID < PADDLE_TRAINERS_NUM )); then
  # Earlier per-rank alternatives: $(expr $OMPI_COMM_WORLD_RANK % 4) or
  # `get_device_id`.
  export IS_TRAINER=1
  export IS_READER=0
  export CUDA_VISIBLE_DEVICES=0,1,2,3
else
  export IS_TRAINER=0
  export IS_READER=1
  export CUDA_VISIBLE_DEVICES=""
fi

# Deliberately unquoted: empty values drop out of the printed line.
# shellcheck disable=SC2086
echo "Trainer :" $CUDA_VISIBLE_DEVICES $PADDLE_TRAINER_ENDPOINTS $PADDLE_TRAINERS_NUM

# Framework runtime flags.
export FLAGS_sync_nccl_allreduce=0 \
       FLAGS_fraction_of_gpu_memory_to_use=0.99 \
       FLAGS_call_stack_level=2 \
       FLAGS_use_fast_math=0 \
       FLAGS_enable_nvtx=1


# --- Run configuration ------------------------------------------------------
# Boolean switches are spelled True/False because they are passed verbatim
# as values on the run_pretrain.py command line.
batch_size=4
eval_batch_size=63
use_amp=True
use_pure_fp16=True

max_steps=7100
log_freq=50
# NOTE(review): eval_iter_start_samples / eval_iter_samples are not passed in
# BERT_CMD - verify whether run_pretrain.py is expected to receive them.
eval_iter_start_samples=150000
eval_iter_samples=150000
max_seq_length=512

dense_seq_output=True
unpad=False
unpad_fmha=False
fused_bias_mha=True
fused_bias_fc=True
# Either True or False is accepted here.
weight_transpose=True

fused_dropout_add_ln=False
exchange_padding=True
cpu_exchange_padding=True

distributed_lamb=True

unpad_embed=False
unpad_fmha_mke_opt=True

sort_eval_data=False

# One log file per trainer: log_<PADDLE_TRAINERS_NUM>/worker.<PADDLE_TRAINER_ID>
LOG_DIR="log_${PADDLE_TRAINERS_NUM}"
mkdir -p -- "${LOG_DIR}"
LOG_FILE=${LOG_DIR}/worker.${PADDLE_TRAINER_ID}

# With CPU-side padding exchange enabled, always read the uncompressed shards,
# overriding any earlier DATA_DIR selection. [Tt]rue matches exactly "true"
# or "True".
if [[ $exchange_padding == [Tt]rue && $cpu_exchange_padding == [Tt]rue ]]; then
  export DATA_DIR="$UNCOMPRESSED_DATA_DIR"
fi


# SEED has no default in this script. When it is unset, "--seed $SEED" below
# would swallow the following "--use_uncompressed_dataset" flag as its value,
# so fail early with a clear message instead.
: "${SEED:?SEED must be set (passed to run_pretrain.py as --seed)}"

# Argument string for the pre-training entry point. It is a single string that
# is word-split at launch time, so none of the values may contain whitespace.
BERT_CMD="run_pretrain.py \
   --max_predictions_per_seq 76 \
   --train_batch_size $batch_size   \
   --eval_batch_size $eval_batch_size \
   --sort_eval_data $sort_eval_data \
   --learning_rate 0.000425 \
   --weight_decay 1e-2 \
   --lamb_epsilon 1e-6 \
   --start_warmup_step 0 \
   --warmup_proportion 0.0 \
   --warmup_steps 0 \
   --input_dir $DATA_DIR \
   --log_freq $log_freq \
   --max_steps $max_steps \
   --tf_ckpt_path $TF_CKPT_PATH \
   --bert_config_path $BERT_CONFIG_PATH \
   --unpad $unpad \
   --unpad_fmha $unpad_fmha \
   --unpad_fmha_mke_opt $unpad_fmha_mke_opt \
   --unpad_embed $unpad_embed \
   --fused_bias_mha $fused_bias_mha \
   --fused_bias_fc $fused_bias_fc \
   --fused_dropout_add_ln $fused_dropout_add_ln \
   --weight_transpose $weight_transpose \
   --max_seq_length $max_seq_length \
   --eval_dir $EVAL_DIR \
   --distributed_lamb $distributed_lamb \
   --exchange_padding $exchange_padding \
   --cpu_exchange_padding $cpu_exchange_padding \
   --seed $SEED \
   --use_uncompressed_dataset $USE_UNCOMPRESSED_DATASET \
   --dense_seq_output $dense_seq_output \
   --gradient_accumulation_steps 14 \
   --opt_lamb_beta_1 0.9 \
   --opt_lamb_beta_2 0.999 \
   --enable_addto True \
   --use_pure_fp16 $use_pure_fp16 \
   --use_amp $use_amp"


# Launch this rank's training process.
#
# The node-local rank r is mapped to slot d = r % 4, and the process is
# pinned to:
#   - FLAGS_selected_gpus=d (HIP_VISIBLE_DEVICES keeps all four DCUs visible),
#   - the CPUs and memory of NUMA node d (numactl),
#   - UCX network device mlx5_d:1 (bandwidth hint 50Gbs).
# stdout and stderr both go to $LOG_FILE. BERT_CMD is a whitespace-separated
# string; it is split into words here without glob expansion.
local_rank=${lrank:-0}
if [[ ! $local_rank =~ ^[0-9]+$ ]]; then
  echo "invalid local rank '${local_rank}' (OMPI_COMM_WORLD_LOCAL_RANK)" >&2
  exit 1
fi
dcu=$(( 10#$local_rank % 4 ))

APP="${PYTHON:-python3} -u $BERT_CMD"
read -r -a app_argv <<< "$APP"

echo "work ${local_rank} less than ${PADDLE_TRAINERS_NUM} on DCU ${dcu}"
export HIP_VISIBLE_DEVICES=0,1,2,3
export FLAGS_selected_gpus=${dcu}
export UCX_NET_DEVICES=mlx5_${dcu}:1
export UCX_IB_PCI_BW=mlx5_${dcu}:50Gbs
numactl --cpunodebind="${dcu}" --membind="${dcu}" "${app_argv[@]}" &> "$LOG_FILE"
