Commit 0a5016b1 authored by wenjh

Merge nv release_v2.9


Signed-off-by: wenjh <wenjh@sugon.com>
parents 063ef88d 70f53666
recursive-include transformer_engine/common/include *.*
@@ -205,7 +205,7 @@ pip Installation
 **Prerequisites for pip installation:**
 * A compatible C++ compiler
-* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) installed
+* CUDA Toolkit with cuDNN and NVCC (NVIDIA CUDA Compiler) if installing from source.
 To install the latest stable version with pip:
...
@@ -295,11 +295,9 @@ def cuda_archs() -> str:
     if archs is None:
         version = cuda_version()
         if version >= (13, 0):
-            archs = "75;80;89;90;100;100a;103a;120"
-        elif version >= (12, 9):
-            archs = "70;80;89;90;100;100a;103a;120"
+            archs = "75;80;89;90;100;120"
         elif version >= (12, 8):
-            archs = "70;80;89;90;100;100a;120"
+            archs = "70;80;89;90;100;120"
         else:
            archs = "70;80;89;90"
     return archs
...
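Note on the hunk above: in upstream Transformer Engine this function first consults the `NVTE_CUDA_ARCHS` environment variable and only falls back to these defaults when it is unset, so the target architecture list can be narrowed at build time. A hedged sketch, assuming that lookup is present in the surrounding code (the values shown are illustrative):

    # Build only for Hopper and Blackwell, assuming setup.py reads NVTE_CUDA_ARCHS.
    NVTE_CUDA_ARCHS="90;100;120" pip install --no-build-isolation .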
@@ -7,23 +7,34 @@ FROM quay.io/pypa/manylinux_2_28_aarch64
 WORKDIR /TransformerEngine/
 COPY ../.. /TransformerEngine/
-ARG VER="12-3"
-ARG ARCH="aarch64"
+ARG CUDA_MAJOR="12"
+ARG CUDA_MINOR="3"
+RUN dnf -y install vim
+# Args for build_wheels.sh
+ARG BUILD_METAPACKAGE=true
+ARG BUILD_COMMON=true
+ARG BUILD_PYTORCH=true
+ARG BUILD_JAX=true
+ENV BUILD_METAPACKAGE=${BUILD_METAPACKAGE}
+ENV BUILD_COMMON=${BUILD_COMMON}
+ENV BUILD_PYTORCH=${BUILD_PYTORCH}
+ENV BUILD_JAX=${BUILD_JAX}
+ENV CUDA_MAJOR=${CUDA_MAJOR}
 # Cuda toolkit, cudnn, driver.
 RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/sbsa/cuda-rhel8.repo
 RUN dnf -y install epel-release
-RUN dnf -y install cuda-compiler-${VER}.${ARCH} \
-    cuda-libraries-${VER}.${ARCH} \
-    cuda-libraries-devel-${VER}.${ARCH}
+RUN dnf -y install cuda-compiler-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64 \
+    cuda-libraries-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64 \
+    cuda-libraries-devel-${CUDA_MAJOR}-${CUDA_MINOR}.aarch64
-RUN dnf -y install --allowerasing cudnn9-cuda-12
+RUN dnf -y install --allowerasing cudnn9-cuda-${CUDA_MAJOR}
 RUN dnf clean all
 RUN rm -rf /var/cache/dnf/*
 RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
-RUN dnf -y install cuda-toolkit
+RUN dnf -y install cuda-toolkit-${CUDA_MAJOR}
 RUN dnf clean all
 RUN dnf -y install glog.aarch64 glog-devel.aarch64
+RUN dnf -y install libnccl libnccl-devel libnccl-static
 ENV PATH="/usr/local/cuda/bin:${PATH}"
 ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
@@ -33,4 +44,4 @@ ENV CUDA_PATH=/usr/local/cuda
 ENV CUDADIR=/usr/local/cuda
 ENV NVTE_RELEASE_BUILD=1
-CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_aarch64", "true", "true", "false", "false", "false"]
+CMD ["/bin/bash", "-c", "bash /TransformerEngine/build_tools/wheel_utils/build_wheels.sh manylinux_2_28_aarch64 $BUILD_METAPACKAGE $BUILD_COMMON $BUILD_PYTORCH $BUILD_JAX $CUDA_MAJOR"]
@@ -7,23 +7,34 @@ FROM quay.io/pypa/manylinux_2_28_x86_64
 WORKDIR /TransformerEngine/
 COPY ../.. /TransformerEngine/
-ARG VER="12-3"
-ARG ARCH="x86_64"
+ARG CUDA_MAJOR="12"
+ARG CUDA_MINOR="3"
+RUN dnf -y install vim
+# Args for build_wheels.sh
+ARG BUILD_METAPACKAGE=true
+ARG BUILD_COMMON=true
+ARG BUILD_PYTORCH=true
+ARG BUILD_JAX=true
+ENV BUILD_METAPACKAGE=${BUILD_METAPACKAGE}
+ENV BUILD_COMMON=${BUILD_COMMON}
+ENV BUILD_PYTORCH=${BUILD_PYTORCH}
+ENV BUILD_JAX=${BUILD_JAX}
+ENV CUDA_MAJOR=${CUDA_MAJOR}
 # Cuda toolkit, cudnn, driver.
 RUN dnf config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo
 RUN dnf -y install epel-release
-RUN dnf -y install cuda-compiler-${VER}.${ARCH} \
-    cuda-libraries-${VER}.${ARCH} \
-    cuda-libraries-devel-${VER}.${ARCH}
+RUN dnf -y install cuda-compiler-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64 \
+    cuda-libraries-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64 \
+    cuda-libraries-devel-${CUDA_MAJOR}-${CUDA_MINOR}.x86_64
-RUN dnf -y install --allowerasing cudnn9-cuda-12
+RUN dnf -y install --allowerasing cudnn9-cuda-${CUDA_MAJOR}
 RUN dnf clean all
 RUN rm -rf /var/cache/dnf/*
 RUN echo "/usr/local/cuda/lib64" >> /etc/ld.so.conf.d/999_nvidia_cuda.conf
-RUN dnf -y install cuda-toolkit
+RUN dnf -y install cuda-toolkit-${CUDA_MAJOR}
 RUN dnf clean all
 RUN dnf -y install glog.x86_64 glog-devel.x86_64
+RUN dnf -y install libnccl libnccl-devel libnccl-static
 ENV PATH="/usr/local/cuda/bin:${PATH}"
 ENV LD_LIBRARY_PATH="/usr/local/cuda/lib64:${LD_LIBRARY_PATH}"
@@ -33,4 +44,4 @@ ENV CUDA_PATH=/usr/local/cuda
 ENV CUDADIR=/usr/local/cuda
 ENV NVTE_RELEASE_BUILD=1
-CMD ["/bin/bash", "/TransformerEngine/build_tools/wheel_utils/build_wheels.sh", "manylinux_2_28_x86_64", "true", "true", "true", "true", "true"]
+CMD ["/bin/bash", "-c", "bash /TransformerEngine/build_tools/wheel_utils/build_wheels.sh manylinux_2_28_x86_64 $BUILD_METAPACKAGE $BUILD_COMMON $BUILD_PYTORCH $BUILD_JAX $CUDA_MAJOR"]
\ No newline at end of file
@@ -9,8 +9,10 @@ BUILD_METAPACKAGE=${2:-true}
 BUILD_COMMON=${3:-true}
 BUILD_PYTORCH=${4:-true}
 BUILD_JAX=${5:-true}
+CUDA_MAJOR=${6:-12}
 export NVTE_RELEASE_BUILD=1
+export PIP_CONSTRAINT=""
 export TARGET_BRANCH=${TARGET_BRANCH:-}
 mkdir -p /wheelhouse/logs
@@ -21,7 +23,7 @@ git checkout $TARGET_BRANCH
 git submodule update --init --recursive
 # Install deps
-/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja
+/opt/python/cp310-cp310/bin/pip install cmake pybind11[global] ninja setuptools wheel nvidia-mathdx==25.1.1
 if $BUILD_METAPACKAGE ; then
     cd /TransformerEngine
@@ -36,32 +38,32 @@ if $BUILD_COMMON ; then
     # Create the wheel.
     /opt/python/cp310-cp310/bin/python setup.py bdist_wheel --verbose --python-tag=py3 --plat-name=$PLATFORM 2>&1 | tee /wheelhouse/logs/common.txt
-    # Repack the wheel for cuda specific package, i.e. cu12.
+    # Repack the wheel for specific cuda version.
     /opt/python/cp310-cp310/bin/wheel unpack dist/*
     # From python 3.10 to 3.11, the package name delimiter in metadata got changed from - (hyphen) to _ (underscore).
-    sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
-    sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
-    mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info"
+    sed -i "s/Name: transformer-engine/Name: transformer-engine-cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
+    sed -i "s/Name: transformer_engine/Name: transformer_engine_cu${CUDA_MAJOR}/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
+    mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu${CUDA_MAJOR}-${VERSION}.dist-info"
     /opt/python/cp310-cp310/bin/wheel pack ${WHL_BASE}
     # Rename the wheel to make it python version agnostic.
     whl_name=$(basename dist/*)
     IFS='-' read -ra whl_parts <<< "$whl_name"
-    whl_name_target="${whl_parts[0]}_cu12-${whl_parts[1]}-py3-none-${whl_parts[4]}"
+    whl_name_target="${whl_parts[0]}_cu${CUDA_MAJOR}-${whl_parts[1]}-py3-none-${whl_parts[4]}"
     rm -rf $WHL_BASE dist
     mv *.whl /wheelhouse/"$whl_name_target"
 fi
 if $BUILD_PYTORCH ; then
     cd /TransformerEngine/transformer_engine/pytorch
     /opt/python/cp310-cp310/bin/pip install torch
     /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/torch.txt
     cp dist/* /wheelhouse/
 fi
 if $BUILD_JAX ; then
     cd /TransformerEngine/transformer_engine/jax
-    /opt/python/cp310-cp310/bin/pip install "jax[cuda12_local]" jaxlib
+    /opt/python/cp310-cp310/bin/pip install "jax[cuda${CUDA_MAJOR}_local]" jaxlib
     /opt/python/cp310-cp310/bin/python setup.py sdist 2>&1 | tee /wheelhouse/logs/jax.txt
     cp dist/* /wheelhouse/
 fi
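Taken together with the Dockerfile CMD changes above, the script's positional interface is now: platform tag, the four build toggles, then the new sixth argument for the CUDA major version. A hedged invocation sketch (the argument values are illustrative, not from the commit):

    # $1 platform, $2 metapackage, $3 common, $4 pytorch, $5 jax, $6 cuda major
    bash /TransformerEngine/build_tools/wheel_utils/build_wheels.sh \
        manylinux_2_28_x86_64 false true false false 13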
@@ -2,7 +2,29 @@
 #
 # See LICENSE for license information.
-docker build --no-cache -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch .
+# Remove leftovers.
+rm -rf aarch_wheelhouse_cu12 aarch_wheelhouse_cu13
+
+# CUDA 12.
+docker build --no-cache \
+    --build-arg CUDA_MAJOR=12 \
+    --build-arg CUDA_MINOR=3 \
+    --build-arg BUILD_METAPACKAGE=false \
+    --build-arg BUILD_COMMON=true \
+    --build-arg BUILD_PYTORCH=false \
+    --build-arg BUILD_JAX=false \
+    -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch .
+docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel"
+docker cp $(docker ps -aq | head -1):/wheelhouse aarch_wheelhouse_cu12
+
+# CUDA 13.
+docker build --no-cache \
+    --build-arg CUDA_MAJOR=13 \
+    --build-arg CUDA_MINOR=0 \
+    --build-arg BUILD_METAPACKAGE=false \
+    --build-arg BUILD_COMMON=true \
+    --build-arg BUILD_PYTORCH=false \
+    --build-arg BUILD_JAX=false \
+    -t "aarch_wheel" -f build_tools/wheel_utils/Dockerfile.aarch .
 docker run --runtime=nvidia --gpus=all --ipc=host "aarch_wheel"
-rm -rf aarch_wheelhouse
-docker cp $(docker ps -aq | head -1):/wheelhouse/ aarch_wheelhouse
+docker cp $(docker ps -aq | head -1):/wheelhouse aarch_wheelhouse_cu13
@@ -2,7 +2,29 @@
 #
 # See LICENSE for license information.
-docker build --no-cache -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 .
+# Remove leftovers.
+rm -rf x86_wheelhouse_cu12 x86_wheelhouse_cu13
+
+# CUDA 12.
+docker build --no-cache \
+    --build-arg CUDA_MAJOR=12 \
+    --build-arg CUDA_MINOR=3 \
+    --build-arg BUILD_METAPACKAGE=true \
+    --build-arg BUILD_COMMON=true \
+    --build-arg BUILD_PYTORCH=true \
+    --build-arg BUILD_JAX=true \
+    -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 .
+docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel"
+docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse_cu12
+
+# CUDA 13.
+docker build --no-cache \
+    --build-arg CUDA_MAJOR=13 \
+    --build-arg CUDA_MINOR=0 \
+    --build-arg BUILD_METAPACKAGE=false \
+    --build-arg BUILD_COMMON=true \
+    --build-arg BUILD_PYTORCH=false \
+    --build-arg BUILD_JAX=false \
+    -t "x86_wheel" -f build_tools/wheel_utils/Dockerfile.x86 .
 docker run --runtime=nvidia --gpus=all --ipc=host "x86_wheel"
-rm -rf x86_wheelhouse
-docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse
+docker cp $(docker ps -aq | head -1):/wheelhouse x86_wheelhouse_cu13
@@ -38,6 +38,14 @@ Transformer Engine can be directly installed from `our PyPI <https://pypi.org/pr
 To obtain the necessary Python bindings for Transformer Engine, the frameworks needed must be explicitly specified as extra dependencies in a comma-separated list (e.g. [jax,pytorch]). Transformer Engine ships wheels for the core library. Source distributions are shipped for the JAX and PyTorch extensions.
+The core package from Transformer Engine (without any framework extensions) can be installed via:
+
+.. code-block:: bash
+
+    pip3 install transformer_engine[core]
+
+By default, this will install the core library compiled for CUDA 12. The CUDA major version can be specified by changing the extra dependency to `core_cu12` or `core_cu13`.
 pip - from GitHub
 -----------------------
...
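As a usage sketch building on the documentation hunk above: extras can be combined per standard pip behavior, so selecting the CUDA 13 core build together with a framework binding would look like the following (the combination is our example, not from the docs):

    # CUDA 13 core library plus the PyTorch bindings:
    pip3 install "transformer_engine[core_cu13,pytorch]"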
@@ -670,7 +670,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_nvfp4(self):
         """Test Transformer Engine with NVFP4"""
         result = self.exec(True, "NVFP4BlockScaling")
-        assert result[0] < 0.451 and result[1] > 0.79
+        assert result[0] < 0.451 and result[1] > 0.788
     @unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
     def test_te_bf16_shardy(self):
@@ -708,7 +708,7 @@ class TestEncoder(unittest.TestCase):
     def test_te_nvfp4_shardy(self):
         """Test Transformer Engine with NVFP4"""
         result = self.exec(True, "NVFP4BlockScaling", enable_shardy=True)
-        assert result[0] < 0.451 and result[1] > 0.79
+        assert result[0] < 0.451 and result[1] > 0.788
 if __name__ == "__main__":
...
@@ -385,7 +385,7 @@ class TestEncoder(unittest.TestCase):
         self.args.use_fp8 = True
         self.args.fp8_recipe = "NVFP4BlockScaling"
         actual = train_and_evaluate(self.args)
-        assert actual[0] < 0.476 and actual[1] > 0.775
+        assert actual[0] < 0.477 and actual[1] > 0.769
 if __name__ == "__main__":
...
@@ -32,6 +32,6 @@ pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/
 # standard sanity and numerics tests with initialized debug
 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
 exit $FAIL
@@ -27,8 +27,8 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py"
-ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
+ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py"
-ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
+ROCBLAS_ATOMICS_MOD=0 HIPBLASLT_ATOMICS_MOD=0 PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
...
@@ -8,5 +8,6 @@ set -xe
 : ${XML_LOG_DIR:=/logs}
 mkdir -p "$XML_LOG_DIR"
-NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
+# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate.
+XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
 SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh
@@ -8,4 +8,5 @@ set -xe
 : ${XML_LOG_DIR:=/logs}
 mkdir -p "$XML_LOG_DIR"
-NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
+# Use --xla_gpu_enable_triton_gemm=false to ensure the reference JAX implementation we are using is accurate.
+XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_*
@@ -161,8 +161,11 @@ if __name__ == "__main__":
     ext_modules = []
     package_data = {}
     include_package_data = False
-    install_requires = ([f"transformer_engine_cu12=={__version__}"],)
+    install_requires = []
     extras_require = {
+        "core": [f"transformer_engine_cu12=={__version__}"],
+        "core_cu12": [f"transformer_engine_cu12=={__version__}"],
+        "core_cu13": [f"transformer_engine_cu13=={__version__}"],
         "pytorch": [f"transformer_engine_torch=={__version__}"],
         "jax": [f"transformer_engine_jax=={__version__}"],
     }
...
@@ -8,7 +8,7 @@ from itertools import product
 import pytest
 import jax
-from jax.experimental.pjit import pjit, _UNSPECIFIED
+from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED
 from transformer_engine.jax.sharding import MeshResource
@@ -154,13 +154,15 @@ def compare_ops(
     grad_args = tuple(range(len(inputs)))
     target_grad_func = jax.value_and_grad(target_func, argnums=grad_args)
-    target_pjitter = pjit(target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
-    target_fwd, target_grads = target_pjitter(*inputs, **kwargs)
-    target_hlo = target_pjitter.lower(*inputs, **kwargs).compile().as_text()
+    target_jitter = jax.jit(
+        target_grad_func, in_shardings=in_shardings, out_shardings=out_shardings
+    )
+    target_fwd, target_grads = target_jitter(*inputs, **kwargs)
+    target_hlo = target_jitter.lower(*inputs, **kwargs).compile().as_text()
     ref_grad_func = jax.value_and_grad(ref_func, argnums=grad_args)
-    ref_pjitter = pjit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
-    ref_fwd, ref_grads = ref_pjitter(*inputs, **kwargs)
+    ref_jitter = jax.jit(ref_grad_func, in_shardings=in_shardings, out_shardings=out_shardings)
+    ref_fwd, ref_grads = ref_jitter(*inputs, **kwargs)
     assert_allclose(target_fwd, ref_fwd, dtype=metric_fwd_dtype)
...
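This hunk tracks upstream JAX's deprecation of `jax.experimental.pjit`: `jax.jit` accepts the same `in_shardings`/`out_shardings` arguments, so `pjit` calls can be replaced one-for-one. A minimal self-contained sketch of the pattern (the mesh axis name and array shapes are our own illustration, not from the test suite):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.sharding import Mesh, NamedSharding, PartitionSpec

    # One-axis device mesh; works even on a single device.
    mesh = Mesh(np.asarray(jax.devices()), ("dp",))
    sharding = NamedSharding(mesh, PartitionSpec("dp", None))

    # jax.jit takes in_shardings/out_shardings directly, replacing pjit.
    doubled = jax.jit(lambda a: a * 2.0, in_shardings=sharding, out_shardings=sharding)

    # Leading axis sized to the device count so the "dp" sharding divides evenly.
    x = jax.device_put(jnp.ones((2 * len(jax.devices()), 4)), sharding)
    y = doubled(x)
    hlo = doubled.lower(x).compile().as_text()  # same lowering path compare_ops uses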
@@ -40,7 +40,6 @@ from transformer_engine.jax.quantize import (
     QuantizerFactory,
     QuantizeLayout,
     noop_quantizer_set,
-    should_use_rht,
 )
 from transformer_engine.jax.quantize import helper
 from transformer_engine.jax.activation import activation
@@ -685,21 +684,14 @@ class TestQuantize:
     Purely quantization related tests that will always test on a wider set of types and shapes
     """
-    def _skip_for_fp4(self, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
-        """Temporary hack to skip unsupported FP4 cases until we implement them"""
+    def _skip_unsupported_dtypes(self, q_dtype, scaling_mode):
+        """Skip unsupported dtypes for given scaling mode. For example, NVFP4 only supports the float4_e2m1 dtype not float8 dtypes."""
         if q_dtype not in scaling_mode.get_compatible_q_dtypes():
             pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}")
             return
-        # HACK: FIXME TODO(jberchtold)
-        row = reduce(operator.mul, input_shape[flatten_axis:], 1)
-        col = reduce(operator.mul, input_shape[:flatten_axis], 1)
-        will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
-        if will_use_rht and (row % 64 != 0 or col % 128 != 0):
-            pytest.skip("Unfused RHT is not supported currently, skipping")
     def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
-        self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
+        self._skip_unsupported_dtypes(q_dtype, scaling_mode)
         key = jax.random.PRNGKey(0)
@@ -780,22 +772,8 @@ class TestQuantize:
         assert_dequantized_scaled_tensor(scaled_tensor, x)
     def _should_use_precise_comparison(
-        self, in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
+        self, in_dtype, scaling_mode, quantizer, input_shape, flatten_axis
     ):
-        # TODO(jberchtold): Remove this hack once we have a better solution to ensure bitwise identical results between TE and JAX RHT+quant implementations. Currently for certain shapes the quantized fp4 data differs by a small amount on <0.5% of the values.
-        RHT_SLIGHT_MISMATCH_SHAPES = [
-            ((32, 256, 128), -1),
-            ((64, 32, 32, 256), -1),
-            ((8192, 2, 4096), -2),
-        ]
-        if (
-            should_use_rht(scaling_mode, q_layout=q_layout)
-            and (input_shape, flatten_axis) in RHT_SLIGHT_MISMATCH_SHAPES
-        ):
-            # TE fused RHT+quant and JAX RHT+quant have slight implementation differences which can lead to small numerical differences on certain shapes
-            return False
         if scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16:
             # With NVFP4 scaling, TE kernels internally use bfloat16 so using a different input dtype can lead to small numerical differences compared to the JAX implementation
             return False
@@ -805,7 +783,7 @@ class TestQuantize:
     def test_quantize_bitwise(
         self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
     ):
-        self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
+        self._skip_unsupported_dtypes(q_dtype, scaling_mode)
         key = jax.random.PRNGKey(0)
         input = jax.random.uniform(key, input_shape, in_dtype)
@@ -816,28 +794,20 @@ class TestQuantize:
         jax_output = _jax_quantize(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
-        try:
-            te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
-        except AssertionError as e:
-            if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
-                error_message = e.args[0]
-                if "RHT requires input to be bfloat16" in error_message:
-                    # Successfully caught the expected error, early return from the test
-                    return
-            raise e
+        te_output = tex.quantize(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
         assert_bitwise_scaled_tensors(
             te_output,
             jax_output,
             precise_comparison=self._should_use_precise_comparison(
-                in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
+                in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
             ),
         )
     def test_quantize_bitwise_jitted(
         self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
     ):
-        self._skip_for_fp4(input_shape, q_dtype, scaling_mode, q_layout, flatten_axis)
+        self._skip_unsupported_dtypes(q_dtype, scaling_mode)
         key = jax.random.PRNGKey(0)
         input = jax.random.uniform(key, input_shape, in_dtype)
@@ -851,21 +821,13 @@ class TestQuantize:
         jax_output = jax_impl_func_jit(input, quantizer=jax_quantizer, flatten_axis=flatten_axis)
-        try:
-            te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
-        except AssertionError as e:
-            if should_use_rht(scaling_mode, q_layout=q_layout) and in_dtype != jnp.bfloat16:
-                error_message = e.args[0]
-                if "RHT requires input to be bfloat16" in error_message:
-                    # Successfully caught the expected error, early return from the test
-                    return
-            raise e
+        te_output = te_impl_func_jit(input, quantizer=te_quantizer, flatten_axis=flatten_axis)
         assert_bitwise_scaled_tensors(
             te_output,
             jax_output,
             precise_comparison=self._should_use_precise_comparison(
-                in_dtype, scaling_mode, q_layout, input_shape, flatten_axis
+                in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
             ),
         )
@@ -985,12 +947,6 @@ class TestStochasticRounding:
     def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
         """Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other."""
-        # HACK: FIXME TODO(jberchtold)
-        row = reduce(operator.mul, input_shape[flatten_axis:], 1)
-        col = reduce(operator.mul, input_shape[:flatten_axis], 1)
-        will_use_rht = should_use_rht(scaling_mode, q_layout=q_layout)
-        if will_use_rht and (row % 64 != 0 or col % 128 != 0):
-            pytest.skip("Unfused RHT is not supported currently, skipping")
         key = jax.random.PRNGKey(0)
         inputs = jax.random.uniform(key, input_shape, in_dtype)
@@ -1007,6 +963,97 @@ class TestStochasticRounding:
         assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4)
@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper(
"scaling_mode", [s for s in supported_scaling_modes if s == ScalingMode.NVFP4_1D_SCALING]
)
class TestRandomizedHadamardTransform:
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
@pytest_parametrize_wrapper("input_shape,flatten_axis", [((64, 128), -1)])
def test_rht_quantize_bitwise_jitted(
self, in_dtype, q_dtype, scaling_mode, q_layout, input_shape, flatten_axis
):
key = jax.random.PRNGKey(0)
inputs = jax.random.uniform(key, input_shape, in_dtype)
te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2,
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
use_rht=True,
)
jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))
jax_output = jax_impl_func_jit(inputs, quantizer=jax_quantizer, flatten_axis=flatten_axis)
te_output = te_impl_func_jit(inputs, quantizer=te_quantizer, flatten_axis=flatten_axis)
assert_bitwise_scaled_tensors(te_output, jax_output)
def _ref_gemm_with_jnp_dot(self, a, b, data_layout):
if data_layout[0] == "T":
a = jnp.swapaxes(a, -1, -2)
if data_layout[1] == "T":
b = jnp.swapaxes(b, -1, -2)
return jnp.dot(a, b)
def _generate_gemm_input(self, m, n, k, data_layout):
key = jax.random.PRNGKey(0)
subkeys = jax.random.split(key, 2)
x = jax.random.uniform(
subkeys[0],
(m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m),
dtype=jnp.bfloat16,
) / jnp.sqrt(k)
w = jax.random.uniform(
subkeys[1],
(k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k),
dtype=jnp.bfloat16,
) / jnp.sqrt(n)
lhs_contracting_dim = (1,) if data_layout[0] == "N" else (0,)
rhs_contracting_dim = (0,) if data_layout[1] == "N" else (1,)
contracting_dims = (lhs_contracting_dim, rhs_contracting_dim)
return (x, w, contracting_dims)
@pytest_parametrize_wrapper("m,n,k", [(64, 32, 64)])
# We do not test NN and TT layouts here as they do not have both inputs using RHT due to RHT only supporting the colwise layout currently
@pytest_parametrize_wrapper("data_layout", ["TN", "NT"])
@pytest_parametrize_wrapper("with_jax_gemm", [True, False])
def test_rht_gemm(self, in_dtype, q_dtype, scaling_mode, m, n, k, data_layout, with_jax_gemm):
key = jax.random.PRNGKey(0)
lhs_scaling_mode, rhs_scaling_mode = scaling_mode, scaling_mode
x, w, contracting_dims = self._generate_gemm_input(m, n, k, data_layout)
lhs_quantizer = QuantizerFactory.create(
scaling_mode=lhs_scaling_mode,
q_dtype=jnp.float4_e2m1fn,
use_rht=True,
)
rhs_quantizer = QuantizerFactory.create(
scaling_mode=rhs_scaling_mode,
q_dtype=jnp.float4_e2m1fn,
use_rht=True,
)
with use_jax_gemm(enabled=with_jax_gemm):
primitive_out = tex.gemm(
x,
w,
contracting_dims=contracting_dims,
lhs_quantizer=lhs_quantizer,
rhs_quantizer=rhs_quantizer,
)
ref_out = self._ref_gemm_with_jnp_dot(x, w, data_layout)
assert_allclose(primitive_out, ref_out, dtype=jnp.float4_e2m1fn)
 @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
 @pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
 @pytest_parametrize_wrapper("input_shape", [(8, 16, 32)])
...
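For readers unfamiliar with the technique the new test class exercises: a randomized Hadamard transform applies random sign flips followed by an orthonormal Hadamard rotation, spreading outliers across a block so NVFP4 quantization loses less information, while the orthonormality keeps the transform invertible. A toy sketch of the idea, not TE's fused kernel; the helper names are ours:

    import jax
    import jax.numpy as jnp

    def hadamard_matrix(n: int) -> jnp.ndarray:
        # Sylvester construction; n must be a power of two.
        h = jnp.array([[1.0]])
        while h.shape[0] < n:
            h = jnp.block([[h, h], [h, -h]])
        return h / jnp.sqrt(n)  # orthonormal, so norms are preserved

    def randomized_hadamard_transform(x: jnp.ndarray, key: jax.Array) -> jnp.ndarray:
        # Random +/-1 sign flips, then a Hadamard rotation along the last axis.
        signs = jnp.where(jax.random.bernoulli(key, 0.5, (x.shape[-1],)), 1.0, -1.0)
        return (x * signs.astype(x.dtype)) @ hadamard_matrix(x.shape[-1]).astype(x.dtype)

    x = jax.random.normal(jax.random.PRNGKey(0), (64, 128), dtype=jnp.bfloat16)
    y = randomized_hadamard_transform(x, jax.random.PRNGKey(1))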
@@ -134,9 +134,12 @@ class TestDistributedLayernorm:
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
-            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
-            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
-            beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec))
+            x_named_sharding = NamedSharding(mesh, x_pspec)
+            g_named_sharding = NamedSharding(mesh, g_pspec)
+            b_named_sharding = NamedSharding(mesh, b_pspec)
+            x_ = jax.device_put(x, x_named_sharding)
+            gamma_ = jax.device_put(gamma, g_named_sharding)
+            beta_ = jax.device_put(beta, b_named_sharding)
             with warnings.catch_warnings(record=True) as warns:
                 try:
@@ -148,8 +151,11 @@ class TestDistributedLayernorm:
                         grad_args=(0, 1, 2),
                         metric_fwd_dtype=q_dtype,
                         metric_bwd_dtype=q_dtype,
-                        in_shardings=(x_pspec, g_pspec, b_pspec),
-                        out_shardings=(None, (x_pspec, g_pspec, b_pspec)),
+                        in_shardings=(x_named_sharding, g_named_sharding, b_named_sharding),
+                        out_shardings=(
+                            None,
+                            (x_named_sharding, g_named_sharding, b_named_sharding),
+                        ),
                     )
                 except AssertionError as err:
                     # Layernorm should still produce the correct numerical result with
@@ -210,8 +216,10 @@ class TestDistributedLayernorm:
         devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape)
         mesh = Mesh(devices, mesh_axes)
         with mesh, autocast(enabled=True, recipe=fp8_recipe, mesh_resource=mesh_resource):
-            x_ = jax.device_put(x, NamedSharding(mesh, x_pspec))
-            gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec))
+            x_named_sharding = NamedSharding(mesh, x_pspec)
+            g_named_sharding = NamedSharding(mesh, g_pspec)
+            x_ = jax.device_put(x, x_named_sharding)
+            gamma_ = jax.device_put(gamma, g_named_sharding)
             with warnings.catch_warnings(record=True) as warns:
                 try:
@@ -223,8 +231,8 @@ class TestDistributedLayernorm:
                         grad_args=(0, 1),
                         metric_fwd_dtype=q_dtype,
                         metric_bwd_dtype=q_dtype,
-                        in_shardings=(x_pspec, g_pspec),
-                        out_shardings=(None, (x_pspec, g_pspec)),
+                        in_shardings=(x_named_sharding, g_named_sharding),
+                        out_shardings=(None, (x_named_sharding, g_named_sharding)),
                     )
                 except AssertionError as err:
                     # RmsNorm should still produce the correct numerical result with
...