"include/git@developer.sourcefind.cn:jerrrrry/infinicore.git" did not exist on "ff84910c9b8e304fe3e44ab16137d7719c7012e3"
Unverified Commit 93f00a79 authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Improvements for building wheels (#1148)



* Improvements for wheels
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes for wheel build
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Move package finder to common
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* format
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fixes
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Lint
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* FIx
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* Fix CI and distributed test
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix paddle ci
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 9437ceb2
...@@ -301,7 +301,7 @@ def install_and_import(package): ...@@ -301,7 +301,7 @@ def install_and_import(package):
globals()[main_package] = importlib.import_module(main_package) globals()[main_package] = importlib.import_module(main_package)
def uninstall_te_fw_packages(): def uninstall_te_wheel_packages():
subprocess.check_call( subprocess.check_call(
[ [
sys.executable, sys.executable,
...@@ -309,6 +309,7 @@ def uninstall_te_fw_packages(): ...@@ -309,6 +309,7 @@ def uninstall_te_fw_packages():
"pip", "pip",
"uninstall", "uninstall",
"-y", "-y",
"transformer_engine_cu12",
"transformer_engine_torch", "transformer_engine_torch",
"transformer_engine_paddle", "transformer_engine_paddle",
"transformer_engine_jax", "transformer_engine_jax",
......
...@@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda ...@@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1 ENV NVTE_RELEASE_BUILD=1
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "false", "false", "true"] CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"]
...@@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda ...@@ -33,4 +33,4 @@ ENV CUDA_PATH=/usr/local/cuda
ENV CUDADIR=/usr/local/cuda ENV CUDADIR=/usr/local/cuda
ENV NVTE_RELEASE_BUILD=1 ENV NVTE_RELEASE_BUILD=1
CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true"] CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"]
...@@ -5,10 +5,11 @@ ...@@ -5,10 +5,11 @@
set -e set -e
PLATFORM=${1:-manylinux_2_28_x86_64} PLATFORM=${1:-manylinux_2_28_x86_64}
BUILD_COMMON=${2:-true} BUILD_METAPACKAGE=${2:-true}
BUILD_JAX=${3:-true} BUILD_COMMON=${3:-true}
BUILD_PYTORCH=${4:-true} BUILD_PYTORCH=${4:-true}
BUILD_PADDLE=${5:-true} BUILD_JAX=${5:-true}
BUILD_PADDLE=${6:-true}
export NVTE_RELEASE_BUILD=1 export NVTE_RELEASE_BUILD=1
export TARGET_BRANCH=${TARGET_BRANCH:-} export TARGET_BRANCH=${TARGET_BRANCH:-}
...@@ -20,12 +21,33 @@ cd /TransformerEngine ...@@ -20,12 +21,33 @@ cd /TransformerEngine
git checkout $TARGET_BRANCH git checkout $TARGET_BRANCH
git submodule update --init --recursive git submodule update --init --recursive
if $BUILD_METAPACKAGE ; then
cd /TransformerEngine
NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt
mv dist/* /wheelhouse/
fi
if $BUILD_COMMON ; then if $BUILD_COMMON ; then
VERSION=`cat build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Create the wheel.
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt
# Repack the wheel for cuda specific package, i.e. cu12.
/opt/python/cp38-cp38/bin/wheel unpack dist/*
# From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore).
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
/opt/python/cp38-cp38/bin/wheel pack ${WHL_BASE}
# Rename the wheel to make it python version agnostic.
whl_name=$(basename dist/*) whl_name=$(basename dist/*)
IFS='-' read -ra whl_parts <<< "$whl_name" IFS='-' read -ra whl_parts <<< "$whl_name"
whl_name_target="${whl_parts[0]}-${whl_parts[1]}-py3-none-${whl_parts[4]}" whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}"
mv dist/"$whl_name" /wheelhouse/"$whl_name_target" rm -rf $WHL_BASE dist
mv *.whl /wheelhouse/"$whl_name_target"
fi fi
if $BUILD_PYTORCH ; then if $BUILD_PYTORCH ; then
...@@ -37,8 +59,8 @@ fi ...@@ -37,8 +59,8 @@ fi
if $BUILD_JAX ; then if $BUILD_JAX ; then
cd /TransformerEngine/transformer_engine/jax cd /TransformerEngine/transformer_engine/jax
/opt/python/cp38-cp38/bin/pip install jax jaxlib /opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib
/opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
cp dist/* /wheelhouse/ cp dist/* /wheelhouse/
fi fi
...@@ -48,30 +70,30 @@ if $BUILD_PADDLE ; then ...@@ -48,30 +70,30 @@ if $BUILD_PADDLE ; then
dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64 dnf -y install libcudnn8-devel.x86_64 libcudnn8.x86_64
cd /TransformerEngine/transformer_engine/paddle cd /TransformerEngine/transformer_engine/paddle
/opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl /opt/python/cp38-cp38/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp38-cp38/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp38.txt
/opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine paddlepaddle-gpu /opt/python/cp38-cp38/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
/opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl /opt/python/cp39-cp39/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp39-cp39/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt /opt/python/cp39-cp39/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp39.txt
/opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine paddlepaddle-gpu /opt/python/cp39-cp39/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
/opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl /opt/python/cp310-cp310/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp310-cp310/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp310.txt
/opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine paddlepaddle-gpu /opt/python/cp310-cp310/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
/opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl /opt/python/cp311-cp311/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp311-cp311/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt /opt/python/cp311-cp311/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp311.txt
/opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine paddlepaddle-gpu /opt/python/cp311-cp311/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
/opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl /opt/python/cp312-cp312/bin/pip install /wheelhouse/*.whl --no-deps
/opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1 /opt/python/cp312-cp312/bin/pip install paddlepaddle-gpu==2.6.1
/opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt /opt/python/cp312-cp312/bin/python setup.py bdist_wheel --verbose --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/paddle_cp312.txt
/opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine paddlepaddle-gpu /opt/python/cp312-cp312/bin/pip uninstall -y transformer-engine transformer-engine-cu12 paddlepaddle-gpu
mv dist/* /wheelhouse/ mv dist/* /wheelhouse/
fi fi
......
...@@ -6,16 +6,30 @@ set -e ...@@ -6,16 +6,30 @@ set -e
: "${TE_PATH:=/opt/transformerengine}" : "${TE_PATH:=/opt/transformerengine}"
pip install wheel
cd $TE_PATH cd $TE_PATH
pip uninstall -y transformer-engine pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-jax
export NVTE_RELEASE_BUILD=1
python setup.py bdist_wheel VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel
cd transformer_engine/jax cd transformer_engine/jax
python setup.py sdist NVTE_RELEASE_BUILD=1 python setup.py sdist
export NVTE_RELEASE_BUILD=0
pip install dist/* pip install dist/*
cd $TE_PATH cd $TE_PATH
pip install dist/* pip install dist/*.whl --no-deps
python $TE_PATH/tests/jax/test_sanity_import.py python $TE_PATH/tests/jax/test_sanity_import.py
...@@ -6,15 +6,28 @@ set -e ...@@ -6,15 +6,28 @@ set -e
: "${TE_PATH:=/opt/transformerengine}" : "${TE_PATH:=/opt/transformerengine}"
pip install wheel==0.44.0 pydantic
cd $TE_PATH cd $TE_PATH
pip uninstall -y transformer-engine pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-paddle
export NVTE_RELEASE_BUILD=1
python setup.py bdist_wheel
pip install dist/*
cd transformer_engine/paddle
python setup.py bdist_wheel
export NVTE_RELEASE_BUILD=0 VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel
pip install dist/*.whl --no-deps
cd transformer_engine/paddle
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
pip install dist/* pip install dist/*
python $TE_PATH/tests/paddle/test_sanity_import.py python $TE_PATH/tests/paddle/test_sanity_import.py
...@@ -6,16 +6,30 @@ set -e ...@@ -6,16 +6,30 @@ set -e
: "${TE_PATH:=/opt/transformerengine}" : "${TE_PATH:=/opt/transformerengine}"
pip install wheel
cd $TE_PATH cd $TE_PATH
pip uninstall -y transformer-engine pip uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-torch
export NVTE_RELEASE_BUILD=1
python setup.py bdist_wheel VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}"
# Core wheel.
NVTE_RELEASE_BUILD=1 python setup.py bdist_wheel
wheel unpack dist/*
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE}
rm dist/*.whl
mv *.whl dist/
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python setup.py bdist_wheel
cd transformer_engine/pytorch cd transformer_engine/pytorch
python setup.py sdist NVTE_RELEASE_BUILD=1 python setup.py sdist
export NVTE_RELEASE_BUILD=0
pip install dist/* pip install dist/*
cd $TE_PATH cd $TE_PATH
pip install dist/* pip install dist/*.whl --no-deps
python $TE_PATH/tests/pytorch/test_sanity_import.py python $TE_PATH/tests/pytorch/test_sanity_import.py
...@@ -4,6 +4,10 @@ ...@@ -4,6 +4,10 @@
set -e set -e
# pkg_resources is deprecated in setuptools 70+ and the packaging submodule
# has been removed from it. This is a temporary fix until upstream MLM fix.
pip install setuptools==69.5.1
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py
......
...@@ -22,7 +22,7 @@ from build_tools.utils import ( ...@@ -22,7 +22,7 @@ from build_tools.utils import (
get_frameworks, get_frameworks,
install_and_import, install_and_import,
remove_dups, remove_dups,
uninstall_te_fw_packages, uninstall_te_wheel_packages,
) )
frameworks = get_frameworks() frameworks = get_frameworks()
...@@ -106,46 +106,69 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]: ...@@ -106,46 +106,69 @@ def setup_requirements() -> Tuple[List[str], List[str], List[str]]:
if __name__ == "__main__": if __name__ == "__main__":
# Dependencies
setup_requires, install_requires, test_requires = setup_requirements()
__version__ = te_version() __version__ = te_version()
ext_modules = [setup_common_extension()] with open("README.rst", encoding="utf-8") as f:
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): long_description = f.read()
# Remove residual FW packages since compiling from source
# results in a single binary with FW extensions included. # Settings for building top level empty package for dependency management.
uninstall_te_fw_packages() if bool(int(os.getenv("NVTE_BUILD_METAPACKAGE", "0"))):
if "pytorch" in frameworks: assert bool(
from build_tools.pytorch import setup_pytorch_extension int(os.getenv("NVTE_RELEASE_BUILD", "0"))
), "NVTE_RELEASE_BUILD env must be set for metapackage build."
ext_modules.append( ext_modules = []
setup_pytorch_extension( cmdclass = {}
"transformer_engine/pytorch/csrc", package_data = {}
current_file_path / "transformer_engine" / "pytorch" / "csrc", include_package_data = False
current_file_path / "transformer_engine", setup_requires = []
install_requires = ([f"transformer_engine_cu12=={__version__}"],)
extras_require = {
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
"paddle": [f"transformer_engine_paddle=={__version__}"],
}
else:
setup_requires, install_requires, test_requires = setup_requirements()
ext_modules = [setup_common_extension()]
cmdclass = {"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}
package_data = {"": ["VERSION.txt"]}
include_package_data = True
extras_require = {"test": test_requires}
if not bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))):
# Remove residual FW packages since compiling from source
# results in a single binary with FW extensions included.
uninstall_te_wheel_packages()
if "pytorch" in frameworks:
from build_tools.pytorch import setup_pytorch_extension
ext_modules.append(
setup_pytorch_extension(
"transformer_engine/pytorch/csrc",
current_file_path / "transformer_engine" / "pytorch" / "csrc",
current_file_path / "transformer_engine",
)
) )
) if "jax" in frameworks:
if "jax" in frameworks: from build_tools.jax import setup_jax_extension
from build_tools.jax import setup_jax_extension
ext_modules.append(
ext_modules.append( setup_jax_extension(
setup_jax_extension( "transformer_engine/jax/csrc",
"transformer_engine/jax/csrc", current_file_path / "transformer_engine" / "jax" / "csrc",
current_file_path / "transformer_engine" / "jax" / "csrc", current_file_path / "transformer_engine",
current_file_path / "transformer_engine", )
) )
) if "paddle" in frameworks:
if "paddle" in frameworks: from build_tools.paddle import setup_paddle_extension
from build_tools.paddle import setup_paddle_extension
ext_modules.append(
ext_modules.append( setup_paddle_extension(
setup_paddle_extension( "transformer_engine/paddle/csrc",
"transformer_engine/paddle/csrc", current_file_path / "transformer_engine" / "paddle" / "csrc",
current_file_path / "transformer_engine" / "paddle" / "csrc", current_file_path / "transformer_engine",
current_file_path / "transformer_engine", )
) )
)
# Configure package # Configure package
setuptools.setup( setuptools.setup(
...@@ -158,13 +181,10 @@ if __name__ == "__main__": ...@@ -158,13 +181,10 @@ if __name__ == "__main__":
"transformer_engine/build_tools", "transformer_engine/build_tools",
], ],
), ),
extras_require={ extras_require=extras_require,
"test": test_requires,
"pytorch": [f"transformer_engine_torch=={__version__}"],
"jax": [f"transformer_engine_jax=={__version__}"],
"paddle": [f"transformer_engine_paddle=={__version__}"],
},
description="Transformer acceleration library", description="Transformer acceleration library",
long_description=long_description,
long_description_content_type="text/x-rst",
ext_modules=ext_modules, ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist}, cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": TimedBdist},
python_requires=">=3.8, <3.13", python_requires=">=3.8, <3.13",
...@@ -178,6 +198,6 @@ if __name__ == "__main__": ...@@ -178,6 +198,6 @@ if __name__ == "__main__":
setup_requires=setup_requires, setup_requires=setup_requires,
install_requires=install_requires, install_requires=install_requires,
license_files=("LICENSE",), license_files=("LICENSE",),
include_package_data=True, include_package_data=include_package_data,
package_data={"": ["VERSION.txt"]}, package_data=package_data,
) )
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
"""FW agnostic user-end APIs""" """FW agnostic user-end APIs"""
import sys
import glob import glob
import sysconfig import sysconfig
import subprocess import subprocess
...@@ -15,6 +16,16 @@ from pathlib import Path ...@@ -15,6 +16,16 @@ from pathlib import Path
import transformer_engine import transformer_engine
def is_package_installed(package):
"""Checks if a pip package is installed."""
return (
subprocess.run(
[sys.executable, "-m", "pip", "show", package], capture_output=True, check=False
).returncode
== 0
)
def get_te_path(): def get_te_path():
"""Find Transformer Engine install path using pip""" """Find Transformer Engine install path using pip"""
return Path(transformer_engine.__path__[0]).parent return Path(transformer_engine.__path__[0]).parent
......
...@@ -5,21 +5,50 @@ ...@@ -5,21 +5,50 @@
# pylint: disable=wrong-import-position,wrong-import-order # pylint: disable=wrong-import-position,wrong-import-order
import logging
import ctypes import ctypes
from importlib.metadata import version
from transformer_engine.common import get_te_path from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension from transformer_engine.common import _get_sys_extension
def _load_library(): def _load_library():
"""Load shared library with Transformer Engine C extensions""" """Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_jax"
if is_package_installed(module_name):
assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
assert is_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install"
" transformer-engine[jax]==VERSION'"
)
if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
logging.info(
"Could not find package %s. Install transformer-engine using 'pip"
" install transformer-engine[jax]==VERSION'",
module_name,
)
extension = _get_sys_extension() extension = _get_sys_extension()
try: try:
so_dir = get_te_path() / "transformer_engine" so_dir = get_te_path() / "transformer_engine"
so_path = next(so_dir.glob(f"transformer_engine_jax.*.{extension}")) so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration: except StopIteration:
so_dir = get_te_path() so_dir = get_te_path()
so_path = next(so_dir.glob(f"transformer_engine_jax.*.{extension}")) so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL) return ctypes.CDLL(so_path, mode=ctypes.RTLD_GLOBAL)
......
...@@ -6,9 +6,41 @@ ...@@ -6,9 +6,41 @@
# pylint: disable=wrong-import-position,wrong-import-order # pylint: disable=wrong-import-position,wrong-import-order
import logging
from importlib.metadata import version
from transformer_engine.common import is_package_installed
def _load_library(): def _load_library():
"""Load shared library with Transformer Engine C extensions""" """Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_paddle"
if is_package_installed(module_name):
assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
assert is_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install"
" transformer-engine[paddle]==VERSION'"
)
if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
logging.info(
"Could not find package %s. Install transformer-engine using 'pip"
" install transformer-engine[paddle]==VERSION'",
module_name,
)
from transformer_engine import transformer_engine_paddle # pylint: disable=unused-import from transformer_engine import transformer_engine_paddle # pylint: disable=unused-import
......
...@@ -6,25 +6,54 @@ ...@@ -6,25 +6,54 @@
# pylint: disable=wrong-import-position,wrong-import-order # pylint: disable=wrong-import-position,wrong-import-order
import logging
import importlib import importlib
import importlib.util
import sys import sys
import torch import torch
from importlib.metadata import version
from transformer_engine.common import get_te_path from transformer_engine.common import get_te_path, is_package_installed
from transformer_engine.common import _get_sys_extension from transformer_engine.common import _get_sys_extension
def _load_library(): def _load_library():
"""Load shared library with Transformer Engine C extensions""" """Load shared library with Transformer Engine C extensions"""
module_name = "transformer_engine_torch"
if is_package_installed(module_name):
assert is_package_installed("transformer_engine"), "Could not find `transformer-engine`."
assert is_package_installed(
"transformer_engine_cu12"
), "Could not find `transformer-engine-cu12`."
assert (
version(module_name)
== version("transformer-engine")
== version("transformer-engine-cu12")
), (
"TransformerEngine package version mismatch. Found"
f" {module_name} v{version(module_name)}, transformer-engine"
f" v{version('transformer-engine')}, and transformer-engine-cu12"
f" v{version('transformer-engine-cu12')}. Install transformer-engine using 'pip install"
" transformer-engine[pytorch]==VERSION'"
)
if is_package_installed("transformer-engine-cu12"):
if not is_package_installed(module_name):
logging.info(
"Could not find package %s. Install transformer-engine using 'pip"
" install transformer-engine[pytorch]==VERSION'",
module_name,
)
extension = _get_sys_extension() extension = _get_sys_extension()
try: try:
so_dir = get_te_path() / "transformer_engine" so_dir = get_te_path() / "transformer_engine"
so_path = next(so_dir.glob(f"transformer_engine_torch.*.{extension}")) so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
except StopIteration: except StopIteration:
so_dir = get_te_path() so_dir = get_te_path()
so_path = next(so_dir.glob(f"transformer_engine_torch.*.{extension}")) so_path = next(so_dir.glob(f"{module_name}.*.{extension}"))
module_name = "transformer_engine_torch"
spec = importlib.util.spec_from_file_location(module_name, so_path) spec = importlib.util.spec_from_file_location(module_name, so_path)
solib = importlib.util.module_from_spec(spec) solib = importlib.util.module_from_spec(spec)
sys.modules[module_name] = solib sys.modules[module_name] = solib
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment