Commit 3069ebc2 authored by Frédéric Bastien, committed by GitHub

Update installation instruction for JAX and add some dependencies. (#117)



* Update installation instructions for JAX and add some dependencies.
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Bring back support for non-pip-installed pybind11.
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Apply suggestions from code review
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>

* Changes following review.
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Change order to make it clearer.
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Add other reviewer's suggestion.
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* pybind11 is needed for all frameworks.
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Add flax as a dependency
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>

* Update README.rst
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>

---------
Signed-off-by: Frederic Bastien <fbastien@nvidia.com>
Signed-off-by: Frédéric Bastien <frederic.bastien@gmail.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent db95afeb
...@@ -162,32 +162,32 @@ Transformer Engine comes preinstalled in the pyTorch container on ...@@ -162,32 +162,32 @@ Transformer Engine comes preinstalled in the pyTorch container on
From source From source
^^^^^^^^^^^ ^^^^^^^^^^^
First, install the prequisites. For JAX and Tensorflow, pybind11 must be installed:
.. code-block:: bash .. code-block:: bash
apt-get install ninja-build pybind11-dev pip install pybind11
Clone the repository and inside it type: Then, you can install this optional dependency:
.. code-block:: bash .. code-block:: bash
NVTE_FRAMEWORK=all pip install . # Building with all frameworks. pip install ninja
NVTE_FRAMEWORK=pytorch pip install . # Building with pyTorch only.
NVTE_FRAMEWORK=jax pip install . # Building with JAX only.
You can also specify which framework bindings to build. The default is pytorch only. Install TE (optionally specifying the framework):
.. code-block:: bash .. code-block:: bash
# Build with TensorFlow bindings git clone https://github.com/NVIDIA/TransformerEngine.git
NVTE_FRAMEWORK=tensorflow pip install . cd TransformerEngine
# Build with Jax bindings # Execute one of the following command
NVTE_FRAMEWORK=jax pip install . NVTE_FRAMEWORK=all pip install . # Build TE for all supported frameworks.
NVTE_FRAMEWORK=pytorch pip install . # Build TE for PyTorch only.
NVTE_FRAMEWORK=jax pip install . # Build TE for JAX only.
NVTE_FRAMEWORK=tensorflow pip install . # Build TE for TensorFlow only.
# Build with all bindings (Pytorch, TF, Jax) If the framework is not explicitly specified, TE will be built for PyTorch only.
NVTE_FRAMEWORK=all pip install .
User Guide User Guide
---------- ----------
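Note on the README hunk above: the new text states that TE is built for PyTorch only when no framework is specified. The following is a minimal sketch of how a build script could read `NVTE_FRAMEWORK` and apply that default. The function names and the validation step are hypothetical and are not taken from this commit, which only shows checks of the form `framework in ("all", "jax")`.

```python
import os

# Values accepted in the README examples of this commit.
SUPPORTED_FRAMEWORKS = ("all", "pytorch", "jax", "tensorflow")


def selected_framework(environ=None):
    """Return the framework named by NVTE_FRAMEWORK (hypothetical helper).

    Falls back to "pytorch" when the variable is unset, matching the
    default described in the README.
    """
    environ = os.environ if environ is None else environ
    framework = environ.get("NVTE_FRAMEWORK", "pytorch")
    if framework not in SUPPORTED_FRAMEWORKS:
        # Assumption: rejecting unknown values is a choice made for this sketch.
        raise ValueError(
            "NVTE_FRAMEWORK must be one of %s, got %r"
            % (", ".join(SUPPORTED_FRAMEWORKS), framework)
        )
    return framework


def frameworks_to_build(framework):
    """Expand "all" into the individual frameworks, as the `in` checks do."""
    return [
        name
        for name in ("pytorch", "jax", "tensorflow")
        if framework in ("all", name)
    ]
```

For example, `frameworks_to_build(selected_framework({"NVTE_FRAMEWORK": "jax"}))` yields only the JAX bindings, while an empty environment yields only PyTorch.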
......
@@ -29,9 +29,12 @@ pip - from GitHub
 Additional Prerequisites
 ^^^^^^^^^^^^^^^^^^^^^^^^
-1. `CMake <https://cmake.org/>`__ version 3.18 or later
-2. `pyTorch <https://pytorch.org/>`__ with GPU support
-3. `Ninja <https://ninja-build.org/>`__
+1. `CMake <https://cmake.org/>`__ version 3.18 or later.
+2. [For pyTorch support] `pyTorch <https://pytorch.org/>`__ with GPU support.
+3. [For JAX support] `JAX <https://github.com/google/jax/>`__ with GPU support, version >= 0.4.7.
+4. [For TensorFlow support] `TensorFlow <https://www.tensorflow.org/>`__ with GPU support.
+5. `pybind11`: `pip install pybind11`.
+6. [Optional] `Ninja <https://ninja-build.org/>`__: `pip install ninja`.
 Installation (stable release)
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
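Note on the prerequisites hunk above: the list introduces minimum versions (CMake 3.18, JAX 0.4.7). Comparing version strings as text gives wrong answers (for instance, "0.4.10" sorts before "0.4.7"), so a check needs numeric comparison. The sketch below is a simplified, hypothetical helper that is not part of this commit. It ignores pre-release suffixes, which a real project would more likely handle with the third-party `packaging` library.

```python
def parse_version(text):
    """Turn a dotted version such as "0.4.7" into a tuple of ints, (0, 4, 7).

    Simplification for this sketch: parsing stops at the first component
    that does not start with a digit, and any non-digit suffix inside a
    component (for example "7rc1") is dropped.
    """
    parts = []
    for piece in text.split("."):
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def meets_minimum(installed, minimum):
    """Return True when `installed` is at least `minimum`, compared numerically."""
    return parse_version(installed) >= parse_version(minimum)
```

With this helper, `meets_minimum("0.4.10", "0.4.7")` is true even though a plain string comparison would say otherwise.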
......
@@ -160,23 +160,24 @@ class PyTorchBuilder(FrameworkBuilderBase):
     def install_requires():
         return ["flash-attn @ git+https://github.com/ksivaman/flash-attention.git@hopper",]
-class JaxBuilder(FrameworkBuilderBase):
+class TensorFlowBuilder(FrameworkBuilderBase):
     def cmake_flags(self):
-        return ["-DENABLE_JAX=ON"]
+        return ["-DENABLE_TENSORFLOW=ON"]
     def run(self, extensions):
-        print("Building jax extensions!")
+        print("Building TensorFlow extensions!")
-class TensorFlowBuilder(FrameworkBuilderBase):
+class JaxBuilder(FrameworkBuilderBase):
     def cmake_flags(self):
-        return ["-DENABLE_TENSORFLOW=ON"]
+        p = [d for d in sys.path if 'dist-packages' in d][0]
+        return ["-DENABLE_JAX=ON", "-DCMAKE_PREFIX_PATH="+p]
     def run(self, extensions):
-        print("Building TensorFlow extensions!")
+        print("Building jax extensions!")
-    @staticmethod
     def install_requires():
-        return ["pydantic",]
+        # TODO: find a way to install pybind11 and ninja directly.
+        return ['cmake', 'flax']
 ext_modules = []
 dlfw_builder_funcs = []
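Note on the `JaxBuilder.cmake_flags` change above: the new code takes the first `sys.path` entry containing `dist-packages` and indexes it with `[0]`, which raises `IndexError` when no such entry exists. That can happen in virtual environments and conda installs, where the directory is usually named `site-packages`. The following is a hedged sketch of a more tolerant lookup. The helper names and the `site-packages` fallback are assumptions made for illustration, not the commit's behaviour.

```python
import sys


def find_package_dir(paths=None, markers=("dist-packages", "site-packages")):
    """Return the first path entry containing one of `markers`, or None.

    Markers are tried in order, so a "dist-packages" entry is preferred,
    mirroring the lookup in the diff above.
    """
    paths = sys.path if paths is None else paths
    for marker in markers:
        for entry in paths:
            if marker in entry:
                return entry
    return None


def jax_cmake_flags(paths=None):
    """Build the JAX CMake flags, adding CMAKE_PREFIX_PATH only when found."""
    flags = ["-DENABLE_JAX=ON"]
    package_dir = find_package_dir(paths)
    if package_dir is not None:
        flags.append("-DCMAKE_PREFIX_PATH=" + package_dir)
    return flags
```

Omitting `-DCMAKE_PREFIX_PATH` when nothing matches leaves CMake to search its default locations, which is one possible design choice. Failing with an explicit error message would be another.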
@@ -207,11 +208,16 @@ if framework in ("all", "pytorch"):
 if framework in ("all", "jax"):
     dlfw_builder_funcs.append(JaxBuilder)
+    # Trigger a better error when pybind11 isn't present.
+    # Sadly, if pybind11 was installed with `apt -y install pybind11-dev`
+    # This doesn't install a python packages. So the line bellow is too strict.
+    # When it fail, we need to detect if cmake will find pybind11.
+    # import pybind11
 if framework in ("all", "tensorflow"):
     dlfw_builder_funcs.append(TensorFlowBuilder)
-dlfw_install_requires = []
+dlfw_install_requires = ['pydantic']
 for builder in dlfw_builder_funcs:
     dlfw_install_requires = dlfw_install_requires + builder.install_requires()
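Note on the commented-out `import pybind11` above: the comment explains that a plain import is too strict, because `apt install pybind11-dev` provides the headers and CMake files without a Python package. The following sketch shows one way the two cases could be told apart. The candidate directories are assumptions about where a distribution might install `pybind11Config.cmake`, and the helper names are hypothetical. Whether CMake actually finds pybind11 also depends on its own search paths.

```python
import importlib.util
import os

# Assumed locations for a system-wide pybind11 CMake config; adjust per distro.
CANDIDATE_CMAKE_DIRS = (
    "/usr/share/cmake/pybind11",
    "/usr/lib/cmake/pybind11",
)


def has_python_pybind11():
    """True when the pybind11 Python package is importable (e.g. from pip)."""
    return importlib.util.find_spec("pybind11") is not None


def find_system_pybind11(cmake_dirs=CANDIDATE_CMAKE_DIRS):
    """Return the first directory holding pybind11Config.cmake, or None."""
    for directory in cmake_dirs:
        if os.path.isfile(os.path.join(directory, "pybind11Config.cmake")):
            return directory
    return None


def check_pybind11(cmake_dirs=CANDIDATE_CMAKE_DIRS):
    """Raise a descriptive error when neither installation style is detected."""
    if has_python_pybind11() or find_system_pybind11(cmake_dirs) is not None:
        return
    raise RuntimeError(
        "pybind11 was not found; install it with `pip install pybind11`."
    )
```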
@@ -272,10 +278,16 @@ class CMakeBuildExtension(build_ext, object):
         build_dir = os.path.abspath(build_dir)
         cmake_args = [
-            "-GNinja",
             "-DCMAKE_BUILD_TYPE=" + config,
             "-DCMAKE_LIBRARY_OUTPUT_DIRECTORY_{}={}".format(config.upper(), build_dir),
         ]
+        try:
+            import ninja
+        except ImportError:
+            pass
+        else:
+            cmake_args.append("-GNinja")
         cmake_args = cmake_args + self.dlfw_flags
         cmake_build_args = ["--config", config]
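Note on the Ninja change above: the generator flag is now added only when the `ninja` Python module can be imported. A Ninja binary installed another way (for example through `apt-get install ninja-build`, as in the old README text) provides no Python module and would not be detected by that check. The sketch below is an illustrative variant, not the commit's code. It tests for the module without importing it and, as an added assumption, also accepts a `ninja` executable on `PATH`.

```python
import importlib.util
import shutil


def cmake_generator_args(module_name="ninja", executable="ninja"):
    """Return ["-GNinja"] when Ninja appears to be available, else [].

    `module_name` and `executable` are parameters only so the function can
    be exercised in isolation; the defaults match the names in the diff.
    """
    has_module = importlib.util.find_spec(module_name) is not None
    has_binary = shutil.which(executable) is not None
    if has_module or has_binary:
        return ["-GNinja"]
    # Without the flag, CMake falls back to its platform default generator.
    return []
```

The result can be appended to `cmake_args` in the same place as the `try`/`except ImportError` block shown in the hunk.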
@@ -399,5 +411,10 @@ setup(
     ext_modules=ext_modules,
     cmdclass={"build_ext": TEBuildExtension},
     install_requires=dlfw_install_requires,
+    extras_require={
+        'test': ['pytest',
+                 'tensorflow_datasets'],
+        'test_pytest': ['onnxruntime',],
+    },
     license_files=("LICENSE",),
 )