build_wheels.sh 2.54 KB
Newer Older
1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
3
4
5
6
7
#
# See LICENSE for license information.

set -e

PLATFORM=${1:-manylinux_2_28_x86_64}
8
9
BUILD_METAPACKAGE=${2:-true}
BUILD_COMMON=${3:-true}
10
BUILD_PYTORCH=${4:-true}
11
BUILD_JAX=${5:-true}
12
13

export NVTE_RELEASE_BUILD=1
14
15
export TARGET_BRANCH=${TARGET_BRANCH:-}
mkdir -p /wheelhouse/logs
16
17
18
19
20
21
22

# Generate wheels for common library.
git config --global --add safe.directory /TransformerEngine
cd /TransformerEngine
git checkout $TARGET_BRANCH
git submodule update --init --recursive

23
24
25
26
27
28
if $BUILD_METAPACKAGE ; then
        cd /TransformerEngine
        NVTE_BUILD_METAPACKAGE=1 /opt/python/cp310-cp310/bin/python setup.py bdist_wheel 2>&1 | tee /wheelhouse/logs/metapackage.txt
        mv dist/* /wheelhouse/
fi

29
if $BUILD_COMMON ; then
30
31
32
33
        VERSION=`cat build_tools/VERSION.txt`
        WHL_BASE="transformer_engine-${VERSION}"

        # Create the wheel.
34
        /opt/python/cp38-cp38/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt
35
36
37
38
39
40
41
42
43
44

        # Repack the wheel for cuda specific package, i.e. cu12.
        /opt/python/cp38-cp38/bin/wheel unpack dist/*
        # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore).
        sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
        sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
        mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
        /opt/python/cp38-cp38/bin/wheel pack ${WHL_BASE}

        # Rename the wheel to make it python version agnostic.
45
46
        whl_name=$(basename dist/*)
        IFS='-' read -ra whl_parts <<< "$whl_name"
47
48
49
        whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}"
        rm -rf $WHL_BASE dist
        mv *.whl /wheelhouse/"$whl_name_target"
50
51
52
53
54
55
56
57
58
59
60
fi

if $BUILD_PYTORCH ; then
	cd /TransformerEngine/transformer_engine/pytorch
	/opt/python/cp38-cp38/bin/pip install torch
	/opt/python/cp38-cp38/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt
	cp dist/* /wheelhouse/
fi

if $BUILD_JAX ; then
	cd /TransformerEngine/transformer_engine/jax
61
62
	/opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib
	/opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
63
64
	cp dist/* /wheelhouse/
fi