..
    Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

|License|

Transformer Engine
==================

`Quickstart <#examples>`_ | `Installation <#installation>`_ | `User Guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html>`_ | `Examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_ | `FP8 Convergence <#fp8-convergence>`_ | `Integrations <#integrations>`_ | `Release notes <https://docs.nvidia.com/deeplearning/transformer-engine/release-notes/index.html>`_

Latest News
===========

* [03/2024] `Turbocharged Training: Optimizing the Databricks Mosaic AI stack with FP8 <https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8>`_
* [03/2024] `FP8 Training Support in SageMaker Model Parallelism Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-release-notes.html>`_
* [12/2023] `New NVIDIA NeMo Framework Features and NVIDIA H200 <https://developer.nvidia.com/blog/new-nvidia-nemo-framework-features-and-nvidia-h200-supercharge-llm-training-performance-and-versatility/>`_

.. image:: docs/examples/H200-NeMo-performance.png
  :width: 600
  :alt: H200

* [11/2023] `Inflection-2: The Next Step Up <https://inflection.ai/inflection-2>`_
* [11/2023] `Unleashing The Power Of Transformers With NVIDIA Transformer Engine <https://lambdalabs.com/blog/unleashing-the-power-of-transformers-with-nvidia-transformer-engine>`_
* [11/2023] `Accelerating PyTorch Training Workloads with FP8 <https://towardsdatascience.com/accelerating-pytorch-training-workloads-with-fp8-5a5123aec7d7>`_
* [09/2023] `Transformer Engine added to AWS DL Container for PyTorch Training <https://github.com/aws/deep-learning-containers/pull/3315>`_
* [06/2023] `Breaking MLPerf Training Records with NVIDIA H100 GPUs <https://developer.nvidia.com/blog/breaking-mlperf-training-records-with-nvidia-h100-gpus/>`_
* [04/2023] `Benchmarking Large Language Models on NVIDIA H100 GPUs with CoreWeave (Part 1) <https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1>`_

What is Transformer Engine?
===========================
.. overview-begin-marker-do-not-remove

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including
using 8-bit floating point (FP8) precision on Hopper and Ada GPUs, to provide better performance with
lower memory utilization in both training and inference. TE provides a collection of highly optimized
building blocks for popular Transformer architectures and an automatic mixed precision-like API that
can be used seamlessly with your framework-specific code. TE also includes a framework-agnostic
C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.

As the number of parameters in Transformer models continues to grow, training and inference for
architectures such as BERT, GPT and T5 become very memory- and compute-intensive. Most deep learning
frameworks train with FP32 by default. However, FP32 is not essential to achieve full accuracy for
many deep learning models. Mixed-precision training, which combines single precision (FP32) with a
lower-precision format (e.g. FP16) when training a model, results in significant speedups with
minimal differences in accuracy compared to FP32 training. The Hopper GPU architecture introduced
FP8 precision, which offers improved performance over FP16 with no significant degradation in
accuracy. Although all major deep learning frameworks support FP16, FP8 support is not available
natively in frameworks today.

TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language
Model (LLM) libraries. It provides a Python API consisting of modules for easily building a Transformer
layer, as well as a framework-agnostic C++ library that includes the structs and kernels needed for FP8
support. Modules provided by TE internally maintain scaling factors and other values needed for FP8
training, greatly simplifying mixed precision training for users.
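
To make the idea of maintained scaling factors more concrete, the following is a minimal pure-Python sketch of the delayed-scaling approach selected by ``recipe.DelayedScaling`` in the examples below. It assumes the formula given in that recipe's documentation, ``scale = (FP8_MAX / amax) / 2 ** margin``, where ``amax`` is taken from a rolling history of absolute maxima using either the ``"max"`` or ``"most_recent"`` algorithm. This is an illustration, not TE's implementation: the class name ``DelayedScalingSketch``, its defaults, and the choice to keep the previous scale when ``amax`` is zero are hypothetical.

.. code-block:: python

  from collections import deque

  # Largest finite magnitudes of the two FP8 formats.
  FP8_MAX = {"E4M3": 448.0, "E5M2": 57344.0}


  class DelayedScalingSketch:
      """Hypothetical, simplified model of delayed scaling (not TE's code)."""

      def __init__(self, fp8_format="E4M3", margin=0, amax_history_len=16,
                   amax_compute_algo="max"):
          self.fp8_max = FP8_MAX[fp8_format]
          self.margin = margin
          self.amax_compute_algo = amax_compute_algo
          # Rolling window of per-step absolute maxima (the amax history).
          self.amax_history = deque(maxlen=amax_history_len)
          self.scale = 1.0

      def update(self, values):
          """Record this step's amax and return the scale for the next step."""
          self.amax_history.append(max(abs(v) for v in values))
          if self.amax_compute_algo == "max":
              amax = max(self.amax_history)  # largest amax in the window
          else:
              amax = self.amax_history[-1]  # "most_recent"
          if amax > 0:  # hypothetical choice: otherwise keep the previous scale
              self.scale = (self.fp8_max / amax) / 2 ** self.margin
          return self.scale

For instance, with E4M3 and ``margin=0``, a step whose largest magnitude is 3.5 yields a scale of 448 / 3.5 = 128, so that, roughly, tensors multiplied by the scale fill the E4M3 range before the cast. With the ``"max"`` algorithm, that scale persists while 3.5 remains the largest value in the history window. Each nonzero ``margin`` step halves the scale to leave headroom.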

Highlights
==========

* Easy-to-use modules for building Transformer layers with FP8 support
* Optimizations (e.g. fused kernels) for Transformer models
* Support for FP8 on NVIDIA Hopper and NVIDIA Ada GPUs
* Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later

Examples
========

PyTorch
^^^^^^^

.. code-block:: python

  import torch
  import transformer_engine.pytorch as te
  from transformer_engine.common import recipe

  # Set dimensions.
  in_features = 768
  out_features = 3072
  hidden_size = 2048

  # Initialize model and inputs.
  model = te.Linear(in_features, out_features, bias=True)
  inp = torch.randn(hidden_size, in_features, device="cuda")

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.E4M3)

  # Enable autocasting for the forward pass
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      out = model(inp)

  loss = out.sum()
  loss.backward()


JAX
^^^

Flax
~~~~

.. code-block:: python

  import flax
  import jax
  import jax.numpy as jnp
  import transformer_engine.jax as te
  import transformer_engine.jax.flax as te_flax
  from transformer_engine.common import recipe

  BATCH = 32
  SEQLEN = 128
  HIDDEN = 1024

  # Initialize RNG and inputs.
  rng = jax.random.PRNGKey(0)
  init_rng, data_rng = jax.random.split(rng)
  inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, fp8_format=recipe.Format.HYBRID)

  # Enable autocasting for the forward pass
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      model = te_flax.DenseGeneral(features=HIDDEN)

      def loss_fn(params, other_vars, inp):
          out = model.apply({'params': params, **other_vars}, inp)
          return jnp.mean(out)

      # Initialize models.
      variables = model.init(init_rng, inp)
      other_variables, params = flax.core.pop(variables, 'params')

      # Construct the forward and backward function
      fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

      for _ in range(10):
          loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
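
The PyTorch example above selects ``recipe.Format.E4M3``, while this example selects ``recipe.Format.HYBRID``, which uses E4M3 for forward-pass tensors and E5M2 for backward-pass tensors (gradients). The pure-Python sketch below decodes every 8-bit pattern of each format to show the trade-off: E4M3 keeps one more mantissa bit, while E5M2 covers a much wider range. It assumes the encodings described in the *FP8 Formats for Deep Learning* paper (E4M3: exponent bias 7, no infinities, a single NaN mantissa pattern; E5M2: exponent bias 15 with IEEE-style infinities and NaNs). The helper names are hypothetical and are not part of TE's API.

.. code-block:: python

  import math


  def decode_e4m3(byte):
      """Decode one E4M3 byte: 1 sign, 4 exponent (bias 7), 3 mantissa bits."""
      sign = -1.0 if byte & 0x80 else 1.0
      exponent = (byte >> 3) & 0xF
      mantissa = byte & 0x7
      if exponent == 0xF and mantissa == 0x7:
          return math.nan  # no infinities; only S.1111.111 encodes NaN
      if exponent == 0:
          return sign * (mantissa / 8) * 2.0 ** (1 - 7)  # subnormal
      return sign * (1 + mantissa / 8) * 2.0 ** (exponent - 7)


  def decode_e5m2(byte):
      """Decode one E5M2 byte: 1 sign, 5 exponent (bias 15), 2 mantissa bits."""
      sign = -1.0 if byte & 0x80 else 1.0
      exponent = (byte >> 2) & 0x1F
      mantissa = byte & 0x3
      if exponent == 0x1F:
          return sign * math.inf if mantissa == 0 else math.nan  # IEEE-style specials
      if exponent == 0:
          return sign * (mantissa / 4) * 2.0 ** (1 - 15)  # subnormal
      return sign * (1 + mantissa / 4) * 2.0 ** (exponent - 15)


  def finite_values(decode):
      """Sorted distinct finite values of a format (+0.0 and -0.0 collapse)."""
      values = (decode(b) for b in range(256))
      return sorted({v for v in values if math.isfinite(v)})


  for name, decode in [("E4M3", decode_e4m3), ("E5M2", decode_e5m2)]:
      vals = finite_values(decode)
      print(name, "max:", vals[-1], "smallest positive:", min(v for v in vals if v > 0))

Running it reports a largest finite value of 448 for E4M3 and 57344 for E5M2. This wider E5M2 range is why it is the usual choice for gradients, whose magnitudes vary more.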

.. overview-end-marker-do-not-remove

Installation
============
.. installation

Pre-requisites
^^^^^^^^^^^^^^^^^^^^
* Linux x86_64
* CUDA 12.0+ for Hopper and CUDA 12.1+ for Ada
* NVIDIA Driver supporting CUDA 12.0 or later
* cuDNN 8.1 or later
* For fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.

Docker
^^^^^^^^^^^^^^^^^^^^

The quickest way to get started with Transformer Engine is to use the Docker images from the
`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_. For example, to use the NGC PyTorch container interactively:

.. code-block:: bash

    docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.10-py3

Here, 23.10 is the container version (YY.MM), corresponding to the October 2023 release.

pip
^^^^^^^^^^^^^^^^^^^^
To install the latest stable version of Transformer Engine:

.. code-block:: bash

    pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable

This automatically detects which supported deep learning frameworks are installed and builds Transformer Engine support for them. To specify frameworks explicitly, set the environment variable ``NVTE_FRAMEWORK`` to a comma-separated list (e.g. ``NVTE_FRAMEWORK=jax,pytorch,paddle``).

Alternatively, the package can be installed directly from `Transformer Engine's PyPI <https://pypi.org/project/transformer-engine/>`_, e.g.:

.. code-block:: bash

    pip install transformer_engine[pytorch]

To obtain the Python bindings for Transformer Engine, the required frameworks must be specified explicitly as extra dependencies in a comma-separated list (e.g. ``[jax,pytorch,paddle]``). Transformer Engine ships wheels for the core library and the PaddlePaddle extensions, and source distributions for the JAX and PyTorch extensions.

From source
^^^^^^^^^^^
`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_.

Compiling with FlashAttention-2
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Transformer Engine release v0.11.0 added support for FlashAttention-2 in PyTorch for improved performance.

FlashAttention-2 compilation is known to be resource-intensive and to require a large amount of RAM (see `this issue <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out-of-memory errors during the installation of Transformer Engine. To work around this, try setting **MAX_JOBS=1** in the environment.

Note that NGC PyTorch 23.08+ containers include FlashAttention-2.

Breaking Changes
================

v1.7: Padding mask definition for PyTorch
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
To unify the definition and usage of the attention mask across all three frameworks in Transformer Engine, the meaning of ``True`` in the PyTorch padding mask has changed from including the corresponding position in attention to excluding it. Since v1.7, all attention mask types follow the same definition: ``True`` masks out the corresponding position, and ``False`` includes it in the attention calculation.

For example:

.. code-block:: python

    # for a batch of 3 sequences where `a`s, `b`s and `c`s are the useful tokens
    # and `0`s are the padding tokens,
    [a, a, a, 0, 0,
     b, b, 0, 0, 0,
     c, c, c, c, 0]
    # the padding mask for this batch before v1.7 is,
    [ True,  True,  True, False, False,
      True,  True, False, False, False,
      True,  True,  True,  True, False]
    # and for v1.7 onwards it should be,
    [False, False, False,  True,  True,
     False, False,  True,  True,  True,
     False, False, False, False,  True]
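
For code that still builds masks in the pre-v1.7 convention, inverting them should be enough to migrate. The sketch below uses plain Python lists in place of boolean tensors, and the helper names ``padding_mask`` and ``from_legacy_mask`` are hypothetical. With PyTorch tensors, the same inversion would typically be written as ``~mask`` or ``torch.logical_not(mask)``.

.. code-block:: python

  def padding_mask(seq_lens, max_len):
      """v1.7+ convention: True marks a padding position to be masked out."""
      return [[pos >= n for pos in range(max_len)] for n in seq_lens]


  def from_legacy_mask(legacy_mask):
      """Convert a pre-v1.7 PyTorch padding mask (True = attend) to v1.7+."""
      return [[not keep for keep in row] for row in legacy_mask]


  # The batch from the example above: sequence lengths 3, 2 and 4, padded to 5.
  new_mask = padding_mask([3, 2, 4], max_len=5)
  legacy_mask = [[pos < n for pos in range(5)] for n in [3, 2, 4]]
  assert from_legacy_mask(legacy_mask) == new_mask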

FP8 Convergence
===============

FP8 has been tested extensively across different model architectures and configurations, and we found **no significant difference** between FP8 and BF16 training loss curves. FP8 has also been validated for accuracy on downstream LLM tasks (e.g. LAMBADA and WikiText). Below are examples of models tested for convergence across different frameworks.

+------------+------------------+---------------------------------------------------------------------------------------------------------+
| Model      | Framework        | Source                                                                                                  |
+============+==================+=========================================================================================================+
| T5-770M    |  JAX/T5x         | https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/t5x#convergence-and-performance|
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| MPT-1.3B   |  Mosaic Composer | https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1                                              |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| GPT-5B     |  JAX/Paxml       | https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results               |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| GPT-5B     |  NeMo Framework  | Available on request                                                                                    |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| Llama2-7B  |  Alibaba PAI     | https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ                                                       |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| T5-11B     |  JAX/T5x         | Available on request                                                                                    |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| MPT-13B    |  Mosaic Composer | https://www.databricks.com/blog/turbocharged-training-optimizing-databricks-mosaic-ai-stack-fp8         |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| GPT-22B    |  NeMo Framework  | Available on request                                                                                    |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| Llama2-70B |  Alibaba PAI     | https://mp.weixin.qq.com/s/NQT0uKXLbXyh5031zBdeBQ                                                       |
+------------+------------------+---------------------------------------------------------------------------------------------------------+
| GPT-175B   |  JAX/Paxml       | https://github.com/NVIDIA/JAX-Toolbox/tree/main/rosetta/rosetta/projects/pax#h100-results               |
+------------+------------------+---------------------------------------------------------------------------------------------------------+

Integrations
============

Transformer Engine has been integrated with popular LLM frameworks such as:

* `DeepSpeed <https://github.com/microsoft/DeepSpeed/pull/3731>`_
* `Hugging Face Accelerate <https://github.com/huggingface/accelerate/releases/tag/v0.17.0>`_
* `Lightning <https://github.com/Lightning-AI/lightning/issues/17172>`_
* `MosaicML Composer <https://github.com/mosaicml/composer/releases/tag/v0.13.1>`_
* `NVIDIA JAX Toolbox <https://github.com/NVIDIA/JAX-Toolbox>`_
* `NVIDIA Megatron-LM <https://github.com/NVIDIA/Megatron-LM>`_
* `NVIDIA NeMo Framework <https://github.com/NVIDIA/NeMo-Megatron-Launcher>`_
* `Amazon SageMaker Model Parallel Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel-core-features-v2-tensor-parallelism.html>`_
* `Levanter <https://github.com/stanford-crfm/levanter>`_
* `Hugging Face Nanotron <https://github.com/huggingface/nanotron>`_ - Coming soon!
* `Colossal-AI <https://github.com/hpcaitech/ColossalAI>`_ - Coming soon!
* `PeriFlow <https://github.com/friendliai/periflow-python-sdk>`_ - Coming soon!
* `GPT-NeoX <https://github.com/EleutherAI/gpt-neox>`_ - Coming soon!


Contributing
============

We welcome contributions to Transformer Engine! To contribute to Transformer Engine and make pull requests,
follow the guidelines outlined in the `<CONTRIBUTING.rst>`_ guide.

Papers
======

* `Attention Is All You Need (original Transformer paper) <https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`_
* `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
* `Megatron-LM sequence parallel <https://arxiv.org/pdf/2205.05198.pdf>`_
* `FP8 Formats for Deep Learning <https://arxiv.org/abs/2209.05433>`_

Videos
======

* `What's New in Transformer Engine and FP8 Training | GTC 2024 <https://www.nvidia.com/en-us/on-demand/session/gtc24-s62457/>`_
* `FP8 Training with Transformer Engine | GTC 2023 <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51393>`_
* `FP8 for Deep Learning | GTC 2023 <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s52166/>`_
* `Inside the Hopper Architecture <https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s42663/>`_

.. |License| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg
   :target: https://opensource.org/licenses/Apache-2.0