# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# A workflow to trigger TE build on GitHub.
#
# Four independent jobs build the package with a different NVTE_FRAMEWORK
# setting each (none, pytorch, jax, all) and then run an import sanity check.
name: 'Build'
on:
  pull_request:
  workflow_dispatch:
jobs:
  # Build with NVTE_FRAMEWORK=none inside a CUDA 12.1 devel image and verify
  # that `transformer_engine` can be imported.
  core:
    name: 'Core'
    runs-on: ubuntu-latest
    container:
      image: nvcr.io/nvidia/cuda:12.1.0-devel-ubuntu22.04
      options: --user root
    steps:
      # The extras specifier is quoted so the shell never treats `[global]`
      # as a glob pattern.
      - name: 'Dependencies'
        run: |
          apt-get update
          apt-get install -y git python3.9 pip cudnn9-cuda-12
          pip install cmake==3.21.0 'pybind11[global]' ninja nvidia-mathdx==25.1.1
      - name: 'Checkout'
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: 'Build'
        run: pip install --no-build-isolation . -v
        env:
          NVTE_FRAMEWORK: none
          # Presumably limits build parallelism to fit the hosted runner.
          MAX_JOBS: '1'
      # Runs from `/` rather than the checkout directory; presumably so the
      # import resolves to the installed package and not the source tree.
      - name: 'Sanity check'
        run: python3 -c "import transformer_engine"
        working-directory: /

  # Build with NVTE_FRAMEWORK=pytorch inside a CUDA 12.8 devel image and run
  # the PyTorch import sanity script.
  pytorch:
    name: 'PyTorch'
    runs-on: ubuntu-latest
    container:
      image: nvcr.io/nvidia/cuda:12.8.0-devel-ubuntu22.04
      options: --user root
    steps:
      # `importlib-metadata>=1.0` must be quoted: an unquoted `>` is parsed by
      # the shell as an output redirection (pip's stdout would be written to
      # a file named `=1.0` and the version constraint would be dropped).
      - name: 'Dependencies'
        run: |
          apt-get update
          apt-get install -y git python3.9 pip cudnn9-cuda-12
          pip install cmake torch ninja pydantic 'importlib-metadata>=1.0' \
            packaging pybind11 numpy einops onnxscript nvidia-mathdx==25.1.1
      - name: 'Checkout'
        uses: actions/checkout@v4
        with:
          submodules: recursive
      # `--no-deps` keeps pip from installing the package's declared
      # dependencies; the ones needed were installed in 'Dependencies'.
      - name: 'Build'
        run: pip install --no-build-isolation . -v --no-deps
        env:
          NVTE_FRAMEWORK: pytorch
          MAX_JOBS: '1'
      - name: 'Sanity check'
        run: python3 tests/pytorch/test_sanity_import.py

  # Build with NVTE_FRAMEWORK=jax inside the NVIDIA JAX image and run the JAX
  # import sanity script.
  jax:
    name: 'JAX'
    runs-on: ubuntu-latest
    container:
      image: ghcr.io/nvidia/jax:jax
      options: --user root
    steps:
      - name: 'Dependencies'
        run: pip install 'pybind11[global]' nvidia-mathdx==25.1.1
      - name: 'Checkout'
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: 'Build'
        run: pip install --no-build-isolation . -v
        env:
          NVTE_FRAMEWORK: jax
          MAX_JOBS: '1'
      - name: 'Sanity check'
        run: python3 tests/jax/test_sanity_import.py

  # Build with NVTE_FRAMEWORK=all inside the NVIDIA JAX image (torch is added
  # via pip) and run both the PyTorch and JAX import sanity scripts.
  all:
    name: 'All'
    runs-on: ubuntu-latest
    container:
      image: ghcr.io/nvidia/jax:jax
      options: --user root
    steps:
      - name: 'Dependencies'
        run: pip install torch 'pybind11[global]' einops onnxscript nvidia-mathdx==25.1.1
      - name: 'Checkout'
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: 'Build'
        run: pip install --no-build-isolation . -v --no-deps
        env:
          NVTE_FRAMEWORK: all
          MAX_JOBS: '1'
      - name: 'Sanity check'
        run: python3 tests/pytorch/test_sanity_import.py && python3 tests/jax/test_sanity_import.py