#!/bin/bash
# Fail fast: abort the wrapper as soon as any unguarded command exits non-zero.
set -o errexit

# =================================================
# Helper functions
# =================================================
# Print the usage text to stdout.
# Globals: tcp_iface, ssh_port (read; shown as the current defaults).
help() {
  cat << EOF
RCCL Tests MPI run helper script
Usage: $(basename "$0") [OPTIONS] [--] COMMAND [ARGS...]

OPTIONS:
    -h, --help    Show this help message and exit
    -np           Total number of processes
                  (multi-node default: sum of per-node counts in --hosts;
                   single-node default: auto-detected GPU count)
    -H, --hosts   Comma-separated list of nodes with optional process count per node
                  Format: node01:8,node02:8
                  If count is omitted, falls back to auto-detected GPU count per node.
                  If --hosts is not given, runs on the local node only.
    --tcp-iface   TCP interface to use for communication (default: ${tcp_iface})
    --ssh-port    SSH port to use for remote connections (default: ${ssh_port})

Any other argument, and everything after "--", is passed through unchanged
to mpirun as the command to launch (e.g. an RCCL test binary and its flags).
EOF
}

# =================================================
# Global variables
# =================================================
# Defaults; the option parser below may overwrite any of these.
np=                 # -np: total MPI ranks (empty = derive automatically)
hosts_raw=          # -H/--hosts: raw "node[:count],..." list (empty = single node)
tcp_iface=p14p2     # --tcp-iface
ssh_port=3333       # --ssh-port

rccltest_args=()    # pass-through words appended to the mpirun command line
mpi_bin=/opt/mpi/bin/mpirun
ompi_prefix=/opt/mpi

# Detect the number of GPUs per node. It is used as the fallback slot count for
# --hosts entries without a count, and as the default -np in single-node mode.
# `grep -c` prints 0 but exits 1 when nothing matches (e.g. hy-smi is missing or
# lists no devices). Under `set -e` that status would abort the script right
# here, before the fallback below could run, so it is absorbed with `|| true`.
ngpu_per_node=$(hy-smi --showid 2>/dev/null | grep -ic "Device ID" || true)
if [[ -z "${ngpu_per_node}" || "${ngpu_per_node}" -eq 0 ]]; then
  echo "[WRAPPER] Failed to get the number of GPUs per node via hy-smi. Defaulting to 8."
  ngpu_per_node=8
else
  echo "[WRAPPER] Detected ${ngpu_per_node} GPUs per node."
fi

# =================================================
# Parameter parsing
# =================================================

# Abort with a usage error when option $1 was given without a value.
# Arguments: $1 - option name
#            $2 - number of words left on the command line, option included
require_optarg() {
  if (( $2 < 2 )); then
    echo "[WRAPPER] Option '$1' requires a value." >&2
    exit 2
  fi
}

# Recognised wrapper options are consumed here. Every other word, and everything
# after "--", is collected verbatim into rccltest_args for the mpirun command.
# Value-taking options are checked first: without a value, `shift 2` fails, so
# under `set -e` the script would die silently (and without it, loop forever).
while [[ $# -gt 0 ]]; do
  case "${1}" in
    -h|--help) help; exit 0 ;;
    -np) require_optarg "$1" "$#"; np=${2}; shift 2 ;;
    -H|--hosts) require_optarg "$1" "$#"; hosts_raw=${2}; shift 2 ;;
    --tcp-iface) require_optarg "$1" "$#"; tcp_iface=${2}; shift 2 ;;
    --ssh-port) require_optarg "$1" "$#"; ssh_port=${2}; shift 2 ;;
    --) shift; rccltest_args+=("$@"); break ;;
    *) rccltest_args+=("${1}"); shift ;;
  esac
done

# An explicit -np must be a positive integer; empty means "derive it later".
if [[ -n "${np}" && ! "${np}" =~ ^[1-9][0-9]*$ ]]; then
  echo "[WRAPPER] Invalid -np value '${np}': expected a positive integer." >&2
  exit 2
fi

# =================================================
# Parse hosts into parallel arrays: node_names[], node_slots[]
# Input format: node01:8,node02:8  (count optional, falls back to ngpu_per_node)
# =================================================
# Globals:   hosts_raw, ngpu_per_node (read); node_names, node_slots (written)
# Returns:   0 on success; 1 (with a message on stderr) if an entry has an
#            empty name or a count that is not a positive integer, or if the
#            list contains no node names at all.
parse_hosts() {
  local -a entries
  local entry name slots

  node_names=()
  node_slots=()

  IFS=',' read -ra entries <<< "${hosts_raw}"
  for entry in "${entries[@]}"; do
    # Tolerate "a, b" and stray commas ("a,,b", "a,"): host names never contain
    # whitespace, so strip it and skip entries that end up empty.
    entry="${entry//[[:space:]]/}"
    if [[ -z "${entry}" ]]; then
      continue
    fi

    name="${entry%%:*}"
    if [[ "${entry}" == *:* ]]; then
      # Everything after the first ':' so that "node:1:2" is rejected below.
      slots="${entry#*:}"
    else
      slots=""
    fi
    # No count ("node01") or an empty one ("node01:") uses the detected GPU count.
    slots="${slots:-${ngpu_per_node}}"

    if [[ -z "${name}" || ! "${slots}" =~ ^[1-9][0-9]*$ ]]; then
      echo "[WRAPPER] Invalid --hosts entry '${entry}': expected node[:count] with a positive count." >&2
      return 1
    fi
    node_names+=("${name}")
    node_slots+=("${slots}")
  done

  if (( ${#node_names[@]} == 0 )); then
    echo "[WRAPPER] No node names found in --hosts '${hosts_raw}'." >&2
    return 1
  fi
}

# =================================================
# Run rccl test script
# =================================================

# Forward every NCCL_/RCCL_/UCX_/HSA_/HIP_ variable in the current environment to
# the launched ranks as "-x NAME" pairs. Built once and shared by both launch
# modes so single- and multi-node runs forward the same set of variables.
mpi_env_args=()
while IFS= read -r env_name; do
  mpi_env_args+=(-x "${env_name}")
done < <(env | grep -E '^(NCCL|RCCL|UCX|HSA|HIP)_' | cut -d= -f1)

# Copy the working tree (build/, scripts/, any configured NCCL topology files),
# the RCCL libraries and the MPI/UCX installs to each node given in "$@".
# All rsyncs run in parallel. Returns 1 after listing every failed transfer on
# stderr; a bare `wait` would always report success, so each job is reaped.
# Globals: ssh_port, PWD, NCCL_TOPO_FILE, NCCL_GRAPH_FILE,
#          NCCL_TOPO_MAPPING_FILE (read)
sync_to_remote_nodes() {
  local node extra i
  local -a workdir_src=("${PWD}/build" "${PWD}/scripts")
  local -a pids=() labels=() failed=()

  # Topology/graph files are optional; ship only the ones that are set.
  for extra in "${NCCL_TOPO_FILE}" "${NCCL_GRAPH_FILE}" "${NCCL_TOPO_MAPPING_FILE}"; do
    if [[ -n "${extra}" ]]; then
      workdir_src+=("${extra}")
    fi
  done

  for node in "$@"; do
    rsync -azP -e "ssh -p ${ssh_port}" "${workdir_src[@]}" "${node}:${PWD}/" &
    pids+=("$!"); labels+=("${node} (${PWD})")
    rsync -azP -e "ssh -p ${ssh_port}" /opt/dtk/rccl/lib "${node}:/opt/dtk/rccl/" &
    pids+=("$!"); labels+=("${node} (/opt/dtk/rccl/lib)")
    rsync -azP -e "ssh -p ${ssh_port}" /opt/mpi /opt/ucx "${node}:/opt/" &
    pids+=("$!"); labels+=("${node} (/opt/mpi, /opt/ucx)")
  done

  for i in "${!pids[@]}"; do
    if ! wait "${pids[$i]}"; then
      failed+=("${labels[$i]}")
    fi
  done

  if (( ${#failed[@]} > 0 )); then
    printf '[WRAPPER] File synchronization failed: %s\n' "${failed[@]}" >&2
    return 1
  fi
}

if [[ -z "${hosts_raw}" ]]; then
  # Run single-node test if --hosts is not set
  echo "[WRAPPER] No compute nodes specified. Running in single-node mode."

  # Default np to ngpu_per_node when not set
  np="${np:-${ngpu_per_node}}"
  echo "Using np=${np}"

  "${mpi_bin}" --allow-run-as-root \
    --bind-to none \
    --mca pml ucx \
    --mca osc ucx \
    --mca btl '^vader,tcp,openib,uct' \
    --mca coll '^hcoll' \
    "${mpi_env_args[@]}" \
    -np "${np}" \
    "${rccltest_args[@]}"
else
  # Multi-node mode
  echo "[WRAPPER] Running in multi-node mode."

  parse_hosts || exit 1

  # Build the MPI -H string ("node:slots,...") and sum the slots for np.
  # `np_sum=$(( ... ))` rather than `(( np_sum += ... ))`: the latter returns 1
  # whenever the result is 0, which would trip `set -e`.
  hosts_string=""
  np_sum=0
  for i in "${!node_names[@]}"; do
    hosts_string+="${node_names[$i]}:${node_slots[$i]},"
    np_sum=$(( np_sum + node_slots[i] ))
  done
  hosts_string="${hosts_string%,}"

  # -np overrides auto-sum if explicitly provided
  np="${np:-${np_sum}}"

  echo "[WRAPPER] MPI hosts: ${hosts_string}"
  echo "[WRAPPER] Total processes (np): ${np}"
  echo "[WRAPPER] Using TCP interface: ${tcp_iface}"
  echo "[WRAPPER] Using SSH port: ${ssh_port}"

  # Copy files to remote nodes (skip current node)
  current_node=$(hostname)
  copyto_hosts=()
  for name in "${node_names[@]}"; do
    if [[ "${name}" != "${current_node}" ]]; then
      copyto_hosts+=("${name}")
    fi
  done

  if (( ${#copyto_hosts[@]} > 0 )); then
    echo "[WRAPPER] Copying files to remote nodes in parallel: ${copyto_hosts[*]}"
    if ! sync_to_remote_nodes "${copyto_hosts[@]}"; then
      exit 1
    fi
    echo "[WRAPPER] Files synchronized successfully."
  fi

  "${mpi_bin}" --allow-run-as-root \
    --prefix "${ompi_prefix}" \
    --bind-to none \
    --mca pml ucx \
    --mca btl_tcp_if_include "${tcp_iface}" \
    --mca plm_rsh_args "-p ${ssh_port}" \
    "${mpi_env_args[@]}" \
    -x ROCM_PATH -x PATH -x LD_LIBRARY_PATH \
    -np "${np}" \
    -H "${hosts_string}" \
    "${rccltest_args[@]}"
fi
