Unverified commit c1b915ae, authored by Kirthi Shankar Sivamani, committed via GitHub
Browse files

Build system refactor for wheels (#877)



Cleanup
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent fc989613
...@@ -9,7 +9,7 @@ from collections import deque ...@@ -9,7 +9,7 @@ from collections import deque
from typing import Callable, List, Optional, Dict, Any, Tuple, Union from typing import Callable, List, Optional, Dict, Any, Tuple, Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from transformer_engine.common.recipe import DelayedScaling, Format from transformer_engine.common.recipe import DelayedScaling, Format
from .constants import dist_group_type from .constants import dist_group_type
......
...@@ -14,7 +14,7 @@ from contextlib import contextmanager ...@@ -14,7 +14,7 @@ from contextlib import contextmanager
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from ._common import _ParameterInitMeta from ._common import _ParameterInitMeta
from ..export import is_in_onnx_export_mode from ..export import is_in_onnx_export_mode
from ..fp8 import ( from ..fp8 import (
......
...@@ -11,7 +11,7 @@ import torch ...@@ -11,7 +11,7 @@ import torch
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from torch.nn import init from torch.nn import init
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from .base import TransformerEngineBaseModule from .base import TransformerEngineBaseModule
from ..cpp_extensions import ( from ..cpp_extensions import (
layernorm_fwd_inf, layernorm_fwd_inf,
......
...@@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union ...@@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from .base import ( from .base import (
get_workspace, get_workspace,
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
"""Fused Adam optimizer.""" """Fused Adam optimizer."""
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from .multi_tensor_apply import multi_tensor_applier from .multi_tensor_apply import multi_tensor_applier
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
import torch import torch
from torch.optim.optimizer import Optimizer, required from torch.optim.optimizer import Optimizer, required
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from .multi_tensor_apply import multi_tensor_applier from .multi_tensor_apply import multi_tensor_applier
......
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Installation script for TE pytorch extensions."""
# pylint: disable=wrong-import-position,wrong-import-order
import sys
import os
import shutil
from pathlib import Path
import setuptools
from torch.utils.cpp_extension import BuildExtension
# The PyTorch extension cannot be built without Torch itself; fail early with
# a clear message instead of a confusing downstream import error.
try:
    import torch  # pylint: disable=unused-import
except ImportError as e:
    raise RuntimeError("This package needs Torch to build.") from e

current_file_path = Path(__file__).parent.resolve()
build_tools_dir = current_file_path.parent.parent / "build_tools"

# Vendor the shared `build_tools` package next to this setup script so the
# `from build_tools...` imports below resolve both for release wheel builds
# (NVTE_RELEASE_BUILD=1) and for in-tree builds where the directory exists.
if int(os.getenv("NVTE_RELEASE_BUILD", "0")) != 0 or build_tools_dir.is_dir():
    shutil.copytree(build_tools_dir, current_file_path / "build_tools", dirs_exist_ok=True)

# These imports must come after the copytree above: they resolve against the
# freshly vendored local copy of build_tools.
from build_tools.build_ext import get_build_ext
from build_tools.utils import package_files, copy_common_headers
from build_tools.te_version import te_version
from build_tools.pytorch import setup_pytorch_extension

# Wrap torch's BuildExtension with the project's CMake-driven build logic.
CMakeBuildExtension = get_build_ext(BuildExtension)
if __name__ == "__main__":
    # Stage the common C++ headers next to the extension sources so they can
    # be packaged into the wheel. Keep one absolute Path for the staging dir
    # and reuse it everywhere below.
    common_headers_dir = "common_headers"
    common_headers_path = current_file_path / common_headers_dir
    copy_common_headers(current_file_path.parent, str(common_headers_path))

    # Extensions: a single C++/CUDA extension built from csrc/ against the
    # staged common headers.
    ext_modules = [
        setup_pytorch_extension(
            "csrc", current_file_path / "csrc", common_headers_path
        )
    ]

    # Configure package
    setuptools.setup(
        name="transformer_engine_torch",
        version=te_version(),
        packages=["csrc", common_headers_dir, "build_tools"],
        description="Transformer acceleration library - Torch Lib",
        ext_modules=ext_modules,
        cmdclass={"build_ext": CMakeBuildExtension},
        install_requires=["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"],
        tests_require=["numpy", "onnxruntime", "torchvision"],
        include_package_data=True,
        package_data={
            "csrc": package_files("csrc"),
            common_headers_dir: package_files(common_headers_dir),
            "build_tools": package_files("build_tools"),
        },
    )

    # Clean up the staged headers after packaging commands.
    # Fix: remove the absolute staging path. The original passed the
    # cwd-relative string "common_headers" to rmtree, which deletes the wrong
    # directory (or raises) whenever setup.py is invoked from a directory
    # other than the one containing this script, while the copy above always
    # targeted current_file_path / "common_headers".
    if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
        shutil.rmtree(common_headers_path)
...@@ -9,7 +9,7 @@ import torch ...@@ -9,7 +9,7 @@ import torch
from torch import nn from torch import nn
import torch._C._onnx as _C_onnx import torch._C._onnx as _C_onnx
from torch.onnx import _type_utils from torch.onnx import _type_utils
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode from transformer_engine.pytorch.export import is_in_onnx_export_mode
from transformer_engine.pytorch.te_onnx_extensions import compute_in_fp32 from transformer_engine.pytorch.te_onnx_extensions import compute_in_fp32
......
...@@ -30,7 +30,7 @@ import torch._C._onnx as _C_onnx ...@@ -30,7 +30,7 @@ import torch._C._onnx as _C_onnx
# Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics # Monkey-patch graph manipulation methods on Graph, used for the ONNX symbolics
from torch.onnx._internal import jit_utils from torch.onnx._internal import jit_utils
import transformer_engine_extensions as tex import transformer_engine_torch as tex
# This file registers custom op symbolic ONNX functions and does not export any symbols. # This file registers custom op symbolic ONNX functions and does not export any symbols.
......
...@@ -10,7 +10,7 @@ from typing import Callable, List, Optional, Tuple, Union ...@@ -10,7 +10,7 @@ from typing import Callable, List, Optional, Tuple, Union
import torch import torch
import transformer_engine_extensions as tex import transformer_engine_torch as tex
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.pytorch.attention import ( from transformer_engine.pytorch.attention import (
InferenceParams, InferenceParams,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment