#!/bin/bash
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Default to one pytest process per visible GPU.
NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}

# Default paths; the `: ${VAR:=default}` idiom assigns only when the variable
# is unset or empty, so both can be overridden from the environment.
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
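
# Example invocation (values here are illustrative; adjust for your setup):
#   NUM_GPUS=4 TE_PATH=/opt/transformerengine XML_LOG_DIR=/tmp/logs \
#     bash run_test_cgemm.sh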

# Check if NVLINK is supported before running tests
echo "*** Checking NVLINK support***"
NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1)
NVLINK_EXIT_CODE=$?

# Skip if the command failed OR the output indicates no NVLINK
if [ $NVLINK_EXIT_CODE -ne 0 ] \
   || [[ "$NVLINK_OUTPUT" == *"not supported"* ]] \
   || [[ "$NVLINK_OUTPUT" == *"No devices"* ]] \
   || [ -z "$NVLINK_OUTPUT" ]; then
  echo "NVLINK is not supported on this platform"
  echo "Collective GEMM tests require NVLINK connectivity"
  echo "SKIPPING all tests"
  # Exit 0 so a missing-NVLINK platform counts as a skip, not a CI failure.
  exit 0
else
  echo "NVLINK support detected"
fi

# Define the test files to run
TEST_FILES=(
  "test_gemm.py"
  "test_dense_grad.py"
  "test_layernorm_mlp_grad.py"
)
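# Each file is launched once per GPU with matching --num-processes and
# --process-id values, so the per-GPU pytest processes form one test group.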

echo
echo "*** Executing tests in examples/jax/collective_gemm/ ***"

HAS_FAILURE=0  # Global failure flag
PIDS=()  # Array to store all process PIDs

# Cleanup function to kill all processes
cleanup() {
  for pid in "${PIDS[@]}"; do
    if kill -0 "$pid" 2>/dev/null; then
      echo "Killing process $pid"
      kill -TERM "$pid" 2>/dev/null || true
    fi
  done
  # Wait a bit and force kill if needed
  sleep 2
  for pid in "${PIDS[@]}"; do
    if kill -0 "$pid" 2>/dev/null; then
      echo "Force killing process $pid"
      kill -KILL "$pid" 2>/dev/null || true
    fi
  done
}
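
# `kill -0` only probes whether a PID is still alive, so cleanup is safe to
# run more than once (the trap below fires it, and it runs again before exit).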

# Set up signal handlers to cleanup on exit
trap cleanup EXIT INT TERM

# Run each test file across all GPUs
for TEST_FILE in "${TEST_FILES[@]}"; do
  echo
  echo "=== Starting test file: $TEST_FILE ..."

  # Clear PIDs array for this test file
  PIDS=()

  for i in $(seq 0 $(($NUM_GPUS - 1))); do
    # Define output file for logs
    LOG_FILE="${TEST_FILE}_gpu_${i}.log"

    if [ $i -eq 0 ]; then
      # For process 0: show live output AND save to log file using tee
      echo "=== Live output from process 0 ==="
      pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
        -vs --junitxml="$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml" \
        "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
        --num-processes=$NUM_GPUS \
        --process-id=$i 2>&1 | tee "$LOG_FILE" &
      PID=$!
      PIDS+=($PID)
    else
      # For other processes: redirect to log files only
      pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
        -vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
        --num-processes=$NUM_GPUS \
        --process-id=$i > "$LOG_FILE" 2>&1 &
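      # Only process 0's log is inspected below; the other per-rank logs are
      # kept until the end of the iteration for debugging.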
      PID=$!
      PIDS+=($PID)
    fi
  done

  # Wait for all processes to finish
  wait

  # Inspect process 0's log (captured via tee) to classify the result
  if grep -q "SKIPPED" "${TEST_FILE}_gpu_0.log"; then
    echo "... $TEST_FILE SKIPPED"
  elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
    echo "... $TEST_FILE FAILED"
    HAS_FAILURE=1
  elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then
    echo "... $TEST_FILE PASSED"
  else
    # No recognizable pytest result marker; treat as a failure
    echo "... $TEST_FILE INVALID"
    HAS_FAILURE=1
  fi

  # Remove the per-GPU log files after processing them
  wait
  rm "${TEST_FILE}"_gpu_*.log
done

wait

# Final cleanup (trap will also call cleanup on exit)
cleanup

exit $HAS_FAILURE