# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# Builds TE for each supported framework (PyTorch, JAX, TensorFlow) inside an
# NGC container and runs an import sanity check. Triggered on pull requests
# and on manual dispatch.
---
name: 'Build'

# `on` is a YAML 1.1 boolean for generic parsers; GitHub reads it as the key.
on:  # yamllint disable-line rule:truthy
  pull_request:
  workflow_dispatch:

jobs:
  pytorch:
    name: 'PyTorch'
    runs-on: ubuntu-latest
    if: false  # NGC PyTorch container does not fit on GitHub runner
    container:
      image: nvcr.io/nvidia/pytorch:23.03-py3
      options: --user root
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
          submodules: recursive
      - name: 'Build'
        run: pip install . -v --no-deps
        env:
          NVTE_FRAMEWORK: pytorch
          # Limit parallel compile jobs; env values are strings.
          MAX_JOBS: "1"
      - name: 'Sanity check'
        run: python tests/pytorch/test_sanity_import.py

  jax:
    name: 'JAX'
    runs-on: ubuntu-latest
    container:
      # image: nvcr.io/nvidia/jax:XX.XX-py3  # Not yet available
      # Until an NGC JAX image exists, the TensorFlow image is used as the
      # base and JAX is installed on top of it in the Build step.
      image: nvcr.io/nvidia/tensorflow:23.03-tf2-py3
      options: --user root
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
          submodules: recursive
      - name: 'Build'
        run: |
          pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
          pip install . -v
        env:
          NVTE_FRAMEWORK: jax
      - name: 'Sanity check'
        run: python tests/jax/test_sanity_import.py

  tensorflow:
    name: 'TensorFlow'
    runs-on: ubuntu-latest
    container:
      image: nvcr.io/nvidia/tensorflow:23.03-tf2-py3
      options: --user root
    steps:
      - name: 'Checkout'
        uses: actions/checkout@v3
        with:
          submodules: recursive
      - name: 'Build'
        run: pip install . -v
        env:
          NVTE_FRAMEWORK: tensorflow
      - name: 'Sanity check'
        run: python tests/tensorflow/test_sanity_import.py