Unverified Commit 1269b2e2 authored by jberchtold-nvidia's avatar jberchtold-nvidia Committed by GitHub
Browse files

[JAX] Ensure JAX reference impl uses an accurate backend in our tests (#2322)



Ensure JAX reference impl uses an accurate backend
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
parent 26370b11
......@@ -8,5 +8,6 @@ set -xe
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate.
XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
......@@ -8,4 +8,5 @@ set -xe
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate.
XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment