Unverified commit 35e687d0 authored by Tim Moon, committed by GitHub

Remove remaining references to TensorFlow (#474)


Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 8e757a45
@@ -31,8 +31,7 @@ jobs:
     name: 'JAX'
     runs-on: ubuntu-latest
     container:
-      #image: nvcr.io/nvidia/jax:XX.XX-py3 # Not yet available
-      image: nvcr.io/nvidia/tensorflow:23.03-tf2-py3
+      image: ghcr.io/nvidia/jax:latest
       options: --user root
     steps:
       - name: 'Checkout'
@@ -40,9 +39,7 @@ jobs:
         with:
           submodules: recursive
       - name: 'Build'
-        run: |
-          pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html && \
-          pip install . -v
+        run: pip install . -v
         env:
           NVTE_FRAMEWORK: jax
       - name: 'Sanity check'
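For context on the change above: the `ghcr.io/nvidia/jax:latest` container is expected to ship JAX itself, so the Build step no longer installs `jax[cuda12_local]` first, and the `NVTE_FRAMEWORK: jax` variable tells `pip install . -v` which bindings to compile. A minimal, hypothetical pre-flight check (not part of the repository) that mirrors this assumption:

.. code-block:: python

    # Hypothetical sanity check: the CI container must now provide JAX,
    # since the workflow no longer pip-installs it before the build.
    import importlib.util
    import os

    os.environ.setdefault("NVTE_FRAMEWORK", "jax")  # what the workflow exports

    if importlib.util.find_spec("jax") is None:
        raise RuntimeError("JAX not found; the container image is expected to provide it")
    print(f"Building Transformer Engine bindings for: {os.environ['NVTE_FRAMEWORK']}")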
@@ -50,16 +50,13 @@ jobs:
     name: 'JAX Python'
     runs-on: ubuntu-latest
     container:
-      #image: nvcr.io/nvidia/jax:XX.XX-py3 # Not yet available
-      image: nvcr.io/nvidia/tensorflow:23.03-tf2-py3
+      image: ghcr.io/nvidia/jax:latest
       options: --user root
     steps:
       - name: 'Checkout'
         uses: actions/checkout@v3
       - name: 'Lint'
         run: |
-          pip install --upgrade "jax[cuda12_local]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
-          pip install flax
           export PYTHON_ONLY=1
           export TE_PATH=.
           bash ./qa/L0_jax_lint/test.sh
@@ -46,8 +46,8 @@ simplifying mixed precision training for users.
 Highlights
 ----------
-* Easy-to-use modules for building Transformer layers with FP8 support
-* Optimizations (e.g. fused kernels) for Transformer models
+* Easy-to-use modules for building Transformer layers with FP8 support
+* Optimizations (e.g. fused kernels) for Transformer models
 * Support for FP8 on NVIDIA Hopper and NVIDIA Ada GPUs
 * Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later
@@ -174,7 +174,7 @@ While the more granular modules in Transformer Engine allow building any Transformer
 the `TransformerLayer` API of Transformer Engine is flexible enough to build multiple major
 Transformer model architectures.
-Transformer Engine supports the following DL frameworks: PyTorch, JAX (Flax, Praxis), and TensorFlow.
+Transformer Engine supports the following DL frameworks: PyTorch and JAX (Flax, Praxis).
 NOTE: For simplicity, we only show PyTorch examples below. For the usage of `TransformerLayer`
 of all supported frameworks, refer to `examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_.
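As the NOTE above says, the linked examples are the reference. For orientation, here is a minimal sketch of typical PyTorch `TransformerLayer` usage with FP8, assuming the `transformer_engine.pytorch` and `transformer_engine.common.recipe` APIs of this release (the layer sizes are illustrative):

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common import recipe

    # A single Transformer layer; the sizes here are illustrative.
    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
    ).cuda()

    # HYBRID uses E4M3 for forward tensors and E5M2 for gradients.
    fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

    x = torch.randn(2048, 4, 1024, device="cuda")  # (sequence, batch, hidden)
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        y = layer(x)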
@@ -233,13 +233,13 @@ Integrations
 Transformer Engine has been integrated with popular LLM frameworks such as:
-* `DeepSpeed <https://github.com/microsoft/DeepSpeed/pull/3731>`_
-* `Hugging Face Accelerate <https://github.com/huggingface/accelerate/releases/tag/v0.17.0>`_
+* `DeepSpeed <https://github.com/microsoft/DeepSpeed/pull/3731>`_
+* `Hugging Face Accelerate <https://github.com/huggingface/accelerate/releases/tag/v0.17.0>`_
 * `Lightning <https://github.com/Lightning-AI/lightning/issues/17172>`_
-* `MosaicML Composer <https://github.com/mosaicml/composer/releases/tag/v0.13.1>`_
+* `MosaicML Composer <https://github.com/mosaicml/composer/releases/tag/v0.13.1>`_
 * `NVIDIA JAX Toolbox <https://github.com/NVIDIA/JAX-Toolbox>`_
-* `NVIDIA Megatron-LM <https://github.com/NVIDIA/Megatron-LM>`_
-* `NVIDIA NeMo <https://github.com/NVIDIA/NeMo>`_
+* `NVIDIA Megatron-LM <https://github.com/NVIDIA/Megatron-LM>`_
+* `NVIDIA NeMo <https://github.com/NVIDIA/NeMo>`_
 * `Amazon SageMaker Model Parallel Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel.html>`_ - Coming soon!
 * `Colossal-AI <https://github.com/hpcaitech/ColossalAI>`_ - Coming soon!
 * `PeriFlow <https://github.com/friendliai/periflow-python-sdk>`_ - Coming soon!
@@ -249,7 +249,7 @@ Contributing
 ==================
 We welcome contributions to Transformer Engine! To contribute to Transformer Engine and make pull requests,
-follow the guidelines outlined in the `<CONTRIBUTING.rst>`_ guide.
+follow the guidelines outlined in the `<CONTRIBUTING.rst>`_ guide.
 
 Papers
 ==================
@@ -262,9 +262,9 @@ Papers
 Videos
 ==================
-* `FP8 Training with Transformer Engine <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51393>`_
-* `FP8 for Deep Learning <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s52166/>`_
-* `Inside the Hopper Architecture <https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s42663/>`_
+* `FP8 Training with Transformer Engine <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51393>`_
+* `FP8 for Deep Learning <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s52166/>`_
+* `Inside the Hopper Architecture <https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s42663/>`_
 
 .. |License| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg
    :target: https://opensource.org/licenses/Apache-2.0
@@ -434,7 +434,7 @@ class CMakeBuildExtension(BuildExtension):
 def setup_common_extension() -> CMakeExtension:
     """Setup CMake extension for common library
 
-    Also builds JAX, TensorFlow, and userbuffers support if needed.
+    Also builds JAX or userbuffers support if needed.
 
     """
     cmake_flags = []
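Only the docstring changes here; the framework selection itself is driven by the `NVTE_FRAMEWORK` variable exported in the CI workflows earlier in this diff. A rough, hypothetical sketch of that kind of gating (the helper name below is illustrative, not the actual setup.py code):

.. code-block:: python

    import os

    def frameworks_to_build():
        """Pick framework bindings from NVTE_FRAMEWORK (hypothetical sketch).

        After this commit, "tensorflow" is no longer a valid choice.
        """
        supported = ("pytorch", "jax")
        requested = os.getenv("NVTE_FRAMEWORK", "all").lower()
        if requested == "all":
            return list(supported)
        frameworks = [fw for fw in requested.split(",") if fw in supported]
        if not frameworks:
            raise ValueError(f"NVTE_FRAMEWORK={requested!r} names no supported framework")
        return frameworks

    print(frameworks_to_build())  # e.g. ['jax'] when NVTE_FRAMEWORK=jax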