Commit ab3e5a92 authored by yuguo's avatar yuguo
Browse files

Merge commit '04c730c0' of...

Merge commit '04c730c0' of https://github.com/NVIDIA/TransformerEngine
parents a8d19fd9 04c730c0
......@@ -45,6 +45,12 @@ jobs:
|| github.actor == 'jberchtold-nvidia'
|| github.actor == 'sanandaraj5597'
|| github.actor == 'negvet'
|| github.actor == 'zhongbozhu'
|| github.actor == 'kwyss-nvidia'
|| github.actor == 'BestJuly'
|| github.actor == 'xiaopoc'
|| github.actor == 'jreiffers'
|| github.actor == 'lhb8125'
)
steps:
- name: Check if comment is issued by authorized person
......
......@@ -145,18 +145,30 @@ Flax
Installation
============
.. installation
Pre-requisites
System Requirements
^^^^^^^^^^^^^^^^^^^^
* Linux x86_64
* CUDA 12.1+ (CUDA 12.8+ for Blackwell)
* NVIDIA Driver supporting CUDA 12.1 or later
* cuDNN 9.3 or later
Docker
^^^^^^^^^^^^^^^^^^^^
* **Hardware:** Blackwell, Hopper, Grace Hopper/Blackwell, Ada, Ampere
* **OS:** Linux (official), WSL2 (limited support)
* **Software:**
* CUDA: 12.1+ (Hopper/Ada/Ampere), 12.8+ (Blackwell) with compatible NVIDIA drivers
* cuDNN: 9.3+
* Compiler: GCC 9+ or Clang 10+ with C++17 support
* Python: 3.12 recommended
* **Source Build Requirements:** CMake 3.18+, Ninja, Git 2.17+, pybind11 2.6.0+
* **Notes:** FP8 features require Compute Capability 8.9+ (Ada/Hopper/Blackwell)
Installation Methods
^^^^^^^^^^^^^^^^^^^
Docker (Recommended)
^^^^^^^^^^^^^^^^^^^
The quickest way to get started with Transformer Engine is by using Docker images on
`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_.
For example to use the NGC PyTorch container interactively,
......@@ -167,41 +179,116 @@ For example to use the NGC PyTorch container interactively,
Where 25.01 (corresponding to January 2025 release) is the container version.
pip
^^^^^^^^^^^^^^^^^^^^
To install the latest stable version of Transformer Engine,
**Benefits of using NGC containers:**
* All dependencies pre-installed with compatible versions and optimized configurations
* NGC PyTorch 23.08+ containers include FlashAttention-2
pip Installation
^^^^^^^^^^^^^^^^^^^
**Prerequisites for pip installation:**
* A compatible C++ compiler
* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) installed
To install the latest stable version with pip:
.. code-block:: bash
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
# For PyTorch integration
pip install --no-build-isolation transformer_engine[pytorch]
# For JAX integration
pip install --no-build-isolation transformer_engine[jax]
# For both frameworks
pip install --no-build-isolation transformer_engine[pytorch,jax]
Alternatively, install directly from the GitHub repository:
.. code-block:: bash
This will automatically detect if any supported deep learning frameworks are installed and build
Transformer Engine support for them. To explicitly specify frameworks, set the environment variable
NVTE_FRAMEWORK to a comma-separated list (e.g. NVTE_FRAMEWORK=jax,pytorch).
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
Alternatively, the package can be directly installed from
`Transformer Engine's PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.
When installing from GitHub, you can explicitly specify frameworks using the environment variable:
.. code-block:: bash
pip3 install transformer_engine[pytorch]
NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
Source Installation
^^^^^^^^^^^^^^^^^^^
`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_
To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be
explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]).
Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX
and PyTorch extensions.
Environment Variables
^^^^^^^^^^^^^^^^^^^
These environment variables can be set before installation to customize the build process:
From source
^^^^^^^^^^^
`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_.
* **CUDA_PATH**: Path to CUDA installation
* **CUDNN_PATH**: Path to cuDNN installation
* **CXX**: Path to C++ compiler
* **NVTE_FRAMEWORK**: Comma-separated list of frameworks to build for (e.g., ``pytorch,jax``)
* **MAX_JOBS**: Limit number of parallel build jobs (default varies by system)
* **NVTE_BUILD_THREADS_PER_JOB**: Control threads per build job
Compiling with FlashAttention-2
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Transformer Engine release v0.11.0 added support for FlashAttention-2 in PyTorch for improved performance.
Compiling with FlashAttention
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Transformer Engine supports both FlashAttention-2 and FlashAttention-3 in PyTorch for improved performance. FlashAttention-3 was added in release v1.11 and is prioritized over FlashAttention-2 when both are present in the environment.
You can verify which FlashAttention version is being used by setting these environment variables:
.. code-block:: bash
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python your_script.py
It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue.
Note that NGC PyTorch 23.08+ containers include FlashAttention-2.
.. troubleshooting-begin-marker-do-not-remove
Troubleshooting
^^^^^^^^^^^^^^^^^^^
**Common Issues and Solutions:**
1. **ABI Compatibility Issues:**
* **Symptoms:** ``ImportError`` with undefined symbols when importing transformer_engine
* **Solution:** Ensure PyTorch and Transformer Engine are built with the same C++ ABI setting. Rebuild PyTorch from source with matching ABI.
* **Context:** If you're using PyTorch built with a different C++ ABI than your system's default, you may encounter these undefined symbol errors. This is particularly common with pip-installed PyTorch outside of containers.
2. **Missing Headers or Libraries:**
* **Symptoms:** CMake errors about missing headers (``cudnn.h``, ``cublas_v2.h``, ``filesystem``, etc.)
* **Solution:** Install missing development packages or set environment variables to point to correct locations:
.. code-block:: bash
export CUDA_PATH=/path/to/cuda
export CUDNN_PATH=/path/to/cudnn
* If CMake can't find a C++ compiler, set the ``CXX`` environment variable.
* Ensure all paths are correctly set before installation.
3. **Build Resource Issues:**
* **Symptoms:** Compilation hangs, system freezes, or out-of-memory errors
* **Solution:** Limit parallel builds:
.. code-block:: bash
MAX_JOBS=1 NVTE_BUILD_THREADS_PER_JOB=1 pip install ...
4. **Verbose Build Logging:**
* For detailed build logs to help diagnose issues:
.. code-block:: bash
cd transformer_engine
pip install -v -v -v --no-build-isolation .
.. troubleshooting-end-marker-do-not-remove
Breaking Changes
================
......
......@@ -140,6 +140,19 @@ def setup_pytorch_extension(
library_dirs.append(mpi_path / "lib")
libraries.append("mpi")
library_dirs = []
libraries = []
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
assert (
os.getenv("NVSHMEM_HOME") is not None
), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
nvshmem_home = Path(os.getenv("NVSHMEM_HOME"))
include_dirs.append(nvshmem_home / "include")
library_dirs.append(nvshmem_home / "lib")
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")
# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
......
......@@ -34,7 +34,7 @@ Transformer Engine can be directly installed from `our PyPI <https://pypi.org/pr
.. code-block:: bash
pip3 install transformer_engine[pytorch]
pip3 install --no-build-isolation transformer_engine[pytorch]
To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions.
......@@ -54,7 +54,7 @@ Execute the following command to install the latest stable version of Transforme
.. code-block:: bash
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@stable
pip3 install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable
This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable `NVTE_FRAMEWORK` to a comma-separated list (e.g. `NVTE_FRAMEWORK=jax,pytorch`).
......@@ -71,7 +71,7 @@ Execute the following command to install the latest development build of Transfo
.. code-block:: bash
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@main
pip3 install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@main
This will automatically detect if any supported deep learning frameworks are installed and build Transformer Engine support for them. To explicitly specify frameworks, set the environment variable `NVTE_FRAMEWORK` to a comma-separated list (e.g. `NVTE_FRAMEWORK=jax,pytorch`). To only build the framework-agnostic C++ API, set `NVTE_FRAMEWORK=none`.
......@@ -79,7 +79,7 @@ In order to install a specific PR, execute (after changing NNN to the PR number)
.. code-block:: bash
pip3 install git+https://github.com/NVIDIA/TransformerEngine.git@refs/pull/NNN/merge
pip3 install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@refs/pull/NNN/merge
Installation (from source)
......@@ -93,8 +93,8 @@ Execute the following commands to install Transformer Engine from source:
git clone --branch stable --recursive https://github.com/NVIDIA/TransformerEngine.git
cd TransformerEngine
export NVTE_FRAMEWORK=pytorch # Optionally set framework
pip3 install . # Build and install
export NVTE_FRAMEWORK=pytorch # Optionally set framework
pip3 install --no-build-isolation . # Build and install
If the Git repository has already been cloned, make sure to also clone the submodules:
......@@ -106,10 +106,14 @@ Extra dependencies for testing can be installed by setting the "test" option:
.. code-block:: bash
pip3 install .[test]
pip3 install --no-build-isolation .[test]
To build the C++ extensions with debug symbols, e.g. with the `-g` flag:
.. code-block:: bash
pip3 install . --global-option=--debug
pip3 install --no-build-isolation . --global-option=--debug
.. include:: ../README.rst
:start-after: troubleshooting-begin-marker-do-not-remove
:end-before: troubleshooting-end-marker-do-not-remove
......@@ -4,20 +4,54 @@
NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_bf16 --num-process=$NUM_GPUS --process-id=$i &
done
wait
# Define the test cases to run
TEST_CASES=(
"test_te_bf16"
"test_te_delayed_scaling_fp8"
"test_te_mxfp8"
"test_te_bf16_shardy"
"test_te_delayed_scaling_fp8_shardy"
)
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_delayed_scaling_fp8 --num-process=$NUM_GPUS --process-id=$i &
done
wait
echo
echo "*** Executing tests in examples/jax/encoder/test_multiprocessing_encoder.py ***"
HAS_FAILURE=0 # Global failure flag
# Run each test case across all GPUs
for TEST_CASE in "${TEST_CASES[@]}"; do
echo
echo "=== Starting test: $TEST_CASE ..."
for i in $(seq 0 $(($NUM_GPUS - 1))); do
# Define output file for logs
LOG_FILE="${TEST_CASE}_gpu_${i}.log"
for i in $(seq 0 $(($NUM_GPUS-1)))
do
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::test_te_mxfp8 --num-process=$NUM_GPUS --process-id=$i &
# Run pytest and redirect stdout and stderr to the log file
pytest -c "$TE_PATH/tests/jax/pytest.ini" \
-vs "$TE_PATH/examples/jax/encoder/test_multiprocessing_encoder.py::TestEncoder::$TEST_CASE" \
--num-process=$NUM_GPUS \
--process-id=$i > "$LOG_FILE" 2>&1 &
done
# Wait for the process to finish
wait
# Check and print the log content accordingly
if grep -q "FAILED" "${TEST_CASE}_gpu_0.log"; then
HAS_FAILURE=1
echo "... $TEST_CASE FAILED"
tail -n +7 "${TEST_CASE}_gpu_0.log"
elif grep -q "SKIPPED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE SKIPPED"
elif grep -q "PASSED" "${TEST_CASE}_gpu_0.log"; then
echo "... $TEST_CASE PASSED"
else
echo "Invalid ${TEST_CASE}_gpu_0.log"
fi
# Remove the log file after processing it
rm ${TEST_CASE}_gpu_*.log
done
wait
exit $HAS_FAILURE
......@@ -57,13 +57,14 @@ class Net(nn.Module):
self_attn_mask_type="padding",
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
mlp_activations=("gelu", "linear"),
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)
x = x.reshape(x.shape[0], -1)
if self.enable_seq_paral:
# Trigger all-gather to collect a complete tensor alone seqence on each device.
# Trigger all-gather to collect a complete tensor alone sequence on each device.
x = jax.lax.with_sharding_constraint(
x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
)
......@@ -257,6 +258,8 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
num_gpu = jax.local_device_count()
......@@ -440,6 +443,9 @@ def encoder_parser(args):
parser.add_argument(
"--enable-sp", action="store_true", default=False, help="Enable sequence parallelism."
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)
return parser.parse_args(args)
......@@ -447,19 +453,18 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
def setUp(self):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])
self.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
......@@ -467,7 +472,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
......@@ -475,14 +480,14 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self):
......@@ -491,7 +496,7 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self):
......@@ -500,7 +505,35 @@ class TestEncoder(unittest.TestCase):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8 + SP"""
self.args.enable_shardy = True
self.args.enable_sp = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.785
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
if __name__ == "__main__":
......
......@@ -238,6 +238,7 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
train_ds, test_ds, num_embed = get_datasets(args.max_seq_len)
num_gpu = jax.local_device_count()
......@@ -409,6 +410,9 @@ def encoder_parser(args):
default="DelayedScaling",
help="Use FP8 recipe (default: DelayedScaling)",
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)
return parser.parse_args(args)
......@@ -416,13 +420,12 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
def setUp(self):
"""Run 3 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])
self.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
......@@ -446,6 +449,24 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
self.args.enable_shardy = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -343,6 +343,7 @@ def get_state_sharding(state, params_sharding):
def train_and_evaluate(args):
"""Execute model training and evaluation loop."""
print(args)
jax.config.update("jax_use_shardy_partitioner", args.enable_shardy)
if args.process_id == 0:
nltk.download("punkt_tab")
......@@ -565,6 +566,9 @@ def encoder_parser(args):
default=0,
help="the ID number of the current process (default: 0)",
)
parser.add_argument(
"--enable-shardy", action="store_true", default=False, help="Enable Shardy (experimental)."
)
return parser.parse_args(args)
......@@ -573,7 +577,7 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
def exec(self, use_fp8, fp8_recipe):
def exec(self, use_fp8, fp8_recipe, *, enable_shardy=False):
"""Run 3 epochs for testing"""
args = encoder_parser([])
......@@ -589,6 +593,7 @@ class TestEncoder(unittest.TestCase):
args.num_process = num_gpu
args.process_id = self.process_id
args.fp8_recipe = fp8_recipe
args.enable_shardy = enable_shardy
return train_and_evaluate(args)
......@@ -604,7 +609,7 @@ class TestEncoder(unittest.TestCase):
def test_te_delayed_scaling_fp8(self):
"""Test Transformer Engine with DelayedScaling FP8"""
result = self.exec(True, "DelayedScaling")
assert result[0] < 0.505 and result[1] > 0.755
assert result[0] < 0.505 and result[1] > 0.754
@unittest.skipIf(
not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
......@@ -614,6 +619,22 @@ class TestEncoder(unittest.TestCase):
result = self.exec(True, "MXFP8BlockScaling")
assert result[0] < 0.505 and result[1] > 0.754
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_shardy(self):
"""Test Transformer Engine with BF16"""
result = self.exec(False, None, enable_shardy=True)
assert result[0] < 0.505 and result[1] > 0.755
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for DelayedScaling FP8"
)
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
result = self.exec(True, "DelayedScaling", enable_shardy=True)
assert result[0] < 0.505 and result[1] > 0.754
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -327,13 +327,12 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
"""Run 4 epochs for testing"""
cls.args = encoder_parser(["--epochs", "3"])
def setUp(self):
"""Run 3 epochs for testing"""
self.args = encoder_parser(["--epochs", "3"])
@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16(self):
......
......@@ -306,8 +306,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)
@classmethod
def setUpClass(cls):
......
......@@ -17,13 +17,17 @@ RET=0
FAILED_CASES=""
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install requirements"
# Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_multigpu_encoder.xml $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
wait
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_model_parallel_encoder.xml $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
wait
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
if [ $RET -ne 0 ]; then
......
......@@ -2,6 +2,8 @@
#
# See LICENSE for license information.
set -x
function error_exit() {
echo "Error: $1"
exit 1
......@@ -18,20 +20,23 @@ FAILED_CASES=""
pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*"
# Test without custom calls
NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py without TE custom calls"
NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist || test_fail "test_mnist.py"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
# Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
......
......@@ -19,27 +19,32 @@ FAILED_CASES=""
set -x
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/test_batched_linear.xml $TE_PATH/tests/pytorch/test_batched_linear.py || test_fail "test_batched_linear.py"
PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
......
......@@ -5,5 +5,7 @@
set -xe
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_distributed_*
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
......@@ -17,16 +17,19 @@ RET=0
FAILED_CASES=""
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --log-cli-level=INFO $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
# python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
......
......@@ -5,9 +5,11 @@
set -x
: ${THUNDER_PATH:=/opt/pytorch/lightning-thunder}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
pip3 install pytest==8.1.1 pytest-benchmark==5.1.0
python3 -m pytest -v -s ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml ${THUNDER_PATH}/thunder/tests/test_transformer_engine_executor.py
# Check return code
# Note: Return code 5 is fine. Lightning tests are skipped on systems
......
......@@ -2,22 +2,45 @@
#
# See LICENSE for license information.
set -xe
set -x
pip install "nltk>=3.8.2"
pip install pytest==8.2.1
function error_exit() {
echo "Error: $1"
exit 1
}
function test_fail() {
RET=1
FAILED_CASES="$FAILED_CASES $1"
echo "Error: sub-test failed: $1"
}
RET=0
FAILED_CASES=""
pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk"
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_praxis_layers.py
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
# Test without custom calls
NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax/test_custom_call_compute.py
NVTE_JAX_UNITTEST_LEVEL="L2" NVTE_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_custom_call_compute.xml $TE_PATH/tests/jax/test_custom_call_compute.py || test_fail "test_custom_call_compute.py"
pip install -r $TE_PATH/examples/jax/mnist/requirements.txt
pip install -r $TE_PATH/examples/jax/encoder/requirements.txt
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Failed to install encoder requirements"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/mnist
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
# Make encoder tests to have run-to-run deterministic to have the stable CI results
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
exit 1
fi
echo "All tests passed"
exit 0
......@@ -5,6 +5,8 @@
set -e
: ${TE_PATH:=/opt/transformerengine}
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"
pip3 install pytest==8.2.1
......@@ -37,6 +39,6 @@ do
fi
# Run tests
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
done
......@@ -80,6 +80,12 @@ def setup_common_extension() -> CMakeExtension:
), "MPI_HOME must be set when compiling with NVTE_UB_WITH_MPI=1"
cmake_flags.append("-DNVTE_UB_WITH_MPI=ON")
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", "0"))):
assert (
os.getenv("NVSHMEM_HOME") is not None
), "NVSHMEM_HOME must be set when compiling with NVTE_ENABLE_NVSHMEM=1"
cmake_flags.append("-DNVTE_ENABLE_NVSHMEM=ON")
if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))):
cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON")
......@@ -125,12 +131,15 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
if "pytorch" in frameworks:
install_reqs.extend(["torch>=2.1"])
install_reqs.append(
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
)
# Blackwell is not supported as of Triton 3.2.0, need custom internal build
# install_reqs.append("triton")
test_reqs.extend(["numpy", "torchvision", "prettytable", "PyYAML"])
if "jax" in frameworks:
install_reqs.extend(["jax", "flax>=0.7.1"])
# test_reqs.extend(["numpy", "praxis"])
test_reqs.extend(["numpy"])
return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment