Unverified Commit 3069ebc2 authored by Frédéric Bastien's avatar Frédéric Bastien Committed by GitHub
Browse files

Update installation instruction for JAX and add some dependencies. (#117)



* Update installation instructio for JAX and add some depenencies.
Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>

* Bring back support for none pip installed pybind11.
Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>

* Apply suggestions from code review
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarFrédéric Bastien <frederic.bastien@gmail.com>

* Changes following review.
Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>

* Change order to make it more clear.
Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>

* Add other reviers suggestion.
Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>

* pybind11 is needed for all FW.
Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>

* Add flax as a dep
Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>

* Update README.rst
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarFrédéric Bastien <frederic.bastien@gmail.com>

---------
Signed-off-by: default avatarFrederic Bastien <fbastien@nvidia.com>
Signed-off-by: default avatarFrédéric Bastien <frederic.bastien@gmail.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent db95afeb
......@@ -162,32 +162,32 @@ Transformer Engine comes preinstalled in the pyTorch container on
From source
^^^^^^^^^^^
First, install the prequisites.
For JAX and Tensorflow, pybind11 must be installed:
.. code-block:: bash
apt-get install ninja-build pybind11-dev
pip install pybind11
Clone the repository and inside it type:
Then, you can install this optional dependency:
.. code-block:: bash
NVTE_FRAMEWORK=all pip install . # Building with all frameworks.
NVTE_FRAMEWORK=pytorch pip install . # Building with pyTorch only.
NVTE_FRAMEWORK=jax pip install . # Building with JAX only.
pip install ninja
You can also specify which framework bindings to build. The default is pytorch only.
Install TE (optionally specifying the framework):
.. code-block:: bash
# Build with TensorFlow bindings
NVTE_FRAMEWORK=tensorflow pip install .
git clone https://github.com/NVIDIA/TransformerEngine.git
cd TransformerEngine
# Build with Jax bindings
NVTE_FRAMEWORK=jax pip install .
# Execute one of the following command
NVTE_FRAMEWORK=all pip install . # Build TE for all supported frameworks.
NVTE_FRAMEWORK=pytorch pip install . # Build TE for PyTorch only.
NVTE_FRAMEWORK=jax pip install . # Build TE for JAX only.
NVTE_FRAMEWORK=tensorflow pip install . # Build TE for TensorFlow only.
# Build with all bindings (Pytorch, TF, Jax)
NVTE_FRAMEWORK=all pip install .
If the framework is not explicitly specified, TE will be built for PyTorch only.
User Guide
----------
......
......@@ -29,9 +29,12 @@ pip - from GitHub
Additional Prerequisites
^^^^^^^^^^^^^^^^^^^^^^^^
1. `CMake <https://cmake.org/>`__ version 3.18 or later
2. `pyTorch <https://pytorch.org/>`__ with GPU support
3. `Ninja <https://ninja-build.org/>`__
1. `CMake <https://cmake.org/>`__ version 3.18 or later.
2. [For pyTorch support] `pyTorch <https://pytorch.org/>`__ with GPU support.
3. [For JAX support] `JAX <https://github.com/google/jax/>`__ with GPU support, version >= 0.4.7.
4. [For TensorFlow support] `TensorFlow <https://www.tensorflow.org/>`__ with GPU support.
5. `pybind11`: `pip install pybind11`.
6. [Optional] `Ninja <https://ninja-build.org/>`__: `pip install ninja`.
Installation (stable release)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
......
......@@ -160,23 +160,24 @@ class PyTorchBuilder(FrameworkBuilderBase):
def install_requires():
return ["flash-attn @ git+https://github.com/ksivaman/flash-attention.git@hopper",]
class JaxBuilder(FrameworkBuilderBase):
class TensorFlowBuilder(FrameworkBuilderBase):
def cmake_flags(self):
return ["-DENABLE_JAX=ON"]
return ["-DENABLE_TENSORFLOW=ON"]
def run(self, extensions):
print("Building jax extensions!")
print("Building TensorFlow extensions!")
class TensorFlowBuilder(FrameworkBuilderBase):
class JaxBuilder(FrameworkBuilderBase):
def cmake_flags(self):
return ["-DENABLE_TENSORFLOW=ON"]
p = [d for d in sys.path if 'dist-packages' in d][0]
return ["-DENABLE_JAX=ON", "-DCMAKE_PREFIX_PATH="+p]
def run(self, extensions):
print("Building TensorFlow extensions!")
print("Building jax extensions!")
@staticmethod
def install_requires():
return ["pydantic",]
# TODO: find a way to install pybind11 and ninja directly.
return ['cmake', 'flax']
ext_modules = []
dlfw_builder_funcs = []
......@@ -207,11 +208,16 @@ if framework in ("all", "pytorch"):
if framework in ("all", "jax"):
dlfw_builder_funcs.append(JaxBuilder)
# Trigger a better error when pybind11 isn't present.
# Sadly, if pybind11 was installed with `apt -y install pybind11-dev`
# This doesn't install a python packages. So the line bellow is too strict.
# When it fail, we need to detect if cmake will find pybind11.
# import pybind11
if framework in ("all", "tensorflow"):
dlfw_builder_funcs.append(TensorFlowBuilder)
dlfw_install_requires = []
dlfw_install_requires = ['pydantic']
for builder in dlfw_builder_funcs:
dlfw_install_requires = dlfw_install_requires + builder.install_requires()
......@@ -272,10 +278,16 @@ class CMakeBuildExtension(build_ext, object):
build_dir = os.path.abspath(build_dir)
cmake_args = [
"-GNinja",
"-DCMAKE_BUILD_TYPE=" + config,
"-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(config.upper(), build_dir),
]
try:
import ninja
except ImportError:
pass
else:
cmake_args.append("-GNinja")
cmake_args = cmake_args + self.dlfw_flags
cmake_build_args = ["--config", config]
......@@ -399,5 +411,10 @@ setup(
ext_modules=ext_modules,
cmdclass={"build_ext": TEBuildExtension},
install_requires=dlfw_install_requires,
extras_require={
'test': ['pytest',
'tensorflow_datasets'],
'test_pytest': ['onnxruntime',],
},
license_files=("LICENSE",),
)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment