..
    Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

|License|

Transformer Engine
==================

`Quickstart <#examples>`_ | `Installation <#installation>`_ | `User Guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html>`_ | `Examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_ | `FP8 Convergence <#fp8-convergence>`_ | `Integrations <#integrations>`_ | `Release notes <https://docs.nvidia.com/deeplearning/transformer-engine/release-notes/index.html>`_

Latest News
==================


* [12/2023] `New NVIDIA NeMo Framework Features and NVIDIA H200 <https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility/>`_

.. image:: docs/examples/H200-NeMo-performance.png
  :width: 600
  :alt: H200

* [11/2023] `Inflection-2: The Next Step Up <https://inflection.ai/inflection-2>`_
* [11/2023] `Unleashing The Power Of Transformers With NVIDIA Transformer Engine <https://lambdalabs.com/blog/unleashing-the-power-of-transformers-with-nvidia-transformer-engine>`_
* [11/2023] `Accelerating PyTorch Training Workloads with FP8 <https://towardsdatascience.com/accelerating-pytorch-training-workloads-with-fp8-5a5123aec7d7>`_
* [09/2023] `Transformer Engine added to AWS DL Container for PyTorch Training <https://github.com/aws/deep-learning-containers/pull/3315>`_
* [06/2023] `Breaking MLPerf Training Records with NVIDIA H100 GPUs <https://developer.nvidia.com/blog/breaking-mlperf-training-records-with-nvidia-h100-gpus/>`_
* [04/2023] `Benchmarking Large Language Models on NVIDIA H100 GPUs with CoreWeave (Part 1) <https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1>`_

What is Transformer Engine?
===========================
.. overview-begin-marker-do-not-remove

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including
using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with
lower memory utilization in both training and inference. TE provides a collection of highly optimized
building blocks for popular Transformer architectures and an automatic mixed precision-like API that
can be used seamlessly with your framework-specific code. TE also includes a framework-agnostic
C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.

As the number of parameters in Transformer models continues to grow, training and inference for
architectures such as BERT, GPT, and T5 become very memory- and compute-intensive. Most deep learning
frameworks train with FP32 by default. For many deep learning models, however, FP32 is not essential
to achieve full accuracy. Mixed-precision training, which combines single precision (FP32) with a
lower-precision format (e.g., FP16) when training a model, results in significant speedups with
minimal differences in accuracy compared to FP32 training. The Hopper GPU architecture introduced
FP8 precision, which offers improved performance over FP16 with no significant degradation in
accuracy. Although all major deep learning frameworks support FP16, FP8 support is not available
natively in frameworks today.

TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language
Model (LLM) libraries. It provides a Python API consisting of modules for easily building a Transformer
layer, as well as a framework-agnostic C++ library that includes the structs and kernels needed for FP8
support. Modules provided by TE internally maintain scaling factors and other values needed for FP8
training, greatly
simplifying mixed precision training for users.

Highlights
----------

* Easy-to-use modules for building Transformer layers with FP8 support
* Optimizations (e.g. fused kernels) for Transformer models
* Support for FP8 on NVIDIA Hopper and NVIDIA Ada GPUs
* Support for optimizations across all precisions (FP16, BF16) on the NVIDIA Ampere GPU architecture and later generations

Examples
----------

PyTorch
^^^^^^^

.. code-block:: python

  import torch
  import transformer_engine.pytorch as te
  from transformer_engine.common import recipe

  # Set dimensions.
  in_features = 768
  out_features = 3072
  hidden_size = 2048

  # Initialize model and inputs.
  model = te.Linear(in_features, out_features, bias=True)
  inp = torch.randn(hidden_size, in_features, device="cuda")

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)

  # Enable autocasting for the forward pass
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      out = model(inp)

  loss = out.sum()
  loss.backward()
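The ``recipe.DelayedScaling`` recipe above relies on per-tensor scaling factors that TE maintains internally. As a rough, framework-free illustration of that idea (not TE's implementation), the NumPy sketch below keeps a rolling history of absolute maxima (amax) for one tensor and derives the scale for the *next* step from that history, backing off by ``2**margin``. The class ``DelayedScalingState``, its ``step`` method, the max-over-history reduction, and the formula ``scale = fp8_max / amax / 2**margin`` are assumptions made for this sketch; values are only clipped to the FP8 range (not rounded to FP8), and the scale is refreshed on every call, roughly as with ``interval=1``.

.. code-block:: python

  import numpy as np

  # Largest finite magnitudes of the two FP8 formats.
  FP8_MAX = {"E4M3": 448.0, "E5M2": 57344.0}


  class DelayedScalingState:
      """Hypothetical per-tensor state illustrating delayed scaling (not TE code)."""

      def __init__(self, history_len=16, fp8_format="E4M3", margin=0):
          self.fp8_max = FP8_MAX[fp8_format]
          self.margin = margin
          self.amax_history = np.zeros(history_len)
          self.scale = 1.0  # scale applied during the current step

      def step(self, tensor):
          """Scale `tensor` with the current scale, then update the scale for the next step."""
          # Saturate to the representable range; a real FP8 cast would also round.
          scaled = np.clip(tensor * self.scale, -self.fp8_max, self.fp8_max)
          # Record this step's absolute maximum in a rolling window (newest first).
          self.amax_history = np.roll(self.amax_history, 1)
          self.amax_history[0] = np.max(np.abs(tensor))
          # Derive the next scale from past statistics; it only takes effect on the next call.
          amax = self.amax_history.max()
          if np.isfinite(amax) and amax > 0:
              self.scale = self.fp8_max / amax / 2.0 ** self.margin
          return scaled


  state = DelayedScalingState(history_len=4)
  first = state.step(np.array([1.0, -2.0]))  # uses the initial scale of 1.0
  second = state.step(np.array([4.0]))  # uses scale 224.0, so 4.0 * 224.0 saturates at 448.0

The second call shows the trade-off of a delayed scale: it was computed from an earlier, smaller amax, so a sudden larger value saturates until the history catches up, which is what a non-zero ``margin`` guards against.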


JAX
^^^

Flax
~~~~

.. code-block:: python

  import flax
  import jax
  import jax.numpy as jnp
  import transformer_engine.jax as te
  import transformer_engine.jax.flax as te_flax
  from transformer_engine.common import recipe

  BATCH = 32
  SEQLEN = 128
  HIDDEN = 1024

  # Initialize RNG and inputs.
  rng = jax.random.PRNGKey(0)
  init_rng, data_rng = jax.random.split(rng)
  inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID)

  # Enable autocasting for the forward pass
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      model = te_flax.DenseGeneral(features=HIDDEN)

      def loss_fn(params, other_vars, inp):
        out = model.apply({'params':params, **other_vars}, inp)
        return jnp.mean(out)

      # Initialize models.
      variables = model.init(init_rng, inp)
      other_variables, params = flax.core.pop(variables, 'params')

      # Construct the forward and backward function
      fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

      for _ in range(10):
        loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
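The two examples choose different FP8 formats: the PyTorch example uses ``recipe.Format.E4M3`` for all FP8 tensors, while ``recipe.Format.HYBRID`` in the JAX example uses E4M3 for forward-pass tensors and E5M2 for gradients in the backward pass. The plain-Python sketch below, built around a hypothetical helper ``fp8_limits``, derives each format's largest finite value and smallest subnormal from its bit layout as described in the FP8 Formats for Deep Learning paper. It illustrates that E4M3 keeps an extra mantissa bit (more precision) while E5M2 covers a much wider range, which suits gradients.

.. code-block:: python

  def fp8_limits(exp_bits, man_bits, ieee_special_values):
      """Return (largest finite value, smallest subnormal) of a sign/exponent/mantissa format."""
      bias = 2 ** (exp_bits - 1) - 1
      if ieee_special_values:
          # E5M2: the all-ones exponent is reserved for Inf/NaN, as in IEEE 754.
          max_exp = (2 ** exp_bits - 2) - bias
          max_significand = 2 - 2 ** -man_bits  # binary 1.11...1
      else:
          # E4M3: the all-ones exponent still holds normal values; only mantissa 1.11...1 is NaN.
          max_exp = (2 ** exp_bits - 1) - bias
          max_significand = 2 - 2 ** -(man_bits - 1)  # binary 1.11...0
      max_finite = max_significand * 2.0 ** max_exp
      min_subnormal = 2.0 ** (1 - bias - man_bits)
      return max_finite, min_subnormal


  E4M3 = fp8_limits(exp_bits=4, man_bits=3, ieee_special_values=False)
  E5M2 = fp8_limits(exp_bits=5, man_bits=2, ieee_special_values=True)
  print("E4M3 max:", E4M3[0])  # 448.0
  print("E5M2 max:", E5M2[0])  # 57344.0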

.. overview-end-marker-do-not-remove

Installation
------------
.. installation

Prerequisites
^^^^^^^^^^^^^^^^^^^^
* Linux x86_64
* CUDA 11.8+ for Hopper and CUDA 12.1+ for Ada
* NVIDIA Driver supporting CUDA 11.8 or later
* cuDNN 8.1 or later
* For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.
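The FP8 requirements above pair each GPU architecture with a minimum CUDA version. A small sketch of that check could look like the following, assuming the standard compute capabilities of Ada (8.9) and Hopper (9.0). ``FP8_MIN_CUDA``, ``parse_version``, and ``fp8_prerequisites_met`` are hypothetical names; the check mirrors only the list above and covers the CUDA toolkit version, not the driver or cuDNN.

.. code-block:: python

  # Minimum CUDA versions for FP8 from the list above, keyed by compute capability
  # (Ada = 8.9, Hopper = 9.0). Architectures not listed here are treated as unsupported.
  FP8_MIN_CUDA = {(8, 9): (12, 1), (9, 0): (11, 8)}


  def parse_version(text):
      """Turn a version string such as "12.1" into a comparable tuple like (12, 1)."""
      return tuple(int(part) for part in text.split(".")[:2])


  def fp8_prerequisites_met(compute_capability, cuda_version):
      """Check a (major, minor) compute capability against the CUDA minimum for FP8."""
      required = FP8_MIN_CUDA.get(tuple(compute_capability))
      if required is None:
          return False  # neither Ada nor Hopper, so outside this README's FP8 list
      return parse_version(cuda_version) >= required


  print(fp8_prerequisites_met((9, 0), "11.8"))  # Hopper with CUDA 11.8 -> True
  print(fp8_prerequisites_met((8, 9), "11.8"))  # Ada needs CUDA 12.1+ -> False

With PyTorch installed, the inputs could come from ``torch.cuda.get_device_capability()`` and ``torch.version.cuda`` (the latter is ``None`` on CPU-only builds).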

Docker
^^^^^^^^^^^^^^^^^^^^

The quickest way to get started with Transformer Engine is by using Docker images on the
`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_. For example, to use the NGC PyTorch container interactively:

.. code-block:: bash

    docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3

Here, 23.10 is the container version, corresponding to the October 2023 release.

pip
^^^^^^^^^^^^^^^^^^^^
To install the latest stable version of Transformer Engine:

.. code-block:: bash

    pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

This automatically detects whether any supported deep learning frameworks are installed and builds Transformer Engine support for them. To specify frameworks explicitly, set the environment variable ``NVTE_FRAMEWORK`` to a comma-separated list (e.g., ``NVTE_FRAMEWORK=jax,pytorch``).
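When scripting installs, the same variable can be passed to pip through the environment. The sketch below is an illustration rather than TE tooling: ``installed_frameworks`` and ``te_install_command`` are hypothetical helpers, treating an importable ``torch`` or ``jax`` module as an installed framework is an assumption (not necessarily how the build itself detects frameworks), and only the two ``NVTE_FRAMEWORK`` values shown above are used.

.. code-block:: python

  import importlib.util
  import os
  import sys

  # Assumed mapping from NVTE_FRAMEWORK names to the module each framework provides.
  FRAMEWORK_MODULES = {"pytorch": "torch", "jax": "jax"}
  TE_STABLE = "git+https://github.com/NVIDIA/TransformerEngine.git@stable"


  def installed_frameworks():
      """Return the NVTE_FRAMEWORK names whose framework module is importable here."""
      return [
          name
          for name, module in FRAMEWORK_MODULES.items()
          if importlib.util.find_spec(module) is not None
      ]


  def te_install_command(frameworks):
      """Build a pip command plus an environment that sets NVTE_FRAMEWORK explicitly."""
      if not frameworks:
          raise ValueError("select at least one framework")
      env = dict(os.environ, NVTE_FRAMEWORK=",".join(frameworks))
      cmd = [sys.executable, "-m", "pip", "install", TE_STABLE]
      return cmd, env

For example, ``cmd, env = te_install_command(["jax", "pytorch"])`` followed by ``subprocess.run(cmd, env=env, check=True)`` would run the install with ``NVTE_FRAMEWORK=jax,pytorch``.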

From source
^^^^^^^^^^^
`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_.

Compiling with FlashAttention-2
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Transformer Engine release v0.11.0 added support for FlashAttention-2 in PyTorch for improved performance.

FlashAttention-2 compilation is known to be resource-intensive and to require a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out-of-memory errors during the installation of Transformer Engine. To work around the issue, try setting **MAX_JOBS=1** in the environment. If the errors persist, install a supported version of FlashAttention-1 (v1.0.6 to v1.0.9).

Note that NGC PyTorch 23.08+ containers include FlashAttention-2.

FP8 Convergence
==================

FP8 has been tested extensively across different model architectures and configurations, and we found **no significant difference** between FP8 and BF16 training loss curves. FP8 has also been validated for accuracy on downstream LLM tasks (e.g., LAMBADA and WikiText). Below are examples of models tested for convergence across different frameworks.

+------------+------------------+---------------------------------------------------------------------------------------------------------+
| Model      | Framework        | Source                                                                                                  |
+============+==================+=========================================================================================================+
| T5-770M    |  JAX/T5x         | https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x#convergence-and-performance|
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| MPT-1.3B   |  Mosaic Composer | https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1                                              |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| GPT-5B     |  JAX/Paxml       | https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results               |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| GPT-5B     |  NeMo Framework  | Available on request                                                                                    |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| Llama2-7B  |  Alibaba PAI     | https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ                                                       |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| T5-11B     |  JAX/T5x         | Available on request                                                                                    |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| GPT-22B    |  NeMo Framework  | Available on request                                                                                    |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| Llama2-70B |  Alibaba PAI     | https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ                                                       |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| GPT-175B   |  JAX/Paxml       | https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results               |
+------------+------------------+---------------------------------------------------------------------------------------------------------+

Integrations
==================

Transformer Engine has been integrated with popular LLM frameworks such as:

* `DeepSpeed <https://github.com/microsoft/DeepSpeed/pull/3731>`_
* `Hugging Face Accelerate <https://github.com/huggingface/accelerate/releases/tag/v0.17.0>`_
* `Lightning <https://github.com/Lightning-AI/lightning/issues/17172>`_
* `MosaicML Composer <https://github.com/mosaicml/composer/releases/tag/v0.13.1>`_
* `NVIDIA JAX Toolbox <https://github.com/NVIDIA/JAX-Toolbox>`_
* `NVIDIA Megatron-LM <https://github.com/NVIDIA/Megatron-LM>`_
* `NVIDIA NeMo Framework <https://github.com/NVIDIA/NeMo-Megatron-Launcher>`_
* `Amazon SageMaker Model Parallel Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-core-features-v2-tensor-parallelism.html>`_
* `Colossal-AI <https://github.com/hpcaitech/ColossalAI>`_ - Coming soon!
* `PeriFlow <https://github.com/friendliai/periflow-python-sdk>`_ - Coming soon!
* `GPT-NeoX <https://github.com/EleutherAI/gpt-neox>`_ - Coming soon!


Contributing
==================

We welcome contributions to Transformer Engine! To contribute and make pull requests,
follow the guidelines outlined in the `<CONTRIBUTING.rst>`_ guide.

Papers
==================

* `Attention Is All You Need (original Transformer paper) <https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`_
* `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
* `Megatron-LM sequence parallel <https://arxiv.org/pdf/2205.05198.pdf>`_
* `FP8 Formats for Deep Learning <https://arxiv.org/abs/2209.05433>`_

Videos
==================

* `FP8 Training with Transformer Engine <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51393>`_
* `FP8 for Deep Learning <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s52166/>`_
* `Inside the Hopper Architecture <https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s42663/>`_

.. |License| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg
   :target: https://opensource.org/licenses/Apache-2.0