Unverified Commit 35e687d0 authored by Tim Moon's avatar Tim Moon Committed by GitHub
Browse files

Remove remaining references to TensorFlow (#474)


Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
parent 8e757a45
...@@ -31,8 +31,7 @@ jobs: ...@@ -31,8 +31,7 @@ jobs:
name: 'JAX' name: 'JAX'
runs-on: ubuntu-latest runs-on: ubuntu-latest
container: container:
#image: nvcr.io/nvidia/jax:XX.XX-py3 # Not yet available image: ghcr.io/nvidia/jax:latest
image: nvcr.io/nvidia/tensorflow:23.03-tf2-py3
options: --user root options: --user root
steps: steps:
- name: 'Checkout' - name: 'Checkout'
...@@ -40,9 +39,7 @@ jobs: ...@@ -40,9 +39,7 @@ jobs:
with: with:
submodules: recursive submodules: recursive
- name: 'Build' - name: 'Build'
run: | run: pip install . -v
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
pip install . -v
env: env:
NVTE_FRAMEWORK: jax NVTE_FRAMEWORK: jax
- name: 'Sanity check' - name: 'Sanity check'
......
...@@ -50,16 +50,13 @@ jobs: ...@@ -50,16 +50,13 @@ jobs:
name: 'JAX Python' name: 'JAX Python'
runs-on: ubuntu-latest runs-on: ubuntu-latest
container: container:
#image: nvcr.io/nvidia/jax:XX.XX-py3 # Not yet available image: ghcr.io/nvidia/jax:latest
image: nvcr.io/nvidia/tensorflow:23.03-tf2-py3
options: --user root options: --user root
steps: steps:
- name: 'Checkout' - name: 'Checkout'
uses: actions/checkout@v3 uses: actions/checkout@v3
- name: 'Lint' - name: 'Lint'
run: | run: |
pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install flax
export PYTHON_ONLY=1 export PYTHON_ONLY=1
export TE_PATH=. export TE_PATH=.
bash ./qa/L0_jax_lint/test.sh bash ./qa/L0_jax_lint/test.sh
...@@ -46,8 +46,8 @@ simplifying mixed precision training for users. ...@@ -46,8 +46,8 @@ simplifying mixed precision training for users.
Highlights Highlights
---------- ----------
* Easy-to-use modules for building Transformer layers with FP8 support * Easy-to-use modules for building Transformer layers with FP8 support
* Optimizations (e.g. fused kernels) for Transformer models * Optimizations (e.g. fused kernels) for Transformer models
* Support for FP8 on NVIDIA Hopper and NVIDIA Ada GPUs * Support for FP8 on NVIDIA Hopper and NVIDIA Ada GPUs
* Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later * Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later
...@@ -174,7 +174,7 @@ While the more granular modules in Transformer Engine allow building any Transfo ...@@ -174,7 +174,7 @@ While the more granular modules in Transformer Engine allow building any Transfo
the `TransformerLayer` API of Transformer Engine is flexible enough to build multiple major the `TransformerLayer` API of Transformer Engine is flexible enough to build multiple major
Transformer model architectures. Transformer model architectures.
Transformer Engine supports the following DL frameworks: PyTorch, JAX (Flax, Praxis), and TensorFlow. Transformer Engine supports the following DL frameworks: PyTorch and JAX (Flax, Praxis).
NOTE: For simplicity, we only show PyTorch examples below. For the usage of `TransformerLayer` NOTE: For simplicity, we only show PyTorch examples below. For the usage of `TransformerLayer`
of all supported frameworks, refer to `examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_. of all supported frameworks, refer to `examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_.
...@@ -233,13 +233,13 @@ Integrations ...@@ -233,13 +233,13 @@ Integrations
Transformer Engine has been integrated with popular LLM frameworks such as: Transformer Engine has been integrated with popular LLM frameworks such as:
* `DeepSpeed <https://github.com/microsoft/DeepSpeed/pull/3731>`_ * `DeepSpeed <https://github.com/microsoft/DeepSpeed/pull/3731>`_
* `Hugging Face Accelerate <https://github.com/huggingface/accelerate/releases/tag/v0.17.0>`_ * `Hugging Face Accelerate <https://github.com/huggingface/accelerate/releases/tag/v0.17.0>`_
* `Lightning <https://github.com/Lightning-AI/lightning/issues/17172>`_ * `Lightning <https://github.com/Lightning-AI/lightning/issues/17172>`_
* `MosaicML Composer <https://github.com/mosaicml/composer/releases/tag/v0.13.1>`_ * `MosaicML Composer <https://github.com/mosaicml/composer/releases/tag/v0.13.1>`_
* `NVIDIA JAX Toolbox <https://github.com/NVIDIA/JAX-Toolbox>`_ * `NVIDIA JAX Toolbox <https://github.com/NVIDIA/JAX-Toolbox>`_
* `NVIDIA Megatron-LM <https://github.com/NVIDIA/Megatron-LM>`_ * `NVIDIA Megatron-LM <https://github.com/NVIDIA/Megatron-LM>`_
* `NVIDIA NeMo <https://github.com/NVIDIA/NeMo>`_ * `NVIDIA NeMo <https://github.com/NVIDIA/NeMo>`_
* `Amazon SageMaker Model Parallel Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel.html>`_ - Coming soon! * `Amazon SageMaker Model Parallel Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel.html>`_ - Coming soon!
* `Colossal-AI <https://github.com/hpcaitech/ColossalAI>`_ - Coming soon! * `Colossal-AI <https://github.com/hpcaitech/ColossalAI>`_ - Coming soon!
* `PeriFlow <https://github.com/friendliai/periflow-python-sdk>`_ - Coming soon! * `PeriFlow <https://github.com/friendliai/periflow-python-sdk>`_ - Coming soon!
...@@ -249,7 +249,7 @@ Contributing ...@@ -249,7 +249,7 @@ Contributing
================== ==================
We welcome contributions to Transformer Engine! To contribute to Transformer Engine and make pull requests, We welcome contributions to Transformer Engine! To contribute to Transformer Engine and make pull requests,
follow the guidelines outlined in the `<CONTRIBUTING.rst>`_ guide. follow the guidelines outlined in the `<CONTRIBUTING.rst>`_ guide.
Papers Papers
================== ==================
...@@ -262,9 +262,9 @@ Papers ...@@ -262,9 +262,9 @@ Papers
Videos Videos
================== ==================
* `FP8 Training with Transformer Engine <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51393>`_ * `FP8 Training with Transformer Engine <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51393>`_
* `FP8 for Deep Learning <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s52166/>`_ * `FP8 for Deep Learning <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s52166/>`_
* `Inside the Hopper Architecture <https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s42663/>`_ * `Inside the Hopper Architecture <https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s42663/>`_
.. |License| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg .. |License| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg
:target: https://opensource.org/licenses/Apache-2.0 :target: https://opensource.org/licenses/Apache-2.0
...@@ -434,7 +434,7 @@ class CMakeBuildExtension(BuildExtension): ...@@ -434,7 +434,7 @@ class CMakeBuildExtension(BuildExtension):
def setup_common_extension() -> CMakeExtension: def setup_common_extension() -> CMakeExtension:
"""Setup CMake extension for common library """Setup CMake extension for common library
Also builds JAX, TensorFlow, and userbuffers support if needed. Also builds JAX or userbuffers support if needed.
""" """
cmake_flags = [] cmake_flags = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment