Commit 5b6ef054 authored by yuguo
parents 76060570 a7eeb28b
## Security
NVIDIA is dedicated to the security and trust of our software products and services, including all source code repositories managed through our organization.
If you need to report a security issue, please use the appropriate contact points outlined below. **Please do not report security vulnerabilities through GitHub/GitLab.**
## Reporting Potential Security Vulnerability in an NVIDIA Product
To report a potential security vulnerability in any NVIDIA product:
- Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html)
- E-Mail: psirt@nvidia.com
- We encourage you to use the following PGP key for secure email communication (an encryption example is shown after this list): [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key)
- Please include the following information:
- Product/Driver name and version/branch that contains the vulnerability
- Type of vulnerability (code execution, denial of service, buffer overflow, etc.)
- Instructions to reproduce the vulnerability
- Proof-of-concept or exploit code
- Potential impact of the vulnerability, including how an attacker could exploit the vulnerability
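For example, with GnuPG installed and the published key saved locally (the key file name below is an illustrative placeholder), a report can be encrypted before mailing:

```sh
# Import NVIDIA's published PSIRT key (nvidia-psirt.asc is a placeholder name
# for the key downloaded from the page linked above).
gpg --import nvidia-psirt.asc
# Encrypt the report for psirt@nvidia.com before attaching it to an email.
gpg --encrypt --armor --recipient psirt@nvidia.com report.txt
```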
While NVIDIA currently does not have a bug bounty program, we do offer acknowledgement when an externally reported security issue is addressed under our coordinated vulnerability disclosure policy. Please visit our [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/) policies page for more information.
## NVIDIA Product Security
For all security-related concerns, please visit NVIDIA's Product Security portal at https://www.nvidia.com/en-us/security
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import subprocess
import time
import pandas as pd
import numpy as np
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
ModelConfig,
_get_attention_backends,
_run_dot_product_attention,
)
pd.set_option("display.precision", 4)
# data type
dtype = torch.bfloat16
# number of timed iterations (after 3 warmup iterations)
num_iters = 3
# checkpointing
ckpt_attn = False
# workspace optimization path for cuDNN attention
workspace_opt = True
# QKV memory layout
qkv_layout = "bshd_bshd_bshd"
# padding between sequences for qkv_format=thd
pad_between_seqs = False
# training mode
is_training = True
model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
}
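# The positional ModelConfig arguments follow the key above: batch size b,
# attention heads h, GQA groups hg, head dim d, query length sq, key/value
# length skv, dropout p, mask type, and bias type. A hypothetical extra case,
# e.g. MQA with a single KV head, could be added as:
#   "test_4": ModelConfig(2, 32, 1, 128, 4096, 4096, 0.0, "causal", "no_bias"),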
def benchmark_dot_product_attention(model, fused_attn_supported, flash_attn_supported):
config = model_configs[model]
if dtype == torch.bfloat16:
tols = dict(atol=2.5e-2, rtol=2.5e-2)
else:
tols = dict(atol=5e-3, rtol=5e-3)
cudnn_times = []
flash_times = []
warmup_iters = 3
for i in range(warmup_iters):
if fused_attn_supported:
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
pad_between_seqs,
is_training,
)
if flash_attn_supported:
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FlashAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
pad_between_seqs,
is_training,
)
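    # Cross-check the two backends numerically before timing: forward outputs
    # and each backward gradient must agree within the dtype-dependent
    # tolerances above.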
if fused_attn_supported and flash_attn_supported:
torch.testing.assert_close(fused_attn_fwd, flash_attn_fwd, **tols)
for i, _ in enumerate(flash_attn_bwd):
torch.testing.assert_close(fused_attn_bwd[i], flash_attn_bwd[i], **tols)
torch.cuda.cudart().cudaProfilerStart()
torch.cuda.synchronize()
fused_attn_start = time.time()
if fused_attn_supported:
for i in range(num_iters):
fused_attn_fwd, fused_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FusedAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
pad_between_seqs,
is_training,
)
torch.cuda.synchronize()
fused_attn_time = time.time() - fused_attn_start if fused_attn_supported else 0
torch.cuda.synchronize()
flash_attn_start = time.time()
if flash_attn_supported:
for i in range(num_iters):
flash_attn_fwd, flash_attn_bwd = _run_dot_product_attention(
dtype,
config,
"FlashAttention",
ckpt_attn,
qkv_layout,
workspace_opt,
pad_between_seqs,
is_training,
)
torch.cuda.synchronize()
flash_attn_time = time.time() - flash_attn_start if flash_attn_supported else 0
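    # Record end-to-end module times (ms per iteration) in times.csv; the
    # kernel-level columns are left at 0 here and filled in later by
    # parse_results() from the nsys GPU trace.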
df = pd.read_csv("times.csv")
df = pd.concat(
[
df,
pd.DataFrame(
[
[
fused_attn_time * 1e3 / num_iters,
0,
0,
0,
flash_attn_time * 1e3 / num_iters,
0,
0,
0,
0,
]
],
columns=df.columns,
),
],
ignore_index=True,
)
df.to_csv("times.csv", index=False)
torch.cuda.cudart().cudaProfilerStop()
def parse_results(per_cudnn, per_flash, model):
filename = f"prof_{model}_cuda_gpu_trace.csv"
df = pd.read_csv(os.path.join("./", filename))
df_times = pd.read_csv("times.csv")
row = len(df_times.index) - 1
if per_cudnn > 0:
t_cudnn_all = df[df["Name"].str.contains("cudnn")]["Duration (ns)"].to_numpy()
t_cudnn_all = t_cudnn_all.reshape(-1, per_cudnn)
t_cudnn_avg = np.average(t_cudnn_all, axis=0)
df_times.loc[row, "FusedAttention Kernels (fwd)"] = t_cudnn_avg[0] / 1e6
df_times.loc[row, "FusedAttention Kernels (bwd)"] = t_cudnn_avg[1:4].sum() / 1e6
df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"] = t_cudnn_avg.sum() / 1e6
if per_flash > 0:
t_flash_all = df[df["Name"].str.contains("flash")]["Duration (ns)"].to_numpy()
t_flash_all = t_flash_all.reshape(-1, per_flash)
t_flash_avg = np.average(t_flash_all, axis=0)
df_times.loc[row, "FlashAttention Kernels (fwd)"] = t_flash_avg[0] / 1e6
df_times.loc[row, "FlashAttention Kernels (bwd)"] = t_flash_avg[1:4].sum() / 1e6
df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"] = t_flash_avg.sum() / 1e6
if per_cudnn > 0 and per_flash > 0:
df_times.loc[row, "Fused vs Flash Kernels Speedup (fwd+bwd)"] = (
df_times.loc[row, "FlashAttention Kernels (fwd+bwd)"]
/ df_times.loc[row, "FusedAttention Kernels (fwd+bwd)"]
)
df_times.to_csv("times.csv", index=False)
def main():
times = pd.DataFrame(
columns=[
"FusedAttention Module",
"FusedAttention Kernels (fwd)",
"FusedAttention Kernels (bwd)",
"FusedAttention Kernels (fwd+bwd)",
"FlashAttention Module",
"FlashAttention Kernels (fwd)",
"FlashAttention Kernels (bwd)",
"FlashAttention Kernels (fwd+bwd)",
"Fused vs Flash Kernels Speedup (fwd+bwd)",
]
)
times.to_csv("times.csv", index=False)
device_id = torch.cuda.current_device()
device_properties = torch.cuda.get_device_properties(device_id)
print(
f"Device {device_id}: "
f"{device_properties.name} GPU, "
f"sm{device_properties.major}{device_properties.minor} compute capability, "
f"{device_properties.total_memory/1024**3:.1f}GB memory"
)
for model in model_configs.keys():
config = model_configs[model]
available_backends, fused_attn_backends = _get_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
window_size=config.window_size,
pad_between_seqs=pad_between_seqs,
)
flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends
        backends = [
            name
            for name, supported in (
                ("cuDNN attention", fused_attn_supported),
                ("flash-attention", flash_attn_supported),
            )
            if supported
        ]
        print(f'Running {model} with {" and ".join(backends)}...')
prof_cmd = [
"nsys",
"profile",
"--capture-range=cudaProfilerApi",
"--capture-range-end=stop-shutdown",
"--force-overwrite=true",
f"--output=prof_{model}",
"python",
"-c",
f""" "import benchmark_attention;""",
f"""benchmark_attention.benchmark_dot_product_attention("""
f"""'{model}', {fused_attn_supported}, {flash_attn_supported})" """,
]
prof_cmd = " ".join(prof_cmd)
subprocess.call(prof_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
stats_cmd = [
"nsys",
"stats",
"-q",
"-r",
"cuda_gpu_trace",
"--format",
"csv,column",
"--force-overwrite=true",
"--force-export=true",
f"--output=prof_{model}",
f"prof_{model}.nsys-rep",
]
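        # Expected number of attention kernels per iteration, used by
        # parse_results() to group the nsys GPU trace: cuDNN attention
        # launches 4 kernels in the base case (1 fwd + 3 bwd), with extra
        # kernels when a post-scale bias or GQA is used.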
if fused_attn_supported:
num_kernels_cudnn = 4
if config.attn_bias_type == "post_scale_bias":
num_kernels_cudnn = num_kernels_cudnn + 1
if config.num_heads != config.num_gqa_groups:
num_kernels_cudnn = num_kernels_cudnn + 2
else:
num_kernels_cudnn = 0
num_kernels_flash = 4 if flash_attn_supported else 0
stats_cmd = " ".join(stats_cmd)
subprocess.call(stats_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
parse_cmd = [
"python",
"-c",
f""" "import benchmark_attention;""",
f"""benchmark_attention.parse_results("""
f"""{num_kernels_cudnn}, {num_kernels_flash}, '{model}')" """,
]
parse_cmd = " ".join(parse_cmd)
subprocess.call(parse_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, shell=True)
df_times = pd.read_csv("times.csv")
df_times.index = list(model_configs.keys())
a = df_times[
[
"FusedAttention Kernels (fwd+bwd)",
"FlashAttention Kernels (fwd+bwd)",
"Fused vs Flash Kernels Speedup (fwd+bwd)",
]
]
a.columns = ["cuDNN fwd+bwd (ms)", "flash-attn fwd+bwd (ms)", "cuDNN vs flash speedup"]
print()
print(a)
if __name__ == "__main__":
main()
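# Usage sketch (assumptions: an NVIDIA GPU is visible and `nsys` from Nsight
# Systems is on PATH, since each config is profiled in a subprocess between
# cudaProfilerStart/Stop):
#   python benchmark_attention.py
# Results accumulate in times.csv and are summarized at the end of main().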
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Build related infrastructure."""
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Installation script."""
import ctypes
import os
import subprocess
import sys
import sysconfig
import copy
import time
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Type
import setuptools
from .utils import (
cmake_bin,
debug_build_enabled,
found_ninja,
get_frameworks,
cuda_path,
get_max_jobs_for_parallel_build,
)
class CMakeExtension(setuptools.Extension):
"""CMake extension module"""
def __init__(
self,
name: str,
cmake_path: Path,
cmake_flags: Optional[List[str]] = None,
) -> None:
super().__init__(name, sources=[]) # No work for base class
self.cmake_path: Path = cmake_path
self.cmake_flags: List[str] = [] if cmake_flags is None else cmake_flags
def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
# Make sure paths are str
_cmake_bin = str(cmake_bin())
cmake_path = str(self.cmake_path)
build_dir = str(build_dir)
install_dir = str(install_dir)
# CMake configure command
build_type = "Debug" if debug_build_enabled() else "Release"
configure_command = [
_cmake_bin,
"-S",
cmake_path,
"-B",
build_dir,
f"-DPython_EXECUTABLE={sys.executable}",
f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
f"-DCMAKE_BUILD_TYPE={build_type}",
f"-DCMAKE_INSTALL_PREFIX={install_dir}",
]
configure_command += self.cmake_flags
import pybind11
pybind11_dir = Path(pybind11.__file__).resolve().parent
pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")
# CMake build and install commands
build_command = [_cmake_bin, "--build", build_dir, "--verbose"]
install_command = [_cmake_bin, "--install", build_dir, "--verbose"]
# Check whether parallel build is restricted
max_jobs = get_max_jobs_for_parallel_build()
if found_ninja():
configure_command.append("-GNinja")
build_command.append("--parallel")
if max_jobs > 0:
build_command.append(str(max_jobs))
# Run CMake commands
start_time = time.perf_counter()
for command in [configure_command, build_command, install_command]:
print(f"Running command {' '.join(command)}")
try:
subprocess.run(command, cwd=build_dir, check=True)
except (CalledProcessError, OSError) as e:
raise RuntimeError(f"Error when running CMake: {e}")
total_time = time.perf_counter() - start_time
print(f"Time for build_ext: {total_time:.2f} seconds")
def get_build_ext(extension_cls: Type[setuptools.Extension], install_so_in_wheel_lib: bool = False):
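    """Return a build_ext class derived from `extension_cls` that additionally
    knows how to build CMakeExtension modules."""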
class _CMakeBuildExtension(extension_cls):
"""Setuptools command with support for CMake extension modules"""
def run(self) -> None:
# Build CMake extensions
for ext in self.extensions:
package_path = Path(self.get_ext_fullpath(ext.name))
install_dir = package_path.resolve().parent
if isinstance(ext, CMakeExtension):
print(f"Building CMake extension {ext.name}")
# Set up incremental builds for CMake extensions
build_dir = os.getenv("NVTE_CMAKE_BUILD_DIR")
if build_dir:
build_dir = Path(build_dir).resolve()
else:
root_dir = Path(__file__).resolve().parent.parent
build_dir = root_dir / "build" / "cmake"
# Ensure the directory exists
build_dir.mkdir(parents=True, exist_ok=True)
ext._build_cmake(
build_dir=build_dir,
install_dir=install_dir,
)
# Build non-CMake extensions as usual
all_extensions = self.extensions
self.extensions = [
ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
]
super().run()
self.extensions = all_extensions
# Ensure that binaries are not in global package space.
lib_dir = (
"wheel_lib"
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or install_so_in_wheel_lib
else ""
)
target_dir = install_dir / "transformer_engine" / lib_dir
target_dir.mkdir(exist_ok=True, parents=True)
for ext in Path(self.build_lib).glob("*.so"):
self.copy_file(ext, target_dir)
os.remove(ext)
def build_extensions(self):
# BuildExtensions from PyTorch already handle CUDA files correctly
# so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed.
if "pytorch" not in get_frameworks():
# Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
# extra_compile_args is a dict.
for ext in self.extensions:
if isinstance(ext.extra_compile_args, dict):
for target in ["cxx", "nvcc"]:
if target not in ext.extra_compile_args.keys():
ext.extra_compile_args[target] = []
# Define new _compile method that redirects to NVCC for .cu and .cuh files.
original_compile_fn = self.compiler._compile
self.compiler.src_extensions += [".cu", ".cuh"]
def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
# Copy before we make any modifications.
cflags = copy.deepcopy(extra_postargs)
original_compiler = self.compiler.compiler_so
try:
_, nvcc_bin = cuda_path()
if os.path.splitext(src)[1] in [".cu", ".cuh"]:
self.compiler.set_executable("compiler_so", str(nvcc_bin))
if isinstance(cflags, dict):
cflags = cflags["nvcc"]
# Add -fPIC if not already specified
if not any("-fPIC" in flag for flag in cflags):
cflags.extend(["--compiler-options", "'-fPIC'"])
# Forward unknown options
if not any("--forward-unknown-opts" in flag for flag in cflags):
cflags.append("--forward-unknown-opts")
elif isinstance(cflags, dict):
cflags = cflags["cxx"]
# Append -std=c++17 if not already in flags
if not any(flag.startswith("-std=") for flag in cflags):
cflags.append("-std=c++17")
return original_compile_fn(obj, src, ext, cc_args, cflags, pp_opts)
finally:
# Put the original compiler back in place.
self.compiler.set_executable("compiler_so", original_compiler)
self.compiler._compile = _compile_fn
super().build_extensions()
return _CMakeBuildExtension
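# Hypothetical usage in a setup.py (names below are illustrative, not part of
# this module):
#
#   from setuptools.command.build_ext import build_ext as BuildExtension
#   setuptools.setup(
#       ext_modules=[CMakeExtension("te_common", cmake_path=Path("."))],
#       cmdclass={"build_ext": get_build_ext(BuildExtension)},
#   )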
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""JAX related extensions."""
import os
from pathlib import Path
import setuptools
from glob import glob
from .utils import cuda_path, all_files_in_dir
from typing import List
def xla_path() -> str:
"""XLA root path lookup.
Throws FileNotFoundError if XLA source is not found."""
try:
from jax.extend import ffi
except ImportError:
if os.getenv("XLA_HOME"):
xla_home = Path(os.getenv("XLA_HOME"))
else:
xla_home = "/opt/xla"
else:
xla_home = ffi.include_dir()
if not os.path.isdir(xla_home):
raise FileNotFoundError("Could not find xla source.")
return xla_home
def setup_jax_extension(
csrc_source_files,
csrc_header_files,
common_header_files,
) -> setuptools.Extension:
"""Setup PyBind11 extension for JAX support"""
# Source files
csrc_source_files = Path(csrc_source_files)
extensions_dir = csrc_source_files / "extensions"
sources = [
csrc_source_files / "utils.cu",
] + all_files_in_dir(extensions_dir, ".cpp")
# Header files
cuda_home, _ = cuda_path()
xla_home = xla_path()
include_dirs = [
cuda_home / "include",
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
xla_home,
]
# Compile flags
cxx_flags = ["-O3"]
nvcc_flags = ["-O3"]
# Define TE/JAX as a Pybind11Extension
from pybind11.setup_helpers import Pybind11Extension
class Pybind11CUDAExtension(Pybind11Extension):
"""Modified Pybind11Extension to allow combined CXX + NVCC compile flags."""
def _add_cflags(self, flags: List[str]) -> None:
if isinstance(self.extra_compile_args, dict):
cxx_flags = self.extra_compile_args.pop("cxx", [])
cxx_flags += flags
self.extra_compile_args["cxx"] = cxx_flags
else:
self.extra_compile_args[:0] = flags
return Pybind11CUDAExtension(
"transformer_engine_jax",
sources=[str(path) for path in sources],
include_dirs=[str(path) for path in include_dirs],
extra_compile_args={"cxx": cxx_flags, "nvcc": nvcc_flags},
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""PyTorch related extensions."""
import os
from pathlib import Path
import setuptools
from .utils import (
all_files_in_dir,
cuda_archs,
cuda_version,
)
def setup_pytorch_extension(
csrc_source_files,
csrc_header_files,
common_header_files,
) -> setuptools.Extension:
"""Setup CUDA extension for PyTorch support"""
# Source files
csrc_source_files = Path(csrc_source_files)
extensions_dir = csrc_source_files / "extensions"
sources = [
csrc_source_files / "common.cpp",
] + all_files_in_dir(extensions_dir)
# Header files
include_dirs = [
common_header_files,
common_header_files / "common",
common_header_files / "common" / "include",
csrc_header_files,
]
# Compiler flags
cxx_flags = [
"-O3",
"-fvisibility=hidden",
]
nvcc_flags = [
"-O3",
"-U__CUDA_NO_HALF_OPERATORS__",
"-U__CUDA_NO_HALF_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT16_OPERATORS__",
"-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
"-U__CUDA_NO_BFLOAT162_OPERATORS__",
"-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
"--expt-relaxed-constexpr",
"--expt-extended-lambda",
"--use_fast_math",
]
cuda_architectures = cuda_archs()
if "70" in cuda_architectures:
nvcc_flags.extend(["-gencode", "arch=compute_70,code=sm_70"])
# Version-dependent CUDA options
try:
version = cuda_version()
except FileNotFoundError:
print("Could not determine CUDA Toolkit version")
else:
if version < (12, 0):
raise RuntimeError("Transformer Engine requires CUDA 12.0 or newer")
nvcc_flags.extend(
(
"--threads",
os.getenv("NVTE_BUILD_THREADS_PER_JOB", "1"),
)
)
for arch in cuda_architectures.split(";"):
if arch == "70":
continue # Already handled
nvcc_flags.extend(["-gencode", f"arch=compute_{arch},code=sm_{arch}"])
if bool(int(os.getenv("NVTE_UB_WITH_MPI", "0"))):
assert (
os.getenv("MPI_HOME") is not None
), "MPI_HOME=/path/to/mpi must be set when compiling with NVTE_UB_WITH_MPI=1!"
mpi_path = Path(os.getenv("MPI_HOME"))
include_dirs.append(mpi_path / "include")
cxx_flags.append("-DNVTE_UB_WITH_MPI")
nvcc_flags.append("-DNVTE_UB_WITH_MPI")
    # Construct PyTorch CUDA extension
    from torch.utils.cpp_extension import CUDAExtension

    return CUDAExtension(
        name="transformer_engine_torch",
        sources=[str(path) for path in sources],
        include_dirs=[str(path) for path in include_dirs],
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine version string."""
import os
from pathlib import Path
import subprocess
def te_version() -> str:
"""Transformer Engine version string
Includes Git commit as local version, unless suppressed with
NVTE_NO_LOCAL_VERSION environment variable.
"""
root_path = Path(__file__).resolve().parent
with open(root_path / "VERSION.txt", "r") as f:
version = f.readline().strip()
if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")) and not bool(
int(os.getenv("NVTE_RELEASE_BUILD", "0"))
):
try:
output = subprocess.run(
["git", "rev-parse", "--short", "HEAD"],
capture_output=True,
cwd=root_path,
check=True,
universal_newlines=True,
)
except (subprocess.CalledProcessError, OSError):
pass
else:
commit = output.stdout.strip()
version += f"+{commit}"
return version
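# Hypothetical example: with VERSION.txt containing "1.13.0" and HEAD at
# commit abc1234, te_version() returns "1.13.0+abc1234"; setting
# NVTE_NO_LOCAL_VERSION=1 (or NVTE_RELEASE_BUILD=1) yields plain "1.13.0".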
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Installation script."""
import functools
import glob
import importlib
import os
import re
import shutil
import subprocess
import sys
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Tuple, Union
@functools.lru_cache(maxsize=None)
def debug_build_enabled() -> bool:
"""Whether to build with a debug configuration"""
for arg in sys.argv:
if arg == "--debug":
sys.argv.remove(arg)
return True
if int(os.getenv("NVTE_BUILD_DEBUG", "0")):
return True
return False
@functools.lru_cache(maxsize=None)
def get_max_jobs_for_parallel_build() -> int:
"""Number of parallel jobs for Nina build"""
# Default: maximum parallel jobs
num_jobs = 0
# Check environment variable
if os.getenv("NVTE_BUILD_MAX_JOBS"):
num_jobs = int(os.getenv("NVTE_BUILD_MAX_JOBS"))
elif os.getenv("MAX_JOBS"):
num_jobs = int(os.getenv("MAX_JOBS"))
# Check command-line arguments
for arg in sys.argv.copy():
if arg.startswith("--parallel="):
num_jobs = int(arg.replace("--parallel=", ""))
sys.argv.remove(arg)
return num_jobs
def all_files_in_dir(path, name_extension=None):
    """Recursively list files under path, optionally keeping only names that
    end with name_extension."""
    all_files = []
    for dirname, _, names in os.walk(path):
        for name in names:
            if name_extension is not None and not name.endswith(name_extension):
                continue
            all_files.append(Path(dirname, name))
    return all_files
def remove_dups(_list: List):
return list(set(_list))
def found_cmake() -> bool:
""" "Check if valid CMake is available
CMake 3.18 or newer is required.
"""
# Check if CMake is available
try:
_cmake_bin = cmake_bin()
except FileNotFoundError:
return False
# Query CMake for version info
output = subprocess.run(
[_cmake_bin, "--version"],
capture_output=True,
check=True,
universal_newlines=True,
)
match = re.search(r"version\s*([\d.]+)", output.stdout)
version = match.group(1).split(".")
version = tuple(int(v) for v in version)
return version >= (3, 18)
def cmake_bin() -> Path:
"""Get CMake executable
Throws FileNotFoundError if not found.
"""
# Search in CMake Python package
_cmake_bin: Optional[Path] = None
try:
from cmake import CMAKE_BIN_DIR
except ImportError:
pass
else:
_cmake_bin = Path(CMAKE_BIN_DIR).resolve() / "cmake"
if not _cmake_bin.is_file():
_cmake_bin = None
# Search in path
if _cmake_bin is None:
_cmake_bin = shutil.which("cmake")
if _cmake_bin is not None:
_cmake_bin = Path(_cmake_bin).resolve()
# Return executable if found
if _cmake_bin is None:
raise FileNotFoundError("Could not find CMake executable")
return _cmake_bin
def found_ninja() -> bool:
""" "Check if Ninja is available"""
return shutil.which("ninja") is not None
def found_pybind11() -> bool:
""" "Check if pybind11 is available"""
# Check if Python package is installed
try:
import pybind11
except ImportError:
pass
else:
return True
# Check if CMake can find pybind11
if not found_cmake():
return False
try:
subprocess.run(
[
"cmake",
"--find-package",
"-DMODE=EXIST",
"-DNAME=pybind11",
"-DCOMPILER_ID=CXX",
"-DLANGUAGE=CXX",
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
check=True,
)
except (CalledProcessError, OSError):
pass
else:
return True
return False
@functools.lru_cache(maxsize=None)
def cuda_path() -> Tuple[str, str]:
"""CUDA root path and NVCC binary path as a tuple.
Throws FileNotFoundError if NVCC is not found."""
# Try finding NVCC
nvcc_bin: Optional[Path] = None
if nvcc_bin is None and os.getenv("CUDA_HOME"):
# Check in CUDA_HOME
cuda_home = Path(os.getenv("CUDA_HOME"))
nvcc_bin = cuda_home / "bin" / "nvcc"
    if nvcc_bin is None:
        # Check if nvcc is in path
        nvcc_bin = shutil.which("nvcc")
        if nvcc_bin is not None:
            # str.rstrip strips a character set rather than a suffix, so
            # recover the CUDA root from the parents of .../bin/nvcc instead.
            nvcc_bin = Path(nvcc_bin)
            cuda_home = nvcc_bin.parent.parent
if nvcc_bin is None:
# Last-ditch guess in /usr/local/cuda
cuda_home = Path("/usr/local/cuda")
nvcc_bin = cuda_home / "bin" / "nvcc"
if not nvcc_bin.is_file():
raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")
return cuda_home, nvcc_bin
@functools.lru_cache(maxsize=None)
def cuda_archs() -> str:
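    """Semicolon-separated CUDA architectures to build for. The default
    depends on the CUDA Toolkit version and can be overridden with the
    NVTE_CUDA_ARCHS environment variable."""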
version = cuda_version()
if os.getenv("NVTE_CUDA_ARCHS") is None:
os.environ["NVTE_CUDA_ARCHS"] = (
"70;80;89;90;100;120" if version >= (12, 8) else "70;80;89;90"
)
return os.getenv("NVTE_CUDA_ARCHS")
def cuda_version() -> Tuple[int, ...]:
"""CUDA Toolkit version as a (major, minor) tuple."""
# Query NVCC for version info
_, nvcc_bin = cuda_path()
output = subprocess.run(
[nvcc_bin, "-V"],
capture_output=True,
check=True,
universal_newlines=True,
)
match = re.search(r"release\s*([\d.]+)", output.stdout)
version = match.group(1).split(".")
return tuple(int(v) for v in version)
def get_frameworks() -> List[str]:
"""DL frameworks to build support for"""
_frameworks: List[str] = []
supported_frameworks = ["pytorch", "jax"]
# Check environment variable
if os.getenv("NVTE_FRAMEWORK"):
_frameworks.extend(os.getenv("NVTE_FRAMEWORK").split(","))
# Check command-line arguments
for arg in sys.argv.copy():
if arg.startswith("--framework="):
_frameworks.extend(arg.replace("--framework=", "").split(","))
sys.argv.remove(arg)
# Detect installed frameworks if not explicitly specified
if not _frameworks:
try:
import torch
except ImportError:
pass
else:
_frameworks.append("pytorch")
try:
import jax
except ImportError:
pass
else:
_frameworks.append("jax")
# Special framework names
if "all" in _frameworks:
_frameworks = supported_frameworks.copy()
if "none" in _frameworks:
_frameworks = []
# Check that frameworks are valid
_frameworks = [framework.lower() for framework in _frameworks]
for framework in _frameworks:
if framework not in supported_frameworks:
raise ValueError(f"Transformer Engine does not support framework={framework}")
return _frameworks
def copy_common_headers(
src_dir: Union[Path, str],
dst_dir: Union[Path, str],
) -> None:
"""Copy headers from core library
src_dir should be the transformer_engine directory within the root
Transformer Engine repository. All .h and .cuh files within
transformer_engine/common are copied into dst_dir. Relative paths
are preserved.
"""
# Find common header files in src dir
headers = glob.glob(
os.path.join(str(src_dir), "common", "**", "*.h"),
recursive=True,
)
headers.extend(
glob.glob(
os.path.join(str(src_dir), "common", "**", "*.cuh"),
recursive=True,
)
)
headers = [Path(path) for path in headers]
# Copy common header files to dst dir
src_dir = Path(src_dir)
dst_dir = Path(dst_dir)
for path in headers:
new_path = dst_dir / path.relative_to(src_dir)
new_path.parent.mkdir(exist_ok=True, parents=True)
shutil.copy(path, new_path)
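# Hypothetical usage: copy_common_headers("transformer_engine", "build/include")
# mirrors transformer_engine/common/**/*.h and *.cuh into build/include/common/.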
def install_and_import(package):
"""Install a package via pip (if not already installed) and import into globals."""
main_package = package.split("[")[0]
subprocess.check_call([sys.executable, "-m", "pip", "install", package])
globals()[main_package] = importlib.import_module(main_package)
def uninstall_te_wheel_packages():
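    """Uninstall framework-specific Transformer Engine wheel packages via pip."""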
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"uninstall",
"-y",
"transformer_engine_cu12",
"transformer_engine_torch",
"transformer_engine_jax",
]
)
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
FROM quay.io/pypa/manylinux_2_28_aarch64
WORKDIR /TransformerEngine/
COPY ../.. /TransformerEngine/
ARG VER="12-3"
ARG ARCH="aarch64"
RUN dnf -y install vim
# CUDA toolkit and cuDNN.
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
RUN dnf -y install epel-release
RUN dnf -y install cuda-compiler-${VER}.${ARCH} \
cuda-libraries-${VER}.${ARCH} \
cuda-libraries-devel-${VER}.${ARCH}
RUN dnf -y install --allowerasing cudnn9-cuda-12
RUN dnf clean all
RUN rm -rf /var/cache/dnf/*
RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
RUN dnf -y install cuda-toolkit
RUN dnf clean all
RUN dnf -y install glog.aarch64 glog-devel.aarch64
ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
ENV CUDA_HOME=/usr/local/cuda
ENV CUDA_ROOT=/usr/local/cuda
ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"]
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
FROM quay.io/pypa/manylinux_2_28_x86_64
WORKDIR /TransformerEngine/
COPY ../.. /TransformerEngine/
ARG VER="12-3"
ARG ARCH="x86_64"
RUN dnf -y install vim
# CUDA toolkit and cuDNN.
RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
RUN dnf -y install epel-release
RUN dnf -y install cuda-compiler-${VER}.${ARCH} \
cuda-libraries-${VER}.${ARCH} \
cuda-libraries-devel-${VER}.${ARCH}
RUN dnf -y install --allowerasing cudnn9-cuda-12
RUN dnf clean all
RUN rm -rf /var/cache/dnf/*
RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
RUN dnf -y install cuda-toolkit
RUN dnf clean all
RUN dnf -y install glog.x86_64 glog-devel.x86_64
ENV PATH="/usr/local/cuda/bin:${PATH}"
ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
ENV CUDA_HOME=/usr/local/cuda
ENV CUDA_ROOT=/usr/local/cuda
ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"]
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
set -e
PLATFORM=${1:-manylinux_2_28_x86_64}
BUILD_METAPACKAGE=${2:-true}
BUILD_COMMON=${3:-true}
BUILD_PYTORCH=${4:-true}
BUILD_JAX=${5:-true}
export NVTE_RELEASE_BUILD=1
export TARGET_BRANCH=${TARGET_BRANCH:-}
mkdir -p /wheelhouse/logs
# Generate wheels for common library.
git config --global --add safe.directory /TransformerEngine
cd /TransformerEngine
git checkout $TARGET_BRANCH
git submodule update --init --recursive
if $BUILD_METAPACKAGE ; then
cd /TransformerEngine
NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt
mv dist/* /wheelhouse/
fi
if $BUILD_COMMON ; then
VERSION=`cat build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Create the wheel.
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt
# Repack the wheel for cuda specific package, i.e. cu12.
/opt/python/cp38-cp38/bin/wheel unpack dist/*
# Between Python 3.10 and 3.11, the package name delimiter in wheel metadata changed from - (hyphen) to _ (underscore), so patch both spellings.
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
/opt/python/cp38-cp38/bin/wheel pack ${WHL_BASE}
# Rename the wheel to make it python version agnostic.
whl_name=$(basename dist/*)
IFS='-' read -ra whl_parts <<< "$whl_name"
whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}"
rm -rf $WHL_BASE dist
mv *.whl /wheelhouse/"$whl_name_target"
fi
if $BUILD_PYTORCH ; then
cd /TransformerEngine/transformer_engine/pytorch
/opt/python/cp38-cp38/bin/pip install torch
/opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt
cp dist/* /wheelhouse/
fi
if $BUILD_JAX ; then
cd /TransformerEngine/transformer_engine/jax
/opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib
/opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
cp dist/* /wheelhouse/
fi
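# Example invocation (arguments: platform, then booleans enabling the
# metapackage, common, PyTorch, and JAX builds; values are illustrative):
#   bash build_wheels.sh manylinux_2_28_x86_64 true true true true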
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
docker build --no-cache -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch .
docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel"
rm -rf aarch_wheelhouse
docker cp $(docker ps -aq | head -1):/wheelhouse/ aarch_wheelhouse
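# Note: `docker ps -aq | head -1` assumes the wheel-build container is the
# most recently created container on this host.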
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
docker build --no-cache -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 .
docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel"
rm -rf x86_wheelhouse
docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse
_build
doxygen
sphinx_rtd_theme
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line, and also
# from the environment for the first two.
SPHINXOPTS ?=
SPHINXBUILD ?= sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile sphinx_rtd_theme
PYTHONPATH=sphinx_rtd_theme:$(PYTHONPATH) $(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
# Patch Sphinx RTD theme 3.0.1 to add version selector in sidebar
sphinx_rtd_theme:
git clone --depth=1 -b 3.0.1 --single-branch https://github.com/readthedocs/sphinx_rtd_theme.git
bash -c "cd sphinx_rtd_theme; git apply ../version_select.patch"
<svg id="NVIDIA_Logo_V" data-name="NVIDIA Logo V" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1211.808 415.949"><defs><style>.cls-1{fill:none;}</style></defs><title>NVIDIA-LogoBlack</title><path id="Reg" d="M1080.665,262.245v-2.692h1.729c.944,0,2.229.07,2.229,1.224,0,1.246-.662,1.468-1.775,1.468h-2.183m0,1.892h1.155l2.68,4.7h2.939l-2.962-4.9a2.658,2.658,0,0,0,2.793-2.905c0-2.563-1.771-3.389-4.762-3.389h-4.328v11.192h2.485v-4.7m12.588-.876c0-6.573-5.108-10.386-10.8-10.386-5.73,0-10.833,3.813-10.833,10.386s5.1,10.395,10.833,10.395c5.69,0,10.8-3.826,10.8-10.395m-3.115,0a7.672,7.672,0,0,1-7.683,8v-.035a7.984,7.984,0,1,1,7.683-7.968Z"/><path id="NVIDIA" d="M696.8,152.076l.011,117.957h33.313V152.078Zm-262.063-.16V270.033h33.61V178.346l26.218.088c8.625,0,14.586,2.066,18.743,6.5,5.269,5.616,7.42,14.667,7.42,31.233v53.865h32.564v-65.26c0-46.576-29.689-52.857-58.734-52.857Zm315.7.164V270.033h54.034c28.789,0,38.183-4.787,48.345-15.521,7.184-7.537,11.825-24.08,11.825-42.158,0-16.581-3.928-31.372-10.784-40.583-12.339-16.47-30.121-19.691-56.666-19.691Zm33.045,25.684h14.325c20.779,0,34.218,9.332,34.218,33.545s-13.439,33.548-34.218,33.548H783.484ZM648.77,152.08l-27.8,93.484-26.641-93.478-35.961-.006,38.047,117.953h48.014L682.771,152.08ZM880.145,270.033h33.318V152.086l-33.326-.006Zm93.386-117.91L927.014,269.992h32.849l7.36-20.832h55.05l6.967,20.832H1064.9l-46.873-117.879Zm21.625,21.5,20.18,55.221h-41Z"/><path id="Eye_Mark" data-name="Eye Mark" d="M219.887,171.742V155.509c1.576-.113,3.168-.2,4.79-.247,44.4-1.4,73.527,38.149,73.527,38.149s-31.46,43.7-65.191,43.7a40.916,40.916,0,0,1-13.126-2.1V185.783c17.285,2.088,20.759,9.723,31.154,27.044l23.111-19.486s-16.87-22.127-45.309-22.127a83.962,83.962,0,0,0-8.956.528m0-53.625v24.248c1.593-.126,3.189-.227,4.79-.285,61.744-2.08,101.968,50.637,101.968,50.637s-46.2,56.183-94.337,56.183a71.1,71.1,0,0,1-12.421-1.093V262.8a81.731,81.731,0,0,0,10.343.67c44.795,0,77.188-22.874,108.557-49.949,5.2,4.164,26.49,14.294,30.869,18.734-29.827,24.967-99.333,45.091-138.737,45.091-3.8,0-7.449-.23-11.032-.573v21.064H390.141V118.117Zm0,116.892v12.8c-41.43-7.387-52.929-50.454-52.929-50.454s19.892-22.04,52.929-25.611v14.041c-.026,0-.042-.007-.065-.007-17.336-2.082-30.882,14.117-30.882,14.117s7.589,27.268,30.947,35.116M146.3,195.487s24.555-36.232,73.584-39.978V142.365c-54.305,4.359-101.332,50.352-101.332,50.352s26.634,77,101.332,84.051V262.8C165.071,255.9,146.3,195.487,146.3,195.487Z"/><rect class="cls-1" width="1211.808" height="415.949"/></svg>
@font-face {
font-family: NVIDIA;
font-style: normal;
font-weight: 300;
src: url(https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Lt.woff) format("woff"),url(https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Lt.woff2) format("woff2")
}
@font-face {
font-family: NVIDIA;
font-style: normal;
font-weight: 400;
src: url(https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Rg.woff) format("woff"),url(https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Rg.woff2) format("woff2")
}
@font-face {
font-family: NVIDIA;
font-style: normal;
font-weight: 500;
src: url(https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Md.woff) format("woff"),url(https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Md.woff2) format("woff2")
}
@font-face {
font-family: NVIDIA;
font-style: normal;
font-weight: 700;
src: url(https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Bd.woff) format("woff"),url(https://images.nvidia.com/etc/designs/nvidiaGDC/clientlibs_base/fonts/nvidia-sans/GLOBAL/NVIDIASans_W_Bd.woff2) format("woff2")
}
body {
font-family: NVIDIA,Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;
}
.rst-content .toctree-wrapper>p.caption,h1,h2,h3,h4,h5,h6,legend {
margin-top: 0;
font-weight: 700;
font-family: NVIDIA,Roboto Slab,ff-tisa-web-pro,Georgia,Arial,sans-serif
}
input[type=color],input[type=date],input[type=datetime-local],input[type=datetime],input[type=email],input[type=month],input[type=number],input[type=password],input[type=search],input[type=tel],input[type=text],input[type=time],input[type=url],input[type=week] {
font-family: NVIDIA,Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;
}
select {
font-family: NVIDIA,Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;
}
.btn {
font-family: NVIDIA,Lato,proxima-nova,Helvetica Neue,Arial,sans-serif;
}
html.writer-html4 .rst-content dl:not(.docutils) .descclassname,html.writer-html4 .rst-content dl:not(.docutils) .descname,html.writer-html4 .rst-content dl:not(.docutils) .sig-name,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .descclassname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .descname,html.writer-html5 .rst-content dl[class]:not(.option-list):not(.field-list):not(.footnote):not(.glossary):not(.simple) .sig-name {
font-family: NVIDIA,SFMono-Regular,Menlo,Monaco,Consolas,Liberation Mono,Courier New,Courier,monospace;
}
footer img {
display: block;
width: 137.5px;
position: relative;
left: -9px;
margin: 0 0 15px 0;
}
footer p {
color: #666666;
font-weight: normal;
font-size: 12px;
line-height: 1.25em;
}
footer p:not(.notices) {
display: inline;
margin: 0;
}
footer p a,
footer p a:link,
footer p a:visited {
color: #666666;
}
footer p a:hover {
color: #666666;
}