Commit f8c2af4c authored by yuguo


Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine
parents e92773a3 1d903f5e
@@ -25,7 +25,7 @@ jobs:
       with:
         submodules: recursive
     - name: 'Build'
-      run: pip install . -v
+      run: pip install --no-build-isolation . -v
       env:
         NVTE_FRAMEWORK: none
         MAX_JOBS: 1
@@ -49,7 +49,7 @@ jobs:
       with:
         submodules: recursive
     - name: 'Build'
-      run: pip install . -v --no-deps
+      run: pip install --no-build-isolation . -v --no-deps
      env:
         NVTE_FRAMEWORK: pytorch
         MAX_JOBS: 1
@@ -68,7 +68,7 @@ jobs:
       with:
         submodules: recursive
     - name: 'Build'
-      run: pip install . -v
+      run: pip install --no-build-isolation . -v
       env:
         NVTE_FRAMEWORK: jax
         MAX_JOBS: 1
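The switch to `pip install --no-build-isolation` in all three build jobs means pip now compiles against the packages already present in the runner's environment rather than a fresh, isolated build environment, so build-time requirements must be installed beforehand. A minimal pre-flight sketch of that constraint (the `build_reqs` list below is illustrative only; the authoritative list lives in the project's `pyproject.toml`):

.. code-block:: python

    # Sketch: with --no-build-isolation, pip builds in the current
    # environment, so build requirements must already be importable.
    import importlib.util

    build_reqs = ["setuptools", "pybind11"]  # illustrative subset, not TE's full list
    missing = [r for r in build_reqs if importlib.util.find_spec(r) is None]
    if missing:
        raise SystemExit(
            "Install these before running `pip install --no-build-isolation .`: "
            + ", ".join(missing)
        )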
@@ -51,6 +51,8 @@ jobs:
           || github.actor == 'xiaopoc'
           || github.actor == 'jreiffers'
           || github.actor == 'lhb8125'
+          || github.actor == 'kunlunl'
+          || github.actor == 'pstjohn'
         )
     steps:
       - name: Check if comment is issued by authorized person
@@ -8,25 +8,22 @@
 Transformer Engine
 ==================

-`Quickstart <#examples>`_ | `Installation <#installation>`_ | `User Guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html>`_ | `Examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_ | `FP8 Convergence <#fp8-convergence>`_ | `Integrations <#integrations>`_ | `Release notes <https://docs.nvidia.com/deeplearning/transformer-engine/release-notes/index.html>`_
+`Quickstart <#examples>`_ | `Installation <#installation>`_ | `User Guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html>`_ | `Examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_ | `FP8 Convergence <#fp8-convergence>`_ | `Integrations <#integrations>`_ | `Release notes <https://docs.nvidia.com/deeplearning/transformer-engine/documentation-archive.html>`_

 Latest News
 ===========

-* [03/2024] `Turbocharged Training: Optimizing the Databricks Mosaic AI stack with FP8 <https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8>`_
-* [03/2024] `FP8 Training Support in SageMaker Model Parallelism Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-release-notes.html>`_
-* [12/2023] `New NVIDIA NeMo Framework Features and NVIDIA H200 <https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility/>`_
-.. image:: docs/examples/H200-NeMo-performance.png
+* [03/2025] `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72778/>`_
+* [03/2025] `Measure and Improve AI Workload Performance with NVIDIA DGX Cloud Benchmarking <https://developer.nvidia.com/blog/measure-and-improve-ai-workload-performance-with-nvidia-dgx-cloud-benchmarking/>`_
+.. image:: docs/examples/comparison-fp8-bf16-training-nvidia-dgx-cloud-benchmarking-performance-explorer.jpg
    :width: 600
-   :alt: H200
-* [11/2023] `Inflection-2: The Next Step Up <https://inflection.ai/inflection-2>`_
-* [11/2023] `Unleashing The Power Of Transformers With NVIDIA Transformer Engine <https://lambdalabs.com/blog/unleashing-the-power-of-transformers-with-nvidia-transformer-engine>`_
-* [11/2023] `Accelerating PyTorch Training Workloads with FP8 <https://towardsdatascience.com/accelerating-pytorch-training-workloads-with-fp8-5a5123aec7d7>`_
-* [09/2023] `Transformer Engine added to AWS DL Container for PyTorch Training <https://github.com/aws/deep-learning-containers/pull/3315>`_
-* [06/2023] `Breaking MLPerf Training Records with NVIDIA H100 GPUs <https://developer.nvidia.com/blog/breaking-mlperf-training-records-with-nvidia-h100-gpus/>`_
-* [04/2023] `Benchmarking Large Language Models on NVIDIA H100 GPUs with CoreWeave (Part 1) <https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1>`_
+   :alt: Comparison of FP8 versus BF16 training, as seen in NVIDIA DGX Cloud Benchmarking Performance Explorer
+* [02/2025] `Understanding the Language of Life's Biomolecules Across Evolution at a New Scale with Evo 2 <https://developer.nvidia.com/blog/understanding-the-language-of-lifes-biomolecules-across-evolution-at-a-new-scale-with-evo-2/>`_
+* [02/2025] `NVIDIA DGX Cloud Introduces Ready-To-Use Templates to Benchmark AI Platform Performance <https://developer.nvidia.com/blog/nvidia-dgx-cloud-introduces-ready-to-use-templates-to-benchmark-ai-platform-performance/>`_
+* [01/2025] `Continued Pretraining of State-of-the-Art LLMs for Sovereign AI and Regulated Industries with iGenius and NVIDIA DGX Cloud <https://developer.nvidia.com/blog/continued-pretraining-of-state-of-the-art-llms-for-sovereign-ai-and-regulated-industries-with-igenius-and-nvidia-dgx-cloud/>`_
+
+`Previous News <#previous-news>`_

 What is Transformer Engine?
 ===========================
@@ -141,6 +138,8 @@ Flax
     for _ in range(10):
         loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)

+For a more comprehensive tutorial, check out our `Quickstart Notebook <https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb>`_.
+
 .. overview-end-marker-do-not-remove
 Installation
@@ -171,13 +170,21 @@ Docker (Recommended)
 ^^^^^^^^^^^^^^^^^^^^
 The quickest way to get started with Transformer Engine is by using Docker images on
 `NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_.
 For example to use the NGC PyTorch container interactively,

 .. code-block:: bash

-    docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.01-py3
+    docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.04-py3
+
+For example to use the NGC JAX container interactively,
+
+.. code-block:: bash
+
+    docker run --gpus all -it --rm nvcr.io/nvidia/jax:25.04-py3

-Where 25.01 (corresponding to January 2025 release) is the container version.
+Where 25.04 (corresponding to April 2025 release) is the container version.

 **Benefits of using NGC containers:**
@@ -349,8 +356,8 @@ Integrations

 Transformer Engine has been integrated with popular LLM frameworks such as:

-* `DeepSpeed <https://github.com/microsoft/DeepSpeed/pull/3731>`_
+* `DeepSpeed <https://github.com/deepspeedai/DeepSpeed/blob/master/tests/unit/runtime/half_precision/test_fp8.py>`_
-* `Hugging Face Accelerate <https://github.com/huggingface/accelerate/releases/tag/v0.17.0>`_
+* `Hugging Face Accelerate <https://huggingface.co/docs/accelerate/main/en/usage_guides/low_precision_training#configuring-transformersengine>`_
 * `Lightning <https://github.com/Lightning-AI/lightning/issues/17172>`_
 * `MosaicML Composer <https://github.com/mosaicml/composer/releases/tag/v0.13.1>`_
 * `NVIDIA JAX Toolbox <https://github.com/NVIDIA/JAX-Toolbox>`_
@@ -381,10 +388,37 @@ Papers

 Videos
 ======

+* `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc24-s62457/>`_
+* `Blackwell Numerics for AI | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72458/>`_
+* `Building LLMs: Accelerating Pretraining of Foundational Models With FP8 Precision | GTC 2025 <https://www.nvidia.com/gtc/session-catalog/?regcode=no-ncid&ncid=no-ncid&tab.catalogallsessionstab=16566177511100015Kus&search=zoho#/session/1726152813607001vnYK>`_
+* `From FP8 LLM Training to Inference: Language AI at Scale | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72799/>`_
 * `What's New in Transformer Engine and FP8 Training | GTC 2024 <https://www.nvidia.com/en-us/on-demand/session/gtc24-s62457/>`_
 * `FP8 Training with Transformer Engine | GTC 2023 <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51393>`_
 * `FP8 for Deep Learning | GTC 2023 <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s52166/>`_
-* `Inside the Hopper Architecture <https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s42663/>`_
+* `Inside the Hopper Architecture | GTC 2022 <https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s42663/>`_

 .. |License| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg
    :target: https://opensource.org/licenses/Apache-2.0
+Previous News
+=============
+
+* [11/2024] `Developing a 172B LLM with Strong Japanese Capabilities Using NVIDIA Megatron-LM <https://developer.nvidia.com/blog/developing-a-172b-llm-with-strong-japanese-capabilities-using-nvidia-megatron-lm/>`_
+* [11/2024] `How FP8 boosts LLM training by 18% on Amazon SageMaker P5 instances <https://aws.amazon.com/blogs/machine-learning/how-fp8-boosts-llm-training-by-18-on-amazon-sagemaker-p5-instances/>`_
+* [11/2024] `Efficiently train models with large sequence lengths using Amazon SageMaker model parallel <https://aws.amazon.com/blogs/machine-learning/efficiently-train-models-with-large-sequence-lengths-using-amazon-sagemaker-model-parallel/>`_
+* [09/2024] `Reducing AI large model training costs by 30% requires just a single line of code from FP8 mixed precision training upgrades <https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades>`_
+* [05/2024] `Accelerating Transformers with NVIDIA cuDNN 9 <https://developer.nvidia.com/blog/accelerating-transformers-with-nvidia-cudnn-9/>`_
+* [03/2024] `Turbocharged Training: Optimizing the Databricks Mosaic AI stack with FP8 <https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8>`_
+* [03/2024] `FP8 Training Support in SageMaker Model Parallelism Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-release-notes.html>`_
+* [12/2023] `New NVIDIA NeMo Framework Features and NVIDIA H200 <https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility/>`_
+
+.. image:: docs/examples/H200-NeMo-performance.png
+   :width: 600
+   :alt: H200
+
+* [11/2023] `Inflection-2: The Next Step Up <https://inflection.ai/inflection-2>`_
+* [11/2023] `Unleashing The Power Of Transformers With NVIDIA Transformer Engine <https://lambdalabs.com/blog/unleashing-the-power-of-transformers-with-nvidia-transformer-engine>`_
+* [11/2023] `Accelerating PyTorch Training Workloads with FP8 <https://towardsdatascience.com/accelerating-pytorch-training-workloads-with-fp8-5a5123aec7d7>`_
+* [09/2023] `Transformer Engine added to AWS DL Container for PyTorch Training <https://github.com/aws/deep-learning-containers/pull/3315>`_
+* [06/2023] `Breaking MLPerf Training Records with NVIDIA H100 GPUs <https://developer.nvidia.com/blog/breaking-mlperf-training-records-with-nvidia-h100-gpus/>`_
+* [04/2023] `Benchmarking Large Language Models on NVIDIA H100 GPUs with CoreWeave (Part 1) <https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1>`_
@@ -4,7 +4,6 @@
 """Installation script."""

-import ctypes
 import os
 import subprocess
 import sys
@@ -25,7 +24,7 @@ from .utils import (
     debug_build_enabled,
     found_ninja,
     get_frameworks,
-    cuda_path,
+    nvcc_path,
     get_max_jobs_for_parallel_build,
 )
@@ -96,7 +95,9 @@ class CMakeExtension(setuptools.Extension):
         print(f"Time for build_ext: {total_time:.2f} seconds")


-def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel_lib: bool = False):
+def get_build_ext(
+    extension_cls: Type[setuptools.Extension], framework_extension_only: bool = False
+):
     class _CMakeBuildExtension(extension_cls):
         """Setuptools command with support for CMake extension modules"""
@@ -131,22 +132,28 @@ def get_build_ext(
             super().run()
             self.extensions = all_extensions

-            # Ensure that binaries are not in global package space.
+            # Ensure that shared objects files for source and PyPI installations live
+            # in separate directories to avoid conflicts during install and runtime.
             lib_dir = (
                 "wheel_lib"
-                if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or install_so_in_wheel_lib
+                if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or framework_extension_only
                 else ""
             )
-            target_dir = install_dir / "transformer_engine" / lib_dir
-            target_dir.mkdir(exist_ok=True, parents=True)
-            for ext in Path(self.build_lib).glob("*.so"):
-                self.copy_file(ext, target_dir)
-                os.remove(ext)
+
+            # Ensure that binaries are not in global package space.
+            # For editable/inplace builds this is not a concern as
+            # the SOs will be in a local directory anyway.
+            if not self.inplace:
+                target_dir = install_dir / "transformer_engine" / lib_dir
+                target_dir.mkdir(exist_ok=True, parents=True)
+                for ext in Path(self.build_lib).glob("*.so"):
+                    self.copy_file(ext, target_dir)
+                    os.remove(ext)

         def build_extensions(self):
-            # BuildExtensions from PyTorch already handle CUDA files correctly
-            # so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed.
+            # For core lib + JAX install, fix build_ext from pybind11.setup_helpers
+            # to handle CUDA files correctly.
             if "pytorch" not in get_frameworks():
                 # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
                 # extra_compile_args is a dict.
@@ -159,20 +166,24 @@ def get_build_ext(
                 # Define new _compile method that redirects to NVCC for .cu and .cuh files.
                 # Also redirect .hip files to HIPCC
                 original_compile_fn = self.compiler._compile
-                self.compiler.src_extensions += [".cu", ".cuh", ".hip"]
+                if not framework_extension_only:
+                    self.compiler.src_extensions += [".cu", ".cuh", ".hip"]

                 def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
                     # Copy before we make any modifications.
                     cflags = copy.deepcopy(extra_postargs)
                     original_compiler = self.compiler.compiler_so
                     try:
-                        if rocm_build():
-                            _, nvcc_bin = rocm_path()
-                        else:
-                            _, nvcc_bin = cuda_path()
                         original_compiler = self.compiler.compiler_so
-                        if os.path.splitext(src)[1] in [".cu", ".cuh", ".hip"]:
+                        if (
+                            os.path.splitext(src)[1] in [".cu", ".cuh", ".hip"]
+                            and not framework_extension_only
+                        ):
+                            if rocm_build():
+                                _, nvcc_bin = rocm_path()
+                            else:
+                                nvcc_bin = nvcc_path()
                             self.compiler.set_executable("compiler_so", str(nvcc_bin))
                             if isinstance(cflags, dict):
                                 cflags = cflags["nvcc"]
@@ -188,7 +199,6 @@ def get_build_ext(
                             # Forward unknown options
                             if not any("--forward-unknown-opts" in flag for flag in cflags):
                                 cflags.append("--forward-unknown-opts")
-
                         elif isinstance(cflags, dict):
                             cflags = cflags["cxx"]
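The `_compile_fn` wrapper above follows a standard distutils technique: temporarily swap the compiler executable depending on the source file's extension, then restore it. A stripped-down sketch of the pattern, independent of TE's actual flags and helper functions:

.. code-block:: python

    # Sketch of the per-source compiler-swap pattern used in _compile_fn.
    # `compiler` is a distutils/setuptools CCompiler instance.
    import os

    def patch_compiler(compiler, nvcc_bin="nvcc"):
        original_compile = compiler._compile
        compiler.src_extensions += [".cu", ".cuh"]

        def _compile(obj, src, ext, cc_args, extra_postargs, pp_opts):
            original_so = compiler.compiler_so
            try:
                if os.path.splitext(src)[1] in (".cu", ".cuh"):
                    # Route CUDA sources to NVCC ...
                    compiler.set_executable("compiler_so", nvcc_bin)
                original_compile(obj, src, ext, cc_args, extra_postargs, pp_opts)
            finally:
                # ... and always restore the host compiler afterwards.
                compiler.compiler_so = original_so

        compiler._compile = _compile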
@@ -4,12 +4,12 @@
 """JAX related extensions."""

 import os
+import shutil
 from pathlib import Path

 import setuptools
-from glob import glob

-from .utils import cuda_path, all_files_in_dir
+from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled
 from typing import List
@@ -41,31 +41,33 @@ def setup_jax_extension(
     # Source files
     csrc_source_files = Path(csrc_source_files)
     extensions_dir = csrc_source_files / "extensions"
-    sources = [
-        csrc_source_files / "utils.cu",
-    ] + all_files_in_dir(extensions_dir, ".cpp")
+    sources = all_files_in_dir(extensions_dir, name_extension="cpp")

     # Header files
-    cuda_home, _ = cuda_path()
-    xla_home = xla_path()
-    include_dirs = [
-        cuda_home / "include",
-        common_header_files,
-        common_header_files / "common",
-        common_header_files / "common" / "include",
-        csrc_header_files,
-        xla_home,
-    ]
+    include_dirs = get_cuda_include_dirs()
+    include_dirs.extend(
+        [
+            common_header_files,
+            common_header_files / "common",
+            common_header_files / "common" / "include",
+            csrc_header_files,
+            xla_path(),
+        ]
+    )

     # Compile flags
     cxx_flags = ["-O3"]
-    nvcc_flags = ["-O3"]
+    if debug_build_enabled():
+        cxx_flags.append("-g")
+        cxx_flags.append("-UNDEBUG")
+    else:
+        cxx_flags.append("-g0")

     # Define TE/JAX as a Pybind11Extension
     from pybind11.setup_helpers import Pybind11Extension

-    class Pybind11CUDAExtension(Pybind11Extension):
-        """Modified Pybind11Extension to allow combined CXX + NVCC compile flags."""
+    class Pybind11CPPExtension(Pybind11Extension):
+        """Modified Pybind11Extension to allow custom CXX flags."""

         def _add_cflags(self, flags: List[str]) -> None:
             if isinstance(self.extra_compile_args, dict):
@@ -75,9 +77,9 @@ def setup_jax_extension(
             else:
                 self.extra_compile_args[:0] = flags

-    return Pybind11CUDAExtension(
+    return Pybind11CPPExtension(
         "transformer_engine_jax",
         sources=[str(path) for path in sources],
         include_dirs=[str(path) for path in include_dirs],
-        extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags},
+        extra_compile_args={"cxx": cxx_flags},
     )
@@ -8,13 +8,7 @@ from pathlib import Path

 import setuptools

-from .utils import (
-    rocm_build,
-    hipify,
-    all_files_in_dir,
-    cuda_archs,
-    cuda_version,
-)
+from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled, rocm_build, hipify


 def setup_pytorch_extension(
@@ -25,19 +19,26 @@ def setup_pytorch_extension(
     """Setup CUDA extension for PyTorch support"""

     # Source files
-    csrc_source_files = Path(csrc_source_files)
-    extensions_dir = csrc_source_files / "extensions"
-    sources = [
-        csrc_source_files / "common.cpp",
-    ] + all_files_in_dir(extensions_dir)
+    sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")

     # Header files
-    include_dirs = [
-        common_header_files,
-        common_header_files / "common",
-        common_header_files / "common" / "include",
-        csrc_header_files,
-    ]
+    if rocm_build():
+        include_dirs = [
+            common_header_files,
+            common_header_files / "common",
+            common_header_files / "common" / "include",
+            csrc_header_files,
+        ]
+    else:
+        include_dirs = get_cuda_include_dirs()
+        include_dirs.extend(
+            [
+                common_header_files,
+                common_header_files / "common",
+                common_header_files / "common" / "include",
+                csrc_header_files,
+            ]
+        )

     if rocm_build():
         current_file_path = Path(__file__).parent.resolve()
@@ -87,44 +88,27 @@ def setup_pytorch_extension(
         cxx_flags.append("-Wno-sign-compare")
     else:
-        nvcc_flags = [
-            "-O3",
-            "-U__CUDA_NO_HALF_OPERATORS__",
-            "-U__CUDA_NO_HALF_CONVERSIONS__",
-            "-U__CUDA_NO_BFLOAT16_OPERATORS__",
-            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
-            "-U__CUDA_NO_BFLOAT162_OPERATORS__",
-            "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
-            "--expt-relaxed-constexpr",
-            "--expt-extended-lambda",
-            "--use_fast_math",
-        ]
+        pass
+
+    if debug_build_enabled():
+        cxx_flags.append("-g")
+        cxx_flags.append("-UNDEBUG")
+        if rocm_build():
+            nvcc_flags.append("-g")
+            nvcc_flags.append("-UNDEBUG")

     # Version-dependent CUDA options
     if rocm_build():
         ##TODO: Figure out which hipcc version starts to support this parallel compilation
         nvcc_flags.extend(["-parallel-jobs=4"])
     else:
-        cuda_architectures = cuda_archs()
-        if "70" in cuda_architectures:
-            nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])
         try:
             version = cuda_version()
         except FileNotFoundError:
-            print("Could not determine CUDA Toolkit version")
+            print("Could not determine CUDA version")
         else:
             if version < (12, 0):
                 raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
-            nvcc_flags.extend(
-                (
-                    "--threads",
-                    os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
-                )
-            )
-            for arch in cuda_architectures.split(";"):
-                if arch == "70":
-                    continue  # Already handled
-                nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])

     # Libraries
     library_dirs = []
@@ -136,12 +120,11 @@ def setup_pytorch_extension(
         mpi_path = Path(os.getenv("MPI_HOME"))
         include_dirs.append(mpi_path / "include")
         cxx_flags.append("-DNVTE_UB_WITH_MPI")
-        nvcc_flags.append("-DNVTE_UB_WITH_MPI")
+        if rocm_build():
+            nvcc_flags.append("-DNVTE_UB_WITH_MPI")
         library_dirs.append(mpi_path / "lib")
         libraries.append("mpi")

-    library_dirs = []
-    libraries = []
     if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
         assert (
             os.getenv("NVSHMEM_HOME") is not None
@@ -151,21 +134,32 @@ def setup_pytorch_extension(
         library_dirs.append(nvshmem_home / "lib")
         libraries.append("nvshmem_host")
         cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
-        nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")
+        if rocm_build():
+            nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")

     # Construct PyTorch CUDA extension
     sources = [str(path) for path in sources]
     include_dirs = [str(path) for path in include_dirs]
-    from torch.utils.cpp_extension import CUDAExtension
+    from torch.utils.cpp_extension import CppExtension, CUDAExtension

-    return CUDAExtension(
-        name="transformer_engine_torch",
-        sources=[str(src) for src in sources],
-        include_dirs=[str(inc) for inc in include_dirs],
-        extra_compile_args={
-            "cxx": cxx_flags,
-            "nvcc": nvcc_flags,
-        },
-        libraries=[str(lib) for lib in libraries],
-        library_dirs=[str(lib_dir) for lib_dir in library_dirs],
-    )
+    if rocm_build():
+        return CUDAExtension(
+            name="transformer_engine_torch",
+            sources=[str(src) for src in sources],
+            include_dirs=[str(inc) for inc in include_dirs],
+            extra_compile_args={
+                "cxx": cxx_flags,
+                "nvcc": nvcc_flags,
+            },
+            libraries=[str(lib) for lib in libraries],
+            library_dirs=[str(lib_dir) for lib_dir in library_dirs],
+        )
+    else:
+        return CppExtension(
+            name="transformer_engine_torch",
+            sources=[str(src) for src in sources],
+            include_dirs=[str(inc) for inc in include_dirs],
+            extra_compile_args={"cxx": cxx_flags},
+            libraries=[str(lib) for lib in libraries],
+            library_dirs=[str(lib_dir) for lib_dir in library_dirs],
+        )
@@ -13,6 +13,7 @@ import shutil
 import subprocess
 import sys
 from pathlib import Path
+from importlib.metadata import version
 from subprocess import CalledProcessError
 from typing import List, Optional, Tuple, Union
@@ -55,7 +56,7 @@ def all_files_in_dir(path, name_extension=None):
     all_files = []
     for dirname, _, names in os.walk(path):
         for name in names:
-            if name_extension is not None and name_extension not in name:
+            if name_extension is not None and not name.endswith(f".{name_extension}"):
                 continue
             all_files.append(Path(dirname, name))
     return all_files
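This fix matters because the old check was a substring test: passing `name_extension="cpp"` would also pick up files that merely contain "cpp" somewhere in their name. A quick illustration of the difference:

.. code-block:: python

    names = ["gemm.cpp", "cpp_notes.txt", "utils.cpp.orig"]

    # Old check (substring): everything here slips through.
    assert [n for n in names if "cpp" in n] == names

    # New check (suffix): only real .cpp files survive.
    assert [n for n in names if n.endswith(".cpp")] == ["gemm.cpp"]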
@@ -172,7 +173,7 @@ def rocm_build() -> bool:
         return True
     try:
-        cuda_path()
+        nvcc_path()
         return False
     except FileNotFoundError:
         pass
@@ -200,8 +201,30 @@ def rocm_path() -> Tuple[str, str]:
 @functools.lru_cache(maxsize=None)
-def cuda_path() -> Tuple[str, str]:
-    """CUDA root path and NVCC binary path as a tuple.
+def cuda_toolkit_include_path() -> Tuple[str, str]:
+    """Returns root path for cuda toolkit includes.
+
+    Returns `None` if CUDA is not found."""
+    # Try finding CUDA
+    cuda_home: Optional[Path] = None
+    if cuda_home is None and os.getenv("CUDA_HOME"):
+        # Check in CUDA_HOME
+        cuda_home = Path(os.getenv("CUDA_HOME")) / "include"
+    if cuda_home is None:
+        # Check in NVCC
+        nvcc_bin = shutil.which("nvcc")
+        if nvcc_bin is not None:
+            cuda_home = Path(nvcc_bin.rstrip("/bin/nvcc")) / "include"
+    if cuda_home is None:
+        # Last-ditch guess in /usr/local/cuda
+        if Path("/usr/local/cuda").is_dir():
+            cuda_home = Path("/usr/local/cuda") / "include"
+    return cuda_home
+
+
+@functools.lru_cache(maxsize=None)
+def nvcc_path() -> Tuple[str, str]:
+    """Returns the NVCC binary path.

     Throws FileNotFoundError if NVCC is not found."""
     # Try finding NVCC
@@ -223,7 +246,34 @@ def nvcc_path() -> Tuple[str, str]:
     if not nvcc_bin.is_file():
         raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")
-    return cuda_home, nvcc_bin
+    return nvcc_bin
+
+
+@functools.lru_cache(maxsize=None)
+def get_cuda_include_dirs() -> Tuple[str, str]:
+    """Returns the CUDA header directory."""
+    # If cuda is installed via toolkit, all necessary headers
+    # are bundled inside the top level cuda directory.
+    if cuda_toolkit_include_path() is not None:
+        return [cuda_toolkit_include_path()]
+
+    # Use pip wheels to include all headers.
+    try:
+        import nvidia
+    except ModuleNotFoundError as e:
+        raise RuntimeError("CUDA not found.")
+
+    cuda_root = Path(nvidia.__file__).parent
+    return [
+        cuda_root / "cuda_nvcc" / "include",
+        cuda_root / "cublas" / "include",
+        cuda_root / "cuda_runtime" / "include",
+        cuda_root / "cudnn" / "include",
+        cuda_root / "cuda_cccl" / "include",
+        cuda_root / "nvtx" / "include",
+        cuda_root / "cuda_nvrtc" / "include",
+    ]
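The pip-wheel fallback relies on the directory layout of NVIDIA's CUDA wheels, which install their headers under the `nvidia` package (for example `nvidia/cuda_runtime/include`). A small sketch to check that layout locally, assuming wheels such as `nvidia-cuda-runtime-cu12` are installed:

.. code-block:: python

    # Sketch: verify that headers from the CUDA pip wheels are visible.
    from pathlib import Path

    try:
        import nvidia  # package created by NVIDIA's CUDA wheels
    except ModuleNotFoundError:
        raise SystemExit("No CUDA wheels installed in this environment")

    root = Path(nvidia.__file__).parent
    header = root / "cuda_runtime" / "include" / "cuda_runtime.h"
    print(header, "exists:", header.exists())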
 @functools.lru_cache(maxsize=None)
@@ -237,18 +287,34 @@ def cuda_archs() -> str:
 def cuda_version() -> Tuple[int, ...]:
-    """CUDA Toolkit version as a (major, minor) tuple."""
-    # Query NVCC for version info
-    _, nvcc_bin = cuda_path()
-    output = subprocess.run(
-        [nvcc_bin, "-V"],
-        capture_output=True,
-        check=True,
-        universal_newlines=True,
-    )
-    match = re.search(r"release\s*([\d.]+)", output.stdout)
-    version = match.group(1).split(".")
-    return tuple(int(v) for v in version)
+    """CUDA Toolkit version as a (major, minor) tuple.
+
+    Try to get the CUDA version by locating the nvcc executable and running
+    `nvcc --version`. If nvcc is not found, look for the CUDA runtime pip package
+    `nvidia-cuda-runtime-cu12` and check its version.
+    """
+    try:
+        nvcc_bin = nvcc_path()
+    except FileNotFoundError as e:
+        pass
+    else:
+        output = subprocess.run(
+            [nvcc_bin, "-V"],
+            capture_output=True,
+            check=True,
+            universal_newlines=True,
+        )
+        match = re.search(r"release\s*([\d.]+)", output.stdout)
+        version = match.group(1).split(".")
+        return tuple(int(v) for v in version)
+
+    try:
+        version_str = version("nvidia-cuda-runtime-cu12")
+        version_tuple = tuple(int(part) for part in version_str.split(".") if part.isdigit())
+        return version_tuple
+    except importlib.metadata.PackageNotFoundError:
+        raise RuntimeError("Could neither find NVCC executable nor CUDA runtime Python package.")
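`cuda_version()` thus probes in two stages: run `nvcc -V` when a toolkit is on the `PATH`, otherwise read the installed version of the `nvidia-cuda-runtime-cu12` wheel. The same fallback chain in isolation, as a self-contained sketch using only the standard library:

.. code-block:: python

    # Sketch of the two-stage CUDA version probe mirroring cuda_version().
    import re
    import shutil
    import subprocess
    from importlib.metadata import PackageNotFoundError, version

    def probe_cuda_version():
        nvcc = shutil.which("nvcc")
        if nvcc is not None:
            out = subprocess.run([nvcc, "-V"], capture_output=True, text=True, check=True)
            release = re.search(r"release\s*([\d.]+)", out.stdout).group(1)
            return tuple(int(v) for v in release.split("."))
        try:
            wheel_version = version("nvidia-cuda-runtime-cu12")
        except PackageNotFoundError:
            raise RuntimeError("Neither NVCC nor the CUDA runtime wheel was found")
        return tuple(int(p) for p in wheel_version.split(".") if p.isdigit())

    if __name__ == "__main__":
        print(probe_cuda_version())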
 def get_frameworks() -> List[str]:
@@ -396,18 +462,3 @@ def install_and_import(package):
     main_package = package.split("[")[0]
     subprocess.check_call([sys.executable, "-m", "pip", "install", package])
     globals()[main_package] = importlib.import_module(main_package)
-
-
-def uninstall_te_wheel_packages():
-    subprocess.check_call(
-        [
-            sys.executable,
-            "-m",
-            "pip",
-            "uninstall",
-            "-y",
-            "transformer_engine_cu12",
-            "transformer_engine_torch",
-            "transformer_engine_jax",
-        ]
-    )
@@ -458,7 +458,7 @@
     " </tr>\n",
     "</table>\n",
     "\n",
-    "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
+    "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
     "\n",
     "<div class=\"alert alert-info\">\n",
     "<b>Note</b>\n",
@@ -8,11 +8,9 @@ import gc
 from contextlib import contextmanager

 import torch
-from torch import nn

 import transformer_engine as te
-from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
-from transformer_engine.pytorch.fp8 import fp8_model_init
+from transformer_engine.pytorch.attention import RotaryPositionEmbedding

 import transformers
 from transformers.models.llama.modeling_llama import (
@@ -37,5 +37,7 @@ def get_fp8_recipe_from_name_string(name: str):
             return recipe.DelayedScaling()
         case "MXFP8BlockScaling":
             return recipe.MXFP8BlockScaling()
+        case "Float8CurrentScaling":
+            return recipe.Float8CurrentScaling()
         case _:
             raise ValueError(f"Invalid fp8_recipe, got {name}")
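The recipe object produced by this helper is ultimately handed to `fp8_autocast`. A minimal PyTorch-side sketch of driving a layer with the newly added recipe, assuming a supported GPU and a TE version where `recipe.Float8CurrentScaling` is available:

.. code-block:: python

    # Sketch: run one TE layer under the current-scaling FP8 recipe.
    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    layer = te.Linear(768, 768, bias=True)
    x = torch.randn(16, 768, device="cuda")

    with te.fp8_autocast(enabled=True, fp8_recipe=recipe.Float8CurrentScaling()):
        y = layer(x)
    y.sum().backward()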
@@ -8,9 +8,11 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
 TEST_CASES=(
     "test_te_bf16"
     "test_te_delayed_scaling_fp8"
+    "test_te_current_scaling_fp8"
     "test_te_mxfp8"
     "test_te_bf16_shardy"
     "test_te_delayed_scaling_fp8_shardy"
+    "test_te_current_scaling_fp8_shardy"
 )

 echo
@@ -441,6 +441,14 @@ class TestEncoder(unittest.TestCase):
         actual = train_and_evaluate(self.args)
         assert actual[0] < 0.535 and actual[1] > 0.73

+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_current_scaling_fp8(self):
+        """Test Transformer Engine with CurrentScaling FP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "Float8CurrentScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.535 and actual[1] > 0.73
+
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
         """Test Transformer Engine with MXFP8"""
@@ -467,6 +475,15 @@ class TestEncoder(unittest.TestCase):
     # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.

+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_current_scaling_fp8_shardy(self):
+        """Test Transformer Engine with CurrentScaling FP8"""
+        self.args.enable_shardy = True
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "Float8CurrentScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.535 and actual[1] > 0.73
+

 if __name__ == "__main__":
     train_and_evaluate(encoder_parser(None))
@@ -609,7 +609,15 @@ class TestEncoder(unittest.TestCase):
     def test_te_delayed_scaling_fp8(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling")
-        assert result[0] < 0.505 and result[1] > 0.754
+        assert result[0] < 0.505 and result[1] > 0.753
+
+    @unittest.skipIf(
+        not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
+    )
+    def test_te_current_scaling_fp8(self):
+        """Test Transformer Engine with CurrentScaling FP8"""
+        result = self.exec(True, "Float8CurrentScaling")
+        assert result[0] < 0.507 and result[1] > 0.753

     @unittest.skipIf(
         not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
@@ -631,10 +639,18 @@ class TestEncoder(unittest.TestCase):
     def test_te_delayed_scaling_fp8_shardy(self):
         """Test Transformer Engine with DelayedScaling FP8"""
         result = self.exec(True, "DelayedScaling", enable_shardy=True)
-        assert result[0] < 0.505 and result[1] > 0.754
+        assert result[0] < 0.505 and result[1] > 0.753

     # TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.

+    @unittest.skipIf(
+        not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
+    )
+    def test_te_current_scaling_fp8_shardy(self):
+        """Test Transformer Engine with CurrentScaling FP8"""
+        result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
+        assert result[0] < 0.507 and result[1] > 0.753
+

 if __name__ == "__main__":
     train_and_evaluate(encoder_parser(None))
@@ -348,6 +348,14 @@ class TestEncoder(unittest.TestCase):
         actual = train_and_evaluate(self.args)
         assert actual[0] < 0.455 and actual[1] > 0.79

+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_current_scaling_fp8(self):
+        """Test Transformer Engine with CurrentScaling FP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "Float8CurrentScaling"
+        actual = train_and_evaluate(self.args)
+        assert actual[0] < 0.455 and actual[1] > 0.79
+
     @unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
     def test_te_mxfp8(self):
         """Test Transformer Engine with MXFP8"""
@@ -350,6 +350,14 @@ class TestMNIST(unittest.TestCase):
         actual = train_and_evaluate(self.args)
         self.verify(actual)

+    @unittest.skipIf(not is_fp8_supported, fp8_reason)
+    def test_te_current_scaling_fp8(self):
+        """Test Transformer Engine with CurrentScaling FP8"""
+        self.args.use_fp8 = True
+        self.args.fp8_recipe = "Float8CurrentScaling"
+        actual = train_and_evaluate(self.args)
+        self.verify(actual)
+

 if __name__ == "__main__":
     train_and_evaluate(mnist_parser(None))
@@ -8,7 +8,7 @@
 # FSDP without deferred initialization:
 #   Duplicate modules initialized on each device. Load on device memory reduced only after
 #   torch.distributed.fsdp.FullyShardedDataParallel mode shards model parameters.
-$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py
+$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py --no-defer-init
 # Sample output on 8xL40S:
 #   [GPU-0] WORLD_SIZE = 8
 #   [GPU-0] TransformerEngine Model:
@@ -40,7 +40,7 @@
 # FSDP with deferred initialization:
 #   Modules initialized with empty parameters via `device='meta'` option. Zero load on device
 #   memory until torch.distributed.fsdp.FullyShardedDataParallel mode triggers a reset on
 #   already sharded model parameters.
-$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py --defer-init
+$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py
 # Sample output on 8xL40S:
 #   [GPU-0] WORLD_SIZE = 8
 #   ...
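With this change, deferred initialization becomes the example's default path, and the old behavior moves behind the `--no-defer-init` flag. The mechanism underneath is PyTorch's meta device, sketched here with a plain `torch.nn` module rather than TE's layers:

.. code-block:: python

    # Sketch: "deferred init" constructs modules on the meta device, so no
    # device memory is allocated until the (sharded) parameters are reset.
    import torch

    with torch.device("meta"):
        layer = torch.nn.Linear(8192, 8192)

    assert layer.weight.is_meta  # only shapes/dtypes exist so far

    # Materialize later; FSDP does this on already-sharded parameters.
    layer = layer.to_empty(device="cuda")
    layer.reset_parameters()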
@@ -43,8 +43,8 @@
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
 NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
-NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
-NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"

 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases:$FAILED_CASES"
@@ -25,9 +25,8 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
-python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
 python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
-# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
+python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"