Unverified Commit c1b915ae authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Build system refactor for wheels (#877)



Cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent fc989613
...@@ -9,9 +9,9 @@ from jax.ad_checkpoint import checkpoint_name ...@@ -9,9 +9,9 @@ from jax.ad_checkpoint import checkpoint_name
import jax import jax
import jax.numpy as jnp import jax.numpy as jnp
from transformer_engine_jax import NVTE_Bias_Type from transformer_engine.transformer_engine_jax import NVTE_Bias_Type
from transformer_engine_jax import NVTE_Mask_Type from transformer_engine.transformer_engine_jax import NVTE_Mask_Type
from transformer_engine_jax import NVTE_QKV_Layout from transformer_engine.transformer_engine_jax import NVTE_QKV_Layout
from .cpp_extensions import FusedAttnHelper from .cpp_extensions import FusedAttnHelper
from .cpp_extensions import fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked from .cpp_extensions import fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Installation script for TE jax extensions."""
# pylint: disable=wrong-import-position,wrong-import-order
import sys
import os
import shutil
from pathlib import Path
import setuptools

# JAX must already be installed before building this extension package.
try:
    import jax  # pylint: disable=unused-import
except ImportError as e:
    raise RuntimeError("This package needs Jax to build.") from e

# Directory containing this setup script.
current_file_path = Path(__file__).parent.resolve()
# Shared build helpers live two directory levels up, at the repository root.
build_tools_dir = current_file_path.parent.parent / "build_tools"
# Vendor a copy of build_tools next to this script so the imports below
# resolve and the helpers ship inside the sdist/wheel.
# NOTE(review): if NVTE_RELEASE_BUILD=1 but build_tools_dir does not exist,
# copytree raises — presumably release builds always run from the full repo.
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_dir):
    shutil.copytree(build_tools_dir, current_file_path / "build_tools", dirs_exist_ok=True)

# These imports resolve against the vendored ./build_tools copied just above.
from build_tools.build_ext import get_build_ext
from build_tools.utils import package_files, copy_common_headers, install_and_import
from build_tools.te_version import te_version
from build_tools.jax import setup_jax_extension

# pybind11 is a build-time-only dependency; install it on the fly if missing.
install_and_import('pybind11')
from pybind11.setup_helpers import build_ext as BuildExtension

# Wrap pybind11's build_ext with the project's CMake-driven build step.
CMakeBuildExtension = get_build_ext(BuildExtension)

if __name__ == "__main__":
    # Extensions
    # Stage the shared C++ headers beside this package so the extension
    # builds without referencing paths outside the source distribution.
    common_headers_dir = "common_headers"
    copy_common_headers(
        current_file_path.parent,
        str(current_file_path / common_headers_dir))
    ext_modules = [
        setup_jax_extension(
            "csrc", current_file_path / "csrc", current_file_path / common_headers_dir)]
    # Configure package
    setuptools.setup(
        name="transformer_engine_jax",
        version=te_version(),
        packages=["csrc", common_headers_dir, "build_tools"],
        description="Transformer acceleration library - Jax Lib",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
        install_requires=["jax", "flax>=0.7.1"],
        tests_require=["numpy", "praxis"],
        include_package_data=True,
        package_data={"csrc": package_files("csrc"),
                      common_headers_dir: package_files(common_headers_dir),
                      "build_tools": package_files("build_tools")},
    )

    # Remove the staged header copy once a packaging command has consumed it.
    if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
        shutil.rmtree(common_headers_dir)
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# #
# See LICENSE for license information. # See LICENSE for license information.
"""Transformer Engine bindings for Paddle""" """Transformer Engine bindings for Paddle"""
# pylint: disable=wrong-import-position,wrong-import-order
def _load_library():
    """Load shared library with Transformer Engine C extensions"""
    # Importing the extension module is the whole point: its import-time code
    # registers the Paddle custom ops. The bound name itself is unused.
    from transformer_engine import transformer_engine_paddle  # pylint: disable=unused-import


# Load the C extensions eagerly at package import time.
_load_library()
from .fp8 import fp8_autocast from .fp8 import fp8_autocast
from .layer import ( from .layer import (
Linear, Linear,
......
...@@ -7,7 +7,7 @@ from enum import Enum ...@@ -7,7 +7,7 @@ from enum import Enum
import paddle import paddle
import transformer_engine_paddle as tex from transformer_engine import transformer_engine_paddle as tex
class FP8FwdTensors(Enum): class FP8FwdTensors(Enum):
......
...@@ -7,7 +7,7 @@ import math ...@@ -7,7 +7,7 @@ import math
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import paddle import paddle
import paddle.nn.functional as F import paddle.nn.functional as F
import transformer_engine_paddle as tex from transformer_engine import transformer_engine_paddle as tex
from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors
from .fp8 import FP8TensorMeta from .fp8 import FP8TensorMeta
......
...@@ -9,7 +9,7 @@ from typing import Tuple, Optional, Dict, Any, Union ...@@ -9,7 +9,7 @@ from typing import Tuple, Optional, Dict, Any, Union
import numpy as np import numpy as np
import paddle import paddle
import transformer_engine_paddle as tex from transformer_engine import transformer_engine_paddle as tex
from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type from .constants import dist_group_type
......
...@@ -11,7 +11,7 @@ from typing import Dict, Any, List, Union ...@@ -11,7 +11,7 @@ from typing import Dict, Any, List, Union
import numpy as np import numpy as np
import paddle import paddle
import transformer_engine_paddle as tex from transformer_engine import transformer_engine_paddle as tex
from .constants import dist_group_type, RecomputeFunctionNames from .constants import dist_group_type, RecomputeFunctionNames
......
...@@ -14,7 +14,7 @@ try: ...@@ -14,7 +14,7 @@ try:
from paddle.incubate.nn.functional import fused_rotary_position_embedding from paddle.incubate.nn.functional import fused_rotary_position_embedding
except ImportError: except ImportError:
fused_rotary_position_embedding = None fused_rotary_position_embedding = None
import transformer_engine_paddle as tex from transformer_engine import transformer_engine_paddle as tex
from .layernorm_linear import LayerNormLinear from .layernorm_linear import LayerNormLinear
from .linear import Linear from .linear import Linear
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Installation script for TE paddle-paddle extensions."""
# pylint: disable=wrong-import-position,wrong-import-order
import sys
import os
import shutil
from pathlib import Path
import setuptools
# Paddle's own build_ext handles custom-op compilation specifics.
from paddle.utils.cpp_extension import BuildExtension

# PaddlePaddle must already be installed before building this package.
try:
    import paddle  # pylint: disable=unused-import
except ImportError as e:
    raise RuntimeError("This package needs Paddle Paddle to build.") from e

# Directory containing this setup script.
current_file_path = Path(__file__).parent.resolve()
# Shared build helpers live two directory levels up, at the repository root.
build_tools_dir = current_file_path.parent.parent / "build_tools"
# Vendor a copy of build_tools next to this script so the imports below
# resolve and the helpers ship inside the sdist/wheel.
# NOTE(review): if NVTE_RELEASE_BUILD=1 but build_tools_dir does not exist,
# copytree raises — presumably release builds always run from the full repo.
if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_dir):
    shutil.copytree(build_tools_dir, current_file_path / "build_tools", dirs_exist_ok=True)

# These imports resolve against the vendored ./build_tools copied just above.
from build_tools.build_ext import get_build_ext  # pylint: disable=wrong-import-position
from build_tools.utils import package_files, copy_common_headers  # pylint: disable=wrong-import-position
from build_tools.te_version import te_version  # pylint: disable=wrong-import-position
from build_tools.paddle import setup_paddle_extension  # pylint: disable=wrong-import-position

# Wrap Paddle's build_ext with the project's CMake-driven build step.
CMakeBuildExtension = get_build_ext(BuildExtension)

if __name__ == "__main__":
    # Extensions
    # Stage the shared C++ headers beside this package so the extension
    # builds without referencing paths outside the source distribution.
    common_headers_dir = "common_headers"
    copy_common_headers(
        current_file_path.parent,
        str(current_file_path / common_headers_dir))
    ext_modules = [
        setup_paddle_extension(
            "csrc", current_file_path / "csrc", current_file_path / common_headers_dir)]
    # Configure package
    setuptools.setup(
        name="transformer_engine_paddle",
        version=te_version(),
        packages=["csrc", common_headers_dir, "build_tools"],
        description="Transformer acceleration library - Paddle Paddle Lib",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
        install_requires=["paddlepaddle-gpu"],
        tests_require=["numpy"],
        include_package_data=True,
        package_data={"csrc": package_files("csrc"),
                      common_headers_dir: package_files(common_headers_dir),
                      "build_tools": package_files("build_tools")},
    )

    # Remove the staged header copy once a packaging command has consumed it.
    if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
        shutil.rmtree(common_headers_dir)
...@@ -3,27 +3,55 @@ ...@@ -3,27 +3,55 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Transformer Engine bindings for pyTorch""" """Transformer Engine bindings for pyTorch"""
# pylint: disable=wrong-import-position,wrong-import-order
import importlib
import sys
import torch import torch
from .module import LayerNormLinear from transformer_engine.common import get_te_path
from .module import Linear from transformer_engine.common import _get_sys_extension
from .module import LayerNormMLP
from .module import LayerNorm
from .module import RMSNorm def _load_library():
from .attention import DotProductAttention """Load shared library with Transformer Engine C extensions"""
from .attention import InferenceParams extension = _get_sys_extension()
from .attention import MultiheadAttention try:
from .transformer import TransformerLayer so_dir = get_te_path() / "transformer_engine"
from .fp8 import fp8_autocast so_path = next(so_dir.glob(f"transformer_engine_torch.*.{extension}"))
from .fp8 import fp8_model_init except StopIteration:
from .graph import make_graphed_callables so_dir = get_te_path()
from .export import onnx_export so_path = next(so_dir.glob(f"transformer_engine_torch.*.{extension}"))
from .distributed import checkpoint
from .distributed import CudaRNGStatesTracker module_name = "transformer_engine_torch"
from .cpu_offload import get_cpu_offload_context spec = importlib.util.spec_from_file_location(module_name, so_path)
from . import optimizers solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib
spec.loader.exec_module(solib)
_load_library()
from transformer_engine.pytorch.module import LayerNormLinear
from transformer_engine.pytorch.module import Linear
from transformer_engine.pytorch.module import LayerNormMLP
from transformer_engine.pytorch.module import LayerNorm
from transformer_engine.pytorch.module import RMSNorm
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention import InferenceParams
from transformer_engine.pytorch.attention import MultiheadAttention
from transformer_engine.pytorch.transformer import TransformerLayer
from transformer_engine.pytorch.fp8 import fp8_autocast
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformer_engine.pytorch.graph import make_graphed_callables
from transformer_engine.pytorch.export import onnx_export
from transformer_engine.pytorch.distributed import checkpoint
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker
from transformer_engine.pytorch.cpu_offload import get_cpu_offload_context
from transformer_engine.pytorch import optimizers
# Register custom op symbolic ONNX functions # Register custom op symbolic ONNX functions
from .te_onnx_extensions import ( from transformer_engine.pytorch.te_onnx_extensions import (
onnx_cast_to_fp8, onnx_cast_to_fp8,
onnx_cast_to_fp8_noalloc, onnx_cast_to_fp8_noalloc,
onnx_cast_from_fp8, onnx_cast_from_fp8,
...@@ -33,7 +61,7 @@ from .te_onnx_extensions import ( ...@@ -33,7 +61,7 @@ from .te_onnx_extensions import (
onnx_layernorm_fwd_fp8, onnx_layernorm_fwd_fp8,
onnx_layernorm_fwd, onnx_layernorm_fwd,
onnx_rmsnorm_fwd, onnx_rmsnorm_fwd,
onnx_rmsnorm_fwd_fp8 onnx_rmsnorm_fwd_fp8,
) )
try: try:
......
...@@ -17,7 +17,7 @@ from packaging.version import Version as PkgVersion ...@@ -17,7 +17,7 @@ from packaging.version import Version as PkgVersion
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import ( from transformer_engine.pytorch.cpp_extensions import (
cast_to_fp8, cast_to_fp8,
cast_from_fp8, cast_from_fp8,
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"""Enums for e2e transformer""" """Enums for e2e transformer"""
import torch import torch
import torch.distributed import torch.distributed
import transformer_engine_extensions as tex import transformer_engine_torch as tex
""" """
......
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
# See LICENSE for license information. # See LICENSE for license information.
"""Python interface for c++ extensions""" """Python interface for c++ extensions"""
from transformer_engine_extensions import * from transformer_engine_torch import *
from .fused_attn import * from .fused_attn import *
from .gemm import * from .gemm import *
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"""Python interface for activation extensions""" """Python interface for activation extensions"""
from typing import Union from typing import Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_torch as tex
__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu', 'srelu'] __all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu', 'srelu']
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"""Python interface for cast extensions""" """Python interface for cast extensions"""
from typing import Optional, Union from typing import Optional, Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_torch as tex
__all__ = ['cast_to_fp8', __all__ = ['cast_to_fp8',
......
...@@ -6,8 +6,8 @@ ...@@ -6,8 +6,8 @@
import math import math
from typing import Tuple, List, Union from typing import Tuple, List, Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from transformer_engine_extensions import ( from transformer_engine_torch import (
NVTE_QKV_Layout, NVTE_QKV_Layout,
NVTE_Bias_Type, NVTE_Bias_Type,
NVTE_Mask_Type, NVTE_Mask_Type,
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"""Python interface for GEMM extensions""" """Python interface for GEMM extensions"""
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from ..constants import TE_DType from ..constants import TE_DType
from ..utils import assert_dim_for_fp8_exec from ..utils import assert_dim_for_fp8_exec
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"""Python interface for normalization extensions""" """Python interface for normalization extensions"""
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_torch as tex
__all__ = ['layernorm_fwd_fp8', __all__ = ['layernorm_fwd_fp8',
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"""Python interface for transpose extensions""" """Python interface for transpose extensions"""
from typing import Optional, Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from ..constants import TE_DType from ..constants import TE_DType
......
...@@ -9,7 +9,7 @@ import warnings ...@@ -9,7 +9,7 @@ import warnings
import torch import torch
from torch.utils._pytree import tree_map from torch.utils._pytree import tree_map
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from .constants import TE_DType from .constants import TE_DType
from .cpp_extensions import fp8_cast_transpose_fused from .cpp_extensions import fp8_cast_transpose_fused
...@@ -346,7 +346,7 @@ class Float8Tensor(torch.Tensor): ...@@ -346,7 +346,7 @@ class Float8Tensor(torch.Tensor):
fp8_meta_index: int, optional fp8_meta_index: int, optional
Index to access in FP8 meta tensors. Required if Index to access in FP8 meta tensors. Required if
fp8_meta is provided and otherwise ignored. fp8_meta is provided and otherwise ignored.
fp8_dtype: transformer_engine_extensions.DType, tex.DType.kFloat8E4M3 fp8_dtype: transformer_engine_torch.DType, tex.DType.kFloat8E4M3
FP8 format. FP8 format.
fp8_scale_inv: torch.Tensor fp8_scale_inv: torch.Tensor
Reciprocal of the scaling factor applied when Reciprocal of the scaling factor applied when
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment