Commit c63766d4 authored by Alp Dener, committed by GitHub

[JAX] Fixing CI failure due to incorrect use of `static_argnums` in jax.jit (#785)



* fixed static argnums for jax.jit in single gpu encoder test, changed warning filtering for pytest
Signed-off-by: Alp Dener <adener@nvidia.com>

* propagating the fix to the JAX mnist example
Signed-off-by: Alp Dener <adener@nvidia.com>

* fixed missing space between flags in QA scripts
Signed-off-by: Alp Dener <adener@nvidia.com>

* added TE warnings into the ignore list
Signed-off-by: Alp Dener <adener@nvidia.com>

---------
Signed-off-by: Alp Dener <adener@nvidia.com>
parent fad0e273
@@ -55,7 +55,7 @@ class Net(nn.Module):
         return x
-@partial(jax.jit, static_argnums=6)
+@partial(jax.jit, static_argnums=(0, 1, 2, 3, 4, 5))
 def train_step(state, inputs, masks, labels, var_collect, rngs):
     """Computes gradients, loss and accuracy for a single batch."""
......
@@ -74,7 +74,7 @@ def apply_model(state, images, labels, var_collect, rngs=None):
     return grads, loss, accuracy
-@partial(jax.jit, static_argnums=2)
+@partial(jax.jit, static_argnums=(0, 1))
 def update_model(state, grads):
     """Update model params and FP8 meta."""
     state = state.apply_gradients(grads=grads[PARAMS_KEY])
......
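The replaced decorators each named an index one past the last positional parameter: `train_step` takes six arguments (indices 0 to 5) and `update_model` takes two (indices 0 and 1). The range check can be sketched with the standard library alone. `check_static_argnums` is a hypothetical helper written for illustration, not part of JAX or Transformer Engine, and it does not reproduce JAX's own validation; in particular it ignores negative indices and does not check that static arguments are hashable, which JAX also requires.

```python
import inspect


def check_static_argnums(fn, static_argnums):
    """Return static_argnums as a tuple, or raise ValueError if an index is out of range.

    Hypothetical helper: only compares indices with the number of
    positional parameters in fn's signature.
    """
    if isinstance(static_argnums, int):
        static_argnums = (static_argnums,)
    positional_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
    )
    n_positional = sum(
        1
        for p in inspect.signature(fn).parameters.values()
        if p.kind in positional_kinds
    )
    out_of_range = [i for i in static_argnums if not 0 <= i < n_positional]
    if out_of_range:
        raise ValueError(
            f"{fn.__name__} accepts {n_positional} positional arguments, "
            f"but static_argnums contains {out_of_range}"
        )
    return tuple(static_argnums)


# Stubs with the same signatures as the functions in the diff.
def train_step(state, inputs, masks, labels, var_collect, rngs):
    pass


def update_model(state, grads):
    pass


if __name__ == "__main__":
    for fn, nums in [(train_step, 6), (update_model, 2), (update_model, (0, 1))]:
        try:
            check_static_argnums(fn, nums)
            print(f"{fn.__name__}: static_argnums={nums} is in range")
        except ValueError as err:
            print(f"rejected: {err}")
```

Running such a check before applying `jax.jit` would flag `static_argnums=6` and `static_argnums=2` for these signatures, while every index in the new tuples is in range.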
@@ -5,14 +5,15 @@
 set -xe
 : ${TE_PATH:=/opt/transformerengine}
-pytest -Wignore -v $TE_PATH/tests/jax -k 'not distributed'
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed'
 pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
 pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
-pytest -Wignore -v $TE_PATH/examples/jax/mnist
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist
 # Make encoder tests to have run-to-run deterministic to have the stable CI results
 export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
-pytest -Wignore -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder --ignore=$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
-pytest -Wignore -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py
@@ -5,5 +5,5 @@
 set -xe
 : ${TE_PATH:=/opt/transformerengine}
-pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*
+pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
[pytest]
filterwarnings=
    ignore:sharding_type of.*:DeprecationWarning
    ignore:major_sharding_type of.*:DeprecationWarning
    ignore:Fused attention is not enabled.*:UserWarning
    ignore:The hookimpl.*:DeprecationWarning
    ignore:xmap is an experimental feature and probably has bugs!
    ignore:the imp module is deprecated in favour of importlib.*:DeprecationWarning
    ignore:can't resolve package from __spec__ or __package__:ImportWarning
    ignore:Using or importing the ABCs.*:DeprecationWarning
    ignore:numpy.ufunc size changed
    ignore:.*experimental feature
    ignore:The distutils.* is deprecated.*:DeprecationWarning
    ignore:backend and device argument on jit is deprecated.*:DeprecationWarning
    ignore:ml_dtypes.float8_e4m3b11 is deprecated.
    ignore:np.find_common_type is deprecated.*:DeprecationWarning
    ignore:jax.numpy.in1d is deprecated.*:DeprecationWarning
    ignore:The numpy.array_api submodule is still experimental.*:UserWarning
    ignore:case not machine-readable.*:UserWarning
    ignore:not machine-readable.*:UserWarning
    ignore:Special cases found for .* but none were parsed.*:UserWarning
    ignore:jax.extend.mlir.dialects.mhlo is deprecated.*:DeprecationWarning
    ignore:jax.experimental.maps and .* are deprecated.*:DeprecationWarning
    ignore:The host_callback APIs are deprecated .*:DeprecationWarning
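
Each `filterwarnings` entry has the form `action:message:category`, where the message part is a regular expression matched against the start of the warning text. That anchoring is presumably why the list carries separate entries for `sharding_type of.*` and `major_sharding_type of.*`. Below is a minimal sketch of the same matching rule using Python's standard `warnings` module rather than pytest itself; `is_ignored` is a hypothetical helper for illustration, and the warning texts passed to it are made up.

```python
import warnings


def is_ignored(message, category, ignore_filters):
    """Return True if a warning with this text and category would be suppressed.

    ignore_filters is a list of (message_regex, category) pairs, mirroring
    the `ignore:<message>:<category>` entries of the ini file.
    """
    with warnings.catch_warnings(record=True) as caught:
        # Baseline: record every warning that no ignore filter catches.
        warnings.simplefilter("always")
        # Filters added later are consulted first, so these take priority.
        for pattern, filter_category in ignore_filters:
            warnings.filterwarnings(
                "ignore", message=pattern, category=filter_category
            )
        warnings.warn(message, category)
    return len(caught) == 0


# Two entries copied from the ini file, expressed as (regex, category) pairs.
FILTERS = [
    ("sharding_type of.*", DeprecationWarning),
    ("Fused attention is not enabled.*", UserWarning),
]

if __name__ == "__main__":
    print(is_ignored("sharding_type of this layer is deprecated",
                     DeprecationWarning, FILTERS))
    print(is_ignored("major_sharding_type of this layer is deprecated",
                     DeprecationWarning, FILTERS))
```

With only these two filters, the second call is not suppressed: the pattern `sharding_type of.*` has to match from the first character of the message, so a text beginning with `major_` needs its own entry.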