mpirun_rccltest 4.06 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
#!/bin/bash
set -e

# =================================================
# Helper functions
# =================================================
#######################################
# Print usage information to stdout.
# Globals (read): tcp_iface, ssh_port - their current values are shown as the
#                 defaults, so call this after those globals are assigned.
#######################################
help() {
  cat << EOF
RCCL Tests MPI run helper script
Usage: $(basename "$0") [OPTIONS] [--] <rccl-test-binary> [test args...]

OPTIONS:
    -h, --help    Show this help message and exit
    -np           Total number of processes (default: sum of per-node counts in
                  --hosts; in single-node mode, the detected GPU count)
    -H, --hosts   Comma-separated list of nodes with optional process count per node
                  Format: node01:8,node02:8
                  If count is omitted, falls back to auto-detected GPU count per node.
    --tcp-iface   TCP interface to use for communication (default: ${tcp_iface})
    --ssh-port    SSH port to use for remote connections (default: ${ssh_port})

Every other argument, and everything after --, is passed to mpirun as the
command to launch. Wrapper options are recognised anywhere before --.
Example: $(basename "$0") -H node01:8,node02:8 -- ./all_reduce_perf -b 8 -e 128M
EOF
}

# =================================================
# Global variables
# =================================================
np=
hosts_raw=
tcp_iface=p14p2
ssh_port=3333

rccltest_args=()
mpi_bin=/opt/mpi/bin/mpirun
ompi_prefix=/opt/mpi

#######################################
# Count the GPUs reported by `hy-smi --showid` on this node.
# Outputs: number of lines containing "Device ID" (case-insensitive);
#          0 when hy-smi is missing or lists no devices.
# Returns: always 0, so it is safe to use in an assignment under `set -e`.
#######################################
count_local_gpus() {
  # `grep -c` prints 0 but exits 1 when nothing matches. Without `|| true` the
  # assignment below would fail, `set -e` would abort the script silently, and
  # the "Defaulting to 8" fallback could never run.
  hy-smi --showid 2>/dev/null | grep -ic "Device ID" || true
}

# Detect the number of GPUs per node (used as fallback when count is not specified in --hosts)
ngpu_per_node=$(count_local_gpus)
if [[ -z "${ngpu_per_node}" || "${ngpu_per_node}" -eq 0 ]]; then
  echo "[WRAPPER] Failed to get the number of GPUs per node via hy-smi. Defaulting to 8."
  ngpu_per_node=8
else
  echo "[WRAPPER] Detected ${ngpu_per_node} GPUs per node."
fi

# =================================================
# Parameter parsing
# =================================================
#######################################
# Abort with a message when an option that takes a value has none.
# Without this, `shift 2` on the last argument fails and `set -e` exits the
# script without telling the user why.
# Arguments: $1 - option name, $2 - number of remaining arguments (incl. $1)
#######################################
_require_value() {
  if [[ "${2}" -lt 2 ]]; then
    echo "[WRAPPER] Option '${1}' requires a value." >&2
    exit 1
  fi
}

#######################################
# Parse the wrapper's options. Anything not recognised, and everything after
# `--`, is appended in order to rccltest_args as the command for mpirun.
# Globals (written): np, hosts_raw, tcp_iface, ssh_port, rccltest_args
# Arguments: the script's command line
#######################################
parse_args() {
  while [[ $# -gt 0 ]]; do
    case "${1}" in
      -h|--help) help; exit 0 ;;
      -np) _require_value "${1}" $#; np=${2}; shift 2 ;;
      -H|--hosts) _require_value "${1}" $#; hosts_raw=${2}; shift 2 ;;
      --tcp-iface) _require_value "${1}" $#; tcp_iface=${2}; shift 2 ;;
      --ssh-port) _require_value "${1}" $#; ssh_port=${2}; shift 2 ;;
      --) shift; rccltest_args+=("$@"); break ;;
      *) rccltest_args+=("${1}"); shift ;;
    esac
  done
}

parse_args "$@"

# =================================================
# Parse hosts into parallel arrays: node_names[], node_slots[]
# Input format: node01:8,node02:8  (count optional, falls back to ngpu_per_node)
# Whitespace is ignored and empty entries (e.g. a trailing comma) are skipped.
# Exits with an error when a node name is missing, a count is not a positive
# integer, or no node remains after parsing.
# =================================================
parse_hosts() {
  local entries entry name slots
  node_names=()
  node_slots=()

  IFS=',' read -ra entries <<< "${hosts_raw}"
  for entry in "${entries[@]}"; do
    # Host names contain no whitespace; drop any typed around the commas.
    entry="${entry//[[:space:]]/}"
    [[ -n "${entry}" ]] || continue

    name="${entry%%:*}"
    if [[ "${entry}" == *:* ]]; then
      slots="${entry##*:}"
    else
      slots=""
    fi
    # An omitted count ("node01" or "node01:") falls back to the detected GPU count.
    slots="${slots:-${ngpu_per_node}}"

    if [[ -z "${name}" ]]; then
      echo "[WRAPPER] Invalid --hosts entry '${entry}': missing node name." >&2
      exit 1
    fi
    # Reject non-numeric, zero and leading-zero counts; the latter would be
    # parsed as octal by bash arithmetic when summing np.
    if [[ ! "${slots}" =~ ^[1-9][0-9]*$ ]]; then
      echo "[WRAPPER] Invalid --hosts entry '${entry}': process count must be a positive integer." >&2
      exit 1
    fi

    node_names+=("${name}")
    node_slots+=("${slots}")
  done

  if [[ ${#node_names[@]} -eq 0 ]]; then
    echo "[WRAPPER] --hosts did not contain any node names." >&2
    exit 1
  fi
}

# =================================================
# Run rccl test script
# =================================================

#######################################
# Build the "-x NAME" list that makes mpirun export environment variables.
# Globals (written): mpi_env_args
# Arguments: $1 - extended regex matched against exported variable names
#######################################
collect_env_args() {
  local var_name
  mpi_env_args=()
  # `compgen -e` lists exported variable names only, so multi-line values
  # cannot inject bogus entries the way line-parsing `env` output can.
  while IFS= read -r var_name; do
    mpi_env_args+=(-x "${var_name}")
  done < <(compgen -e | grep -E -- "${1}")
}

# mpirun has nothing to launch without a test command; say so up front.
if [[ ${#rccltest_args[@]} -eq 0 ]]; then
  echo "[WRAPPER] No RCCL test command given." >&2
  help >&2
  exit 1
fi

if [[ -z "${hosts_raw}" ]]; then
  # Run single-node test if --hosts is not set
  echo "[WRAPPER] No compute nodes specified. Running in single-node mode."

  # Default np to ngpu_per_node when not set
  np="${np:-${ngpu_per_node}}"
  echo "Using np=${np}"

  # Same prefixes as multi-node mode; local ranks inherit the environment
  # anyway, so forwarding HIP_* here is harmless and keeps both modes aligned.
  collect_env_args '^(NCCL|RCCL|UCX|HSA|HIP)_'

  "${mpi_bin}" --allow-run-as-root \
    --bind-to none \
    --mca pml ucx \
    --mca osc ucx \
    --mca btl ^vader,tcp,openib,uct \
    --mca coll ^hcoll \
    "${mpi_env_args[@]}" \
    -np "${np}" \
    "${rccltest_args[@]}"
else
  # Multi-node mode
  echo "[WRAPPER] Running in multi-node mode."

  parse_hosts

  # Build MPI -H string and auto-sum np
  hosts_string=""
  np_sum=0
  for i in "${!node_names[@]}"; do
    hosts_string+="${node_names[i]}:${node_slots[i]},"
    # Plain assignment instead of `(( np_sum += ... ))`: an arithmetic command
    # whose result is 0 returns status 1 and would abort under `set -e`.
    np_sum=$(( np_sum + node_slots[i] ))
  done
  hosts_string="${hosts_string%,}"

  # -np overrides auto-sum if explicitly provided
  np="${np:-${np_sum}}"

  echo "[WRAPPER] MPI hosts: ${hosts_string}"
  echo "[WRAPPER] Total processes (np): ${np}"
  echo "[WRAPPER] Using TCP interface: ${tcp_iface}"
  echo "[WRAPPER] Using SSH port: ${ssh_port}"

  collect_env_args '^(NCCL|RCCL|UCX|HSA|HIP)_'

  # NOTE(review): btl_tcp_if_include only restricts Open MPI's TCP BTL; the
  # runtime's out-of-band channel (oob_tcp_if_include) and RCCL's own socket
  # interface are not pinned here - confirm whether they should be.
  "${mpi_bin}" --allow-run-as-root \
    --prefix "${ompi_prefix}" \
    --bind-to none \
    --mca pml ucx \
    --mca btl_tcp_if_include "${tcp_iface}" \
    --mca plm_rsh_args "-p ${ssh_port}" \
    "${mpi_env_args[@]}" \
    -x ROCM_PATH -x PATH -x LD_LIBRARY_PATH \
    -np "${np}" \
    -H "${hosts_string}" \
    "${rccltest_args[@]}"
fi