Unverified Commit c1b915ae authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Build system refactor for wheels (#877)



Cleanup
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent fc989613
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
*.ncu-rep *.ncu-rep
*.sqlite *.sqlite
*.onnx *.onnx
.eggs *.eggs
build/ build/
*.so *.so
*.egg-info *.egg-info
...@@ -27,3 +27,15 @@ docs/_build ...@@ -27,3 +27,15 @@ docs/_build
docs/doxygen docs/doxygen
*.log *.log
CMakeFiles/CMakeSystem.cmake CMakeFiles/CMakeSystem.cmake
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
develop-eggs/
dist/
downloads/
.pytest_cache/
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Build related infrastructure."""
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Installation script."""
import ctypes
import os
import subprocess
import sys
import sysconfig
import copy
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Type
import setuptools
from .utils import (
cmake_bin,
debug_build_enabled,
found_ninja,
get_frameworks,
cuda_path,
)
class CMakeExtension(setuptools.Extension):
    """CMake extension module.

    A placeholder setuptools Extension whose compilation is delegated to
    CMake (configure + build + install) rather than the default compiler
    driver; `sources` is intentionally empty so the base class does no work.
    """

    def __init__(
        self,
        name: str,
        cmake_path: Path,
        cmake_flags: Optional[List[str]] = None,
    ) -> None:
        """Store the CMake project location and extra configure flags.

        Args:
            name: extension module name.
            cmake_path: directory containing the extension's CMakeLists.txt.
            cmake_flags: extra flags forwarded to the CMake configure step.
        """
        super().__init__(name, sources=[])  # No work for base class
        self.cmake_path: Path = cmake_path
        self.cmake_flags: List[str] = [] if cmake_flags is None else cmake_flags

    def _build_cmake(self, build_dir: Path, install_dir: Path) -> None:
        """Configure, build, and install the extension with CMake.

        Args:
            build_dir: out-of-source CMake build directory (must already exist,
                since it is used as the working directory of every command).
            install_dir: prefix the built artifacts are installed into.

        Raises:
            RuntimeError: if any CMake command fails to run or exits nonzero.
        """
        # Make sure paths are str
        _cmake_bin = str(cmake_bin())
        cmake_path = str(self.cmake_path)
        build_dir = str(build_dir)
        install_dir = str(install_dir)

        # CMake configure command
        build_type = "Debug" if debug_build_enabled() else "Release"
        configure_command = [
            _cmake_bin,
            "-S",
            cmake_path,
            "-B",
            build_dir,
            f"-DPython_EXECUTABLE={sys.executable}",
            f"-DPython_INCLUDE_DIR={sysconfig.get_path('include')}",
            f"-DCMAKE_BUILD_TYPE={build_type}",
            f"-DCMAKE_INSTALL_PREFIX={install_dir}",
        ]
        configure_command += self.cmake_flags
        if found_ninja():
            configure_command.append("-GNinja")

        # Point CMake at the pybind11 config shipped inside the pybind11
        # Python package so find_package(pybind11) succeeds.
        import pybind11
        pybind11_dir = Path(pybind11.__file__).resolve().parent
        pybind11_dir = pybind11_dir / "share" / "cmake" / "pybind11"
        configure_command.append(f"-Dpybind11_DIR={pybind11_dir}")

        # CMake build and install commands
        build_command = [_cmake_bin, "--build", build_dir]
        install_command = [_cmake_bin, "--install", build_dir]

        # Run CMake commands
        for command in [configure_command, build_command, install_command]:
            print(f"Running command {' '.join(command)}")
            try:
                subprocess.run(command, cwd=build_dir, check=True)
            except (CalledProcessError, OSError) as e:
                # Chain the original error so the underlying CMake failure
                # stays visible in the traceback.
                raise RuntimeError(f"Error when running CMake: {e}") from e
def get_build_ext(extension_cls: Type[setuptools.Extension]):
    """Wrap a framework-specific build_ext class with CMake extension support.

    Returns a subclass of *extension_cls* that builds CMakeExtension
    instances via CMake, builds all other extensions as usual, writes a
    stub module for the Paddle extension when needed, and moves built
    binaries into the `transformer_engine` package directory.
    """
    class _CMakeBuildExtension(extension_cls):
        """Setuptools command with support for CMake extension modules"""

        def run(self) -> None:
            # Build CMake extensions
            for ext in self.extensions:
                # Destination of the built extension; its parent directory is
                # used as the CMake install prefix.
                package_path = Path(self.get_ext_fullpath(ext.name))
                install_dir = package_path.resolve().parent
                if isinstance(ext, CMakeExtension):
                    print(f"Building CMake extension {ext.name}")
                    # Set up incremental builds for CMake extensions
                    setup_dir = Path(__file__).resolve().parent
                    build_dir = setup_dir / "build" / "cmake"
                    # Ensure the directory exists
                    build_dir.mkdir(parents=True, exist_ok=True)
                    ext._build_cmake(
                        build_dir=build_dir,
                        install_dir=install_dir,
                    )

            # Build non-CMake extensions as usual. Temporarily hide the
            # CMake extensions from the parent class, then restore them.
            all_extensions = self.extensions
            self.extensions = [
                ext for ext in self.extensions if not isinstance(ext, CMakeExtension)
            ]
            super().run()
            self.extensions = all_extensions

            # Locate the Paddle extension, if one is being built.
            paddle_ext = None
            if "paddle" in get_frameworks():
                for ext in self.extensions:
                    if "paddle" in ext.name:
                        paddle_ext = ext
                        break

            # Manually write stub file for Paddle extension
            if paddle_ext is not None:
                # Load libtransformer_engine.so to avoid linker errors
                if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
                    # Source compilation from top-level (--editable)
                    search_paths = list(Path(__file__).resolve().parent.parent.parent.iterdir())
                    # Also search the build output directory (regular source build)
                    search_paths.extend(list(Path(self.build_lib).iterdir()))
                else:
                    # Only during release sdist build.
                    import transformer_engine
                    search_paths = list(Path(transformer_engine.__path__[0]).iterdir())
                    del transformer_engine

                # Last matching path wins if several are found.
                common_so_path = ""
                for path in search_paths:
                    if path.name.startswith("libtransformer_engine."):
                        common_so_path = str(path)
                assert common_so_path, "Could not find libtransformer_engine"
                ctypes.CDLL(common_so_path, mode=ctypes.RTLD_GLOBAL)

                # Figure out stub file path
                module_name = paddle_ext.name
                assert module_name.endswith("_pd_"), \
                    "Expected Paddle extension module to end with '_pd_'"
                stub_name = module_name[:-4]  # remove '_pd_'
                stub_path = os.path.join(self.build_lib, "transformer_engine", stub_name + ".py")
                Path(stub_path).parent.mkdir(exist_ok=True, parents=True)

                # Figure out library name
                # Note: This library doesn't actually exist. Paddle
                # internally reinserts the '_pd_' suffix.
                so_path = self.get_ext_fullpath(module_name)
                _, so_ext = os.path.splitext(so_path)
                lib_name = stub_name + so_ext

                # Write stub file
                print(f"Writing Paddle stub for {lib_name} into file {stub_path}")
                from paddle.utils.cpp_extension.extension_utils import custom_write_stub
                custom_write_stub(lib_name, stub_path)

            # Ensure that binaries are not in global package space.
            # NOTE(review): install_dir comes from the last loop iteration
            # above, so this assumes at least one extension is configured.
            target_dir = install_dir / "transformer_engine"
            target_dir.mkdir(exist_ok=True, parents=True)

            for ext in Path(self.build_lib).glob("*.so"):
                self.copy_file(ext, target_dir)
                os.remove(ext)

            # For paddle, the stub file needs to be copied to the install location.
            if paddle_ext is not None:
                stub_path = Path(self.build_lib) / "transformer_engine"
                for stub in stub_path.glob("transformer_engine_paddle.py"):
                    self.copy_file(stub, target_dir)

        def build_extensions(self):
            """Compile extensions, routing .cu/.cuh sources through NVCC.

            Only needed for the pybind11 (JAX) build; PyTorch and Paddle
            provide their own CUDA-aware build_ext classes.
            """
            # BuildExtensions from PyTorch and PaddlePaddle already handle CUDA files correctly
            # so we don't need to modify their compiler. Only the pybind11 build_ext needs to be fixed.
            if "pytorch" not in get_frameworks() and "paddle" not in get_frameworks():
                # Ensure at least an empty list of flags for 'cxx' and 'nvcc' when
                # extra_compile_args is a dict.
                for ext in self.extensions:
                    if isinstance(ext.extra_compile_args, dict):
                        for target in ['cxx', 'nvcc']:
                            if target not in ext.extra_compile_args.keys():
                                ext.extra_compile_args[target] = []

                # Define new _compile method that redirects to NVCC for .cu and .cuh files.
                original_compile_fn = self.compiler._compile
                self.compiler.src_extensions += ['.cu', '.cuh']

                def _compile_fn(obj, src, ext, cc_args, extra_postargs, pp_opts) -> None:
                    # Copy before we make any modifications.
                    cflags = copy.deepcopy(extra_postargs)
                    original_compiler = self.compiler.compiler_so
                    try:
                        _, nvcc_bin = cuda_path()
                        original_compiler = self.compiler.compiler_so
                        if os.path.splitext(src)[1] in ['.cu', '.cuh']:
                            # CUDA source: swap the compiler for NVCC.
                            self.compiler.set_executable('compiler_so', str(nvcc_bin))
                            if isinstance(cflags, dict):
                                cflags = cflags['nvcc']
                            # Add -fPIC if not already specified
                            if not any('-fPIC' in flag for flag in cflags):
                                cflags.extend(['--compiler-options', "'-fPIC'"])
                            # Forward unknown options
                            if not any('--forward-unknown-opts' in flag for flag in cflags):
                                cflags.append('--forward-unknown-opts')
                        elif isinstance(cflags, dict):
                            cflags = cflags['cxx']
                        # Append -std=c++17 if not already in flags
                        if not any(flag.startswith('-std=') for flag in cflags):
                            cflags.append('-std=c++17')
                        return original_compile_fn(obj, src, ext, cc_args, cflags, pp_opts)
                    finally:
                        # Put the original compiler back in place.
                        self.compiler.set_executable('compiler_so', original_compiler)

                self.compiler._compile = _compile_fn

            super().build_extensions()

    return _CMakeBuildExtension
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Paddle-paddle related extensions."""
from pathlib import Path
import setuptools
from .utils import cuda_path
from typing import List
def setup_jax_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup PyBind11 extension for JAX support"""
    # Source files
    src_dir = Path(csrc_source_files)
    sources = [
        src_dir / "extensions.cpp",
        src_dir / "modules.cpp",
        src_dir / "utils.cu",
    ]

    # Header files (CUDA headers plus TE common/core headers)
    cuda_home, _ = cuda_path()
    include_dirs = [
        cuda_home / "include",
        common_header_files,
        common_header_files / "common",
        common_header_files / "common" / "include",
        csrc_header_files,
    ]

    # Compile flags
    cxx_flags = ["-O3"]
    nvcc_flags = ["-O3"]

    # Define TE/JAX as a Pybind11Extension
    from pybind11.setup_helpers import Pybind11Extension

    class Pybind11CUDAExtension(Pybind11Extension):
        """Modified Pybind11Extension to allow combined CXX + NVCC compile flags."""

        def _add_cflags(self, flags: List[str]) -> None:
            # When flags are split per compiler, merge new flags into 'cxx';
            # otherwise prepend them to the flat flag list.
            if isinstance(self.extra_compile_args, dict):
                merged = self.extra_compile_args.pop('cxx', [])
                merged += flags
                self.extra_compile_args['cxx'] = merged
            else:
                self.extra_compile_args[:0] = flags

    return Pybind11CUDAExtension(
        "transformer_engine_jax",
        sources=[str(src) for src in sources],
        include_dirs=[str(inc) for inc in include_dirs],
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Paddle-paddle related extensions."""
from pathlib import Path
import setuptools
from .utils import cuda_version
def setup_paddle_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup CUDA extension for Paddle support"""
    # Source files
    src_dir = Path(csrc_source_files)
    sources = [
        src_dir / "extensions.cu",
        src_dir / "common.cpp",
        src_dir / "custom_ops.cu",
    ]

    # Header files
    include_dirs = [
        common_header_files,
        common_header_files / "common",
        common_header_files / "common" / "include",
        csrc_header_files,
    ]

    # Compiler flags
    cxx_flags = ["-O3"]
    nvcc_flags = [
        "-O3",
        "-gencode",
        "arch=compute_70,code=sm_70",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # Version-dependent CUDA options
    try:
        version = cuda_version()
    except FileNotFoundError:
        print("Could not determine CUDA Toolkit version")
    else:
        if version >= (11, 2):
            nvcc_flags.extend(["--threads", "4"])
        if version >= (11, 0):
            nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
        if version >= (11, 8):
            nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])

    # Construct Paddle CUDA extension
    from paddle.utils.cpp_extension import CUDAExtension
    ext = CUDAExtension(
        sources=[str(src) for src in sources],
        include_dirs=[str(inc) for inc in include_dirs],
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
    # Paddle reinserts the '_pd_' suffix internally, so set the name here.
    ext.name = "transformer_engine_paddle_pd_"
    return ext
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""PyTorch related extensions."""
import os
from pathlib import Path
import setuptools
from .utils import (
all_files_in_dir,
cuda_version,
userbuffers_enabled,
)
def setup_pytorch_extension(
    csrc_source_files,
    csrc_header_files,
    common_header_files,
) -> setuptools.Extension:
    """Setup CUDA extension for PyTorch support"""
    # Source files: top-level sources plus everything under extensions/
    src_dir = Path(csrc_source_files)
    sources = [
        src_dir / "common.cu",
        src_dir / "ts_fp8_op.cpp",
    ] + all_files_in_dir(src_dir / "extensions")

    # Header files
    include_dirs = [
        common_header_files,
        common_header_files / "common",
        common_header_files / "common" / "include",
        csrc_header_files,
    ]

    # Compiler flags
    cxx_flags = ["-O3"]
    nvcc_flags = [
        "-O3",
        "-gencode",
        "arch=compute_70,code=sm_70",
        "-U__CUDA_NO_HALF_OPERATORS__",
        "-U__CUDA_NO_HALF_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT16_OPERATORS__",
        "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
        "-U__CUDA_NO_BFLOAT162_OPERATORS__",
        "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
        "--expt-relaxed-constexpr",
        "--expt-extended-lambda",
        "--use_fast_math",
    ]

    # Version-dependent CUDA options
    try:
        version = cuda_version()
    except FileNotFoundError:
        print("Could not determine CUDA Toolkit version")
    else:
        if version >= (11, 2):
            nvcc_flags.extend(["--threads", "4"])
        if version >= (11, 0):
            nvcc_flags.extend(["-gencode", "arch=compute_80,code=sm_80"])
        if version >= (11, 8):
            nvcc_flags.extend(["-gencode", "arch=compute_90,code=sm_90"])

    # userbuffers support
    if userbuffers_enabled():
        mpi_home = os.getenv("MPI_HOME")
        if mpi_home:
            include_dirs.append(Path(mpi_home) / "include")
        cxx_flags.append("-DNVTE_WITH_USERBUFFERS")
        nvcc_flags.append("-DNVTE_WITH_USERBUFFERS")

    # Construct PyTorch CUDA extension
    from torch.utils.cpp_extension import CUDAExtension
    return CUDAExtension(
        name="transformer_engine_torch",
        sources=[str(src) for src in sources],
        include_dirs=[str(inc) for inc in include_dirs],
        extra_compile_args={
            "cxx": cxx_flags,
            "nvcc": nvcc_flags,
        },
    )
...@@ -7,6 +7,7 @@ import os ...@@ -7,6 +7,7 @@ import os
from pathlib import Path from pathlib import Path
import subprocess import subprocess
def te_version() -> str: def te_version() -> str:
"""Transformer Engine version string """Transformer Engine version string
...@@ -15,9 +16,10 @@ def te_version() -> str: ...@@ -15,9 +16,10 @@ def te_version() -> str:
""" """
root_path = Path(__file__).resolve().parent root_path = Path(__file__).resolve().parent
with open(root_path / "VERSION", "r") as f: with open(root_path / "VERSION.txt", "r") as f:
version = f.readline().strip() version = f.readline().strip()
if not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0")): if (not int(os.getenv("NVTE_NO_LOCAL_VERSION", "0"))
and not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0")))):
try: try:
output = subprocess.run( output = subprocess.run(
["git", "rev-parse" , "--short", "HEAD"], ["git", "rev-parse" , "--short", "HEAD"],
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Installation script."""
import os
import re
import glob
import shutil
import subprocess
import sys
from functools import cache
from pathlib import Path
from subprocess import CalledProcessError
from typing import List, Optional, Tuple
@cache
def userbuffers_enabled() -> bool:
    """Check if userbuffers support is enabled.

    Controlled by the NVTE_WITH_USERBUFFERS environment variable. When it is
    enabled, MPI_HOME must also be set (the build consumes it).

    Raises:
        RuntimeError: if NVTE_WITH_USERBUFFERS is set but MPI_HOME is not.
    """
    if int(os.getenv("NVTE_WITH_USERBUFFERS", "0")):
        # Explicit raise instead of `assert`: asserts are stripped when
        # Python runs with -O, which would silently skip this check.
        if not os.getenv("MPI_HOME"):
            raise RuntimeError("MPI_HOME must be set if NVTE_WITH_USERBUFFERS=1")
        return True
    return False
@cache
def debug_build_enabled() -> bool:
    """Whether to build with a debug configuration"""
    # A "--debug" command-line flag takes precedence; consume it so that
    # setuptools never sees the unknown argument.
    if "--debug" in sys.argv:
        sys.argv.remove("--debug")
        return True
    # Otherwise fall back to the NVTE_BUILD_DEBUG environment variable.
    return bool(int(os.getenv("NVTE_BUILD_DEBUG", "0")))
def all_files_in_dir(path):
    """Recursively collect every file under *path* as a list of Paths."""
    return [
        Path(root, fname)
        for root, _, fnames in os.walk(path)
        for fname in fnames
    ]
def remove_dups(_list: List) -> List:
    """Return *_list* with duplicates removed, preserving first-seen order.

    Uses dict.fromkeys so the output order is deterministic (the previous
    set-based version returned elements in arbitrary order). Elements must
    be hashable.
    """
    return list(dict.fromkeys(_list))
def found_cmake() -> bool:
""" "Check if valid CMake is available
CMake 3.18 or newer is required.
"""
# Check if CMake is available
try:
_cmake_bin = cmake_bin()
except FileNotFoundError:
return False
# Query CMake for version info
output = subprocess.run(
[_cmake_bin, "--version"],
capture_output=True,
check=True,
universal_newlines=True,
)
match = re.search(r"version\s*([\d.]+)", output.stdout)
version = match.group(1).split(".")
version = tuple(int(v) for v in version)
return version >= (3, 18)
def cmake_bin() -> Path:
    """Get CMake executable

    Throws FileNotFoundError if not found.
    """
    # Prefer the CMake bundled with the `cmake` Python package, if installed.
    try:
        from cmake import CMAKE_BIN_DIR
    except ImportError:
        pass
    else:
        candidate = Path(CMAKE_BIN_DIR).resolve() / "cmake"
        if candidate.is_file():
            return candidate

    # Otherwise fall back to whatever is on PATH.
    path_bin = shutil.which("cmake")
    if path_bin is not None:
        return Path(path_bin).resolve()

    raise FileNotFoundError("Could not find CMake executable")
def found_ninja() -> bool:
    """Check if Ninja is available"""
    # which() returns None when ninja is absent, so this is a clean bool.
    return bool(shutil.which("ninja"))
def found_pybind11() -> bool:
    """Check if pybind11 is available

    True if either the pybind11 Python package is installed or CMake can
    locate a pybind11 package on its own.
    """
    # Check if Python package is installed
    try:
        import pybind11
    except ImportError:
        pass
    else:
        return True

    # Check if CMake can find pybind11
    if not found_cmake():
        return False
    try:
        subprocess.run(
            [
                # Use the resolved CMake binary: the bare "cmake" name may
                # not be on PATH when CMake comes from the Python package.
                str(cmake_bin()),
                "--find-package",
                "-DMODE=EXIST",
                "-DNAME=pybind11",
                "-DCOMPILER_ID=CXX",
                "-DLANGUAGE=CXX",
            ],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=True,
        )
    except (CalledProcessError, OSError):
        return False
    return True
@cache
def cuda_path() -> Tuple[Path, Path]:
    """CUDA root path and NVCC binary path as a tuple.

    Search order: $CUDA_HOME, then PATH, then /usr/local/cuda.

    Raises:
        FileNotFoundError: if NVCC is not found.
    """
    # Try finding NVCC
    nvcc_bin: Optional[Path] = None
    cuda_home: Optional[Path] = None
    if os.getenv("CUDA_HOME"):
        # Check in CUDA_HOME
        cuda_home = Path(os.getenv("CUDA_HOME"))
        nvcc_bin = cuda_home / "bin" / "nvcc"
    if nvcc_bin is None:
        # Check if nvcc is in path
        which_result = shutil.which("nvcc")
        if which_result is not None:
            nvcc_bin = Path(which_result)
            # NVCC lives at <root>/bin/nvcc, so the root is two levels up.
            # (The old str.rstrip("/bin/nvcc") stripped a *character set*
            # and could mangle roots ending in those characters.)
            cuda_home = nvcc_bin.parent.parent
    if nvcc_bin is None:
        # Last-ditch guess in /usr/local/cuda
        cuda_home = Path("/usr/local/cuda")
        nvcc_bin = cuda_home / "bin" / "nvcc"
    if not nvcc_bin.is_file():
        raise FileNotFoundError(f"Could not find NVCC at {nvcc_bin}")
    return cuda_home, nvcc_bin
def cuda_version() -> Tuple[int, ...]:
    """CUDA Toolkit version as a (major, minor) tuple.

    Raises:
        FileNotFoundError: if NVCC cannot be located (via cuda_path).
        RuntimeError: if NVCC's version output cannot be parsed.
    """
    # Query NVCC for version info
    _, nvcc_bin = cuda_path()
    output = subprocess.run(
        [nvcc_bin, "-V"],
        capture_output=True,
        check=True,
        universal_newlines=True,
    )
    match = re.search(r"release\s*([\d.]+)", output.stdout)
    if match is None:
        # Fail with a clear message instead of AttributeError on None.
        raise RuntimeError(f"Could not parse NVCC version output: {output.stdout}")
    return tuple(int(v) for v in match.group(1).split("."))
def get_frameworks() -> List[str]:
    """DL frameworks to build support for"""
    supported_frameworks = ["pytorch", "jax", "paddle"]
    selected: List[str] = []

    # Check environment variable
    env_value = os.getenv("NVTE_FRAMEWORK")
    if env_value:
        selected.extend(env_value.split(","))

    # Check command-line arguments (consume them so setuptools never sees them)
    for arg in sys.argv.copy():
        if arg.startswith("--framework="):
            selected.extend(arg.replace("--framework=", "").split(","))
            sys.argv.remove(arg)

    # Detect installed frameworks if not explicitly specified
    if not selected:
        for framework, module in (
            ("pytorch", "torch"),
            ("jax", "jax"),
            ("paddle", "paddle"),
        ):
            try:
                __import__(module)
            except ImportError:
                continue
            selected.append(framework)

    # Special framework names
    if "all" in selected:
        selected = supported_frameworks.copy()
    if "none" in selected:
        selected = []

    # Check that frameworks are valid
    selected = [framework.lower() for framework in selected]
    for framework in selected:
        if framework not in supported_frameworks:
            raise ValueError(
                f"Transformer Engine does not support framework={framework}"
            )
    return selected
def package_files(directory):
    """List every file under *directory* as a path relative to it."""
    prefix = f"{directory}/"
    collected = []
    for root, _, names in os.walk(directory):
        root_path = Path(root)
        # Strip the top-level directory prefix to get relative paths.
        collected.extend(
            str(root_path / name).replace(prefix, "") for name in names
        )
    return collected
def copy_common_headers(te_src, dst):
    """Copy every .h under te_src/common into dst, preserving relative paths."""
    src_root = str(te_src)
    pattern = os.path.join(src_root, "common", "**", "*.h")
    for header in glob.glob(pattern, recursive=True):
        # Path of the header relative to te_src (skip the trailing separator).
        relative = header[len(src_root) + 1:]
        target = Path(dst) / relative
        target.parent.mkdir(exist_ok=True, parents=True)
        shutil.copy(header, target)
def install_and_import(package):
    """Install a package via pip (if not already installed) and import into globals.

    Args:
        package: importable package name (also used as the pip requirement).

    Raises:
        subprocess.CalledProcessError: if the pip install fails.
    """
    import importlib
    try:
        importlib.import_module(package)
    except ImportError:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', package])
    # Import (or re-import) after the try block rather than in a `finally`:
    # an ImportError raised inside `finally` would mask the pip failure.
    globals()[package] = importlib.import_module(package)
...@@ -17,7 +17,7 @@ from datetime import date ...@@ -17,7 +17,7 @@ from datetime import date
te_path = os.path.dirname(os.path.realpath(__file__)) te_path = os.path.dirname(os.path.realpath(__file__))
with open(te_path + "/../VERSION", "r") as f: with open(te_path + "/../build_tools/VERSION.txt", "r") as f:
te_version = f.readline().strip() te_version = f.readline().strip()
release_year = 2022 release_year = 2022
......
...@@ -204,7 +204,7 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model): ...@@ -204,7 +204,7 @@ def share_parameters_with_transformerlayer_te_model(te_model, basic_model):
def cast_to_representable(inp, scale = 1., fp8_format='e4m3'): def cast_to_representable(inp, scale = 1., fp8_format='e4m3'):
import transformer_engine.pytorch.cpp_extensions as texcpp import transformer_engine.pytorch.cpp_extensions as texcpp
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch.constants import TE_DType
fp8_type = tex.DType.kFloat8E4M3 if fp8_format == 'e4m3' else tex.DType.kFloat8E5M2 fp8_type = tex.DType.kFloat8E4M3 if fp8_format == 'e4m3' else tex.DType.kFloat8E5M2
input_type = TE_DType[inp.dtype] input_type = TE_DType[inp.dtype]
......
[MASTER] [MASTER]
extension-pkg-whitelist=transformer_engine_jax extension-pkg-whitelist=transformer_engine_jax
extension-pkg-allow-list=transformer_engine.transformer_engine_jax
disable=too-many-locals, disable=too-many-locals,
invalid-name, invalid-name,
too-many-arguments, too-many-arguments,
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
set -xe set -xe
pip install pytest==7.2
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed' pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/tests/jax -k 'not distributed'
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
set -xe set -xe
pip install pytest==7.2
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
pytest -Wignore -v $TE_PATH/tests/paddle pytest -Wignore -v $TE_PATH/tests/paddle
pytest -Wignore -v $TE_PATH/examples/paddle/mnist pytest -Wignore -v $TE_PATH/examples/paddle/mnist
[MASTER] [MASTER]
extension-pkg-whitelist=torch, extension-pkg-whitelist=torch,
transformer_engine_extensions transformer_engine_torch
disable=too-many-locals, disable=too-many-locals,
too-many-public-methods, too-many-public-methods,
......
...@@ -14,7 +14,7 @@ then ...@@ -14,7 +14,7 @@ then
echo "Checking common API headers" echo "Checking common API headers"
cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include cpplint --root transformer_engine/common/include --recursive transformer_engine/common/include
echo "Checking C++ files" echo "Checking C++ files"
cpplint --recursive --exclude=transformer_engine/common/include transformer_engine cpplint --recursive --exclude=transformer_engine/common/include --exclude=transformer_engine/build_tools/build transformer_engine
cpplint --recursive transformer_engine/pytorch cpplint --recursive transformer_engine/pytorch
fi fi
if [ -z "${CPP_ONLY}" ] if [ -z "${CPP_ONLY}" ]
......
This diff is collapsed.
...@@ -26,7 +26,7 @@ if(NOT DEFINED TE_LIB_PATH) ...@@ -26,7 +26,7 @@ if(NOT DEFINED TE_LIB_PATH)
OUTPUT_VARIABLE TE_LIB_PATH) OUTPUT_VARIABLE TE_LIB_PATH)
endif() endif()
find_library(TE_LIB NAMES transformer_engine PATHS ${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/transformer_engine" ENV TE_LIB_PATH REQUIRED)
message(STATUS "Found transformer_engine library: ${TE_LIB}") message(STATUS "Found transformer_engine library: ${TE_LIB}")
include_directories(../../transformer_engine/common/include) include_directories(../../transformer_engine/common/include)
......
...@@ -22,7 +22,7 @@ from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLay ...@@ -22,7 +22,7 @@ from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLay
from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine.jax.cpp_extensions import FusedAttnHelper
from transformer_engine_jax import NVTE_Fused_Attn_Backend from transformer_engine.transformer_engine_jax import NVTE_Fused_Attn_Backend
from utils import assert_allclose from utils import assert_allclose
......
...@@ -15,7 +15,7 @@ import pytest ...@@ -15,7 +15,7 @@ import pytest
from utils import assert_allclose from utils import assert_allclose
from transformer_engine_jax import get_device_compute_capability from transformer_engine.transformer_engine_jax import get_device_compute_capability
from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.common.recipe import DelayedScaling, Format
from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections from transformer_engine.jax import fp8_autocast, update_fp8_metas, update_collections
from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral from transformer_engine.jax.flax import DenseGeneral, LayerNormDenseGeneral
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment