Unverified Commit c1b915ae authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Build system refactor for wheels (#877)



Cleanup
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent fc989613
......@@ -9,7 +9,7 @@
*.ncu-rep
*.sqlite
*.onnx
.eggs
*.eggs
build/
*.so
*.egg-info
......@@ -27,3 +27,15 @@ docs/_build
docs/doxygen
*.log
CMakeFiles/CMakeSystem.cmake
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
develop-eggs/
dist/
downloads/
.pytest_cache/
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Build related infrastructure."""
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Installation script."""
import ctypes
import os
import subprocess
import sys
import sysconfig
import copy
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Type
import setuptools
from .utils import (
cmake_bin,
debug_build_enabled,
found_ninja,
get_frameworks,
cuda_path,
)
class CMakeExtension(setuptools.Extension):
    """Setuptools extension module that is configured, built, and installed via CMake."""

    def __init__(
        self,
        name: str,
        cmake_path: Path,
        cmake_flags: Optional[List[str]] = None,
    ) -> None:
        """Create a CMake-backed extension.

        Args:
            name: Extension module name.
            cmake_path: Directory containing the CMakeLists.txt to configure.
            cmake_flags: Extra flags appended to the CMake configure command.
        """
        super().__init__(name, sources=[])  # No work for base class
        self.cmake_path: Path = cmake_path
        self.cmake_flags: List[str] = [] if cmake_flags is None else cmake_flags

    def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
        """Run the CMake configure, build, and install steps.

        Raises:
            RuntimeError: if any CMake invocation fails; the original
                subprocess error is chained as the cause.
        """
        # Make sure paths are str
        _cmake_bin = str(cmake_bin())
        cmake_path = str(self.cmake_path)
        build_dir = str(build_dir)
        install_dir = str(install_dir)

        # CMake configure command
        build_type = "Debug" if debug_build_enabled() else "Release"
        configure_command = [
            _cmake_bin,
            "-S",
            cmake_path,
            "-B",
            build_dir,
            f"-DPython_EXECUTABLE={sys.executable}",
            f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
            f"-DCMAKE_BUILD_TYPE={build_type}",
            f"-DCMAKE_INSTALL_PREFIX={install_dir}",
        ]
        configure_command += self.cmake_flags
        if found_ninja():
            configure_command.append("-GNinja")

        # Point CMake at the pybind11 config shipped with the Python package.
        import pybind11

        pybind11_dir = Path(pybind11.__file__).resolve().parent
        pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
        configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")

        # CMake build and install commands
        build_command = [_cmake_bin, "--build", build_dir]
        install_command = [_cmake_bin, "--install", build_dir]

        # Run CMake commands
        for command in [configure_command, build_command, install_command]:
            print(f"Running command {' '.join(command)}")
            try:
                subprocess.run(command, cwd=build_dir, check=True)
            except (CalledProcessError, OSError) as e:
                # Chain the underlying error so the real CMake failure is
                # visible in the traceback (previously it was discarded).
                raise RuntimeError(f"Error when running CMake: {e}") from e
def get_build_ext(extension_cls: Type[setuptools.Extension]):
    """Wrap a framework-specific build_ext class with CMake-extension support.

    `extension_cls` is the build_ext command from setuptools, PyTorch, or
    Paddle; the returned subclass additionally builds any `CMakeExtension`
    entries and handles the Paddle stub-file workaround.
    """

    class _CMakeBuildExtension(extension_cls):
        """Setuptools command with support for CMake extension modules"""

        def run(self) -> None:
            # Build CMake extensions
            for ext in self.extensions:
                package_path = Path(self.get_ext_fullpath(ext.name))
                install_dir = package_path.resolve().parent
                if isinstance(ext, CMakeExtension):
                    print(f"Building CMake extension {ext.name}")
                    # Set up incremental builds for CMake extensions
                    setup_dir = Path(__file__).resolve().parent
                    build_dir = setup_dir / "build" / "cmake"

                    # Ensure the directory exists
                    build_dir.mkdir(parents=True, exist_ok=True)

                    ext._build_cmake(
                        build_dir=build_dir,
                        install_dir=install_dir,
                    )

            # Build non-CMake extensions as usual; the extension list is
            # temporarily filtered and restored afterwards.
            all_extensions = self.extensions
            self.extensions = [
                ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
            ]
            super().run()
            self.extensions = all_extensions

            paddle_ext = None
            if "paddle" in get_frameworks():
                for ext in self.extensions:
                    if "paddle" in ext.name:
                        paddle_ext = ext
                        break

            # Manually write stub file for Paddle extension
            if paddle_ext is not None:
                # Load libtransformer_engine.so to avoid linker errors
                if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
                    # Source compilation from top-level (--editable)
                    search_paths = list(Path(__file__).resolve().parent.parent.parent.iterdir())
                    # Also search the build directory for a freshly built library.
                    search_paths.extend(list(Path(self.build_lib).iterdir()))
                else:
                    # Only during release sdist build.
                    import transformer_engine

                    search_paths = list(Path(transformer_engine.__path__[0]).iterdir())
                    del transformer_engine

                common_so_path = ""
                for path in search_paths:
                    if path.name.startswith("libtransformer_engine."):
                        common_so_path = str(path)
                assert common_so_path, "Could not find libtransformer_engine"
                ctypes.CDLL(common_so_path, mode=ctypes.RTLD_GLOBAL)

                # Figure out stub file path
                module_name = paddle_ext.name
                assert module_name.endswith("_pd_"), \
                    "Expected Paddle extension module to end with '_pd_'"
                stub_name = module_name[:-4]  # remove '_pd_'
                stub_path = os.path.join(self.build_lib, "transformer_engine", stub_name + ".py")
                Path(stub_path).parent.mkdir(exist_ok=True, parents=True)

                # Figure out library name
                # Note: This library doesn't actually exist. Paddle
                # internally reinserts the '_pd_' suffix.
                so_path = self.get_ext_fullpath(module_name)
                _, so_ext = os.path.splitext(so_path)
                lib_name = stub_name + so_ext

                # Write stub file
                print(f"Writing Paddle stub for {lib_name} into file {stub_path}")
                from paddle.utils.cpp_extension.extension_utils import custom_write_stub
                custom_write_stub(lib_name, stub_path)

            # Ensure that binaries are not in global package space.
            # NOTE(review): install_dir here is the value left over from the
            # last loop iteration above — assumes at least one extension exists.
            target_dir = install_dir / "transformer_engine"
            target_dir.mkdir(exist_ok=True, parents=True)

            for ext in Path(self.build_lib).glob("*.so"):
                self.copy_file(ext, target_dir)
                os.remove(ext)

            # For paddle, the stub file needs to be copied to the install location.
            if paddle_ext is not None:
                stub_path = Path(self.build_lib) / "transformer_engine"
                for stub in stub_path.glob("transformer_engine_paddle.py"):
                    self.copy_file(stub, target_dir)

        def build_extensions(self):
            """Compile extensions, routing .cu/.cuh sources through NVCC when needed."""
            # BuildExtensions from PyTorch and PaddlePaddle already handle CUDA files correctly
            # so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed.
            if "pytorch" not in get_frameworks() and "paddle" not in get_frameworks():
                # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
                # extra_compile_args is a dict.
                for ext in self.extensions:
                    if isinstance(ext.extra_compile_args, dict):
                        for target in ['cxx', 'nvcc']:
                            if target not in ext.extra_compile_args.keys():
                                ext.extra_compile_args[target] = []

                # Define new _compile method that redirects to NVCC for .cu and .cuh files.
                original_compile_fn = self.compiler._compile
                self.compiler.src_extensions += ['.cu', '.cuh']

                def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
                    # Copy before we make any modifications.
                    cflags = copy.deepcopy(extra_postargs)
                    original_compiler = self.compiler.compiler_so
                    try:
                        _, nvcc_bin = cuda_path()
                        original_compiler = self.compiler.compiler_so

                        # CUDA sources compile with NVCC and the 'nvcc' flag set.
                        if os.path.splitext(src)[1] in ['.cu', '.cuh']:
                            self.compiler.set_executable('compiler_so', str(nvcc_bin))
                            if isinstance(cflags, dict):
                                cflags = cflags['nvcc']

                            # Add -fPIC if not already specified
                            if not any('-fPIC' in flag for flag in cflags):
                                cflags.extend(['--compiler-options', "'-fPIC'"])

                            # Forward unknown options
                            if not any('--forward-unknown-opts' in flag for flag in cflags):
                                cflags.append('--forward-unknown-opts')
                        elif isinstance(cflags, dict):
                            cflags = cflags['cxx']

                        # Append -std=c++17 if not already in flags
                        if not any(flag.startswith('-std=') for flag in cflags):
                            cflags.append('-std=c++17')

                        return original_compile_fn(obj, src, ext, cc_args, cflags, pp_opts)
                    finally:
                        # Put the original compiler back in place.
                        self.compiler.set_executable('compiler_so', original_compiler)

                self.compiler._compile = _compile_fn

            super().build_extensions()

    return _CMakeBuildExtension
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Paddle-paddle related extensions."""
from pathlib import Path
import setuptools
from .utils import cuda_path
from typing import List
def setup_jax_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup PyBind11 extension for JAX support"""
    # Sources for the JAX bindings.
    source_root = Path(csrc_source_files)
    sources = [
        source_root / name
        for name in ("extensions.cpp", "modules.cpp", "utils.cu")
    ]

    # Header search paths: CUDA headers plus the common TE headers.
    cuda_home, _ = cuda_path()
    include_dirs = [
        cuda_home / "include",
        common_header_files,
        common_header_files / "common",
        common_header_files / "common" / "include",
        csrc_header_files,
    ]

    # Optimization flags for host and device compilation.
    cxx_flags = ["-O3"]
    nvcc_flags = ["-O3"]

    from pybind11.setup_helpers import Pybind11Extension

    class Pybind11CUDAExtension(Pybind11Extension):
        """Pybind11Extension variant that accepts dict-style {cxx, nvcc} compile flags."""

        def _add_cflags(self, flags: List[str]) -> None:
            # Pybind11Extension appends its own flags; when flags are split
            # per-compiler, route them to the 'cxx' bucket.
            if isinstance(self.extra_compile_args, dict):
                merged = self.extra_compile_args.pop('cxx', []) + flags
                self.extra_compile_args['cxx'] = merged
            else:
                self.extra_compile_args[:0] = flags

    return Pybind11CUDAExtension(
        "transformer_engine_jax",
        sources=[str(entry) for entry in sources],
        include_dirs=[str(entry) for entry in include_dirs],
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Paddle-paddle related extensions."""
from pathlib import Path
import setuptools
from .utils import cuda_version
def setup_paddle_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup CUDA extension for Paddle support"""
    # Sources for the Paddle custom ops.
    src_root = Path(csrc_source_files)
    sources = [
        str(src_root / name)
        for name in ("extensions.cu", "common.cpp", "custom_ops.cu")
    ]

    # Header search paths.
    include_dirs = [
        str(entry)
        for entry in (
            common_header_files,
            common_header_files / "common",
            common_header_files / "common" / "include",
            csrc_header_files,
        )
    ]

    # Base compiler flags.
    cxx_flags = ["-O3"]
    nvcc_flags = [
        "-O3",
        "-gencode",
        "arch=compute_70,code=sm_70",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # Architecture/thread options that depend on the CUDA Toolkit version.
    try:
        version = cuda_version()
    except FileNotFoundError:
        print("Could not determine CUDA Toolkit version")
    else:
        if version >= (11, 2):
            nvcc_flags.extend(["--threads", "4"])
        if version >= (11, 0):
            nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
        if version >= (11, 8):
            nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])

    # Construct Paddle CUDA extension.
    from paddle.utils.cpp_extension import CUDAExtension

    extension = CUDAExtension(
        sources=sources,
        include_dirs=include_dirs,
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
    # Paddle internally strips and re-adds the '_pd_' suffix; set the name directly.
    extension.name = "transformer_engine_paddle_pd_"
    return extension
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""PyTorch related extensions."""
import os
from pathlib import Path
import setuptools
from .utils import (
all_files_in_dir,
cuda_version,
userbuffers_enabled,
)
def setup_pytorch_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup CUDA extension for PyTorch support"""
    # Sources: the fixed entry points plus every file under extensions/.
    src_root = Path(csrc_source_files)
    sources = [
        src_root / "common.cu",
        src_root / "ts_fp8_op.cpp",
    ] + all_files_in_dir(src_root / "extensions")

    # Header search paths.
    include_dirs = [
        common_header_files,
        common_header_files / "common",
        common_header_files / "common" / "include",
        csrc_header_files,
    ]

    # Base compiler flags.
    cxx_flags = ["-O3"]
    nvcc_flags = [
        "-O3",
        "-gencode",
        "arch=compute_70,code=sm_70",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # Architecture/thread options that depend on the CUDA Toolkit version.
    try:
        version = cuda_version()
    except FileNotFoundError:
        print("Could not determine CUDA Toolkit version")
    else:
        for minimum, extra in (
            ((11, 2), ["--threads", "4"]),
            ((11, 0), ["-gencode", "arch=compute_80,code=sm_80"]),
            ((11, 8), ["-gencode", "arch=compute_90,code=sm_90"]),
        ):
            if version >= minimum:
                nvcc_flags.extend(extra)

    # Optional userbuffers (MPI-based) support.
    if userbuffers_enabled():
        if os.getenv("MPI_HOME"):
            include_dirs.append(Path(os.getenv("MPI_HOME")) / "include")
        cxx_flags.append("-DNVTE_WITH_USERBUFFERS")
        nvcc_flags.append("-DNVTE_WITH_USERBUFFERS")

    # Construct PyTorch CUDA extension.
    from torch.utils.cpp_extension import CUDAExtension

    return CUDAExtension(
        name="transformer_engine_torch",
        sources=[str(entry) for entry in sources],
        include_dirs=[str(entry) for entry in include_dirs],
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
......@@ -7,6 +7,7 @@ import os
from pathlib import Path
import subprocess
def te_version() -> str:
"""Transformer Engine version string
......@@ -15,9 +16,10 @@ def te_version() -> str:
"""
root_path = Path(__file__).resolve().parent
with open(root_path / "VERSION", "r") as f:
with open(root_path / "VERSION.txt", "r") as f:
version = f.readline().strip()
if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")):
if (not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0"))
and not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))):
try:
output = subprocess.run(
["git", "rev-parse" , "--short", "HEAD"],
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Installation script."""
import os
import re
import glob
import shutil
import subprocess
import sys
from functools import cache
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Tuple
@cache
def userbuffers_enabled() -> bool:
    """Whether NVTE_WITH_USERBUFFERS requests building userbuffers support."""
    if not int(os.getenv("NVTE_WITH_USERBUFFERS", "0")):
        return False
    # Userbuffers builds against MPI, located via MPI_HOME.
    assert os.getenv("MPI_HOME"), "MPI_HOME must be set if NVTE_WITH_USERBUFFERS=1"
    return True
@cache
def debug_build_enabled() -> bool:
    """Whether to build with a debug configuration"""
    # A '--debug' command-line flag takes priority; consume it so that
    # setuptools never sees it.
    if "--debug" in sys.argv:
        sys.argv.remove("--debug")
        return True
    # Otherwise honor the NVTE_BUILD_DEBUG environment variable.
    return bool(int(os.getenv("NVTE_BUILD_DEBUG", "0")))
def all_files_in_dir(path):
    """Recursively collect every file under *path* as a list of Path objects."""
    collected = []
    for dirpath, _, filenames in os.walk(path):
        collected.extend(Path(dirpath, fname) for fname in filenames)
    return collected
def remove_dups(_list: List):
    """Return a copy of *_list* with duplicate entries removed.

    Uses ``dict.fromkeys`` instead of ``list(set(...))`` so the first-seen
    order of entries is preserved, keeping generated requirement lists
    deterministic across runs.
    """
    return list(dict.fromkeys(_list))
def found_cmake() -> bool:
    """Check if a valid CMake is available.

    CMake 3.18 or newer is required.
    """
    # Check if CMake is available
    try:
        _cmake_bin = cmake_bin()
    except FileNotFoundError:
        return False

    # Query CMake for version info
    output = subprocess.run(
        [_cmake_bin, "--version"],
        capture_output=True,
        check=True,
        universal_newlines=True,
    )
    match = re.search(r"version\s*([\d.]+)", output.stdout)
    if match is None:
        # Unparseable version output: treat CMake as unusable instead of
        # crashing with AttributeError on match.group().
        return False
    version = tuple(int(v) for v in match.group(1).split("."))
    return version >= (3, 18)
def cmake_bin() -> Path:
    """Locate the CMake executable.

    Prefers the binary bundled with the ``cmake`` Python package, then
    falls back to searching PATH.

    Raises:
        FileNotFoundError: if no CMake executable can be found.
    """
    # First choice: the binary shipped by the 'cmake' pip package.
    candidate: Optional[Path] = None
    try:
        from cmake import CMAKE_BIN_DIR
    except ImportError:
        pass
    else:
        candidate = Path(CMAKE_BIN_DIR).resolve() / "cmake"
        if not candidate.is_file():
            candidate = None

    # Second choice: whatever is on PATH.
    if candidate is None:
        located = shutil.which("cmake")
        if located is not None:
            candidate = Path(located).resolve()

    if candidate is None:
        raise FileNotFoundError("Could not find CMake executable")
    return candidate
def found_ninja() -> bool:
    """Check whether the Ninja build tool is available on PATH."""
    return bool(shutil.which("ninja"))
def found_pybind11() -> bool:
    """Check if pybind11 is available.

    Looks for the Python package first, then asks CMake whether it can
    locate a pybind11 installation.
    """
    # Python package?
    try:
        import pybind11  # noqa: F401
    except ImportError:
        pass
    else:
        return True

    # CMake-visible package? Requires a working CMake.
    if not found_cmake():
        return False
    try:
        subprocess.run(
            [
                "cmake",
                "--find-package",
                "-DMODE=EXIST",
                "-DNAME=pybind11",
                "-DCOMPILER_ID=CXX",
                "-DLANGUAGE=CXX",
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
        )
    except (CalledProcessError, OSError):
        return False
    return True
@cache
def cuda_path() -> Tuple[str, str]:
    """CUDA root path and NVCC binary path as a tuple.

    Searches CUDA_HOME, then PATH, then /usr/local/cuda.

    Raises:
        FileNotFoundError: if NVCC is not found.
    """
    # Try finding NVCC
    nvcc_bin: Optional[Path] = None
    if nvcc_bin is None and os.getenv("CUDA_HOME"):
        # Check in CUDA_HOME
        cuda_home = Path(os.getenv("CUDA_HOME"))
        nvcc_bin = cuda_home / "bin" / "nvcc"
    if nvcc_bin is None:
        # Check if nvcc is in path
        nvcc_bin = shutil.which("nvcc")
        if nvcc_bin is not None:
            nvcc_bin = Path(nvcc_bin)
            # Bug fix: str.rstrip("/bin/nvcc") strips a *character set*
            # (any trailing '/', 'b', 'i', 'n', 'v', 'c'), which mangles
            # CUDA roots ending in one of those characters. Derive the
            # root with path arithmetic instead.
            cuda_home = nvcc_bin.parent.parent
    if nvcc_bin is None:
        # Last-ditch guess in /usr/local/cuda
        cuda_home = Path("/usr/local/cuda")
        nvcc_bin = cuda_home / "bin" / "nvcc"
    if not nvcc_bin.is_file():
        raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")

    return cuda_home, nvcc_bin
def cuda_version() -> Tuple[int, ...]:
    """CUDA Toolkit version as a (major, minor) tuple."""
    # Parse the release number out of `nvcc -V`.
    _, nvcc_bin = cuda_path()
    result = subprocess.run(
        [nvcc_bin, "-V"],
        capture_output=True,
        check=True,
        universal_newlines=True,
    )
    release = re.search(r"release\s*([\d.]+)", result.stdout).group(1)
    return tuple(int(part) for part in release.split("."))
def get_frameworks() -> List[str]:
    """DL frameworks to build support for"""
    _frameworks: List[str] = []
    supported_frameworks = ["pytorch", "jax", "paddle"]

    # Environment variable takes first priority.
    env_value = os.getenv("NVTE_FRAMEWORK")
    if env_value:
        _frameworks.extend(env_value.split(","))

    # Consume any --framework=... command-line arguments.
    for arg in sys.argv.copy():
        if arg.startswith("--framework="):
            _frameworks.extend(arg.replace("--framework=", "").split(","))
            sys.argv.remove(arg)

    # Fall back to auto-detecting installed frameworks.
    if not _frameworks:
        for framework, module in (
            ("pytorch", "torch"),
            ("jax", "jax"),
            ("paddle", "paddle"),
        ):
            try:
                __import__(module)
            except ImportError:
                pass
            else:
                _frameworks.append(framework)

    # Special framework names
    if "all" in _frameworks:
        _frameworks = supported_frameworks.copy()
    if "none" in _frameworks:
        _frameworks = []

    # Check that frameworks are valid
    _frameworks = [name.lower() for name in _frameworks]
    for name in _frameworks:
        if name not in supported_frameworks:
            raise ValueError(
                f"Transformer Engine does not support framework={name}"
            )

    return _frameworks
def package_files(directory):
    """List every file under *directory* as a path relative to it."""
    relative_paths = []
    for dirpath, _, filenames in os.walk(directory):
        base = Path(dirpath)
        for fname in filenames:
            # Strip the leading 'directory/' prefix to make the path relative.
            relative_paths.append(str(base / fname).replace(f"{directory}/", ""))
    return relative_paths
def copy_common_headers(te_src, dst):
    """Mirror all .h headers under te_src/common into dst, preserving layout."""
    header_root = te_src / "common"
    prefix_len = len(str(te_src)) + 1  # strip leading 'te_src/' from each path
    for src_file in glob.glob(os.path.join(str(header_root), "**", "*.h"), recursive=True):
        dest_file = Path(os.path.join(dst, src_file[prefix_len:]))
        dest_file.parent.mkdir(exist_ok=True, parents=True)
        shutil.copy(src_file, dest_file)
def install_and_import(package):
    """Import *package*, pip-installing it first if it is missing.

    The imported module is bound into this module's globals under the
    package name, mirroring a top-level ``import`` statement.
    """
    import importlib

    try:
        importlib.import_module(package)
    except ImportError:
        # Not importable yet: install with the running interpreter's pip.
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
    finally:
        globals()[package] = importlib.import_module(package)
......@@ -17,7 +17,7 @@ from datetime import date
te_path = os.path.dirname(os.path.realpath(__file__))
with open(te_path + "/../VERSION", "r") as f:
with open(te_path + "/../build_tools/VERSION.txt", "r") as f:
te_version = f.readline().strip()
release_year = 2022
......
......@@ -204,7 +204,7 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model):
def cast_to_representable(inp, scale = 1., fp8_format='e4m3'):
import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType
fp8_type = tex.DType.kFloat8E4M3 if fp8_format == 'e4m3' else tex.DType.kFloat8E5M2
input_type = TE_DType[inp.dtype]
......
[MASTER]
extension-pkg-whitelist=transformer_engine_jax
extension-pkg-allow-list=transformer_engine.transformer_engine_jax
disable=too-many-locals,
invalid-name,
too-many-arguments,
......
......@@ -4,6 +4,7 @@
set -xe
pip install pytest==7.2
: ${TE_PATH:=/opt/transformerengine}
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed'
......
......@@ -4,6 +4,7 @@
set -xe
pip install pytest==7.2
: ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/paddle
pytest -Wignore -v $TE_PATH/examples/paddle/mnist
[MASTER]
extension-pkg-whitelist=torch,
transformer_engine_extensions
transformer_engine_torch
disable=too-many-locals,
too-many-public-methods,
......
......@@ -14,7 +14,7 @@ then
echo "Checking common API headers"
cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
echo "Checking C++ files"
cpplint --recursive --exclude=transformer_engine/common/include transformer_engine
cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine
cpplint --recursive transformer_engine/pytorch
fi
if [ -z "${CPP_ONLY}" ]
......
......@@ -4,241 +4,65 @@
"""Installation script."""
import ctypes
from functools import lru_cache
import os
from pathlib import Path
import re
import shutil
import subprocess
from subprocess import CalledProcessError
import sys
import sysconfig
from typing import List, Optional, Tuple, Union
from typing import List, Tuple
import setuptools
from setuptools.command.build_ext import build_ext
from te_version import te_version
from build_tools.build_ext import CMakeExtension, get_build_ext
from build_tools.utils import (
found_cmake,
found_ninja,
found_pybind11,
remove_dups,
userbuffers_enabled,
get_frameworks,
install_and_import,
)
from build_tools.te_version import te_version
# Project directory root
root_path: Path = Path(__file__).resolve().parent
@lru_cache(maxsize=1)
def with_debug_build() -> bool:
"""Whether to build with a debug configuration"""
for arg in sys.argv:
if arg == "--debug":
sys.argv.remove(arg)
return True
if int(os.getenv("NVTE_BUILD_DEBUG", "0")):
return True
return False
frameworks = get_frameworks()
current_file_path = Path(__file__).parent.resolve()
# Call once in global scope since this function manipulates the
# command-line arguments. Future calls will use a cached value.
with_debug_build()
def found_cmake() -> bool:
""""Check if valid CMake is available
CMake 3.18 or newer is required.
"""
# Check if CMake is available
try:
_cmake_bin = cmake_bin()
except FileNotFoundError:
return False
# Query CMake for version info
output = subprocess.run(
[_cmake_bin, "--version"],
capture_output=True,
check=True,
universal_newlines=True,
)
match = re.search(r"version\s*([\d.]+)", output.stdout)
version = match.group(1).split('.')
version = tuple(int(v) for v in version)
return version >= (3, 18)
def cmake_bin() -> Path:
"""Get CMake executable
Throws FileNotFoundError if not found.
"""
# Search in CMake Python package
_cmake_bin: Optional[Path] = None
try:
import cmake
except ImportError:
pass
else:
cmake_dir = Path(cmake.__file__).resolve().parent
_cmake_bin = cmake_dir / "data" / "bin" / "cmake"
if not _cmake_bin.is_file():
_cmake_bin = None
# Search in path
if _cmake_bin is None:
_cmake_bin = shutil.which("cmake")
if _cmake_bin is not None:
_cmake_bin = Path(_cmake_bin).resolve()
# Return executable if found
if _cmake_bin is None:
raise FileNotFoundError("Could not find CMake executable")
return _cmake_bin
def found_ninja() -> bool:
""""Check if Ninja is available"""
return shutil.which("ninja") is not None
from setuptools.command.build_ext import build_ext as BuildExtension
if "pytorch" in frameworks:
from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks:
from paddle.utils.cpp_extension import BuildExtension
elif "jax" in frameworks:
install_and_import('pybind11')
from pybind11.setup_helpers import build_ext as BuildExtension
def found_pybind11() -> bool:
""""Check if pybind11 is available"""
# Check if Python package is installed
try:
import pybind11
except ImportError:
pass
else:
return True
CMakeBuildExtension = get_build_ext(BuildExtension)
# Check if CMake can find pybind11
if not found_cmake():
return False
try:
subprocess.run(
[
"cmake",
"--find-package",
"-DMODE=EXIST",
"-DNAME=pybind11",
"-DCOMPILER_ID=CXX",
"-DLANGUAGE=CXX",
],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
check=True,
)
except (CalledProcessError, OSError):
pass
else:
return True
return False
def cuda_version() -> Tuple[int, ...]:
"""CUDA Toolkit version as a (major, minor) tuple
def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library
Throws FileNotFoundError if NVCC is not found.
Also builds JAX or userbuffers support if needed.
"""
cmake_flags = []
if userbuffers_enabled():
cmake_flags.append("-DNVTE_WITH_USERBUFFERS=ON")
# Try finding NVCC
nvcc_bin: Optional[Path] = None
if nvcc_bin is None and os.getenv("CUDA_HOME"):
# Check in CUDA_HOME
cuda_home = Path(os.getenv("CUDA_HOME"))
nvcc_bin = cuda_home / "bin" / "nvcc"
if nvcc_bin is None:
# Check if nvcc is in path
nvcc_bin = shutil.which("nvcc")
if nvcc_bin is not None:
nvcc_bin = Path(nvcc_bin)
if nvcc_bin is None:
# Last-ditch guess in /usr/local/cuda
cuda_home = Path("/usr/local/cuda")
nvcc_bin = cuda_home / "bin" / "nvcc"
if not nvcc_bin.is_file():
raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")
# Query NVCC for version info
output = subprocess.run(
[nvcc_bin, "-V"],
capture_output=True,
check=True,
universal_newlines=True,
)
match = re.search(r"release\s*([\d.]+)", output.stdout)
version = match.group(1).split('.')
return tuple(int(v) for v in version)
@lru_cache(maxsize=1)
def with_userbuffers() -> bool:
"""Check if userbuffers support is enabled"""
if int(os.getenv("NVTE_WITH_USERBUFFERS", "0")):
assert os.getenv("MPI_HOME"), \
"MPI_HOME must be set if NVTE_WITH_USERBUFFERS=1"
return True
return False
@lru_cache(maxsize=1)
def frameworks() -> List[str]:
"""DL frameworks to build support for"""
_frameworks: List[str] = []
supported_frameworks = ["pytorch", "jax", "paddle"]
# Check environment variable
if os.getenv("NVTE_FRAMEWORK"):
_frameworks.extend(os.getenv("NVTE_FRAMEWORK").split(","))
# Check command-line arguments
for arg in sys.argv.copy():
if arg.startswith("--framework="):
_frameworks.extend(arg.replace("--framework=", "").split(","))
sys.argv.remove(arg)
# Detect installed frameworks if not explicitly specified
if not _frameworks:
try:
import torch
except ImportError:
pass
else:
_frameworks.append("pytorch")
try:
import jax
except ImportError:
pass
else:
_frameworks.append("jax")
try:
import paddle
except ImportError:
pass
else:
_frameworks.append("paddle")
# Special framework names
if "all" in _frameworks:
_frameworks = supported_frameworks.copy()
if "none" in _frameworks:
_frameworks = []
# Check that frameworks are valid
_frameworks = [framework.lower() for framework in _frameworks]
for framework in _frameworks:
if framework not in supported_frameworks:
raise ValueError(
f"Transformer Engine does not support framework={framework}"
# Project directory root
root_path = Path(__file__).resolve().parent
return CMakeExtension(
name="transformer_engine",
cmake_path=root_path / Path("transformer_engine"),
cmake_flags=cmake_flags,
)
return _frameworks
# Call once in global scope since this function manipulates the
# command-line arguments. Future calls will use a cached value.
frameworks()
def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
"""Setup Python dependencies
Returns dependencies for build, runtime, and testing.
"""
# Common requirements
......@@ -250,373 +74,65 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
]
test_reqs: List[str] = ["pytest"]
def add_unique(l: List[str], vals: Union[str, List[str]]) -> None:
"""Add entry to list if not already included"""
if isinstance(vals, str):
vals = [vals]
for val in vals:
if val not in l:
l.append(val)
# Requirements that may be installed outside of Python
if not found_cmake():
add_unique(setup_reqs, "cmake>=3.18")
setup_reqs.append("cmake>=3.18")
if not found_ninja():
add_unique(setup_reqs, "ninja")
# Framework-specific requirements
if "pytorch" in frameworks():
add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.5.8,!=2.0.9,!=2.1.0"])
add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"])
if "jax" in frameworks():
setup_reqs.append("ninja")
if not found_pybind11():
add_unique(setup_reqs, "pybind11")
add_unique(install_reqs, ["jax", "flax>=0.7.1"])
add_unique(test_reqs, ["numpy", "praxis"])
if "paddle" in frameworks():
add_unique(install_reqs, "paddlepaddle-gpu")
add_unique(test_reqs, "numpy")
return setup_reqs, install_reqs, test_reqs
class CMakeExtension(setuptools.Extension):
"""CMake extension module"""
def __init__(
self,
name: str,
cmake_path: Path,
cmake_flags: Optional[List[str]] = None,
) -> None:
super().__init__(name, sources=[]) # No work for base class
self.cmake_path: Path = cmake_path
self.cmake_flags: List[str] = [] if cmake_flags is None else cmake_flags
setup_reqs.append("pybind11")
def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
return [remove_dups(reqs) for reqs in [setup_reqs, install_reqs, test_reqs]]
# Make sure paths are str
_cmake_bin = str(cmake_bin())
cmake_path = str(self.cmake_path)
build_dir = str(build_dir)
install_dir = str(install_dir)
# CMake configure command
build_type = "Debug" if with_debug_build() else "Release"
configure_command = [
_cmake_bin,
"-S",
cmake_path,
"-B",
build_dir,
f"-DPython_EXECUTABLE={sys.executable}",
f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
f"-DCMAKE_BUILD_TYPE={build_type}",
f"-DCMAKE_INSTALL_PREFIX={install_dir}",
]
configure_command += self.cmake_flags
if found_ninja():
configure_command.append("-GNinja")
try:
import pybind11
except ImportError:
pass
else:
pybind11_dir = Path(pybind11.__file__).resolve().parent
pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")
# CMake build and install commands
build_command = [_cmake_bin, "--build", build_dir]
install_command = [_cmake_bin, "--install", build_dir]
# Run CMake commands
for command in [configure_command, build_command, install_command]:
print(f"Running command {' '.join(command)}")
try:
subprocess.run(command, cwd=build_dir, check=True)
except (CalledProcessError, OSError) as e:
raise RuntimeError(f"Error when running CMake: {e}")
# PyTorch extension modules require special handling
if "pytorch" in frameworks():
from torch.utils.cpp_extension import BuildExtension
elif "paddle" in frameworks():
from paddle.utils.cpp_extension import BuildExtension
else:
from setuptools.command.build_ext import build_ext as BuildExtension
class CMakeBuildExtension(BuildExtension):
    """Setuptools command with support for CMake extension modules"""

    def __init__(self, *args, **kwargs) -> None:
        # No extra state; defer entirely to the framework's build_ext.
        super().__init__(*args, **kwargs)

    def run(self) -> None:
        """Build CMake extensions first, then the remaining extensions.

        CMake extensions are built/installed directly via their own
        `_build_cmake`; everything else is handled by the parent
        build_ext.  Paddle extensions additionally get a linker search
        path and a hand-written stub module.
        """
        # Build CMake extensions
        for ext in self.extensions:
            if isinstance(ext, CMakeExtension):
                print(f"Building CMake extension {ext.name}")
                # Set up incremental builds for CMake extensions
                setup_dir = Path(__file__).resolve().parent
                build_dir = setup_dir / "build" / "cmake"
                build_dir.mkdir(parents=True, exist_ok=True)  # Ensure the directory exists
                # Install next to where setuptools would place the module.
                package_path = Path(self.get_ext_fullpath(ext.name))
                install_dir = package_path.resolve().parent
                ext._build_cmake(
                    build_dir=build_dir,
                    install_dir=install_dir,
                )

        # Paddle requires linker search path for libtransformer_engine.so
        paddle_ext = None
        if "paddle" in frameworks():
            for ext in self.extensions:
                if "paddle" in ext.name:
                    ext.library_dirs.append(self.build_lib)
                    paddle_ext = ext
                    break

        # Build non-CMake extensions as usual.  Temporarily hide the
        # CMake extensions from the parent class, then restore the list.
        all_extensions = self.extensions
        self.extensions = [
            ext for ext in self.extensions
            if not isinstance(ext, CMakeExtension)
        ]
        super().run()
        self.extensions = all_extensions

        # Manually write stub file for Paddle extension
        if paddle_ext is not None:

            # Load libtransformer_engine.so to avoid linker errors
            for path in Path(self.build_lib).iterdir():
                if path.name.startswith("libtransformer_engine."):
                    ctypes.CDLL(str(path), mode=ctypes.RTLD_GLOBAL)

            # Figure out stub file path
            module_name = paddle_ext.name
            assert module_name.endswith("_pd_"), \
                "Expected Paddle extension module to end with '_pd_'"
            stub_name = module_name[:-4]  # remove '_pd_'
            stub_path = os.path.join(self.build_lib, stub_name + ".py")

            # Figure out library name
            # Note: This library doesn't actually exist. Paddle
            # internally reinserts the '_pd_' suffix.
            so_path = self.get_ext_fullpath(module_name)
            _, so_ext = os.path.splitext(so_path)
            lib_name = stub_name + so_ext

            # Write stub file
            print(f"Writing Paddle stub for {lib_name} into file {stub_path}")
            from paddle.utils.cpp_extension.extension_utils import custom_write_stub
            custom_write_stub(lib_name, stub_path)
def setup_common_extension() -> CMakeExtension:
    """Create the CMake extension for the core transformer_engine library.

    Adds CMake configure flags for JAX bindings and userbuffers support
    when those are requested.
    """
    flags = []
    if "jax" in frameworks():
        flags += ["-DENABLE_JAX=ON"]
    if with_userbuffers():
        flags += ["-DNVTE_WITH_USERBUFFERS=ON"]
    return CMakeExtension(
        name="transformer_engine",
        cmake_path=root_path / "transformer_engine",
        cmake_flags=flags,
    )
def _all_files_in_dir(path):
all_files = []
for dirname, _, names in os.walk(path):
for name in names:
all_files.append(Path(dirname, name))
return all_files
def setup_pytorch_extension() -> setuptools.Extension:
    """Setup CUDA extension for PyTorch support

    Collects sources, include paths, and compiler flags (with CUDA
    version- and userbuffers-dependent additions) and returns a
    torch.utils.cpp_extension.CUDAExtension.
    """
    # Source files
    src_dir = root_path / "transformer_engine" / "pytorch" / "csrc"
    extensions_dir = src_dir / "extensions"
    sources = [
        src_dir / "common.cu",
        src_dir / "ts_fp8_op.cpp",
        # We need to compile system.cpp because the pytorch extension uses
        # transformer_engine::getenv. This is a workaround to avoid direct
        # linking with libtransformer_engine.so, as the pre-built PyTorch
        # wheel from conda or PyPI was not built with CXX11_ABI, and will
        # cause undefined symbol issues.
        root_path / "transformer_engine" / "common" / "util" / "system.cpp",
    ] + \
        _all_files_in_dir(extensions_dir)

    # Header files
    include_dirs = [
        root_path / "transformer_engine" / "common" / "include",
        root_path / "transformer_engine" / "pytorch" / "csrc",
        root_path / "transformer_engine",
        root_path / "3rdparty" / "cudnn-frontend" / "include",
    ]

    # Compiler flags
    cxx_flags = ["-O3"]
    nvcc_flags = [
        "-O3",
        "-gencode",
        "arch=compute_70,code=sm_70",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # Version-dependent CUDA options
    try:
        version = cuda_version()
    except FileNotFoundError:
        # No nvcc found; proceed with the baseline sm_70 flags only.
        print("Could not determine CUDA Toolkit version")
    else:
        if version >= (11, 2):
            nvcc_flags.extend(["--threads", "4"])
        if version >= (11, 0):
            nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
        if version >= (11, 8):
            nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])

    # userbuffers support
    if with_userbuffers():
        if os.getenv("MPI_HOME"):
            mpi_home = Path(os.getenv("MPI_HOME"))
            include_dirs.append(mpi_home / "include")
        cxx_flags.append("-DNVTE_WITH_USERBUFFERS")
        nvcc_flags.append("-DNVTE_WITH_USERBUFFERS")

    # Construct PyTorch CUDA extension
    sources = [str(path) for path in sources]
    include_dirs = [str(path) for path in include_dirs]
    from torch.utils.cpp_extension import CUDAExtension
    return CUDAExtension(
        name="transformer_engine_extensions",
        sources=sources,
        include_dirs=include_dirs,
        # libraries=["transformer_engine"], ### TODO (tmoon) Debug linker errors
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
def setup_paddle_extension() -> setuptools.Extension:
    """Setup CUDA extension for Paddle support"""
    # Paddle C++/CUDA sources (order preserved)
    csrc_dir = root_path / "transformer_engine" / "paddle" / "csrc"
    source_files = [
        csrc_dir / "extensions.cu",
        csrc_dir / "common.cpp",
        csrc_dir / "custom_ops.cu",
    ]

    # Header search paths
    header_dirs = [
        root_path / "transformer_engine" / "common" / "include",
        root_path / "transformer_engine" / "paddle" / "csrc",
        root_path / "transformer_engine",
    ]

    # Compiler flags
    cxx_args = ["-O3"]
    nvcc_args = [
        "-O3",
        "-gencode",
        "arch=compute_70,code=sm_70",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # Options only supported by newer CUDA toolkits
    try:
        cuda_ver = cuda_version()
    except FileNotFoundError:
        print("Could not determine CUDA Toolkit version")
    else:
        if cuda_ver >= (11, 2):
            nvcc_args += ["--threads", "4"]
        if cuda_ver >= (11, 0):
            nvcc_args += ["-gencode", "arch=compute_80,code=sm_80"]
        if cuda_ver >= (11, 8):
            nvcc_args += ["-gencode", "arch=compute_90,code=sm_90"]

    # Construct Paddle CUDA extension
    from paddle.utils.cpp_extension import CUDAExtension
    ext = CUDAExtension(
        sources=[str(p) for p in source_files],
        include_dirs=[str(p) for p in header_dirs],
        libraries=["transformer_engine"],
        extra_compile_args={
            "cxx": cxx_args,
            "nvcc": nvcc_args,
        },
    )
    # Paddle internally strips and reinserts the '_pd_' suffix.
    ext.name = "transformer_engine_paddle_pd_"
    return ext
def main():
    """Package installation entry point.

    Gathers dependencies and extension modules, then hands off to
    setuptools.setup().
    """
    # Submodules to install
    packages = setuptools.find_packages(
        include=["transformer_engine", "transformer_engine.*"],
    )

    # Dependencies
    setup_requires, install_requires, test_requires = setup_requirements()

    # Extensions: the common library is always built; framework-specific
    # extensions are added only for frameworks found in the environment.
    ext_modules = [setup_common_extension()]
    if "pytorch" in frameworks():
        ext_modules.append(setup_pytorch_extension())
    if "paddle" in frameworks():
        ext_modules.append(setup_paddle_extension())

    # Configure package
    setuptools.setup(
        name="transformer_engine",
        version=te_version(),
        packages=packages,
        description="Transformer acceleration library",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
        setup_requires=setup_requires,
        install_requires=install_requires,
        extras_require={"test": test_requires},
        license_files=("LICENSE",),
        include_package_data=True,
        package_data={"": ["VERSION.txt"]},
    )


if __name__ == "__main__":
    main()
......@@ -26,7 +26,7 @@ if(NOT DEFINED TE_LIB_PATH)
OUTPUT_VARIABLE TE_LIB_PATH)
endif()
find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/transformer_engine" ENV TE_LIB_PATH REQUIRED)
message(STATUS "Found transformer_engine library: ${TE_LIB}")
include_directories(../../transformer_engine/common/include)
......
......@@ -22,7 +22,7 @@ from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLay
from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from utils import assert_allclose
......
......@@ -15,7 +15,7 @@ import pytest
from utils import assert_allclose
from transformer_engine.transformer_engine_jax import get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment