Commit f8c2af4c authored by yuguo

Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine

parents e92773a3 1d903f5e
......@@ -25,7 +25,7 @@ jobs:
with:
submodules: recursive
- name: 'Build'
run: pip install . -v
run: pip install --no-build-isolation . -v
env:
NVTE_FRAMEWORK: none
MAX_JOBS: 1
......@@ -49,7 +49,7 @@ jobs:
with:
submodules: recursive
- name: 'Build'
run: pip install . -v --no-deps
run: pip install --no-build-isolation . -v --no-deps
env:
NVTE_FRAMEWORK: pytorch
MAX_JOBS: 1
......@@ -68,7 +68,7 @@ jobs:
with:
submodules: recursive
- name: 'Build'
run: pip install . -v
run: pip install --no-build-isolation . -v
env:
NVTE_FRAMEWORK: jax
MAX_JOBS: 1
......
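These CI changes add `--no-build-isolation`, so pip compiles Transformer Engine against the packages already installed in the runner (including the framework selected by `NVTE_FRAMEWORK`) rather than a fresh isolated build environment. This means pip no longer installs build requirements automatically; a minimal preflight sketch, with the package list as an assumption for illustration:

    # Minimal sketch: with --no-build-isolation, verify build requirements
    # are already importable before invoking pip. Package names assumed.
    import importlib.util

    required = ["setuptools", "pybind11", "torch"]  # adjust per NVTE_FRAMEWORK
    missing = [pkg for pkg in required if importlib.util.find_spec(pkg) is None]
    if missing:
        raise SystemExit(f"Missing build requirements: {', '.join(missing)}")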
......@@ -51,6 +51,8 @@ jobs:
|| github.actor == 'xiaopoc'
|| github.actor == 'jreiffers'
|| github.actor == 'lhb8125'
|| github.actor == 'kunlunl'
|| github.actor == 'pstjohn'
)
steps:
- name: Check if comment is issued by authorized person
......
......@@ -8,25 +8,22 @@
Transformer Engine
==================
`Quickstart <#examples>`_ | `Installation <#installation>`_ | `User Guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html>`_ | `Examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_ | `FP8 Convergence <#fp8-convergence>`_ | `Integrations <#integrations>`_ | `Release notes <https://docs.nvidia.com/deeplearning/transformer-engine/release-notes/index.html>`_
`Quickstart <#examples>`_ | `Installation <#installation>`_ | `User Guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html>`_ | `Examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_ | `FP8 Convergence <#fp8-convergence>`_ | `Integrations <#integrations>`_ | `Release notes <https://docs.nvidia.com/deeplearning/transformer-engine/documentation-archive.html>`_
Latest News
===========
* [03/2025] `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72778/>`_
* [03/2025] `Measure and Improve AI Workload Performance with NVIDIA DGX Cloud Benchmarking <https://developer.nvidia.com/blog/measure-and-improve-ai-workload-performance-with-nvidia-dgx-cloud-benchmarking/>`_
* [03/2024] `Turbocharged Training: Optimizing the Databricks Mosaic AI stack with FP8 <https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8>`_
* [03/2024] `FP8 Training Support in SageMaker Model Parallelism Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-release-notes.html>`_
* [12/2023] `New NVIDIA NeMo Framework Features and NVIDIA H200 <https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility/>`_
.. image:: docs/examples/H200-NeMo-performance.png
.. image:: docs/examples/comparison-fp8-bf16-training-nvidia-dgx-cloud-benchmarking-performance-explorer.jpg
:width: 600
:alt: H200
:alt: Comparison of FP8 versus BF16 training, as seen in NVIDIA DGX Cloud Benchmarking Performance Explorer
* [11/2023] `Inflection-2: The Next Step Up <https://inflection.ai/inflection-2>`_
* [11/2023] `Unleashing The Power Of Transformers With NVIDIA Transformer Engine <https://lambdalabs.com/blog/unleashing-the-power-of-transformers-with-nvidia-transformer-engine>`_
* [11/2023] `Accelerating PyTorch Training Workloads with FP8 <https://towardsdatascience.com/accelerating-pytorch-training-workloads-with-fp8-5a5123aec7d7>`_
* [09/2023] `Transformer Engine added to AWS DL Container for PyTorch Training <https://github.com/aws/deep-learning-containers/pull/3315>`_
* [06/2023] `Breaking MLPerf Training Records with NVIDIA H100 GPUs <https://developer.nvidia.com/blog/breaking-mlperf-training-records-with-nvidia-h100-gpus/>`_
* [04/2023] `Benchmarking Large Language Models on NVIDIA H100 GPUs with CoreWeave (Part 1) <https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1>`_
* [02/2025] `Understanding the Language of Life's Biomolecules Across Evolution at a New Scale with Evo 2 <https://developer.nvidia.com/blog/understanding-the-language-of-lifes-biomolecules-across-evolution-at-a-new-scale-with-evo-2/>`_
* [02/2025] `NVIDIA DGX Cloud Introduces Ready-To-Use Templates to Benchmark AI Platform Performance <https://developer.nvidia.com/blog/nvidia-dgx-cloud-introduces-ready-to-use-templates-to-benchmark-ai-platform-performance/>`_
* [01/2025] `Continued Pretraining of State-of-the-Art LLMs for Sovereign AI and Regulated Industries with iGenius and NVIDIA DGX Cloud <https://developer.nvidia.com/blog/continued-pretraining-of-state-of-the-art-llms-for-sovereign-ai-and-regulated-industries-with-igenius-and-nvidia-dgx-cloud/>`_
`Previous News <#previous-news>`_
What is Transformer Engine?
===========================
......@@ -141,6 +138,8 @@ Flax
for _ in range(10):
loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
For a more comprehensive tutorial, check out our `Quickstart Notebook <https://github.com/NVIDIA/TransformerEngine/blob/main/docs/examples/quickstart.ipynb>`_.
.. overview-end-marker-do-not-remove
Installation
......@@ -171,13 +170,21 @@ Docker (Recommended)
^^^^^^^^^^^^^^^^^^^
The quickest way to get started with Transformer Engine is to use Docker images from the
`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_.
For example, to use the NGC PyTorch container interactively,
.. code-block:: bash
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.01-py3
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.04-py3
For example, to use the NGC JAX container interactively,
.. code-block:: bash
docker run --gpus all -it --rm nvcr.io/nvidia/jax:25.04-py3
Where 25.01 (corresponding to January 2025 release) is the container version.
Where 25.04 (corresponding to the April 2025 release) is the container version.
**Benefits of using NGC containers:**
......@@ -349,8 +356,8 @@ Integrations
Transformer Engine has been integrated with popular LLM frameworks such as:
* `DeepSpeed <https://github.com/microsoft/DeepSpeed/pull/3731>`_
* `Hugging Face Accelerate <https://github.com/huggingface/accelerate/releases/tag/v0.17.0>`_
* `DeepSpeed <https://github.com/deepspeedai/DeepSpeed/blob/master/tests/unit/runtime/half_precision/test_fp8.py>`_
* `Hugging Face Accelerate <https://huggingface.co/docs/accelerate/main/en/usage_guides/low_precision_training#configuring-transformersengine>`_
* `Lightning <https://github.com/Lightning-AI/lightning/issues/17172>`_
* `MosaicML Composer <https://github.com/mosaicml/composer/releases/tag/v0.13.1>`_
* `NVIDIA JAX Toolbox <https://github.com/NVIDIA/JAX-Toolbox>`_
......@@ -381,10 +388,37 @@ Papers
Videos
======
* `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72778/>`_
* `Blackwell Numerics for AI | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72458/>`_
* `Building LLMs: Accelerating Pretraining of Foundational Models With FP8 Precision | GTC 2025 <https://www.nvidia.com/gtc/session-catalog/?regcode=no-ncid&ncid=no-ncid&tab.catalogallsessionstab=16566177511100015Kus&search=zoho#/session/1726152813607001vnYK>`_
* `From FP8 LLM Training to Inference: Language AI at Scale | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72799/>`_
* `What's New in Transformer Engine and FP8 Training | GTC 2024 <https://www.nvidia.com/en-us/on-demand/session/gtc24-s62457/>`_
* `FP8 Training with Transformer Engine | GTC 2023 <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51393>`_
* `FP8 for Deep Learning | GTC 2023 <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s52166/>`_
* `Inside the Hopper Architecture <https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s42663/>`_
* `Inside the Hopper Architecture | GTC 2022 <https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s42663/>`_
.. |License| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg
:target: https://opensource.org/licenses/Apache-2.0
Previous News
=============
* [11/2024] `Developing a 172B LLM with Strong Japanese Capabilities Using NVIDIA Megatron-LM <https://developer.nvidia.com/blog/developing-a-172b-llm-with-strong-japanese-capabilities-using-nvidia-megatron-lm/>`_
* [11/2024] `How FP8 boosts LLM training by 18% on Amazon SageMaker P5 instances <https://aws.amazon.com/blogs/machine-learning/how-fp8-boosts-llm-training-by-18-on-amazon-sagemaker-p5-instances/>`_
* [11/2024] `Efficiently train models with large sequence lengths using Amazon SageMaker model parallel <https://aws.amazon.com/blogs/machine-learning/efficiently-train-models-with-large-sequence-lengths-using-amazon-sagemaker-model-parallel/>`_
* [09/2024] `Reducing AI large model training costs by 30% requires just a single line of code from FP8 mixed precision training upgrades <https://company.hpc-ai.com/blog/reducing-ai-large-model-training-costs-by-30-requires-just-a-single-line-of-code-from-fp8-mixed-precision-training-upgrades>`_
* [05/2024] `Accelerating Transformers with NVIDIA cuDNN 9 <https://developer.nvidia.com/blog/accelerating-transformers-with-nvidia-cudnn-9/>`_
* [03/2024] `Turbocharged Training: Optimizing the Databricks Mosaic AI stack with FP8 <https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8>`_
* [03/2024] `FP8 Training Support in SageMaker Model Parallelism Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-release-notes.html>`_
* [12/2023] `New NVIDIA NeMo Framework Features and NVIDIA H200 <https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility/>`_
.. image:: docs/examples/H200-NeMo-performance.png
:width: 600
:alt: H200
* [11/2023] `Inflection-2: The Next Step Up <https://inflection.ai/inflection-2>`_
* [11/2023] `Unleashing The Power Of Transformers With NVIDIA Transformer Engine <https://lambdalabs.com/blog/unleashing-the-power-of-transformers-with-nvidia-transformer-engine>`_
* [11/2023] `Accelerating PyTorch Training Workloads with FP8 <https://towardsdatascience.com/accelerating-pytorch-training-workloads-with-fp8-5a5123aec7d7>`_
* [09/2023] `Transformer Engine added to AWS DL Container for PyTorch Training <https://github.com/aws/deep-learning-containers/pull/3315>`_
* [06/2023] `Breaking MLPerf Training Records with NVIDIA H100 GPUs <https://developer.nvidia.com/blog/breaking-mlperf-training-records-with-nvidia-h100-gpus/>`_
* [04/2023] `Benchmarking Large Language Models on NVIDIA H100 GPUs with CoreWeave (Part 1) <https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1>`_
......@@ -4,7 +4,6 @@
"""Installation script."""
import ctypes
import os
import subprocess
import sys
......@@ -25,7 +24,7 @@ from .utils import (
debug_build_enabled,
found_ninja,
get_frameworks,
cuda_path,
nvcc_path,
get_max_jobs_for_parallel_build,
)
......@@ -96,7 +95,9 @@ class CMakeExtension(setuptools.Extension):
print(f"Time for build_ext: {total_time:.2f} seconds")
def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel_lib: bool = False):
def get_build_ext(
extension_cls: Type[setuptools.Extension], framework_extension_only: bool = False
):
class _CMakeBuildExtension(extension_cls):
"""Setuptools command with support for CMake extension modules"""
......@@ -131,22 +132,28 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
super().run()
self.extensions = all_extensions
# Ensure that binaries are not in global package space.
# Ensure that shared objects files for source and PyPI installations live
# in separate directories to avoid conflicts during install and runtime.
lib_dir = (
"wheel_lib"
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or install_so_in_wheel_lib
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or framework_extension_only
else ""
)
target_dir = install_dir / "transformer_engine" / lib_dir
target_dir.mkdir(exist_ok=True, parents=True)
for ext in Path(self.build_lib).glob("*.so"):
self.copy_file(ext, target_dir)
os.remove(ext)
# Ensure that binaries are not in global package space.
# For editable/inplace builds this is not a concern as
# the SOs will be in a local directory anyway.
if not self.inplace:
target_dir = install_dir / "transformer_engine" / lib_dir
target_dir.mkdir(exist_ok=True, parents=True)
for ext in Path(self.build_lib).glob("*.so"):
self.copy_file(ext, target_dir)
os.remove(ext)
def build_extensions(self):
# BuildExtensions from PyTorch already handle CUDA files correctly
# so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed.
# For core lib + JAX install, fix build_ext from pybind11.setup_helpers
# to handle CUDA files correctly.
if "pytorch" not in get_frameworks():
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
# extra_compile_args is a dict.
......@@ -159,20 +166,24 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
# Define new _compile method that redirects to NVCC for .cu and .cuh files.
# Also redirect .hip files to HIPCC
original_compile_fn = self.compiler._compile
self.compiler.src_extensions += [".cu", ".cuh", ".hip"]
if not framework_extension_only:
self.compiler.src_extensions += [".cu", ".cuh", ".hip"]
def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
# Copy before we make any modifications.
cflags = copy.deepcopy(extra_postargs)
original_compiler = self.compiler.compiler_so
try:
if rocm_build():
_, nvcc_bin = rocm_path()
else:
_, nvcc_bin = cuda_path()
original_compiler = self.compiler.compiler_so
if os.path.splitext(src)[1] in [".cu", ".cuh", ".hip"]:
if (
os.path.splitext(src)[1] in [".cu", ".cuh", ".hip"]
and not framework_extension_only
):
if rocm_build():
_, nvcc_bin = rocm_path()
else:
nvcc_bin = nvcc_path()
self.compiler.set_executable("compiler_so", str(nvcc_bin))
if isinstance(cflags, dict):
cflags = cflags["nvcc"]
......@@ -188,7 +199,6 @@ def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel
# Forward unknown options
if not any("--forward-unknown-opts" in flag for flag in cflags):
cflags.append("--forward-unknown-opts")
elif isinstance(cflags, dict):
cflags = cflags["cxx"]
......
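The `_compile_fn` wrapper above redirects distutils to NVCC (or HIPCC) for `.cu`/`.cuh`/`.hip` sources while leaving other files on the host compiler, and now skips the redirect entirely when only framework extensions are built. A standalone sketch of the same monkey-patching pattern, with the NVCC location as an assumption:

    # Sketch of the compiler-redirect pattern used above: swap the compiler
    # binary for selected source extensions, then always restore the original.
    import os

    def wrap_compile(compiler, nvcc_bin="/usr/local/cuda/bin/nvcc"):
        original_compile = compiler._compile

        def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts):
            original_compiler = compiler.compiler_so
            try:
                if os.path.splitext(src)[1] in (".cu", ".cuh"):
                    compiler.set_executable("compiler_so", str(nvcc_bin))
                return original_compile(obj, src, ext, cc_args, extra_postargs, pp_opts)
            finally:
                compiler.compiler_so = original_compiler

        compiler._compile = _compile_fn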
......@@ -4,12 +4,12 @@
"""JAX related extensions."""
import os
import shutil
from pathlib import Path
import setuptools
from glob import glob
from .utils import cuda_path, all_files_in_dir
from .utils import get_cuda_include_dirs, all_files_in_dir, debug_build_enabled
from typing import List
......@@ -41,31 +41,33 @@ def setup_jax_extension(
# Source files
csrc_source_files = Path(csrc_source_files)
extensions_dir = csrc_source_files / "extensions"
sources = [
csrc_source_files / "utils.cu",
] + all_files_in_dir(extensions_dir, ".cpp")
sources = all_files_in_dir(extensions_dir, name_extension="cpp")
# Header files
cuda_home, _ = cuda_path()
xla_home = xla_path()
include_dirs = [
cuda_home / "include",
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
xla_home,
]
include_dirs = get_cuda_include_dirs()
include_dirs.extend(
[
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
xla_path(),
]
)
# Compile flags
cxx_flags = ["-O3"]
nvcc_flags = ["-O3"]
if debug_build_enabled():
cxx_flags.append("-g")
cxx_flags.append("-UNDEBUG")
else:
cxx_flags.append("-g0")
# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension
class Pybind11CUDAExtension(Pybind11Extension):
"""Modified Pybind11Extension to allow combined CXX + NVCC compile flags."""
class Pybind11CPPExtension(Pybind11Extension):
"""Modified Pybind11Extension to allow custom CXX flags."""
def _add_cflags(self, flags: List[str]) -> None:
if isinstance(self.extra_compile_args, dict):
......@@ -75,9 +77,9 @@ def setup_jax_extension(
else:
self.extra_compile_args[:0] = flags
return Pybind11CUDAExtension(
return Pybind11CPPExtension(
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags},
extra_compile_args={"cxx": cxx_flags},
)
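`Pybind11CPPExtension` overrides `_add_cflags` because pybind11's helper assumes `extra_compile_args` is a flat list, while this build can pass a dict keyed by compiler. A self-contained sketch of that override, assuming pybind11 is installed:

    # Sketch: make pybind11's flag injection work with dict-valued
    # extra_compile_args by prepending flags to the "cxx" entry.
    from typing import List
    from pybind11.setup_helpers import Pybind11Extension

    class DictAwareExtension(Pybind11Extension):
        def _add_cflags(self, flags: List[str]) -> None:
            if isinstance(self.extra_compile_args, dict):
                cxx_flags = self.extra_compile_args.setdefault("cxx", [])
                cxx_flags[:0] = flags
            else:
                self.extra_compile_args[:0] = flags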
......@@ -8,13 +8,7 @@ from pathlib import Path
import setuptools
from .utils import (
rocm_build,
hipify,
all_files_in_dir,
cuda_archs,
cuda_version,
)
from .utils import all_files_in_dir, cuda_version, get_cuda_include_dirs, debug_build_enabled, rocm_build, hipify
def setup_pytorch_extension(
......@@ -25,19 +19,26 @@ def setup_pytorch_extension(
"""Setup CUDA extension for PyTorch support"""
# Source files
csrc_source_files = Path(csrc_source_files)
extensions_dir = csrc_source_files / "extensions"
sources = [
csrc_source_files / "common.cpp",
] + all_files_in_dir(extensions_dir)
sources = all_files_in_dir(Path(csrc_source_files), name_extension="cpp")
# Header files
include_dirs = [
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
]
if rocm_build():
include_dirs = [
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
]
else:
include_dirs = get_cuda_include_dirs()
include_dirs.extend(
[
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
]
)
if rocm_build():
current_file_path = Path(__file__).parent.resolve()
......@@ -87,44 +88,27 @@ def setup_pytorch_extension(
cxx_flags.append("-Wno-sign-compare")
else:
nvcc_flags = [
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
]
pass
if debug_build_enabled():
cxx_flags.append("-g")
cxx_flags.append("-UNDEBUG")
if rocm_build():
nvcc_flags.append("-g")
nvcc_flags.append("-UNDEBUG")
# Version-dependent CUDA options
if rocm_build():
# TODO: Figure out which hipcc version first supports parallel compilation
nvcc_flags.extend(["-parallel-jobs=4"])
else:
cuda_architectures = cuda_archs()
if "70" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])
try:
version = cuda_version()
except FileNotFoundError:
print("Could not determine CUDA Toolkit version")
print("Could not determine CUDA version")
else:
if version < (12, 0):
raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
nvcc_flags.extend(
(
"--threads",
os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
)
)
for arch in cuda_architectures.split(";"):
if arch == "70":
continue # Already handled
nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])
# Libraries
library_dirs = []
......@@ -136,12 +120,11 @@ def setup_pytorch_extension(
mpi_path = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_path / "include")
cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DNVTE_UB_WITH_MPI")
library_dirs.append(mpi_path / "lib")
libraries.append("mpi")
if rocm_build():
nvcc_flags.append("-DNVTE_UB_WITH_MPI")
library_dirs.append(mpi_path / "lib")
libraries.append("mpi")
library_dirs = []
libraries = []
if bool(int(os.getenv("NVTE_ENABLE_NVSHMEM", 0))):
assert (
os.getenv("NVSHMEM_HOME") is not None
......@@ -151,21 +134,32 @@ def setup_pytorch_extension(
library_dirs.append(nvshmem_home / "lib")
libraries.append("nvshmem_host")
cxx_flags.append("-DNVTE_ENABLE_NVSHMEM")
nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")
if rocm_build():
nvcc_flags.append("-DNVTE_ENABLE_NVSHMEM")
# Construct PyTorch CUDA extension
sources = [str(path) for path in sources]
include_dirs = [str(path) for path in include_dirs]
from torch.utils.cpp_extension import CUDAExtension
return CUDAExtension(
name="transformer_engine_torch",
sources=[str(src) for src in sources],
include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
from torch.utils.cpp_extension import CppExtension, CUDAExtension
if rocm_build():
return CUDAExtension(
name="transformer_engine_torch",
sources=[str(src) for src in sources],
include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={
"cxx": cxx_flags,
"nvcc": nvcc_flags,
},
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
else:
return CppExtension(
name="transformer_engine_torch",
sources=[str(src) for src in sources],
include_dirs=[str(inc) for inc in include_dirs],
extra_compile_args={"cxx": cxx_flags},
libraries=[str(lib) for lib in libraries],
library_dirs=[str(lib_dir) for lib_dir in library_dirs],
)
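With the CUDA kernels now compiled into the separate core library, the CUDA path builds the PyTorch bindings as a plain `CppExtension`; only the ROCm path still compiles device sources directly and needs `CUDAExtension`. A hedged usage sketch of the `CppExtension` variant, with placeholder source and include paths:

    # Sketch: build the PyTorch binding as a plain C++ extension once no
    # .cu files are compiled into it. Paths are placeholders.
    from torch.utils.cpp_extension import CppExtension

    ext = CppExtension(
        name="transformer_engine_torch",
        sources=["csrc/common.cpp"],
        include_dirs=["transformer_engine/common/include"],
        extra_compile_args={"cxx": ["-O3"]},
    )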
......@@ -13,6 +13,7 @@ import shutil
import subprocess
import sys
from pathlib import Path
from importlib.metadata import version
from subprocess import CalledProcessError
from typing import List, Optional, Tuple, Union
......@@ -55,7 +56,7 @@ def all_files_in_dir(path, name_extension=None):
all_files = []
for dirname, _, names in os.walk(path):
for name in names:
if name_extension is not None and name_extension not in name:
if name_extension is not None and not name.endswith(f".{name_extension}"):
continue
all_files.append(Path(dirname, name))
return all_files
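The filter now matches the full dot-prefixed suffix instead of a substring, so `name_extension="cpp"` no longer accidentally picks up names that merely contain "cpp". An illustrative check with hypothetical file names:

    # The old substring test would keep "old.cpp.orig" and "misc.hcpp";
    # the endswith test keeps only real .cpp files.
    names = ["gemm.cpp", "utils.cu", "old.cpp.orig", "misc.hcpp"]
    selected = [n for n in names if n.endswith(".cpp")]
    assert selected == ["gemm.cpp"]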
......@@ -172,7 +173,7 @@ def rocm_build() -> bool:
return True
try:
cuda_path()
nvcc_path()
return False
except FileNotFoundError:
pass
......@@ -200,8 +201,30 @@ def rocm_path() -> Tuple[str, str]:
@functools.lru_cache(maxsize=None)
def cuda_path() -> Tuple[str, str]:
"""CUDA root path and NVCC binary path as a tuple.
def cuda_toolkit_include_path() -> Optional[Path]:
    """Returns the include directory of the CUDA Toolkit,
    or `None` if CUDA is not found."""

    # Try finding CUDA
    cuda_home: Optional[Path] = None
    if os.getenv("CUDA_HOME"):
        # Check in CUDA_HOME
        cuda_home = Path(os.getenv("CUDA_HOME")) / "include"
    if cuda_home is None:
        # Check relative to the NVCC binary
        nvcc_bin = shutil.which("nvcc")
        if nvcc_bin is not None:
            cuda_home = Path(nvcc_bin).parent.parent / "include"
    if cuda_home is None:
        # Last-ditch guess in /usr/local/cuda
        if Path("/usr/local/cuda").is_dir():
            cuda_home = Path("/usr/local/cuda") / "include"
    return cuda_home
@functools.lru_cache(maxsize=None)
def nvcc_path() -> Path:
    """Returns the NVCC binary path.
# Try finding NVCC
......@@ -223,7 +246,34 @@ def cuda_path() -> Tuple[str, str]:
if not nvcc_bin.is_file():
raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")
return cuda_home, nvcc_bin
return nvcc_bin
@functools.lru_cache(maxsize=None)
def get_cuda_include_dirs() -> List[Path]:
    """Returns the CUDA header directories."""

    # If CUDA is installed via the toolkit, all necessary headers
    # are bundled inside the top-level CUDA directory.
    if cuda_toolkit_include_path() is not None:
        return [cuda_toolkit_include_path()]

    # Otherwise, use the pip wheels to locate all headers.
    try:
        import nvidia
    except ModuleNotFoundError as e:
        raise RuntimeError("CUDA not found.") from e
cuda_root = Path(nvidia.__file__).parent
return [
cuda_root / "cuda_nvcc" / "include",
cuda_root / "cublas" / "include",
cuda_root / "cuda_runtime" / "include",
cuda_root / "cudnn" / "include",
cuda_root / "cuda_cccl" / "include",
cuda_root / "nvtx" / "include",
cuda_root / "cuda_nvrtc" / "include",
]
@functools.lru_cache(maxsize=None)
......@@ -237,18 +287,34 @@ def cuda_archs() -> str:
def cuda_version() -> Tuple[int, ...]:
"""CUDA Toolkit version as a (major, minor) tuple."""
# Query NVCC for version info
_, nvcc_bin = cuda_path()
output = subprocess.run(
[nvcc_bin, "-V"],
capture_output=True,
check=True,
universal_newlines=True,
)
match = re.search(r"release\s*([\d.]+)", output.stdout)
version = match.group(1).split(".")
return tuple(int(v) for v in version)
"""CUDA Toolkit version as a (major, minor) tuple.
Try to get cuda version by locating the nvcc executable and running nvcc --version. If
nvcc is not found, look for the cuda runtime package pip `nvidia-cuda-runtime-cu12`
and check pip version.
"""
try:
nvcc_bin = nvcc_path()
except FileNotFoundError as e:
pass
else:
output = subprocess.run(
[nvcc_bin, "-V"],
capture_output=True,
check=True,
universal_newlines=True,
)
match = re.search(r"release\s*([\d.]+)", output.stdout)
version = match.group(1).split(".")
return tuple(int(v) for v in version)
try:
version_str = version("nvidia-cuda-runtime-cu12")
version_tuple = tuple(int(part) for part in version_str.split(".") if part.isdigit())
return version_tuple
except importlib.metadata.PackageNotFoundError:
raise RuntimeError("Could neither find NVCC executable nor CUDA runtime Python package.")
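When NVCC is unavailable (for example, in a wheels-only environment), the fallback above reads the version of the CUDA runtime wheel from pip metadata. A standalone sketch of that lookup, assuming the wheel may or may not be installed:

    # Sketch: read the CUDA version from the nvidia-cuda-runtime-cu12
    # wheel's pip metadata.
    from importlib.metadata import PackageNotFoundError, version

    try:
        ver_str = version("nvidia-cuda-runtime-cu12")
        print(tuple(int(p) for p in ver_str.split(".") if p.isdigit()))
    except PackageNotFoundError:
        print("nvidia-cuda-runtime-cu12 not installed")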
def get_frameworks() -> List[str]:
......@@ -396,18 +462,3 @@ def install_and_import(package):
main_package = package.split("[")[0]
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
globals()[main_package] = importlib.import_module(main_package)
def uninstall_te_wheel_packages():
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"uninstall",
"-y",
"transformer_engine_cu12",
"transformer_engine_torch",
"transformer_engine_jax",
]
)
......@@ -458,7 +458,7 @@
" </tr>\n",
"</table>\n",
"\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
......
......@@ -8,11 +8,9 @@ import gc
from contextlib import contextmanager
import torch
from torch import nn
import transformer_engine as te
from transformer_engine.pytorch.dot_product_attention.rope import RotaryPositionEmbedding
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformer_engine.pytorch.attention import RotaryPositionEmbedding
import transformers
from transformers.models.llama.modeling_llama import (
......
......@@ -37,5 +37,7 @@ def get_fp8_recipe_from_name_string(name: str):
return recipe.DelayedScaling()
case "MXFP8BlockScaling":
return recipe.MXFP8BlockScaling()
case "Float8CurrentScaling":
return recipe.Float8CurrentScaling()
case _:
raise ValueError(f"Invalid fp8_recipe, got {name}")
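The name-to-recipe mapping gains `Float8CurrentScaling`. A hedged sketch of how a selected recipe is typically consumed through the PyTorch API (the model and input here are assumed, not part of this diff):

    # Sketch: use a recipe object inside an fp8_autocast region.
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    fp8_recipe = recipe.Float8CurrentScaling()
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(inp)  # hypothetical TE model forward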
......@@ -8,9 +8,11 @@ NUM_GPUS=${NUM_GPUS:-$(nvidia-smi -L | wc -l)}
TEST_CASES=(
"test_te_bf16"
"test_te_delayed_scaling_fp8"
"test_te_current_scaling_fp8"
"test_te_mxfp8"
"test_te_bf16_shardy"
"test_te_delayed_scaling_fp8_shardy"
"test_te_current_scaling_fp8_shardy"
)
echo
......
......@@ -441,6 +441,14 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
......@@ -467,6 +475,15 @@ class TestEncoder(unittest.TestCase):
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8_shardy(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.enable_shardy = True
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.535 and actual[1] > 0.73
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -609,7 +609,15 @@ class TestEncoder(unittest.TestCase):
def test_te_delayed_scaling_fp8(self):
"""Test Transformer Engine with DelayedScaling FP8"""
result = self.exec(True, "DelayedScaling")
assert result[0] < 0.505 and result[1] > 0.754
assert result[0] < 0.505 and result[1] > 0.753
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
)
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling")
assert result[0] < 0.507 and result[1] > 0.753
@unittest.skipIf(
not is_mxfp8_supported(), "Device compute capability 10.0+ is required for MXFP8"
......@@ -631,10 +639,18 @@ class TestEncoder(unittest.TestCase):
def test_te_delayed_scaling_fp8_shardy(self):
"""Test Transformer Engine with DelayedScaling FP8"""
result = self.exec(True, "DelayedScaling", enable_shardy=True)
assert result[0] < 0.505 and result[1] > 0.754
assert result[0] < 0.505 and result[1] > 0.753
# TODO(jreiffers): Add mxfp8 Shardy tests once supported in JAX.
@unittest.skipIf(
not is_fp8_supported(), "Device compute capability 9.0+ is required for CurrentScaling FP8"
)
def test_te_current_scaling_fp8_shardy(self):
"""Test Transformer Engine with CurrentScaling FP8"""
result = self.exec(True, "Float8CurrentScaling", enable_shardy=True)
assert result[0] < 0.507 and result[1] > 0.753
if __name__ == "__main__":
train_and_evaluate(encoder_parser(None))
......@@ -348,6 +348,14 @@ class TestEncoder(unittest.TestCase):
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.455 and actual[1] > 0.79
@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
......
......@@ -350,6 +350,14 @@ class TestMNIST(unittest.TestCase):
actual = train_and_evaluate(self.args)
self.verify(actual)
@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_current_scaling_fp8(self):
"""Test Transformer Engine with CurrentScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "Float8CurrentScaling"
actual = train_and_evaluate(self.args)
self.verify(actual)
if __name__ == "__main__":
train_and_evaluate(mnist_parser(None))
......@@ -8,7 +8,7 @@
# FSDP without deferred initialization:
# Duplicate modules initialized on each device. Load on device memory reduced only after
# torch.distributed.fsdp.FullyShardedDataParallel mode shards model parameters.
$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py
$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py --no-defer-init
# Sample output on 8xL40S:
# [GPU-0] WORLD_SIZE = 8
# [GPU-0] TransformerEngine Model:
......@@ -40,7 +40,7 @@ $ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsd
# Modules initialized with empty parameters via `device='meta'` option. Zero load on device
# memory until torch.distributed.fsdp.FullyShardedDataParallel mode triggers a reset on
# on already sharded model parameters.
$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py --defer-init
$ torchrun --standalone --nnodes=1 --nproc-per-node=$(nvidia-smi -L | wc -l) fsdp.py
# Sample output on 8xL40S:
# [GPU-0] WORLD_SIZE = 8
# ...
......
......@@ -43,8 +43,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
if [ "$RET" -ne 0 ]; then
echo "Error in the following test cases:$FAILED_CASES"
......
......@@ -25,9 +25,8 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --log-cli-level=INFO --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
# python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
......