Unverified Commit ea43b18e authored by zlsh80826's avatar zlsh80826 Committed by GitHub
Browse files

[JAX] Fix JAX distributed unit tests (#521)



* Remove assertion for NO_MASK
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix JAX distributed unit tests name
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

---------
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
parent 6159af4e
...@@ -5,5 +5,5 @@ ...@@ -5,5 +5,5 @@
set -xe set -xe
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_custom_ops.py pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
set -xe set -xe
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax pytest -Wignore -v $TE_PATH/tests/jax -k 'not distributed'
pip install -r $TE_PATH/examples/jax/mnist/requirements.txt pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
......
...@@ -54,9 +54,6 @@ def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed ...@@ -54,9 +54,6 @@ def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed
""" """
Self fused attention wrapper Self fused attention wrapper
""" """
assert attn_mask_type is not AttnMaskType.NO_MASK, \
"Currently not support AttnMaskType.NO_MASK."
output = _self_fused_attn(qkv, output = _self_fused_attn(qkv,
bias, bias,
mask, mask,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment