Unverified Commit c1b915ae authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Build system refactor for wheels (#877)



Cleanup
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent fc989613
......@@ -9,9 +9,9 @@ from jax.ad_checkpoint import checkpoint_name
import jax
import jax.numpy as jnp
from transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout
from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout
from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Installation script for TE jax extensions."""
# pylint: disable=wrong-import-position,wrong-import-order
import sys
import os
import shutil
from pathlib import Path
import setuptools

# Fail early with an actionable message: the extension build is pointless
# without Jax installed.
try:
    import jax  # pylint: disable=unused-import
except ImportError as e:
    raise RuntimeError("This package needs Jax to build.") from e

current_file_path = Path(__file__).parent.resolve()
build_tools_dir = current_file_path.parent.parent / "build_tools"
# Vendor the shared build_tools package next to this setup script so the
# imports below resolve and the files ship inside the sdist/wheel.
# NOTE(review): if NVTE_RELEASE_BUILD=1 but build_tools_dir does not exist,
# copytree will raise — presumably release builds always run from the full
# source tree; confirm.
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_dir):
    shutil.copytree(build_tools_dir, current_file_path / "build_tools", dirs_exist_ok=True)

from build_tools.build_ext import get_build_ext
from build_tools.utils import package_files, copy_common_headers, install_and_import
from build_tools.te_version import te_version
from build_tools.jax import setup_jax_extension

# pybind11 supplies the build_ext helper; install it on the fly if absent.
install_and_import('pybind11')
from pybind11.setup_helpers import build_ext as BuildExtension

# Wrap pybind11's build_ext with the project's CMake-driven build step.
CMakeBuildExtension = get_build_ext(BuildExtension)

if __name__ == "__main__":
    # Extensions
    # Copy the common C++ headers beside this script so the extension build
    # can reference them and they can be bundled as package data.
    common_headers_dir = "common_headers"
    copy_common_headers(
        current_file_path.parent,
        str(current_file_path / common_headers_dir))
    ext_modules = [
        setup_jax_extension(
            "csrc", current_file_path / "csrc", current_file_path / common_headers_dir)]

    # Configure package
    setuptools.setup(
        name="transformer_engine_jax",
        version=te_version(),
        packages=["csrc", common_headers_dir, "build_tools"],
        description="Transformer acceleration library - Jax Lib",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
        install_requires=["jax", "flax>=0.7.1"],
        tests_require=["numpy", "praxis"],
        include_package_data=True,
        package_data={"csrc": package_files("csrc"),
                      common_headers_dir: package_files(common_headers_dir),
                      "build_tools": package_files("build_tools")},
    )

    # Remove the copied headers after a packaging command so they do not
    # linger in the source tree. NOTE(review): the relative path assumes the
    # working directory is this script's directory — confirm against how the
    # wheel build invokes setup.py.
    if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
        shutil.rmtree(common_headers_dir)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Transformer Engine bindings for Paddle"""
# pylint: disable=wrong-import-position,wrong-import-order
def _load_library():
    """Load shared library with Transformer Engine C extensions"""
    # Importing the compiled extension module is what loads the underlying
    # shared object; the imported name itself is not used afterwards.
    from transformer_engine import transformer_engine_paddle  # pylint: disable=unused-import


_load_library()
from .fp8 import fp8_autocast
from .layer import (
Linear,
......
......@@ -7,7 +7,7 @@ from enum import Enum
import paddle
import transformer_engine_paddle as tex
from transformer_engine import transformer_engine_paddle as tex
class FP8FwdTensors(Enum):
......
......@@ -7,7 +7,7 @@ import math
from typing import Optional, Tuple, Union
import paddle
import paddle.nn.functional as F
import transformer_engine_paddle as tex
from transformer_engine import transformer_engine_paddle as tex
from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors
from .fp8 import FP8TensorMeta
......
......@@ -9,7 +9,7 @@ from typing import Tuple, Optional, Dict, Any, Union
import numpy as np
import paddle
import transformer_engine_paddle as tex
from transformer_engine import transformer_engine_paddle as tex
from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type
......
......@@ -11,7 +11,7 @@ from typing import Dict, Any, List, Union
import numpy as np
import paddle
import transformer_engine_paddle as tex
from transformer_engine import transformer_engine_paddle as tex
from .constants import dist_group_type, RecomputeFunctionNames
......
......@@ -14,7 +14,7 @@ try:
from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError:
fused_rotary_position_embedding = None
import transformer_engine_paddle as tex
from transformer_engine import transformer_engine_paddle as tex
from .layernorm_linear import LayerNormLinear
from .linear import Linear
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Installation script for TE paddle-paddle extensions."""
# pylint: disable=wrong-import-position,wrong-import-order
import sys
import os
import shutil
from pathlib import Path
import setuptools

# Import Paddle (and its cpp_extension helper) inside the guard so a missing
# Paddle produces this clear error. Previously `from paddle.utils.cpp_extension
# import BuildExtension` ran BEFORE the guard, so the RuntimeError below could
# never fire — users saw a bare ImportError instead.
try:
    import paddle  # pylint: disable=unused-import
    from paddle.utils.cpp_extension import BuildExtension
except ImportError as e:
    raise RuntimeError("This package needs Paddle Paddle to build.") from e

current_file_path = Path(__file__).parent.resolve()
build_tools_dir = current_file_path.parent.parent / "build_tools"
# Vendor the shared build_tools package next to this setup script so the
# imports below resolve and the files ship inside the sdist/wheel.
# NOTE(review): if NVTE_RELEASE_BUILD=1 but build_tools_dir does not exist,
# copytree will raise — presumably release builds always run from the full
# source tree; confirm.
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_dir):
    shutil.copytree(build_tools_dir, current_file_path / "build_tools", dirs_exist_ok=True)

from build_tools.build_ext import get_build_ext  # pylint: disable=wrong-import-position
from build_tools.utils import package_files, copy_common_headers  # pylint: disable=wrong-import-position
from build_tools.te_version import te_version  # pylint: disable=wrong-import-position
from build_tools.paddle import setup_paddle_extension  # pylint: disable=wrong-import-position

# Wrap Paddle's build_ext with the project's CMake-driven build step.
CMakeBuildExtension = get_build_ext(BuildExtension)

if __name__ == "__main__":
    # Extensions
    # Copy the common C++ headers beside this script so the extension build
    # can reference them and they can be bundled as package data.
    common_headers_dir = "common_headers"
    copy_common_headers(
        current_file_path.parent,
        str(current_file_path / common_headers_dir))
    ext_modules = [
        setup_paddle_extension(
            "csrc", current_file_path / "csrc", current_file_path / common_headers_dir)]

    # Configure package
    setuptools.setup(
        name="transformer_engine_paddle",
        version=te_version(),
        packages=["csrc", common_headers_dir, "build_tools"],
        description="Transformer acceleration library - Paddle Paddle Lib",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
        install_requires=["paddlepaddle-gpu"],
        tests_require=["numpy"],
        include_package_data=True,
        package_data={"csrc": package_files("csrc"),
                      common_headers_dir: package_files(common_headers_dir),
                      "build_tools": package_files("build_tools")},
    )

    # Remove the copied headers after a packaging command so they do not
    # linger in the source tree. NOTE(review): the relative path assumes the
    # working directory is this script's directory — confirm against how the
    # wheel build invokes setup.py.
    if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
        shutil.rmtree(common_headers_dir)
......@@ -3,27 +3,55 @@
# See LICENSE for license information.
"""Transformer Engine bindings for pyTorch"""
# pylint: disable=wrong-import-position,wrong-import-order
import importlib
import sys
import torch
from .module import LayerNormLinear
from .module import Linear
from .module import LayerNormMLP
from .module import LayerNorm
from .module import RMSNorm
from .attention import DotProductAttention
from .attention import InferenceParams
from .attention import MultiheadAttention
from .transformer import TransformerLayer
from .fp8 import fp8_autocast
from .fp8 import fp8_model_init
from .graph import make_graphed_callables
from .export import onnx_export
from .distributed import checkpoint
from .distributed import CudaRNGStatesTracker
from .cpu_offload import get_cpu_offload_context
from . import optimizers
from transformer_engine.common import get_te_path
from transformer_engine.common import _get_sys_extension
def _load_library():
    """Load shared library with Transformer Engine C extensions.

    Searches for the compiled ``transformer_engine_torch`` extension first
    under ``<te_path>/transformer_engine`` (wheel layout) and then directly
    under ``<te_path>``, loads it from file, and registers it in
    ``sys.modules`` so later ``import transformer_engine_torch`` statements
    resolve to the loaded module.

    Raises:
        FileNotFoundError: if no matching shared object exists in either
            candidate directory. (Previously a bare ``StopIteration`` from
            the fallback ``next()`` escaped here, which is opaque to users
            and is converted to ``RuntimeError`` inside generators.)
    """
    # `import importlib` alone does not guarantee the `util` submodule is
    # bound; import it explicitly before use.
    import importlib.util

    extension = _get_sys_extension()
    pattern = f"transformer_engine_torch.*.{extension}"
    # Wheel installs place the .so inside the package directory; source
    # builds leave it at the top level. Try both locations in order.
    candidate_dirs = (get_te_path() / "transformer_engine", get_te_path())
    so_path = None
    for so_dir in candidate_dirs:
        so_path = next(so_dir.glob(pattern), None)
        if so_path is not None:
            break
    if so_path is None:
        raise FileNotFoundError(
            f"Could not find {pattern} in any of: "
            f"{[str(d) for d in candidate_dirs]}"
        )

    module_name = "transformer_engine_torch"
    spec = importlib.util.spec_from_file_location(module_name, so_path)
    solib = importlib.util.module_from_spec(spec)
    # Register before exec so the extension can self-reference during init.
    sys.modules[module_name] = solib
    spec.loader.exec_module(solib)


_load_library()
from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear
from transformer_engine.pytorch.module import LayerNormMLP
from transformer_engine.pytorch.module import LayerNorm
from transformer_engine.pytorch.module import RMSNorm
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformer_engine.pytorch.graph import make_graphed_callables
from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.distributed import checkpoint
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
from transformer_engine.pytorch import optimizers
# Register custom op symbolic ONNX functions
from .te_onnx_extensions import (
from transformer_engine.pytorch.te_onnx_extensions import (
onnx_cast_to_fp8,
onnx_cast_to_fp8_noalloc,
onnx_cast_from_fp8,
......@@ -33,7 +61,7 @@ from .te_onnx_extensions import (
onnx_layernorm_fwd_fp8,
onnx_layernorm_fwd,
onnx_rmsnorm_fwd,
onnx_rmsnorm_fwd_fp8
onnx_rmsnorm_fwd_fp8,
)
try:
......
......@@ -17,7 +17,7 @@ from packaging.version import Version as PkgVersion
import torch
import torch.nn.functional as F
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import (
cast_to_fp8,
cast_from_fp8,
......
......@@ -5,7 +5,7 @@
"""Enums for e2e transformer"""
import torch
import torch.distributed
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
"""
......
......@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Python interface for c++ extensions"""
from transformer_engine_extensions import *
from transformer_engine_torch import *
from .fused_attn import *
from .gemm import *
......
......@@ -5,7 +5,7 @@
"""Python interface for activation extensions"""
from typing import Union
import torch
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu', 'srelu']
......
......@@ -5,7 +5,7 @@
"""Python interface for cast extensions"""
from typing import Optional, Union
import torch
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
__all__ = ['cast_to_fp8',
......
......@@ -6,8 +6,8 @@
import math
from typing import Tuple, List, Union
import torch
import transformer_engine_extensions as tex
from transformer_engine_extensions import (
import transformer_engine_torch as tex
from transformer_engine_torch import (
NVTE_QKV_Layout,
NVTE_Bias_Type,
NVTE_Mask_Type,
......
......@@ -5,7 +5,7 @@
"""Python interface for GEMM extensions"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
from ..constants import TE_DType
from ..utils import assert_dim_for_fp8_exec
......
......@@ -5,7 +5,7 @@
"""Python interface for normalization extensions"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
__all__ = ['layernorm_fwd_fp8',
......
......@@ -5,7 +5,7 @@
"""Python interface for transpose extensions"""
from typing import Optional, Tuple, Union
import torch
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
from ..constants import TE_DType
......
......@@ -9,7 +9,7 @@ import warnings
import torch
from torch.utils._pytree import tree_map
import transformer_engine_extensions as tex
import transformer_engine_torch as tex
from .constants import TE_DType
from .cpp_extensions import fp8_cast_transpose_fused
......@@ -346,7 +346,7 @@ class Float8Tensor(torch.Tensor):
fp8_meta_index: int, optional
Index to access in FP8 meta tensors. Required if
fp8_meta is provided and otherwise ignored.
fp8_dtype: transformer_engine_extensions.DType, tex.DType.kFloat8E4M3
fp8_dtype: transformer_engine_torch.DType, tex.DType.kFloat8E4M3
FP8 format.
fp8_scale_inv: torch.Tensor
Reciprocal of the scaling factor applied when
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment