Unverified Commit 35e687d0 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Remove remaining references to TensorFlow (#474)


Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 8e757a45
...@@ -31,8 +31,7 @@ jobs: ...@@ -31,8 +31,7 @@ jobs:
name: 'JAX' name: 'JAX'
runs-on: ubuntu-latest runs-on: ubuntu-latest
container: container:
#image: nvcr.io/nvidia/jax:XX.XX-py3 # Not yet available image: ghcr.io/nvidia/jax:latest
image: nvcr.io/nvidia/tensorflow:23.03-tf2-py3
options: --user root options: --user root
steps: steps:
- name: 'Checkout' - name: 'Checkout'
...@@ -40,9 +39,7 @@ jobs: ...@@ -40,9 +39,7 @@ jobs:
with: with:
submodules: recursive submodules: recursive
- name: 'Build' - name: 'Build'
run: | run: pip install . -v
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
pip install . -v
env: env:
NVTE_FRAMEWORK: jax NVTE_FRAMEWORK: jax
- name: 'Sanity check' - name: 'Sanity check'
......
...@@ -50,16 +50,13 @@ jobs: ...@@ -50,16 +50,13 @@ jobs:
name: 'JAX Python' name: 'JAX Python'
runs-on: ubuntu-latest runs-on: ubuntu-latest
container: container:
#image: nvcr.io/nvidia/jax:XX.XX-py3 # Not yet available image: ghcr.io/nvidia/jax:latest
image: nvcr.io/nvidia/tensorflow:23.03-tf2-py3
options: --user root options: --user root
steps: steps:
- name: 'Checkout' - name: 'Checkout'
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: 'Lint' - name: 'Lint'
run: | run: |
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install flax
export PYTHON_ONLY=1 export PYTHON_ONLY=1
export TE_PATH=. export TE_PATH=.
bash ./qa/L0_jax_lint/test.sh bash ./qa/L0_jax_lint/test.sh
...@@ -174,7 +174,7 @@ While the more granular modules in Transformer Engine allow building any Transfo ...@@ -174,7 +174,7 @@ While the more granular modules in Transformer Engine allow building any Transfo
the `TransformerLayer` API of Transformer Engine is flexible enough to build multiple major the `TransformerLayer` API of Transformer Engine is flexible enough to build multiple major
Transformer model architectures. Transformer model architectures.
Transformer Engine supports the following DL frameworks: PyTorch, JAX (Flax, Praxis), and TensorFlow. Transformer Engine supports the following DL frameworks: PyTorch and JAX (Flax, Praxis).
NOTE: For simplicity, we only show PyTorch examples below. For the usage of `TransformerLayer` NOTE: For simplicity, we only show PyTorch examples below. For the usage of `TransformerLayer`
of all supported frameworks, refer to `examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_. of all supported frameworks, refer to `examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_.
......
...@@ -434,7 +434,7 @@ class CMakeBuildExtension(BuildExtension): ...@@ -434,7 +434,7 @@ class CMakeBuildExtension(BuildExtension):
def setup_common_extension() -> CMakeExtension: def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library """Setup CMake extension for common library
Also builds JAX, TensorFlow, and userbuffers support if needed. Also builds JAX or userbuffers support if needed.
""" """
cmake_flags = [] cmake_flags = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment