..
    Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

|License|

Transformer Engine
==================

.. overview-begin-marker-do-not-remove

Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, including
support for 8-bit floating point (FP8) precision on Hopper GPUs, which provides better performance
with lower memory utilization in both training and inference. TE provides a collection of highly
optimized building blocks for popular Transformer architectures and an automatic mixed
precision-like API that can be used seamlessly with your own framework-specific code. TE also
includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to
enable FP8 support for Transformers.
Przemek Tredak's avatar
Przemek Tredak committed
19
20

As the number of parameters in Transformer models continues to grow, training and inference for
architectures such as BERT, GPT and T5 become very memory and compute intensive. Most deep learning
frameworks train with FP32 by default. This is not essential, however, to achieve full accuracy for
many deep learning models. Mixed-precision training, which combines single-precision (FP32) with a
lower-precision format (e.g. FP16) when training a model, yields significant speedups with minimal
differences in accuracy compared to FP32 training. The Hopper GPU architecture introduces FP8
precision, which offers improved performance over FP16 with no degradation in accuracy. Although
all major deep learning frameworks support FP16, FP8 support is not natively available today.
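
For instance, pyTorch's native automatic mixed precision illustrates this pattern; the following is
a minimal FP16 sketch for illustration only (it does not use Transformer Engine, and the sizes are
arbitrary):

.. code-block:: python

  import torch

  model = torch.nn.Linear(768, 3072).cuda()
  inp = torch.randn(2048, 768, device="cuda")

  # Eligible ops run in FP16; the loss is scaled to avoid gradient underflow.
  scaler = torch.cuda.amp.GradScaler()
  with torch.autocast(device_type="cuda", dtype=torch.float16):
      out = model(inp)
      loss = out.sum()
  scaler.scale(loss).backward()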

TE addresses the problem of FP8 support by providing APIs that integrate with popular Large
Language Model (LLM) libraries. It provides a Python layer consisting of modules to easily build
Transformer layers, as well as a framework-agnostic C++ library including the structs and kernels
needed for FP8 support. Modules provided by TE internally maintain the scaling factors and other
values needed for FP8 training, greatly simplifying mixed precision training for users.


Examples
--------

pyTorch
^^^^^^^

.. code-block:: python

  import torch
  import transformer_engine.pytorch as te
  from transformer_engine.common import recipe

  # Set dimensions.
  in_features = 768
  out_features = 3072
  hidden_size = 2048

  # Initialize model and inputs.
  model = te.Linear(in_features, out_features, bias=True)
  inp = torch.randn(hidden_size, in_features, device="cuda")

  # Create an FP8 recipe. Note: All input args are optional.
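  #   margin:     extra headroom applied when computing the scaling factor.
  #   interval:   how often (in training steps) the scaling factors are recomputed.
  #   fp8_format: Format.E4M3 uses the E4M3 format for both forward and backward tensors.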
  fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.E4M3)

  # Enable autocasting for the forward pass
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      out = model(inp)

  loss = out.sum()
  loss.backward()
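
  # Hypothetical continuation: any standard pyTorch optimizer works on TE
  # modules unchanged (the optimizer choice and learning rate are illustrative).
  optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
  optimizer.step()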

JAX
^^^

.. code-block:: python

  import jax
  import jax.numpy as jnp
  import transformer_engine.jax as te
  from transformer_engine.common import recipe

  BATCH = 32
  SEQLEN = 128
  HIDDEN = 1024

  # Initialize RNG and inputs.
  rng = jax.random.PRNGKey(0)
  init_rng, data_rng = jax.random.split(rng)
  inp = jax.random.normal(data_rng, [BATCH, SEQLEN, HIDDEN], jnp.float32)

  # Create an FP8 recipe. Note: All input args are optional.
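  # Format.HYBRID uses E4M3 for forward-pass tensors and E5M2 for gradients.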
  fp8_recipe = recipe.DelayedScaling(margin=0, interval=1, fp8_format=recipe.Format.HYBRID)

  # Enable autocasting for the forward pass
  with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
      model = te.DenseGeneral(features=HIDDEN)

      def loss_fn(params, other_vars, inp):
        out = model.apply({'params':params, **other_vars}, inp)
        return jnp.mean(out)

      # Initialize models.
      variables = model.init(init_rng, inp)
      other_variables, params = variables.pop('params')

      # Construct the forward and backward function
      fwd_bwd_fn = jax.value_and_grad(loss_fn, argnums=(0, 1))

      for _ in range(10):
        loss, (param_grads, other_grads) = fwd_bwd_fn(params, other_variables, inp)
        # Update the FP8 metadata (amax history and scaling factors).
        other_variables = te.update_fp8_metas(other_grads)
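
In a full training loop, the forward-and-backward function would typically be JIT-compiled as
well; a sketch, assuming input shapes stay static across iterations:

.. code-block:: python

  # Optional: JIT-compile the value-and-grad computation for performance.
  fwd_bwd_fn = jax.jit(jax.value_and_grad(loss_fn, argnums=(0, 1)))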


Highlights
----------

* Easy-to-use modules for building Transformer layers with FP8 support on H100 GPUs.
* Optimizations (e.g. fused kernels) for Transformer models across all precisions and NVIDIA GPU
  architectures, usable with or without FP8 (see the sketch below).
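
A minimal sketch of the second point, using the `LayerNormLinear` fused module outside of FP8
(the sizes here are illustrative):

.. code-block:: python

  import torch
  import transformer_engine.pytorch as te

  # LayerNormLinear fuses LayerNorm and the subsequent GEMM into a single module;
  # without an fp8_autocast context it runs in the input's higher precision.
  fused = te.LayerNormLinear(1024, 3072).cuda()
  out = fused(torch.randn(128, 1024, device="cuda"))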

.. overview-end-marker-do-not-remove

Installation
------------

In the NGC container
^^^^^^^^^^^^^^^^^^^^

Transformer Engine comes preinstalled in the pyTorch container on
`NVIDIA GPU Cloud <https://ngc.nvidia.com>`_ (versions 22.09 and later).

From source
^^^^^^^^^^^

Clone the repository and inside it type:

.. code-block:: bash

  NVTE_FRAMEWORK=all pip install .     # Building with all frameworks.
  NVTE_FRAMEWORK=pytorch pip install . # Building with pyTorch only.
  NVTE_FRAMEWORK=jax pip install .     # Building with JAX only.
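
A quick import check can verify the installation; this smoke test assumes the pyTorch build was
selected:

.. code-block:: python

  # Should succeed without errors if the pyTorch extension was built correctly.
  import transformer_engine.pytorch as te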

User Guide
----------

For examples, tutorials and API reference please refer to the
`User Guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html>`_.

Transformer Architectures
-------------------------

While the more granular modules in Transformer Engine allow building any Transformer architecture,
the `TransformerLayer` API is flexible enough to build multiple major variations of Transformers.

NOTE: For simplicity, we only show pyTorch examples below. For the usage of `TransformerLayer`
in all supported frameworks, refer to `examples <https://github.com/NVIDIA/TransformerEngine/tree/main/examples>`_.
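
As a concrete baseline for the variants below, here is a minimal sketch of constructing and
calling `TransformerLayer` directly (the sizes are illustrative assumptions):

.. code-block:: python

  import torch
  import transformer_engine.pytorch as te

  hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16
  seq_len, batch_size = 128, 4

  layer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_heads).cuda()

  # The default input layout is sequence-first: [seq_len, batch_size, hidden_size].
  hidden_states = torch.randn(seq_len, batch_size, hidden_size, device="cuda")
  out = layer(hidden_states)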

GPT
^^^

The `GPT` architecture has `LayerNorm` at the input side (before `QKV Gemm`), and the residual
connection is taken from the input of that `LayerNorm`. In TE this can be achieved by setting the
following arguments in the `TransformerLayer` API:

.. code-block:: python

  transformer_engine.pytorch.TransformerLayer(
          ...,
          ...,
          apply_residual_connection_post_layernorm=False,
          output_layernorm=False,
          layer_type="encoder",
  )

BERT
^^^^

The `BERT` architecture has `LayerNorm` at the output side (after the final `BiasDropoutAdd`), and
the residual connection is taken from the output of that `LayerNorm`. In TE this can be achieved by
setting the following arguments in the `TransformerLayer` API:

.. code-block:: python

  transformer_engine.pytorch.TransformerLayer(
          ...,
          ...,
          apply_residual_connection_post_layernorm=True,
          output_layernorm=True,
          layer_type="encoder",
  )

T5
^^

The `T5` architecture has an additional `cross-attention` + `BiasDropoutAdd` + `LayerNorm` block
before the `MLP` layer. In TE this can be added by setting the `layer_type` to `decoder` in the
`TransformerLayer` API:

.. code-block:: python

  transformer_engine.pytorch.TransformerLayer(
          ...,
          ...,
          layer_type="decoder",
  )

Contributing to Transformer Engine
----------------------------------

We welcome contributions to Transformer Engine. To contribute to TE and make pull requests,
follow the guidelines outlined in the `<CONTRIBUTING.rst>`_ document.

Useful Links
------------

* `Attention original paper <https://proceedings.neurips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`_

* `Megatron-LM tensor parallel <https://arxiv.org/pdf/1909.08053.pdf>`_

* `Megatron-LM sequence parallel <https://arxiv.org/pdf/2205.05198.pdf>`_

.. |License| image:: https://img.shields.io/badge/License-Apache%202.0-blue.svg
   :target: https://opensource.org/licenses/Apache-2.0