..
    Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

|License|

Transformer Engine
==================

`Quickstart <#examples>`_ | `Installation <#installation>`_ | `User Guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html>`_ | `Examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_ | `Model Support <#model-support>`_ | `Integrations <#integrations>`_ | `Release notes <https://docs.nvidia.com/deeplearning/transformer-engine/release-notes/index.html>`_

Latest News
==================

* [04/2023] `Benchmarking Large Language Models on NVIDIA H100 GPUs with CoreWeave (Part 1) <https://www.mosaicml.com/blog/coreweave-nvidia-h100-part-1>`_


What is Transformer Engine?
===========================
.. overview-begin-marker-do-not-remove

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including
using 8-bit floating point (FP8) precision on Hopper GPUs, to provide better performance with lower
memory utilization in both training and inference. TE provides a collection of highly optimized
building blocks for popular Transformer architectures and an automatic mixed precision-like API that
can be used seamlessly with your framework-specific code. TE also includes a framework-agnostic
C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.

As the number of parameters in Transformer models continues to grow, training and inference for
architectures such as BERT, GPT and T5 become very memory and compute intensive. Most deep learning
frameworks train with FP32 by default. This is not essential, however, to achieve full accuracy for
many deep learning models. Using mixed-precision training, which combines single-precision (FP32)
with a lower-precision format (e.g. FP16) when training a model, results in significant speedups with
minimal differences in accuracy as compared to FP32 training. The Hopper GPU architecture introduced
FP8 precision, which offers improved performance over FP16 with no degradation in accuracy. Although
all major deep learning frameworks support FP16, FP8 support is not available natively in frameworks today.

TE addresses the problem of FP8 support by providing APIs that integrate with popular Large Language
Model (LLM) libraries. It provides a Python API consisting of modules to easily build a Transformer
layer as well as a framework-agnostic library in C++ including structs and kernels needed for FP8 support.
Modules provided by TE internally maintain scaling factors and other values needed for FP8 training, greatly
simplifying mixed precision training for users.
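
As a rough illustration of what maintaining scaling factors involves, the following framework-free
sketch derives a scaling factor from recently observed absolute maxima (amax) so that values fit the
FP8 range. The function name, the use of the history maximum and the power-of-two margin are
assumptions made for this example and are not TE's actual implementation; 448 is the largest finite
value of the E4M3 format.

.. code-block:: python

  # Illustrative sketch only; not Transformer Engine's internal code.
  E4M3_MAX = 448.0  # largest finite value representable in FP8 E4M3


  def compute_scale(amax_history, fp8_max=E4M3_MAX, margin=0):
      """Return a scaling factor that maps the recent amax onto the FP8 range."""
      amax = max(amax_history)
      if amax <= 0.0:
          return 1.0  # nothing observed yet; leave values unscaled
      # Each unit of margin leaves a factor of two of headroom.
      return fp8_max / amax / (2.0 ** margin)


  history = [0.5, 3.5, 1.25]  # hypothetical amax values from previous steps
  scale = compute_scale(history)

In such a scheme a tensor would be multiplied by the scale before the cast to FP8 and divided by it
afterwards.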

Highlights
----------

* Easy-to-use modules for building Transformer layers with FP8 support
* Optimizations (e.g. fused kernels) for Transformer models
* Support for FP8 on NVIDIA Hopper (H100) and NVIDIA Ada (RTX 40 series) GPUs
* Support for optimizations across all precisions (FP16, BF16) on NVIDIA Ampere GPU architecture generations and later
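
The hardware requirements listed above can be expressed in terms of CUDA compute capability. The
sketch below assumes the usual mapping (Ampere 8.0, Ada 8.9, Hopper 9.0); the helper names are
hypothetical and not part of the TE API.

.. code-block:: python

  # Hypothetical helpers, not part of Transformer Engine.
  AMPERE = (8, 0)  # first architecture covered by the FP16/BF16 optimizations
  ADA = (8, 9)     # first architecture with FP8 support


  def supports_te_optimizations(capability):
      """FP16/BF16 optimizations target Ampere and later."""
      return tuple(capability) >= AMPERE


  def supports_fp8(capability):
      """FP8 requires Ada (8.9), Hopper (9.0) or later."""
      return tuple(capability) >= ADA

With PyTorch installed, the capability of the current device can be obtained from
`torch.cuda.get_device_capability()` and passed to these helpers.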

Examples
----------

PyTorch
^^^^^^^

.. code-block:: python

  import torch
  import transformer_engine.pytorch as te
  from transformer_engine.common import recipe

  # Set dimensions.
  in_features = 768
  out_features = 3072
  hidden_size = 2048

  # Initialize model and inputs.
  model = te.Linear(in_features, out_features, bias=True)
  inp = torch.randn(hidden_size, in_features, device="cuda")

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)

  # Enable autocasting for the forward pass
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      out = model(inp)

  loss = out.sum()
  loss.backward()


JAX
^^^

Flax
~~~~

.. code-block:: python

  import jax
  import jax.numpy as jnp
  import transformer_engine.jax as te
  import transformer_engine.jax.flax as te_flax
  from transformer_engine.common import recipe

  BATCH = 32
  SEQLEN = 128
  HIDDEN = 1024

  # Initialize RNG and inputs.
  rng = jax.random.PRNGKey(0)
  init_rng, data_rng = jax.random.split(rng)
  inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

  # Create an FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID)

  # Enable autocasting for the forward pass
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      model = te_flax.DenseGeneral(features=HIDDEN)

      def loss_fn(params, other_vars, inp):
        out = model.apply({'params':params, **other_vars}, inp)
        return jnp.mean(out)

      # Initialize models.
      variables = model.init(init_rng, inp)
      other_variables, params = variables.pop('params')

      # Construct the forward and backward function
      fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

      for _ in range(10):
        loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
        # Update FP8 metas
        other_variables = te.update_fp8_metas(other_grads)

TensorFlow
^^^^^^^^^^

.. code-block:: python

  import tensorflow as tf
  import transformer_engine.tensorflow as te
  from transformer_engine.common import recipe
  
  # Set dimensions.
  in_features = 768
  out_features = 3072
  hidden_size = 2048
  
  # Initialize model and inputs.
  model = te.Dense(out_features, use_bias=True)
  inp = tf.random.normal((hidden_size, in_features))
  
  optimizer = tf.keras.optimizers.Adam(0.001)
  
  # Create FP8 recipe. Note: All input args are optional.
  fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)
  
  with tf.GradientTape(persistent=True) as tape:
      # Enables autocasting for the forward pass
      with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
          out = model(inp)
      loss = tf.reduce_sum(out)
  grads = tape.gradient(loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))

.. overview-end-marker-do-not-remove

Installation
------------
.. installation

In the NGC container
^^^^^^^^^^^^^^^^^^^^

The quickest way to get started with Transformer Engine is to use the NGC PyTorch container from the
`NVIDIA GPU Cloud Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_ (versions 22.09 and later).

.. code-block:: bash

    docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:23.04-py3

Here, 23.04 is the container version, which encodes the release date as year and month (23.04 is the April 2023 release).

Pre-requisites
^^^^^^^^^^^^^^^^^^^^
* Linux x86_64
* CUDA 11.8 or later
* NVIDIA Driver supporting CUDA 11.8 or later
* cuDNN 8.1 or later
* For FP8 fused attention, CUDA 12.1 or later, NVIDIA Driver supporting CUDA 12.1 or later, and cuDNN 8.9 or later.
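
The version requirements above can be checked programmatically. The following sketch compares
version strings against the minimums listed; the constant and function names are hypothetical, and
obtaining the installed versions is left to the reader.

.. code-block:: python

  # Minimum versions taken from the list above; helper names are hypothetical.
  MIN_CUDA, MIN_CUDNN = (11, 8), (8, 1)
  MIN_CUDA_FP8_ATTN, MIN_CUDNN_FP8_ATTN = (12, 1), (8, 9)


  def parse_version(text):
      """Turn a string such as "12.1" into a comparable tuple (12, 1)."""
      return tuple(int(part) for part in text.split("."))


  def meets_prerequisites(cuda, cudnn, fp8_fused_attention=False):
      """Check CUDA and cuDNN version strings against the documented minimums."""
      min_cuda = MIN_CUDA_FP8_ATTN if fp8_fused_attention else MIN_CUDA
      min_cudnn = MIN_CUDNN_FP8_ATTN if fp8_fused_attention else MIN_CUDNN
      return parse_version(cuda) >= min_cuda and parse_version(cudnn) >= min_cudnn

Comparing tuples rather than strings avoids ordering mistakes such as treating "11.10" as older
than "11.8".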

From source
^^^^^^^^^^^

`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html>`_.

Model Support
-------------

While the more granular modules in Transformer Engine allow building any Transformer architecture,
the `TransformerLayer` API of Transformer Engine is flexible enough to build multiple major
Transformer model architectures.

Transformer Engine supports the following DL frameworks: PyTorch, JAX (Flax, Praxis), and TensorFlow.

NOTE: For simplicity, we only show PyTorch examples below. For `TransformerLayer` usage in
all supported frameworks, refer to the `examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_.

GPT
^^^

The `GPT` architecture has `LayerNorm` at the input side (before `QKV Gemm`) and the residual connection
is taken from the input of that `LayerNorm`. In TE this can be achieved by setting the following
arguments in the `TransformerLayer` API.

.. code-block:: python

  transformer_engine.pytorch.TransformerLayer(
          ...,
          ...,
          apply_residual_connection_post_layernorm=False,
          output_layernorm=False,
          layer_type="encoder",
  )

BERT
^^^^

The `BERT` architecture has `LayerNorm` at the output side (after the final `BiasDropoutAdd`) and the
residual connection is taken from the output of that `LayerNorm`. In TE this can be achieved by
setting the following arguments in the `TransformerLayer` API.

.. code-block:: python

  transformer_engine.pytorch.TransformerLayer(
          ...,
          ...,
          apply_residual_connection_post_layernorm=True,
          output_layernorm=True,
          layer_type="encoder",
  )
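
The difference between the two placements of `LayerNorm` can be illustrated with a toy,
framework-free sketch of the generic pre-LN and post-LN patterns. It is a simplification under
stated assumptions (plain Python lists, no learnable parameters, a stand-in sublayer) and does not
reproduce the exact data flow inside `TransformerLayer`; all names are hypothetical.

.. code-block:: python

  import math


  def layer_norm(x, eps=1e-5):
      """Normalize a list of floats to zero mean and (roughly) unit variance."""
      mean = sum(x) / len(x)
      var = sum((v - mean) ** 2 for v in x) / len(x)
      return [(v - mean) / math.sqrt(var + eps) for v in x]


  def pre_ln_block(x, sublayer):
      """GPT-style: normalize first, add the residual from the block input."""
      return [a + b for a, b in zip(x, sublayer(layer_norm(x)))]


  def post_ln_block(x, sublayer):
      """BERT-style: add the residual first, then normalize the output."""
      return layer_norm([a + b for a, b in zip(x, sublayer(x))])


  def double(values):
      """Stand-in for an attention or MLP sublayer."""
      return [2.0 * v for v in values]


  x = [1.0, 2.0, 3.0, 4.0]
  pre_out = pre_ln_block(x, double)
  post_out = post_ln_block(x, double)

In the pre-LN variant the un-normalized input passes straight through the residual path, whereas the
post-LN variant always emits a normalized output.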

T5
^^

The `T5` architecture has an additional `cross-attention` + `BiasDropoutAdd` + `LayerNorm` block before
the `MLP` layer. In TE this can be added by setting the `layer_type` to `decoder` in the
`TransformerLayer` API.

.. code-block:: python

  transformer_engine.pytorch.TransformerLayer(
          ...,
          ...,
          layer_type="decoder",
  )
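
The three configurations above can be gathered in one place. This is a sketch that uses only the
arguments shown in this section; the dictionary and helper are hypothetical, and the remaining
constructor arguments (elided above) must still be supplied by the caller.

.. code-block:: python

  # Hypothetical lookup table built from the arguments shown above.
  ARCH_KWARGS = {
      "gpt": dict(
          apply_residual_connection_post_layernorm=False,
          output_layernorm=False,
          layer_type="encoder",
      ),
      "bert": dict(
          apply_residual_connection_post_layernorm=True,
          output_layernorm=True,
          layer_type="encoder",
      ),
      "t5": dict(layer_type="decoder"),
  }


  def layer_kwargs(arch, **extra):
      """Merge architecture-specific arguments with caller-supplied ones."""
      return {**ARCH_KWARGS[arch.lower()], **extra}

The resulting dictionary can be unpacked with `**` when constructing a `TransformerLayer`.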

Integrations
==================

Transformer Engine has been integrated with several popular open-source DL frameworks such as:

* `DeepSpeed <https://github.com/microsoft/DeepSpeed/pull/3731>`_ 
* `Hugging Face Accelerate <https://github.com/huggingface/accelerate/releases/tag/v0.17.0>`_ 
* `MosaicML Composer <https://github.com/mosaicml/composer/releases/tag/v0.13.1>`_ 
* `Megatron-LM <https://github.com/NVIDIA/Megatron-LM>`_ 
* `Amazon SageMaker Model Parallel Library <https://docs.aws.amazon.com/sagemaker/latest/dg/model-parallel.html>`_ - Coming soon!
* `Colossal-AI <https://github.com/hpcaitech/ColossalAI>`_ - Coming soon!
* `Lightning <https://github.com/Lightning-AI/lightning/issues/17172>`_ - Coming soon!
* `PeriFlow <https://github.com/friendliai/periflow-python-sdk>`_ - Coming soon!


Contributing
==================

We welcome contributions to Transformer Engine! To contribute to Transformer Engine and make pull requests,
follow the guidelines outlined in the `<CONTRIBUTING.rst>`_ guide. 

Papers
==================

* `Attention original paper <https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`_
* `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_
* `Megatron-LM sequence parallel <https://arxiv.org/pdf/2205.05198.pdf>`_

Videos
==================

* `FP8 Training with Transformer Engine <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51393>`_  
* `FP8 for Deep Learning <https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s52166/>`_  
* `Inside the Hopper Architecture <https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s42663/>`_  

.. |License| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg
   :target: https://opensource.org/licenses/Apache-2.0