Unverified Commit 35e687d0 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Remove remaining references to TensorFlow (#474)


Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 8e757a45
......@@ -31,8 +31,7 @@ jobs:
name: 'JAX'
runs-on: ubuntu-latest
container:
#image: nvcr.io/nvidia/jax:XX.XX-py3 # Not yet available
image: nvcr.io/nvidia/tensorflow:23.03-tf2-py3
image: ghcr.io/nvidia/jax:latest
options: --user root
steps:
- name: 'Checkout'
......@@ -40,9 +39,7 @@ jobs:
with:
submodules: recursive
- name: 'Build'
run: |
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
pip install . -v
run: pip install . -v
env:
NVTE_FRAMEWORK: jax
- name: 'Sanity check'
......
......@@ -50,16 +50,13 @@ jobs:
name: 'JAX Python'
runs-on: ubuntu-latest
container:
#image: nvcr.io/nvidia/jax:XX.XX-py3 # Not yet available
image: nvcr.io/nvidia/tensorflow:23.03-tf2-py3
image: ghcr.io/nvidia/jax:latest
options: --user root
steps:
- name: 'Checkout'
uses: actions/checkout@v3
- name: 'Lint'
run: |
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install flax
export PYTHON_ONLY=1
export TE_PATH=.
bash ./qa/L0_jax_lint/test.sh
......@@ -174,7 +174,7 @@ While the more granular modules in Transformer Engine allow building any Transfo
the `TransformerLayer` API of Transformer Engine is flexible enough to build multiple major
Transformer model architectures.
Transformer Engine supports the following DL frameworks: PyTorch, JAX (Flax, Praxis), and TensorFlow.
Transformer Engine supports the following DL frameworks: PyTorch and JAX (Flax, Praxis).
NOTE: For simplicity, we only show PyTorch examples below. For the usage of `TransformerLayer`
of all supported frameworks, refer to `examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_.
......
......@@ -434,7 +434,7 @@ class CMakeBuildExtension(BuildExtension):
def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library
Also builds JAX, TensorFlow, and userbuffers support if needed.
Also builds JAX or userbuffers support if needed.
"""
cmake_flags = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment