"...git@developer.sourcefind.cn:OpenDAS/TransformerEngine.git" did not exist on "8ec01e5e77908cab9b8278d87776b2e5253a6a63"
Commit ab3e5a92 authored by yuguo


Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
@@ -45,6 +45,12 @@ jobs:
|| github.actor == 'jberchtold-nvidia'
|| github.actor == 'sanandaraj5597'
|| github.actor == 'negvet'
|| github.actor == 'zhongbozhu'
|| github.actor == 'kwyss-nvidia'
|| github.actor == 'BestJuly'
|| github.actor == 'xiaopoc'
|| github.actor == 'jreiffers'
|| github.actor == 'lhb8125'
)
steps:
- name: Check if comment is issued by authorized person
@@ -145,18 +145,30 @@ Flax

Installation
============

.. installation

System Requirements
^^^^^^^^^^^^^^^^^^^^
* **Hardware:** Blackwell, Hopper, Grace Hopper/Blackwell, Ada, Ampere
* **OS:** Linux (official), WSL2 (limited support)
* **Software:**

  * CUDA: 12.1+ (Hopper/Ada/Ampere), 12.8+ (Blackwell) with compatible NVIDIA drivers
  * cuDNN: 9.3+
  * Compiler: GCC 9+ or Clang 10+ with C++17 support
  * Python: 3.12 recommended

* **Source Build Requirements:** CMake 3.18+, Ninja, Git 2.17+, pybind11 2.6.0+
* **Notes:** FP8 features require Compute Capability 8.9+ (Ada/Hopper/Blackwell)
Installation Methods
^^^^^^^^^^^^^^^^^^^^

Docker (Recommended)
^^^^^^^^^^^^^^^^^^^^
The quickest way to get started with Transformer Engine is by using Docker images on
`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_.
For example, to use the NGC PyTorch container interactively,

@@ -167,41 +179,116 @@ For example to use the NGC PyTorch container interactively,

Where 25.01 (corresponding to January 2025 release) is the container version.
**Benefits of using NGC containers:**

* All dependencies pre-installed with compatible versions and optimized configurations
* NGC PyTorch 23.08+ containers include FlashAttention-2
pip Installation
^^^^^^^^^^^^^^^^^^^

**Prerequisites for pip installation:**

* A compatible C++ compiler
* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) installed

To install the latest stable version with pip:

.. code-block:: bash
   # For PyTorch integration
   pip install --no-build-isolation transformer_engine[pytorch]

   # For JAX integration
   pip install --no-build-isolation transformer_engine[jax]

   # For both frameworks
   pip install --no-build-isolation transformer_engine[pytorch,jax]
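To confirm that the extension actually loads in your environment, a quick import smoke test helps (a minimal check; swap the module for the extras you installed):

.. code-block:: bash

   # Fails with an ImportError if the core library, CUDA, or cuDNN
   # cannot be found at runtime.
   python3 -c "import transformer_engine.pytorch"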
Alternatively, install directly from the GitHub repository:
.. code-block:: bash
   pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

When installing from GitHub, you can explicitly specify frameworks using the environment variable:

.. code-block:: bash

   NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
Source Installation
^^^^^^^^^^^^^^^^^^^
`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_
Environment Variables
^^^^^^^^^^^^^^^^^^^^^

These environment variables can be set before installation to customize the build process (see the example build below):

* **CUDA_PATH**: Path to CUDA installation
* **CUDNN_PATH**: Path to cuDNN installation
* **CXX**: Path to C++ compiler
* **NVTE_FRAMEWORK**: Comma-separated list of frameworks to build for (e.g., ``pytorch,jax``)
* **MAX_JOBS**: Limit number of parallel build jobs (default varies by system)
* **NVTE_BUILD_THREADS_PER_JOB**: Control threads per build job
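For example, a source build that pins the toolchain and keeps resource usage modest could be driven entirely through these variables (an illustrative sketch; all paths are placeholders for your local installations):

.. code-block:: bash

   # Placeholder paths -- point these at your actual installations.
   export CUDA_PATH=/usr/local/cuda
   export CUDNN_PATH=/usr/local/cudnn
   export CXX=/usr/bin/g++

   # Build only the PyTorch extension, with limited build parallelism.
   NVTE_FRAMEWORK=pytorch MAX_JOBS=4 NVTE_BUILD_THREADS_PER_JOB=1 \
       pip install --no-build-isolation transformer_engine[pytorch]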
Compiling with FlashAttention
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Transformer Engine supports both FlashAttention-2 and FlashAttention-3 in PyTorch for improved performance. FlashAttention-3 was added in release v1.11 and is prioritized over FlashAttention-2 when both are present in the environment.

You can verify which FlashAttention version is being used by setting these environment variables:

.. code-block:: bash

   NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python your_script.py

It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue.

.. troubleshooting-begin-marker-do-not-remove
Troubleshooting
^^^^^^^^^^^^^^^^^^^

**Common Issues and Solutions:**

1. **ABI Compatibility Issues:**

   * **Symptoms:** ``ImportError`` with undefined symbols when importing transformer_engine
   * **Solution:** Ensure PyTorch and Transformer Engine are built with the same C++ ABI setting. Rebuild PyTorch from source with matching ABI.
   * **Context:** If you're using PyTorch built with a different C++ ABI than your system's default, you may encounter these undefined symbol errors. This is particularly common with pip-installed PyTorch outside of containers.
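To check which ABI your PyTorch build uses before rebuilding anything, a one-liner is usually enough (a minimal check, assuming PyTorch is importable):

.. code-block:: bash

   # Prints True if PyTorch was compiled with the C++11 ABI
   # (_GLIBCXX_USE_CXX11_ABI=1); the Transformer Engine build must match.
   python3 -c "import torch; print(torch.compiled_with_cxx11_abi())"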
2. **Missing Headers or Libraries:**

   * **Symptoms:** CMake errors about missing headers (``cudnn.h``, ``cublas_v2.h``, ``filesystem``, etc.)
   * **Solution:** Install missing development packages or set environment variables to point to correct locations:

     .. code-block:: bash

        export CUDA_PATH=/path/to/cuda
        export CUDNN_PATH=/path/to/cudnn

   * If CMake can't find a C++ compiler, set the ``CXX`` environment variable.
   * Ensure all paths are correctly set before installation.
3. **Build Resource Issues:**

   * **Symptoms:** Compilation hangs, system freezes, or out-of-memory errors
   * **Solution:** Limit parallel builds:

     .. code-block:: bash

        MAX_JOBS=1 NVTE_BUILD_THREADS_PER_JOB=1 pip install ...
4. **Verbose Build Logging:**

   * For detailed build logs to help diagnose issues:

     .. code-block:: bash

        cd transformer_engine
        pip install -v -v -v --no-build-isolation .

.. troubleshooting-end-marker-do-not-remove

Breaking Changes
================
@@ -140,6 +140,19 @@ def setup_pytorch_extension(
        library_dirs.append(mpi_path / "lib")
        libraries.append("mpi")
    library_dirs = []
    libraries = []
    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        nvshmem_home = Path(os.getenv("NVSHMEM_HOME"))
        include_dirs.append(nvshmem_home / "include")
        library_dirs.append(nvshmem_home / "lib")
        libraries.append("nvshmem_host")
        cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
        nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")
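For context, the new flag is driven entirely from the environment at build time; a plausible invocation from a local checkout might look like this (illustrative only, with a placeholder NVSHMEM path):

.. code-block:: bash

   # NVSHMEM_HOME is a placeholder -- point it at your NVSHMEM install.
   NVTE_ENABLE_NVSHMEM=1 NVSHMEM_HOME=/opt/nvshmem \
       pip install --no-build-isolation .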
    # Construct PyTorch CUDA extension
    sources = [str(path) for path in sources]
    include_dirs = [str(path) for path in include_dirs]
@@ -34,7 +34,7 @@ Transformer Engine can be directly installed from `our PyPI <https://pypi.org/pr

.. code-block:: bash

   pip3 install --no-build-isolation transformer_engine[pytorch]

To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions.
@@ -54,7 +54,7 @@ Execute the following command to install the latest stable version of Transforme

.. code-block:: bash

   pip3 install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable

This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable `NVTE_FRAMEWORK` to a comma-separated list (e.g. `NVTE_FRAMEWORK=jax,pytorch`).
@@ -71,7 +71,7 @@ Execute the following command to install the latest development build of Transfo

.. code-block:: bash

   pip3 install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@main

This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable `NVTE_FRAMEWORK` to a comma-separated list (e.g. `NVTE_FRAMEWORK=jax,pytorch`). To only build the framework-agnostic C++ API, set `NVTE_FRAMEWORK=none`.
@@ -79,7 +79,7 @@ In order to install a specific PR, execute (after changing NNN to the PR number)

.. code-block:: bash

   pip3 install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@refs/pull/NNN/merge

Installation (from source)
@@ -93,8 +93,8 @@ Execute the following commands to install Transformer Engine from source:

   git clone --branch stable --recursive https://github.com/NVIDIA/TransformerEngine.git
   cd TransformerEngine
   export NVTE_FRAMEWORK=pytorch  # Optionally set framework
   pip3 install --no-build-isolation .  # Build and install

If the Git repository has already been cloned, make sure to also clone the submodules:
@@ -106,10 +106,14 @@ Extra dependencies for testing can be installed by setting the "test" option:

.. code-block:: bash

   pip3 install --no-build-isolation .[test]

To build the C++ extensions with debug symbols, e.g. with the `-g` flag:

.. code-block:: bash

   pip3 install --no-build-isolation . --global-option=--debug

.. include:: ../README.rst
   :start-after: troubleshooting-begin-marker-do-not-remove
   :end-before: troubleshooting-end-marker-do-not-remove
@@ -4,20 +4,54 @@

NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
# Define the test cases to run
TEST_CASES=(
    "test_te_bf16"
    "test_te_delayed_scaling_fp8"
    "test_te_mxfp8"
    "test_te_bf16_shardy"
    "test_te_delayed_scaling_fp8_shardy"
)

echo
echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***"

HAS_FAILURE=0  # Global failure flag

# Run each test case across all GPUs
for TEST_CASE in "${TEST_CASES[@]}"; do
    echo
    echo "=== Starting test: $TEST_CASE ..."
    for i in $(seq 0 $(($NUM_GPUS - 1))); do
        # Define output file for logs
        LOG_FILE="${TEST_CASE}_gpu_${i}.log"

        # Run pytest and redirect stdout and stderr to the log file
        pytest -c "$TE_PATH/tests/jax/pytest.ini" \
            -vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
            --num-process=$NUM_GPUS \
            --process-id=$i > "$LOG_FILE" 2>&1 &
    done

    # Wait for the processes to finish
    wait

    # Check and print the log content accordingly
    if grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then
        HAS_FAILURE=1
        echo "... $TEST_CASE FAILED"
        tail -n +7 "${TEST_CASE}_gpu_0.log"
    elif grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
        echo "... $TEST_CASE SKIPPED"
    elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then
        echo "... $TEST_CASE PASSED"
    else
        echo "Invalid ${TEST_CASE}_gpu_0.log"
    fi

    # Remove the log files after processing them
    rm ${TEST_CASE}_gpu_*.log
done

exit $HAS_FAILURE
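For reference, the runner can also be invoked on its own; a hypothetical standalone run (assuming TE_PATH points at the repository root) would be:

.. code-block:: bash

   TE_PATH=/opt/transformerengine NUM_GPUS=8 \
       bash $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh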
@@ -57,13 +57,14 @@ class Net(nn.Module):
            self_attn_mask_type="padding",
            enable_relative_embedding=False,
            enable_sequence_parallel=self.enable_seq_paral,
            mlp_activations=("gelu", "linear"),
        )
        x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

        x = x.reshape(x.shape[0], -1)

        if self.enable_seq_paral:
            # Trigger all-gather to collect a complete tensor along the sequence dimension on each device.
            x = jax.lax.with_sharding_constraint(
                x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
            )
@@ -257,6 +258,8 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
    print(args)
    jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)

    train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
    num_gpu = jax.local_device_count()
@@ -440,6 +443,9 @@ def encoder_parser(args):
    parser.add_argument(
        "--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
    )
    parser.add_argument(
        "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
    )
    return parser.parse_args(args)
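Because the example parses ``sys.argv`` when executed directly, the new flag can also be exercised from the command line; a hypothetical invocation (the file name is inferred from context):

.. code-block:: bash

   python3 examples/jax/encoder/test_model_parallel_encoder.py --enable-shardy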
@@ -447,19 +453,18 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

    def setUp(self):
        """Run 3 epochs for testing"""
        self.args = encoder_parser(["--epochs", "3"])

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
        """Test Transformer Engine with BF16"""
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.455 and actual[1] > 0.785

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8(self):
@@ -467,7 +472,7 @@ class TestEncoder(unittest.TestCase):
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.455 and actual[1] > 0.785

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8(self):
@@ -475,14 +480,14 @@ class TestEncoder(unittest.TestCase):
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.455 and actual[1] > 0.785

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16_with_sp(self):
        """Test Transformer Engine with BF16 + SP"""
        self.args.enable_sp = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.455 and actual[1] > 0.785

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8_with_sp(self):

@@ -491,7 +496,7 @@ class TestEncoder(unittest.TestCase):
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.455 and actual[1] > 0.785

    @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
    def test_te_mxfp8_with_sp(self):

@@ -500,7 +505,35 @@ class TestEncoder(unittest.TestCase):
        self.args.use_fp8 = True
        self.args.fp8_recipe = "MXFP8BlockScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.455 and actual[1] > 0.785
    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16_shardy(self):
        """Test Transformer Engine with BF16"""
        self.args.enable_shardy = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.455 and actual[1] > 0.785

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8_shardy(self):
        """Test Transformer Engine with DelayedScaling FP8"""
        self.args.enable_shardy = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.455 and actual[1] > 0.785

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8_with_sp_shardy(self):
        """Test Transformer Engine with DelayedScaling FP8 + SP"""
        self.args.enable_shardy = True
        self.args.enable_sp = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.455 and actual[1] > 0.785

    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
if __name__ == "__main__": if __name__ == "__main__":
......
@@ -238,6 +238,7 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
    print(args)
    jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)

    train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
    num_gpu = jax.local_device_count()
@@ -409,6 +410,9 @@ def encoder_parser(args):
        default="DelayedScaling",
        help="Use FP8 recipe (default: DelayedScaling)",
    )
    parser.add_argument(
        "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
    )
    return parser.parse_args(args)
@@ -416,13 +420,12 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

    def setUp(self):
        """Run 3 epochs for testing"""
        self.args = encoder_parser(["--epochs", "3"])

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
@@ -446,6 +449,24 @@ class TestEncoder(unittest.TestCase):
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.535 and actual[1] > 0.73

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16_shardy(self):
        """Test Transformer Engine with BF16"""
        self.args.enable_shardy = True
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.535 and actual[1] > 0.73

    @unittest.skipIf(not is_fp8_supported, fp8_reason)
    def test_te_delayed_scaling_fp8_shardy(self):
        """Test Transformer Engine with DelayedScaling FP8"""
        self.args.enable_shardy = True
        self.args.use_fp8 = True
        self.args.fp8_recipe = "DelayedScaling"
        actual = train_and_evaluate(self.args)
        assert actual[0] < 0.535 and actual[1] > 0.73

    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
if __name__ == "__main__":
    train_and_evaluate(encoder_parser(None))
@@ -343,6 +343,7 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args):
    """Execute model training and evaluation loop."""
    print(args)
    jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)

    if args.process_id == 0:
        nltk.download("punkt_tab")
@@ -565,6 +566,9 @@ def encoder_parser(args):
        default=0,
        help="the ID number of the current process (default: 0)",
    )
    parser.add_argument(
        "--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
    )
    return parser.parse_args(args)
@@ -573,7 +577,7 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

    def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
        """Run 3 epochs for testing"""
        args = encoder_parser([])
@@ -589,6 +593,7 @@ class TestEncoder(unittest.TestCase):
        args.num_process = num_gpu
        args.process_id = self.process_id
        args.fp8_recipe = fp8_recipe
        args.enable_shardy = enable_shardy

        return train_and_evaluate(args)
@@ -604,7 +609,7 @@ class TestEncoder(unittest.TestCase):
    def test_te_delayed_scaling_fp8(self):
        """Test Transformer Engine with DelayedScaling FP8"""
        result = self.exec(True, "DelayedScaling")
        assert result[0] < 0.505 and result[1] > 0.754

    @unittest.skipIf(
        not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
@@ -614,6 +619,22 @@ class TestEncoder(unittest.TestCase):
        result = self.exec(True, "MXFP8BlockScaling")
        assert result[0] < 0.505 and result[1] > 0.754

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16_shardy(self):
        """Test Transformer Engine with BF16"""
        result = self.exec(False, None, enable_shardy=True)
        assert result[0] < 0.505 and result[1] > 0.755

    @unittest.skipIf(
        not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
    )
    def test_te_delayed_scaling_fp8_shardy(self):
        """Test Transformer Engine with DelayedScaling FP8"""
        result = self.exec(True, "DelayedScaling", enable_shardy=True)
        assert result[0] < 0.505 and result[1] > 0.754

    # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
if __name__ == "__main__":
    train_and_evaluate(encoder_parser(None))
@@ -327,13 +327,12 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
    """Encoder unittests"""

    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

    def setUp(self):
        """Run 3 epochs for testing"""
        self.args = encoder_parser(["--epochs", "3"])

    @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
    def test_te_bf16(self):
@@ -306,8 +306,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase):
    """MNIST unittests"""

    is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
    is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

    @classmethod
    def setUpClass(cls):
@@ -17,13 +17,17 @@ RET=0
FAILED_CASES=""

: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install requirements"

# Make the encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
wait
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
wait
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"

if [ $RET -ne 0 ]; then
@@ -2,6 +2,8 @@
#
# See LICENSE for license information.

set -x

function error_exit() {
    echo "Error: $1"
    exit 1
@@ -18,20 +20,23 @@ FAILED_CASES=""
pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*"

# Test without custom calls
NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"

pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"

pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
# Make the encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"

if [ $RET -ne 0 ]; then
    echo "Error: some sub-tests failed: $FAILED_CASES"
@@ -19,27 +19,32 @@ FAILED_CASES=""
set -x

: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/test_batched_linear.xml $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"

if [ "$RET" -ne 0 ]; then
    echo "Error in the following test cases:$FAILED_CASES"
@@ -5,5 +5,7 @@
set -xe

: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
@@ -17,16 +17,19 @@ RET=0
FAILED_CASES=""

: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"

if [ "$RET" -ne 0 ]; then
    echo "Error in the following test cases:$FAILED_CASES"
@@ -5,9 +5,11 @@
set -x

: ${THUNDER_PATH:=/opt/pytorch/lightning-thunder}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.1.1 pytest-benchmark==5.1.0

python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py

# Check return code
# Note: Return code 5 is fine. Lightning tests are skipped on systems
@@ -2,22 +2,45 @@
#
# See LICENSE for license information.

set -x

function error_exit() {
    echo "Error: $1"
    exit 1
}

function test_fail() {
    RET=1
    FAILED_CASES="$FAILED_CASES $1"
    echo "Error: sub-test failed: $1"
}

RET=0
FAILED_CASES=""

pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"

# Test without custom calls
NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"

pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"

# Make the encoder tests run-to-run deterministic for stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"

if [ $RET -ne 0 ]; then
    echo "Error: some sub-tests failed: $FAILED_CASES"
    exit 1
fi

echo "All tests passed"
exit 0
@@ -5,6 +5,8 @@
set -e

: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.2.1

@@ -37,6 +39,6 @@ do
    fi

    # Run tests
    NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
done
@@ -80,6 +80,12 @@ def setup_common_extension() -> CMakeExtension:
        ), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
        cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")

    if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
        assert (
            os.getenv("NVSHMEM_HOME") is not None
        ), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
        cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")

    if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
        cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
@@ -125,12 +131,15 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
    if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
        if "pytorch" in frameworks:
            install_reqs.extend(["torch>=2.1"])
            install_reqs.append(
                "nvdlfw-inspect @"
                " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
            )
            # Blackwell is not supported as of Triton 3.2.0, need custom internal build
            # install_reqs.append("triton")
            test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
        if "jax" in frameworks:
            install_reqs.extend(["jax", "flax>=0.7.1"])
            test_reqs.extend(["numpy"])

    return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]