#!/bin/bash
set -e

# =================================================
# Helper functions
# =================================================
help() {
    # Print usage information to stdout.
    # The heredoc delimiter is intentionally unquoted so the ${npx}, ${nx},
    # ... references expand to the current defaults; this function must only
    # be called after the global variables below have been initialized
    # (it is invoked from the option-parsing loop, which runs after them).
    cat << EOF
rocHPCG MPI run helper script
Usage: $(basename "$0") [OPTIONS]

OPTIONS:
    -h, --help    Show this help message and exit
    --npx         Number of processes in x dimension of process grid (default: ${npx})
    --npy         Number of processes in y dimension of process grid (default: ${npy})
    --npz         Number of processes in z dimension of process grid (default: ${npz})
    --nx          Problem size in x dimension (default: ${nx})
    --ny          Problem size in y dimension (default: ${ny})
    --nz          Problem size in z dimension (default: ${nz})
    --rt          Benchmarking time in seconds (> 1800s for official runs) (default: ${runtime})
    --tol         Residual tolerance, skip reference verification if set (default: ${tol})
    --pz          Partition boundary in z process dimension (default: 0, uniform grid)
    --zl          Local nz value for processes with z rank < pz (default: equal to ${nz})
    --zu          Local nz value for processes with z rank >= pz (default: equal to ${nz})

    -H, --hosts   Comma-separated list of nodes to run on
    --tcp-iface   TCP interface to use for communication (default: ${tcp_iface})
    --ssh-port    SSH port to use for remote connections (default: ${ssh_port})
EOF
}

# =================================================
# Global variables
# =================================================
# Process-grid dimensions (MPI ranks per axis); total ranks = npx*npy*npz.
npx=1
npy=1
npz=1
# Per-rank local problem size in each dimension.
nx=560
ny=280
nz=280
runtime=60   # benchmark duration in seconds (--rt)
tol=1        # residual tolerance; per the help text, setting it skips reference verification (--tol)
pz=0         # z partition boundary; 0 means a uniform grid (--pz)
zl=${nz}     # local nz for z ranks < pz; tracks nz until --zl overrides it
zu=${nz}     # local nz for z ranks >= pz; tracks nz until --zu overrides it

nodes=              # comma-separated host list (-H/--hosts); empty selects single-node mode
tcp_iface=p14p2     # NIC used for MPI TCP traffic (--tcp-iface)
ssh_port=3333       # SSH port used by mpirun and rsync for remote nodes (--ssh-port)

# Paths to the bundled OpenMPI/UCX installs and the rochpcg launcher,
# all resolved relative to the current working directory.
rochpcg_runscript="${PWD}/run_rochpcg"
mpi_bin="${PWD}/deps/openmpi/bin/mpirun"
ompi_prefix="${PWD}/deps/openmpi"
ompi_lib_dir="${PWD}/deps/openmpi/lib"
ompi_lib64_dir="${PWD}/deps/openmpi/lib64"
ucx_lib_dir="${PWD}/deps/ucx/lib"
ucx_lib64_dir="${PWD}/deps/ucx/lib64"

# Put the bundled MPI/UCX ahead of any system install. OPAL_PREFIX lets the
# relocated OpenMPI find its own components; all three are re-exported to
# remote ranks via `-x` in the multi-node mpirun invocation below.
export PATH="${ompi_prefix}/bin${PATH:+:${PATH}}"
export LD_LIBRARY_PATH="${ompi_lib_dir}:${ompi_lib64_dir}:${ucx_lib_dir}:${ucx_lib64_dir}${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}"
export OPAL_PREFIX="${ompi_prefix}"

# Detect the number of GPUs per node.
# NOTE: `grep -c` prints "0" but exits with status 1 when nothing matches,
# and the whole pipeline fails when hy-smi is absent. Either case would kill
# the script here under `set -e` before the fallback below could run, so the
# pipeline status is explicitly discarded with `|| true`.
ngpu_per_node=$(hy-smi --showid 2>/dev/null | grep -ic "Device ID" || true)
if [[ -z "${ngpu_per_node}" || "${ngpu_per_node}" -eq 0 ]]; then
  echo "Failed to get the number of GPUs per node. Defaulting to 8." | sed 's/GPUs per node/GPUs per node via hy-smi/'
  ngpu_per_node=8
else
  echo "Detected ${ngpu_per_node} GPUs per node."
fi

# =================================================
# Parameter parsing
# =================================================
# Uses util-linux enhanced getopt for long-option support. getopt emits a
# normalized version of the command line; `eval set --` then replaces the
# positional parameters with that normalized form so the case loop below
# can consume option/value pairs with `shift 2`.
GETOPT_PARSE=$(getopt --name "${0}" --options hH: --longoptions help,npx:,npy:,npz:,nx:,ny:,nz:,rt:,tol:,pz:,zl:,zu:,hosts:,tcp-iface:,ssh-port: -- "$@") \
  || { echo "getopt invocation failed; could not parse the command line"; exit 1; }

eval set -- "${GETOPT_PARSE}"

while true; do
  case "${1}" in
    -h|--help) help; exit 0 ;;
    --npx) npx=${2}; shift 2 ;;
    --npy) npy=${2}; shift 2 ;;
    --npz) npz=${2}; shift 2 ;;
    --nx) nx=${2}; shift 2 ;;
    --ny) ny=${2}; shift 2 ;;
    # --nz also resets zl/zu so they keep their documented default of
    # "equal to nz" unless --zl/--zu are passed explicitly (later on the
    # command line, since getopt preserves option order).
    --nz)
        nz=${2}
        zl=${nz}
        zu=${nz}
        shift 2 ;;
    --rt) runtime=${2}; shift 2 ;;
    --tol) tol=${2}; shift 2 ;;
    --pz) pz=${2}; shift 2 ;;
    --zl) zl=${2}; shift 2 ;;
    --zu) zu=${2}; shift 2 ;;
    -H|--hosts) nodes=${2}; shift 2 ;;
    --tcp-iface) tcp_iface=${2}; shift 2 ;;
    --ssh-port) ssh_port=${2}; shift 2 ;;
    --) shift ; break ;;  # end-of-options marker inserted by getopt
    *)  echo "Unexpected command line parameter received; aborting";
        exit 1
        ;;
  esac
done

# Assemble the argument string forwarded to the rochpcg run script.
# Kept as one flat string (not an array) on purpose: the mpirun invocations
# below expand it unquoted and rely on word splitting, and all values are
# plain numbers.
rochpcg_args="--npx=${npx} --npy=${npy} --npz=${npz}"
rochpcg_args="${rochpcg_args} --nx=${nx} --ny=${ny} --nz=${nz}"
rochpcg_args="${rochpcg_args} --rt=${runtime} --tol=${tol}"
rochpcg_args="${rochpcg_args} --pz=${pz} --zl=${zl} --zu=${zu}"

# Total MPI rank count is the product of the process-grid dimensions.
np=$(( npx * npy * npz ))

# =================================================
# Run rochpcg script
# =================================================
# Run single-node test if --hosts is not set
if [ -z "${nodes}" ]; then
  echo "No compute nodes specified. Running in single-node mode."

  # ${rochpcg_args} is expanded unquoted on purpose: it is a flat string of
  # space-separated flags that must word-split into individual arguments.
  "${mpi_bin}" --allow-run-as-root \
    --bind-to none \
    --mca pml ucx \
    --mca osc ucx \
    --mca btl ^vader,tcp,openib,uct \
    --mca coll ^hcoll \
    -x UCX_TLS=self,sm,rocm \
    -x UCX_RNDV_SCHEME=put_zcopy \
    -x UCX_MEMTYPE_CACHE=y \
    -x HSA_FORCE_FINE_GRAIN_PCIE=1 \
    -np "${np}" \
    "${rochpcg_runscript}" ${rochpcg_args}
else
  echo "Running in multi-node mode. Using nodes: ${nodes}"
  echo "Using TCP interface: ${tcp_iface}"
  echo "Using SSH port: ${ssh_port}"

  # Build the mpirun -H host list, giving each node one slot per GPU.
  IFS=',' read -ra node_array <<< "${nodes}"
  hosts_string=""
  for node in "${node_array[@]}"; do
    hosts_string+="${node}:${ngpu_per_node},"
  done
  hosts_string="${hosts_string%,}"

  echo "MPI hosts: ${hosts_string}"

  # Every node except the one we are running on needs a copy of the files.
  current_node=$(hostname)
  copyto_hosts=()
  for node in "${node_array[@]}"; do
    if [[ "${node}" != "${current_node}" ]]; then
      copyto_hosts+=("${node}")
    fi
  done

  # Copy files using rsync only if there are other nodes to copy to.
  # The transfers run in parallel; each PID is collected and waited on
  # individually because a bare `wait` always returns 0 and would silently
  # report success even when a transfer failed, letting the run proceed
  # with stale or missing binaries on some nodes.
  if [ ${#copyto_hosts[@]} -gt 0 ]; then
    echo "Copying files to other nodes in parallel: ${copyto_hosts[*]}"
    rsync_pids=()
    for node in "${copyto_hosts[@]}"; do
      rsync -az -e "ssh -p ${ssh_port}" build deps "${rochpcg_runscript}" "${node}:/workspace/" &
      rsync_pids+=("$!")
    done
    for pid in "${rsync_pids[@]}"; do
      wait "${pid}" \
        || { echo "File synchronization to a remote node failed (rsync pid ${pid})." >&2; exit 1; }
    done
    echo "Files synchronized successfully."
  fi

  # Multi-node run. PATH/LD_LIBRARY_PATH/OPAL_PREFIX are forwarded with -x
  # so remote ranks use the bundled OpenMPI/UCX; ${rochpcg_args} is expanded
  # unquoted on purpose (see single-node branch).
  "${mpi_bin}" --allow-run-as-root \
    --prefix "${ompi_prefix}" \
    --map-by ppr:${ngpu_per_node}:node --bind-to none \
    --mca pml ucx \
    --mca osc ucx \
    --mca btl ^openib \
    --mca btl_tcp_if_include "${tcp_iface}" \
    --mca plm_rsh_args "-p ${ssh_port}" \
    --mca coll_hcoll_enable 0 \
    -x UCX_TLS=self,sm,rocm,rc \
    -x UCX_RNDV_SCHEME=put_zcopy \
    -x UCX_RNDV_FRAG_MEM_TYPE=rocm \
    -x UCX_MEMTYPE_CACHE=n \
    -x UCX_LOG_LEVEL=fatal \
    -x HSA_FORCE_FINE_GRAIN_PCIE=1 \
    -x PATH -x LD_LIBRARY_PATH -x OPAL_PREFIX \
    -np "${np}" \
    -H "${hosts_string}" \
    "${rochpcg_runscript}" ${rochpcg_args}
fi