Unverified Commit ea43b18e authored by zlsh80826's avatar zlsh80826 Committed by GitHub
Browse files

[JAX] Fix JAX distributed unit tests (#521)



* Remove assertion for NO_MASK
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

* Fix JAX distributed unit tests name
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>

---------
Signed-off-by: default avatarReese Wang <rewang@nvidia.com>
parent 6159af4e
......@@ -5,5 +5,5 @@
set -xe
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_custom_ops.py
pytest -Wignore -v $TE_PATH/tests/jax/test_distributed_*
......@@ -5,7 +5,7 @@
set -xe
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/jax
pytest -Wignore -v $TE_PATH/tests/jax -k 'not distributed'
pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
......
......@@ -54,9 +54,6 @@ def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, seed
"""
Self fused attention wrapper
"""
assert attn_mask_type is not AttnMaskType.NO_MASK, \
"Currently not support AttnMaskType.NO_MASK."
output = _self_fused_attn(qkv,
bias,
mask,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment