# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# A workflow to trigger TE build on GitHub
#
# Three independent jobs (PyTorch, JAX, TensorFlow) each build a wheel inside
# a framework container, upload it as an artifact, install it, and run that
# framework's import sanity check.
name: 'Build'

# Runs on every pull request and can also be started manually.
on:
  pull_request:
  workflow_dispatch:

jobs:
  pytorch:
    name: 'PyTorch'
    runs-on: ubuntu-latest
    container:
      image: nvcr.io/nvidia/pytorch:23.03-py3
      # Run as root inside the container.
      options: --user root
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: 'Build'
        # NVTE_FRAMEWORK is consumed by the project's build; presumably it
        # selects which framework extension gets compiled (confirm in setup).
        run: |
          mkdir -p wheelhouse && \
          NVTE_FRAMEWORK=pytorch pip wheel -w wheelhouse . -v
      - name: 'Upload wheel'
        # upload-artifact@v3 is deprecated and rejected by GitHub; v4 needs
        # a unique artifact name per job, which each job here already has.
        uses: actions/upload-artifact@v4
        with:
          name: te_wheel_pyt
          path: wheelhouse/transformer_engine*.whl
          retention-days: 7
          # Fail here instead of only warning when no wheel was produced.
          if-no-files-found: error
      - name: 'Install'
        run: pip install --no-cache-dir wheelhouse/transformer_engine*.whl
      - name: 'Sanity check'
        run: python tests/pytorch/test_sanity_import.py

  jax:
    name: 'JAX'
    runs-on: ubuntu-latest
    container:
      # A dedicated JAX image is not yet available, so this job uses the
      # TensorFlow image and installs JAX in the build step.
      # image: nvcr.io/nvidia/jax:XX.XX-py3
      image: nvcr.io/nvidia/tensorflow:23.03-tf2-py3
      # Run as root inside the container.
      options: --user root
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: 'Build'
        # Installs the build tools and JAX first, then builds the wheel.
        run: |
          pip install ninja pybind11 && \
          pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
          mkdir -p wheelhouse && \
          NVTE_FRAMEWORK=jax pip wheel -w wheelhouse . -v
      - name: 'Upload wheel'
        # upload-artifact@v3 is deprecated and rejected by GitHub; v4 needs
        # a unique artifact name per job, which each job here already has.
        uses: actions/upload-artifact@v4
        with:
          name: te_wheel_jax
          path: wheelhouse/transformer_engine*.whl
          retention-days: 7
          # Fail here instead of only warning when no wheel was produced.
          if-no-files-found: error
      - name: 'Install'
        run: pip install --no-cache-dir wheelhouse/transformer_engine*.whl
      - name: 'Sanity check'
        run: python tests/jax/test_sanity_import.py

  tensorflow:
    name: 'TensorFlow'
    runs-on: ubuntu-latest
    container:
      image: nvcr.io/nvidia/tensorflow:23.03-tf2-py3
      # Run as root inside the container.
      options: --user root
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: 'Build'
        # Installs the build tools first, then builds the wheel.
        run: |
          pip install ninja pybind11 && \
          mkdir -p wheelhouse && \
          NVTE_FRAMEWORK=tensorflow pip wheel -w wheelhouse . -v
      - name: 'Upload wheel'
        # upload-artifact@v3 is deprecated and rejected by GitHub; v4 needs
        # a unique artifact name per job, which each job here already has.
        uses: actions/upload-artifact@v4
        with:
          name: te_wheel_tf
          path: wheelhouse/transformer_engine*.whl
          retention-days: 7
          # Fail here instead of only warning when no wheel was produced.
          if-no-files-found: error
      - name: 'Install'
        run: pip install --no-cache-dir wheelhouse/transformer_engine*.whl
      - name: 'Sanity check'
        run: python tests/tensorflow/test_sanity_import.py