Unverified commit ae572af0, authored by Phuong Nguyen, committed by GitHub
Browse files

[JAX] Fixes for L0_jax_distributed_unittest (#1884)



* Include tests that were previously excluded by accident

* Execute run_test_multiprocessing_encoder.sh in a nested bash shell (instead of sourcing it) and use the inner shell's exit code to detect failure

* Adapt run_test_multiprocessing to handle segfault
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
parent ba8c923e
......@@ -6,13 +6,13 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
# Define the test cases to run
TEST_CASES=(
# "test_te_bf16"
"test_te_bf16"
"test_te_delayed_scaling_fp8"
# "test_te_current_scaling_fp8"
# "test_te_mxfp8"
# "test_te_bf16_shardy"
"test_te_current_scaling_fp8"
"test_te_mxfp8"
"test_te_bf16_shardy"
"test_te_delayed_scaling_fp8_shardy"
# "test_te_current_scaling_fp8_shardy"
"test_te_current_scaling_fp8_shardy"
)
echo
......@@ -40,21 +40,20 @@ for TEST_CASE in "${TEST_CASES[@]}"; do
wait
tail -n +7 "${TEST_CASE}_gpu_0.log"
tail -n +7 "${TEST_CASE}_gpu_0.log"
# Check and print the log content accordingly
if grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then
HAS_FAILURE=1
echo "... $TEST_CASE FAILED"
elif grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
if grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE SKIPPED"
elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE PASSED"
else
echo "Invalid ${TEST_CASE}_gpu_0.log"
HAS_FAILURE=1
echo "... $TEST_CASE FAILED"
fi
# Remove the log file after processing it
wait
rm ${TEST_CASE}_gpu_*.log
done
wait
exit $HAS_FAILURE
......@@ -24,11 +24,11 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
# Make the encoder tests run-to-run deterministic so that CI results are stable
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
# wait
# python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
# wait
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
wait
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
wait
TE_PATH=$TE_PATH bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment