Unverified Commit 09813578 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Build scripts for pip wheels (#1036)



* Specify python version
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add classifiers for python
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add utils to build wheels
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* make wheel scripts
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add aarch
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix paddle wheel
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* PaddlePaddle only builds for x86
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add optional fwk deps
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Python3.8; catch install error
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* [wip] cudnn9 compile with paddle support
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* [wip] dont link cudnn
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* dlopen cudnn
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* dynamically load nvrtc
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove residual packages; exclude stub from nvrtc .so search
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Exclude builtins from nvrtc .so search
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* properly include files for sdist
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* paddle wheel tie to python version
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix paddle build from src [wip]
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix workflow paddle build
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix paddle
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix paddle
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix lint from pr986
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add sanity wheel test
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Add sanity import to wheel test
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* remove upper limit on paddlepaddle version
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove unused imports
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Remove pybind11 dependency
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix cpp tests
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Search .sos in cuda home
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* CLeanup, remove residual code
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6ae584dd
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "extensions.h"
namespace transformer_engine {
namespace jax {
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "extensions.h"
#include "transformer_engine/layer_norm.h"
#include "transformer_engine/rmsnorm.h"
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "extensions.h"
namespace transformer_engine {
namespace jax {
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "extensions.h"
namespace transformer_engine {
namespace jax {
......
......@@ -4,7 +4,7 @@
* See LICENSE for license information.
************************************************************************/
#include "jax/csrc/extensions.h"
#include "extensions.h"
#include "transformer_engine/cast.h"
namespace transformer_engine {
......
......@@ -6,7 +6,7 @@
#include "transformer_engine/softmax.h"
#include "jax/csrc/extensions.h"
#include "extensions.h"
namespace transformer_engine {
namespace jax {
......
......@@ -6,7 +6,7 @@
#include "transformer_engine/transpose.h"
#include "jax/csrc/extensions.h"
#include "extensions.h"
namespace transformer_engine {
namespace jax {
......
......@@ -29,13 +29,14 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext
from build_tools.utils import package_files, copy_common_headers, install_and_import
from build_tools.utils import copy_common_headers, install_and_import
from build_tools.te_version import te_version
from build_tools.jax import setup_jax_extension
install_and_import("pybind11")
from pybind11.setup_helpers import build_ext as BuildExtension
os.environ["NVTE_PROJECT_BUILDING"] = "1"
CMakeBuildExtension = get_build_ext(BuildExtension)
......@@ -53,18 +54,12 @@ if __name__ == "__main__":
setuptools.setup(
name="transformer_engine_jax",
version=te_version(),
packages=["csrc", common_headers_dir, "build_tools"],
description="Transformer acceleration library - Jax Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
install_requires=["jax", "flax>=0.7.1"],
tests_require=["numpy", "praxis"],
include_package_data=True,
package_data={
"csrc": package_files("csrc"),
common_headers_dir: package_files(common_headers_dir),
"build_tools": package_files("build_tools"),
},
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
shutil.rmtree(common_headers_dir)
shutil.rmtree("build_tools")
recursive-include build_tools *.*
recursive-include common_headers *.*
recursive-include csrc *.*
......@@ -29,15 +29,13 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
shutil.copytree(build_tools_dir, build_tools_copy)
from build_tools.build_ext import get_build_ext # pylint: disable=wrong-import-position
from build_tools.utils import (
package_files,
copy_common_headers,
) # pylint: disable=wrong-import-position
from build_tools.te_version import te_version # pylint: disable=wrong-import-position
from build_tools.paddle import setup_paddle_extension # pylint: disable=wrong-import-position
from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers
from build_tools.te_version import te_version
from build_tools.paddle import setup_paddle_extension
os.environ["NVTE_PROJECT_BUILDING"] = "1"
CMakeBuildExtension = get_build_ext(BuildExtension)
......@@ -55,18 +53,12 @@ if __name__ == "__main__":
setuptools.setup(
name="transformer_engine_paddle",
version=te_version(),
packages=["csrc", common_headers_dir, "build_tools"],
description="Transformer acceleration library - Paddle Paddle Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
install_requires=["paddlepaddle-gpu"],
install_requires=["paddlepaddle-gpu>=2.6.1"],
tests_require=["numpy"],
include_package_data=True,
package_data={
"csrc": package_files("csrc"),
common_headers_dir: package_files(common_headers_dir),
"build_tools": package_files("build_tools"),
},
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
shutil.rmtree(common_headers_dir)
shutil.rmtree("build_tools")
recursive-include build_tools *.*
recursive-include common_headers *.*
recursive-include csrc *.*
......@@ -30,11 +30,12 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext
from build_tools.utils import package_files, copy_common_headers
from build_tools.utils import copy_common_headers
from build_tools.te_version import te_version
from build_tools.pytorch import setup_pytorch_extension
os.environ["NVTE_PROJECT_BUILDING"] = "1"
CMakeBuildExtension = get_build_ext(BuildExtension)
......@@ -52,18 +53,12 @@ if __name__ == "__main__":
setuptools.setup(
name="transformer_engine_torch",
version=te_version(),
packages=["csrc", common_headers_dir, "build_tools"],
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
install_requires=["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"],
tests_require=["numpy", "onnxruntime", "torchvision"],
include_package_data=True,
package_data={
"csrc": package_files("csrc"),
common_headers_dir: package_files(common_headers_dir),
"build_tools": package_files("build_tools"),
},
)
if any(x in sys.argv for x in (".", "sdist", "bdist_wheel")):
shutil.rmtree(common_headers_dir)
shutil.rmtree("build_tools")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment