Unverified Commit a91e4585 authored by Phuong Nguyen's avatar Phuong Nguyen Committed by GitHub
Browse files

[JAX] Add xml export for `test_multiprocessing_encoder` and `test_cgemm` (#2210)



* add xml export for test_multiprocessing_encoder and test_cgemm
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
parent d75bf43f
...@@ -4,6 +4,10 @@ ...@@ -4,6 +4,10 @@
NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)} NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
# Check if NVLINK is supported before running tests # Check if NVLINK is supported before running tests
echo "*** Checking NVLINK support***" echo "*** Checking NVLINK support***"
NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1) NVLINK_OUTPUT=$(nvidia-smi nvlink --status 2>&1)
...@@ -69,7 +73,8 @@ for TEST_FILE in "${TEST_FILES[@]}"; do ...@@ -69,7 +73,8 @@ for TEST_FILE in "${TEST_FILES[@]}"; do
# For process 0: show live output AND save to log file using tee # For process 0: show live output AND save to log file using tee
echo "=== Live output from process 0 ===" echo "=== Live output from process 0 ==="
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \ -vs --junitxml=$XML_LOG_DIR/collective_gemm_${TEST_FILE}.xml \
"$TE_PATH/examples/jax/collective_gemm/$TEST_FILE" \
--num-processes=$NUM_GPUS \ --num-processes=$NUM_GPUS \
--process-id=$i 2>&1 | tee "$LOG_FILE" & --process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$! PID=$!
...@@ -94,8 +99,11 @@ for TEST_FILE in "${TEST_FILES[@]}"; do ...@@ -94,8 +99,11 @@ for TEST_FILE in "${TEST_FILES[@]}"; do
elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then elif grep -q "FAILED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE FAILED" echo "... $TEST_FILE FAILED"
HAS_FAILURE=1 HAS_FAILURE=1
else elif grep -q "PASSED" "${TEST_FILE}_gpu_0.log"; then
echo "... $TEST_FILE PASSED" echo "... $TEST_FILE PASSED"
else
echo "... $TEST_FILE INVALID"
HAS_FAILURE=1
fi fi
# Remove the log files after processing them # Remove the log files after processing them
......
...@@ -15,11 +15,37 @@ TEST_CASES=( ...@@ -15,11 +15,37 @@ TEST_CASES=(
"test_te_current_scaling_fp8_shardy" "test_te_current_scaling_fp8_shardy"
) )
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
echo echo
echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***" echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***"
HAS_FAILURE=0 # Global failure flag HAS_FAILURE=0 # Global failure flag
PIDS=() # Array to store all process PIDs
# Cleanup function to kill all processes
cleanup() {
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Killing process $pid"
kill -TERM "$pid" 2>/dev/null || true
fi
done
# Wait a bit and force kill if needed
sleep 2
for pid in "${PIDS[@]}"; do
if kill -0 "$pid" 2>/dev/null; then
echo "Force killing process $pid"
kill -KILL "$pid" 2>/dev/null || true
fi
done
}
# Set up signal handlers to cleanup on exit
trap cleanup EXIT INT TERM
# Run each test case across all GPUs # Run each test case across all GPUs
for TEST_CASE in "${TEST_CASES[@]}"; do for TEST_CASE in "${TEST_CASES[@]}"; do
echo echo
...@@ -29,25 +55,40 @@ for TEST_CASE in "${TEST_CASES[@]}"; do ...@@ -29,25 +55,40 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
# Define output file for logs # Define output file for logs
LOG_FILE="${TEST_CASE}_gpu_${i}.log" LOG_FILE="${TEST_CASE}_gpu_${i}.log"
# Run pytest and redirect stdout and stderr to the log file # For process 0: show live output AND save to log file using tee
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \ if [ $i -eq 0 ]; then
-vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \ echo "=== Live output from process 0 ==="
--num-process=$NUM_GPUS \ pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
--process-id=$i > "$LOG_FILE" 2>&1 & -vs --junitxml=$XML_LOG_DIR/multiprocessing_encoder_${TEST_CASE}.xml \
done "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
--num-process=$NUM_GPUS \
--process-id=$i 2>&1 | tee "$LOG_FILE" &
PID=$!
PIDS+=($PID)
else
pytest -s -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
--num-process=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
PID=$!
PIDS+=($PID)
fi
done
# Wait for the process to finish # Wait for the process to finish
wait wait
tail -n +7 "${TEST_CASE}_gpu_0.log"
# Check and print the log content accordingly # Check and print the log content accordingly
if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE SKIPPED" echo "... $TEST_CASE SKIPPED"
elif grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE FAILED"
HAS_FAILURE=1
elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE PASSED" echo "... $TEST_CASE PASSED"
else else
echo "... $TEST_CASE INVALID"
HAS_FAILURE=1 HAS_FAILURE=1
echo "... $TEST_CASE FAILED"
fi fi
# Remove the log file after processing it # Remove the log file after processing it
...@@ -56,4 +97,8 @@ for TEST_CASE in "${TEST_CASES[@]}"; do ...@@ -56,4 +97,8 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
done done
wait wait
# Final cleanup (trap will also call cleanup on exit)
cleanup
exit $HAS_FAILURE exit $HAS_FAILURE
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment