{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "962d87bb",
   "metadata": {},
   "source": [
    "\n",
    "\n",
    "# Getting Started\n",
    "\n",
    "## Overview\n",
    "\n",
    "Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper, Ada, as well as 8-bit and 4-bit floating point (NVFP4) precision on Blackwell GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your JAX code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n",
    "\n",
    "This guide shows how to start using Transformer Engine with JAX. Similar tutorial for pyTorch is available [here](quickstart.ipynb).\n",
    "We recommend you to try understanding the basics of JAX first, using these resources:\n",
    "\n",
    "- Thinking in JAX: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html\n",
    "- JAX 101: https://docs.jax.dev/en/latest/jax-101.html\n",
    "- Key concepts in JAX: https://docs.jax.dev/en/latest/key-concepts.html#jax-arrays-jax-array\n",
    "- Flax 101: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/index.html\n",
    "\n",
    "## Let's build a Transformer decoder layer!\n",
    "<small>_This is based upon the GPT decoder layer with causal masking, which prevents each position from attending to future positions._</small>\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Summary</b>\n",
    "    \n",
    "We build a basic Transformer layer using regular Flax modules. This will be our baseline for later comparisons with Transformer Engine.\n",
    "\n",
    "</div>\n",
    "\n",
    "Let's start with creating the transformer layer using plain [FLAX Linen](https://flax.readthedocs.io/en/stable/) . Figure 1 shows the overall structure.\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"transformer_layer.png\" width=\"20%\">\n",
    "<figcaption> Figure 1: Structure of a GPT decoder layer.</figcaption>\n",
    "</figure>\n",
    "\n",
    "We construct the components as follows:\n",
    "\n",
    "- `LayerNorm`: `nn.LayerNorm` (Flax)\n",
    "- `QKV Projection`: `nn.Dense` (conceptually there are three seperate `Dense` layers for Q, K, and V separately, but we fuse them together into a single `Dense` layer that is three times larger)\n",
    "- `DotProductAttention`: `nn.MuliheadDotProductAttention` (Flax)\n",
    "- `Projection`: `nn.Dense` (Flax)\n",
    "- `Dropout`: `nn.Dropout` (Flax)\n",
    "- `MLP`: `FlaxMLP` implemented using `nn.Dense` and `nn.gelu`\n",
    "\n",
    "Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_jax_utils.py](quickstart_jax_utils.py). Putting it all together:  \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "d5284a38",
   "metadata": {},
   "outputs": [],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "from flax import linen as nn\n",
    "import quickstart_jax_utils as utils\n",
    "from typing import Optional"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "a4d1cfdc",
   "metadata": {},
   "outputs": [],
   "source": [
    "class FlaxMLP(nn.Module):\n",
    "    \"\"\"Feed-forward network in Transformer layer\n",
    "    Built with plain Flax modules.\n",
    "    \"\"\"\n",
    "    hidden_size: int\n",
    "    ffn_hidden_size: int\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n",
    "        x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x)\n",
    "        x = nn.gelu(x, approximate=True)  # equivalent to tanh approximation\n",
    "        x = nn.Dense(features=self.hidden_size, use_bias=True)(x)\n",
    "        return x\n",
    "\n",
    "class FlaxTransformerLayer(nn.Module):\n",
    "    \"\"\"Basic Transformer layer using plain Flax modules\"\"\"\n",
    "    hidden_size: int\n",
    "    ffn_hidden_size: int\n",
    "    num_attention_heads: int\n",
    "    layernorm_eps: float = 1e-5\n",
    "    attention_dropout: float = 0.1\n",
    "    \n",
    "    def setup(self):\n",
    "        self.kv_channels = self.hidden_size // self.num_attention_heads\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(\n",
    "        self, \n",
    "        x: jnp.ndarray, \n",
    "        attention_mask: Optional[jnp.ndarray] = None,\n",
    "        deterministic: bool = False\n",
    "    ) -> jnp.ndarray:\n",
    "        # Create causal mask if not provided\n",
    "        if attention_mask is None:\n",
    "            attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
    "        \n",
    "        res = x\n",
    "        x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
    "        \n",
    "        # Fused QKV projection\n",
    "        qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x)\n",
    "        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
    "        q, k, v = jnp.split(qkv, 3, axis=3)\n",
    "        \n",
    "        # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n",
    "        # which is the correct format for dot_product_attention\n",
    "        \n",
    "        # Apply dot product attention\n",
    "        # Note: dot_product_attention expects mask to be broadcastable to \n",
    "        # [batch, num_heads, q_length, kv_length], but attention_mask from \n",
    "        # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n",
    "        \n",
    "        # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n",
    "        dropout_rng = None\n",
    "        if not deterministic and self.attention_dropout > 0:\n",
    "            dropout_rng = self.make_rng('dropout')\n",
    "        \n",
    "        x = nn.dot_product_attention(\n",
    "            query=q,\n",
    "            key=k,\n",
    "            value=v,\n",
    "            mask=attention_mask,\n",
    "            dropout_rng=dropout_rng,\n",
    "            dropout_rate=self.attention_dropout,\n",
    "            deterministic=deterministic,\n",
    "            broadcast_dropout=True,\n",
    "        )\n",
    "        \n",
    "        # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n",
    "        x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n",
    "        \n",
    "        x = res + x\n",
    "        \n",
    "        # Second residual connection\n",
    "        res = x\n",
    "        x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
    "        \n",
    "        # MLP\n",
    "        mlp = FlaxMLP(\n",
    "            hidden_size=self.hidden_size,\n",
    "            ffn_hidden_size=self.ffn_hidden_size,\n",
    "        )\n",
    "        x = mlp(x)\n",
    "        \n",
    "        return x + res\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fbc3510b",
   "metadata": {},
   "source": [
    "## Testing Performance\n",
    "\n",
    "Now let's test the performance of our FlaxTransformerLayer:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "8b44649d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layer configuration\n",
    "hidden_size = 4096\n",
    "sequence_length = 2048\n",
    "batch_size = 4\n",
    "ffn_hidden_size = 16384\n",
    "num_attention_heads = 32\n",
    "dtype = jnp.bfloat16\n",
    "\n",
    "# Synthetic data\n",
    "key, dropout_key = jax.random.split(jax.random.PRNGKey(42))\n",
    "x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n",
    "dy = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "id": "e44ed26d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Pure Flax FlaxTransformerLayer initialized successfully!\n",
      "Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n"
     ]
    }
   ],
   "source": [
    "# Initialize the FlaxTransformerLayer\n",
    "flax_transformer = FlaxTransformerLayer(\n",
    "    hidden_size=hidden_size,\n",
    "    ffn_hidden_size=ffn_hidden_size,\n",
    "    num_attention_heads=num_attention_heads,\n",
    ")\n",
    "\n",
    "# Initialize parameters\n",
    "params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
    "\n",
    "print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n",
    "print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "de91af7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input shape: (4, 2048, 4096)\n",
      "Output shape: (4, 2048, 4096)\n",
      "Output dtype: float32\n",
      "Forward pass completed successfully!\n"
     ]
    }
   ],
   "source": [
    "# Example usage of forward pass\n",
    "y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)\n",
    "print(f\"Input shape: {x.shape}\")\n",
    "print(f\"Output shape: {y.shape}\")\n",
    "print(f\"Output dtype: {y.dtype}\")\n",
    "print(\"Forward pass completed successfully!\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "037bc8d9",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 18.546080589294434 ms\n"
     ]
    }
   ],
   "source": [
    "import importlib\n",
    "import quickstart_jax_utils\n",
    "importlib.reload(quickstart_jax_utils)\n",
    "\n",
    "utils.speedometer(\n",
    "    model_apply_fn=flax_transformer.apply,\n",
    "    variables=params,\n",
    "    input=x,\n",
    "    output_grad=dy,\n",
    "    dropout_key=dropout_key,\n",
    "    forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
    ")"
   ]
  },
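  {
   "cell_type": "markdown",
   "id": "speedometer-sketch",
   "metadata": {},
   "source": [
    "The `utils.speedometer` helper is defined in [quickstart_jax_utils.py](quickstart_jax_utils.py). Roughly, it times a jitted forward + backward pass of the layer and reports the mean step time. A minimal sketch of such a helper (assumed behavior and signature; the actual implementation may differ) could look like this:\n",
    "\n",
    "```python\n",
    "import time  # jax and jnp are already imported above\n",
    "\n",
    "def simple_speedometer(model_apply_fn, variables, input, output_grad,\n",
    "                       dropout_key, forward_kwargs, iters=50, warmup=10):\n",
    "    def loss_fn(variables, x):\n",
    "        y = model_apply_fn(variables, x, rngs={\"dropout\": dropout_key}, **forward_kwargs)\n",
    "        # Weighting the output by `output_grad` makes the backward pass propagate dy through the layer\n",
    "        return jnp.sum(y * output_grad)\n",
    "\n",
    "    step = jax.jit(jax.grad(loss_fn, argnums=(0, 1)))  # grads w.r.t. params and input\n",
    "\n",
    "    for _ in range(warmup):  # compile and warm up\n",
    "        jax.block_until_ready(step(variables, input))\n",
    "    start = time.time()\n",
    "    for _ in range(iters):\n",
    "        jax.block_until_ready(step(variables, input))\n",
    "    print(f\"Mean time: {(time.time() - start) / iters * 1000} ms\")\n",
    "```"
   ]
  },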
  {
   "cell_type": "markdown",
   "id": "ccb16f31",
   "metadata": {},
   "source": [
    "## Meet Transformer Engine\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Summary</b>\n",
    "    \n",
    "Now that we have a basic Transformer layer in Flax, let's use Transformer Engine to speed up the training. The following examples show how to use TE modules.\n",
    "\n",
    "</div>\n",
    "\n",
    "As a reminder, the FlaxTransformerLayer above used:\n",
    "\n",
    "- `nn.LayerNorm`: Flax LayerNorm\n",
    "- `nn.Dense`: Flax Dense layer for QKV projection  \n",
    "- `nn.MultiheadDotProductAttention`: Flax MultiheadDotProductAttention\n",
    "- `nn.Dense`: Flax Dense layer for projection\n",
    "- `nn.Dropout`: Flax Dropout\n",
    "- `FlaxMLP`: Custom MLP implemented from `nn.Dense`\n",
    "\n",
    "Below we show how to use Transformer Engine Flax modules for better performance:\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "bed20d6b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformer_engine.jax as te\n",
    "import transformer_engine.jax.flax as te_flax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f28cb444",
   "metadata": {},
   "source": [
    "TE provides a set of Flax Linen modules that can be used to build Transformer layers. The simplest of the provided modules are the `DenseGeneral ` and `LayerNorm` layers, which we can use instead of `flax.linen.Dense` and ` flax.linen.LayerNorm`. Let's modify our `FlaxTransformerLayer`:"
   ]
  },
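  {
   "cell_type": "markdown",
   "id": "te-dense-swap-sketch",
   "metadata": {},
   "source": [
    "In the simplest case this is a drop-in swap. A minimal sketch (with a hypothetical `out_features` value, shown only to illustrate the one-to-one replacement):\n",
    "\n",
    "```python\n",
    "# Plain Flax\n",
    "x = nn.Dense(features=out_features, use_bias=True)(x)\n",
    "x = nn.LayerNorm(epsilon=1e-5)(x)\n",
    "\n",
    "# Transformer Engine equivalents\n",
    "x = te_flax.DenseGeneral(features=out_features, use_bias=True)(x)\n",
    "x = te_flax.LayerNorm(epsilon=1e-5)(x)\n",
    "```"
   ]
  },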
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "56105579",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention\n",
    "\n",
    "\n",
    "class TEUnfusedMLP(nn.Module):\n",
    "    hidden_size : int\n",
    "    ffn_hidden_size: int\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:\n",
    "        x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True) (x)\n",
    "        x = x.reshape(*x.shape[:-1], 1, x.shape[-1])\n",
    "        x = te.activation.activation(x, activation_type=('gelu',))\n",
    "        x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True) (x)\n",
    "        return x\n",
    "\n",
    "class TEUnfusedTransformerLayer(nn.Module):\n",
    "    hidden_size: int\n",
    "    ffn_hidden_size: int \n",
    "    num_attention_heads: int  \n",
    "    layernorm_eps: float = 1e-5\n",
    "    attention_dropout: float = 0.1 \n",
    "    use_te_attention: bool = True  # True for TE attention, False for Flax attention\n",
    "\n",
    "    def setup(self):\n",
    "        self.kv_channels = self.hidden_size // self.num_attention_heads\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(\n",
    "        self, \n",
    "        x: jnp.ndarray,\n",
    "        attention_mask: Optional[jnp.ndarray] = None,\n",
    "        deterministic: bool = False\n",
    "    ) -> jnp.ndarray:\n",
    "        # Create causal mask if not provided\n",
    "        if attention_mask is None:\n",
    "            attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n",
    "        \n",
    "        res = x\n",
    "        x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
    "\n",
    "        # Fused QKV projection\n",
    "        qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x)\n",
    "        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
    "        q, k, v = jnp.split(qkv, 3, axis=3)\n",
    "\n",
    "        # Attention - either TE or Flax implementation\n",
    "        if self.use_te_attention:\n",
    "            # Use TE's DotProductAttention\n",
    "            attention = TEDotProductAttention(\n",
    "                head_dim=self.kv_channels,\n",
    "                num_attention_heads=self.num_attention_heads,\n",
    "                num_gqa_groups=self.num_attention_heads,  # No GQA\n",
    "                attention_dropout=self.attention_dropout,\n",
    "                attn_mask_type='causal',\n",
    "            )\n",
    "            x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
    "            # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
    "            x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n",
    "            x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n",
    "            x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n",
    "        else:\n",
    "            # Use Flax's MultiHeadDotProductAttention\n",
    "            q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n",
    "            k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n",
    "            v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n",
    "            \n",
    "            attention = nn.MultiHeadDotProductAttention(\n",
    "                num_heads=self.num_attention_heads,\n",
    "                qkv_features=self.kv_channels,\n",
    "                dropout_rate=self.attention_dropout,\n",
    "            )\n",
    "            x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n",
    "\n",
    "        x = res + x\n",
    "\n",
    "        # Second residual connection\n",
    "        res = x\n",
    "        x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n",
    "\n",
    "        # MLP\n",
    "        mlp = TEUnfusedMLP(\n",
    "            hidden_size=self.hidden_size,\n",
    "            ffn_hidden_size=self.ffn_hidden_size\n",
    "        )\n",
    "\n",
    "        x = mlp(x, deterministic=deterministic)\n",
    "\n",
    "        return x + res"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a76911ac",
   "metadata": {},
   "source": [
    "Testing performance of the model, using `DenseGeneral`, `LayerNorm` and activation from TE, while keeping Flax's `MultiHeadDotProductAttention` the same as the first simple Transformer in JAX implementation. To read more about this implementation from Flax, you can refer to this documentation:  https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "4b67511f",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 16.375374794006348 ms\n"
     ]
    }
   ],
   "source": [
    "te_unfused_transformer_with_flax_MHA = TEUnfusedTransformerLayer(\n",
    "    hidden_size, \n",
    "    ffn_hidden_size, \n",
    "    num_attention_heads,\n",
    "    use_te_attention=False\n",
    ")\n",
    "\n",
    "te_params = te_unfused_transformer_with_flax_MHA.init(key, x, attention_mask=None, deterministic=False)\n",
    "\n",
    "utils.speedometer(\n",
    "    model_apply_fn=te_unfused_transformer_with_flax_MHA.apply,\n",
    "    variables=te_params,  # Ensure the correct `params` is passed\n",
    "    input=x,\n",
    "    output_grad=dy,\n",
    "    dropout_key=dropout_key,\n",
    "    forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b230058",
   "metadata": {},
   "source": [
    "Now, we move on to also replace the attention sub-layer with TE's `DotProductAttention` implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "5146cd99",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:634: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n",
      "  warnings.warn(\n",
      "/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:742: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
      "Fall back to the unfused attention.\n",
      "Please try to update the cuDNN and TE to the latest version.\n",
      "self.dtype=<class 'jax.numpy.float32'>\n",
      "qkv_layout=<QKVLayout.BSHD_BSHD_BSHD: <NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: 9>>\n",
      "attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
      "attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
      "self.attention_dropout=0.1\n",
      "self.num_attention_heads=32\n",
      "self.num_gqa_groups=32\n",
      "seqlen_q=2048\n",
      "seqlen_kv=2048\n",
      "head_dim_qk=128\n",
      "head_dim_v=128\n",
      "\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 12.403340339660645 ms\n"
     ]
    }
   ],
   "source": [
    "te_unfused_transformer = TEUnfusedTransformerLayer(\n",
    "    hidden_size, \n",
    "    ffn_hidden_size, \n",
    "    num_attention_heads,\n",
    ")\n",
    "\n",
    "te_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
    "\n",
    "utils.speedometer(\n",
    "    model_apply_fn=te_unfused_transformer.apply,\n",
    "    variables=te_params,  # Ensure the correct `params` is passed\n",
    "    input=x,\n",
    "    output_grad=dy,\n",
    "    dropout_key=dropout_key,\n",
    "    forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9a101d3",
   "metadata": {},
   "source": [
    "## Enabling Quantization (FP8 or FP4)\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Summary</b>\n",
    "    \n",
    "We configure a TE module to perform compute in FP8.\n",
    "\n",
    "</div>\n",
    "\n",
    "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/jax.rst#transformer_engine.jax.fp8_autocast) context manager. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options.\n",
    "\n",
    "<div class=\"alert alert-warning\">\n",
    "\n",
    "<b>Important: FP8 Metadata Initialization</b>\n",
    "\n",
    "When using FP8, the model **must be initialized within the `autocast` context**. This creates a special collection called `fp8_metas` that contains scaling factors and other metadata required for FP8 computation. If you initialize a model outside of `autocast` and then try to use it with FP8, you will get a `ScopeCollectionNotFound` error because the `fp8_metas` collection was never created.\n",
    "\n",
    "</div>"
   ]
  },
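  {
   "cell_type": "markdown",
   "id": "fp8-init-pattern-sketch",
   "metadata": {},
   "source": [
    "As a minimal sketch of the failure mode described above (the `model`, `key`, `x`, and `fp8_recipe` names stand in for the objects defined in the surrounding cells):\n",
    "\n",
    "```python\n",
    "# Anti-pattern: init outside autocast, so no `fp8_metas` collection is created.\n",
    "params = model.init(key, x)\n",
    "with te.autocast(enabled=True, recipe=fp8_recipe):\n",
    "    y = model.apply(params, x)  # raises ScopeCollectionNotFound\n",
    "\n",
    "# Correct pattern (used in the cells below): init inside autocast.\n",
    "with te.autocast(enabled=True, recipe=fp8_recipe):\n",
    "    params = model.init(key, x)\n",
    "    y = model.apply(params, x)\n",
    "```"
   ]
  },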
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "c2eee376",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_engine.common.recipe import Format, DelayedScaling\n",
    "fp8_format = Format.HYBRID\n",
    "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "de96827c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 9.396424293518066 ms\n"
     ]
    }
   ],
   "source": [
    "with te.autocast(enabled=True, recipe=fp8_recipe):\n",
    "    te_unfused_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
    "\n",
    "    # Example usage of forward \n",
    "    y = te_unfused_transformer.apply(te_unfused_params, x, attention_mask=None, deterministic=True)\n",
    "\n",
    "utils.speedometer(\n",
    "    model_apply_fn=te_unfused_transformer.apply,\n",
    "    variables=te_unfused_params,  # Ensure the correct `params` is passed\n",
    "    input=x,\n",
    "    output_grad=dy,\n",
    "    dropout_key=dropout_key,\n",
    "    forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
    "    autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3801b201",
   "metadata": {},
   "source": [
    "\n",
    "## Fused TE Modules\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Summary</b>\n",
    "    \n",
    "We optimize the example Transformer layer with TE modules for fused operations.\n",
    "\n",
    "</div>\n",
    "\n",
    "The `DenseGeneral` layer is enough to build any Transformer model and it enables usage of the Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations such as kernel fusions in mixed-precision recipes, increasing the achievable speedup.\n",
    "\n",
    "Transformer Engine therefore provides coarser modules that span multiple layers:\n",
    "\n",
    "* `LayerNormDenseGeneral`\n",
    "* `LayerNormMLP`\n",
    "* `TransformerLayer`\n",
    "\n",
    "To see a complete list of all the functions TE Flax support, you can view it here: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#modules\n",
    "\n",
    "Building a third iteration of our Transformer layer with `LayerNormDenseGeneral` and `LayerNormMLP`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "11203785",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TEFusedTransformerLayer(nn.Module):\n",
    "    hidden_size: int\n",
    "    ffn_hidden_size: int \n",
    "    num_attention_heads: int  \n",
    "    layernorm_eps: float = 1e-5\n",
    "    attention_dropout: float = 0.1\n",
    "\n",
    "    def setup(self):\n",
    "        self.kv_channels = self.hidden_size // self.num_attention_heads\n",
    "\n",
    "    @nn.compact\n",
    "    def __call__(\n",
    "        self, \n",
    "        x: jnp.ndarray,\n",
    "        attention_mask: Optional[jnp.ndarray] = None,\n",
    "        deterministic: bool = False\n",
    "    ) -> jnp.ndarray:\n",
    "        res = x\n",
    "\n",
    "         # Fused QKV projection\n",
    "        qkv,_ = te_flax.LayerNormDenseGeneral(features=3 * self.hidden_size, \n",
    "                                              epsilon=self.layernorm_eps, \n",
    "                                              use_bias=True, \n",
    "                                              return_layernorm_output=False)(x)\n",
    "        qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n",
    "        q, k, v = jnp.split(qkv, 3, axis=3)\n",
    "\n",
    "        # Attention using TE's DotProductAttention\n",
    "        attention = TEDotProductAttention(\n",
    "            head_dim=self.kv_channels,\n",
    "            num_attention_heads=self.num_attention_heads,\n",
    "            num_gqa_groups=self.num_attention_heads,  \n",
    "            attention_dropout=self.attention_dropout,\n",
    "            attn_mask_type='causal',\n",
    "        )\n",
    "        x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n",
    "        # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n",
    "        x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n",
    "        x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n",
    "        x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n",
    "\n",
    "        x = res + x\n",
    "\n",
    "        # Second residual connection\n",
    "        res = x\n",
    "        x,_ = te_flax.LayerNormMLP(intermediate_dim=self.ffn_hidden_size, \n",
    "                                 epsilon=self.layernorm_eps,\n",
    "                                 use_bias=True,\n",
    "                                 activations=('gelu',),\n",
    "                                 intermediate_dropout_rate=0.0,\n",
    "                                 return_layernorm_output=False\n",
    "                                 )(x, deterministic=deterministic)\n",
    "\n",
    "        return x + res"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "334cff59",
   "metadata": {},
   "source": [
    "Similar to the unnfused model, we also compare the performance of fused model when using Flax's MultiheadDotProductAttention implementation and TE's."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "6b0c705e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 9.145426750183105 ms\n"
     ]
    }
   ],
   "source": [
    "te_fused_transformer = TEFusedTransformerLayer(\n",
    "    hidden_size, \n",
    "    ffn_hidden_size, \n",
    "    num_attention_heads\n",
    ")\n",
    "\n",
    "with te.autocast(enabled=True, recipe=fp8_recipe):\n",
    "    te_fused_params = te_fused_transformer.init(key, x, attention_mask=None, deterministic=False)\n",
    "    # Example usage of forward \n",
    "    y = te_fused_transformer.apply(te_fused_params, x, attention_mask=None, deterministic=True)\n",
    "\n",
    "utils.speedometer(\n",
    "    model_apply_fn=te_fused_transformer.apply,\n",
    "    variables=te_fused_params,\n",
    "    input=x,\n",
    "    output_grad=dy,\n",
    "    dropout_key=dropout_key,\n",
    "    forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
    "    autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe}\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a45c12c8",
   "metadata": {},
   "source": [
    "Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "b2aaa8ef",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/code/github/TransformerEngine/transformer_engine/jax/flax/transformer.py:742: UserWarning: Fused attention is not enabled because there is no available kernel.\n",
      "Fall back to the unfused attention.\n",
      "Please try to update the cuDNN and TE to the latest version.\n",
      "self.dtype=<class 'jax.numpy.float32'>\n",
      "qkv_layout=<QKVLayout.BS3HD: <NVTE_QKV_Layout.NVTE_BS3HD: 5>>\n",
      "attn_bias_type=<AttnBiasType.NO_BIAS: <NVTE_Bias_Type.NVTE_NO_BIAS: 0>>\n",
      "attn_mask_type=<AttnMaskType.CAUSAL_MASK: <NVTE_Mask_Type.NVTE_CAUSAL_MASK: 2>>\n",
      "self.attention_dropout=0.1\n",
      "self.num_attention_heads=32\n",
      "self.num_gqa_groups=32\n",
      "seqlen_q=2048\n",
      "seqlen_kv=2048\n",
      "head_dim_qk=128\n",
      "head_dim_v=128\n",
      "\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "\n",
    "te_transformer = te_flax.TransformerLayer(\n",
    "    hidden_size=hidden_size,\n",
    "    mlp_hidden_size=ffn_hidden_size, \n",
    "    num_attention_heads=num_attention_heads,\n",
    "    mlp_activations=(\"gelu\",),\n",
    "    self_attn_mask_type='causal',\n",
    "    layernorm_epsilon=1e-5,\n",
    "    use_bias=True,\n",
    "    intermediate_dropout=0.0,\n",
    "    enable_relative_embedding=False,\n",
    "    self_attn_bias_type='no_bias',\n",
    "    hidden_dropout=0.0\n",
    ")\n",
    "\n",
    "with te.autocast(enabled=True, recipe=fp8_recipe):\n",
    "    te_transformer_params = te_transformer.init(key, x, deterministic=False)\n",
    "    y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "b9cdbf22",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mean time: 9.020795822143555 ms\n"
     ]
    }
   ],
   "source": [
    "utils.speedometer(\n",
    "    model_apply_fn=te_transformer.apply,\n",
    "    model_init_fn=te_transformer.init,\n",
    "    variables=te_transformer_params,\n",
    "    input=x,\n",
    "    output_grad=dy,\n",
    "    dropout_key=dropout_key,\n",
    "    forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n",
    "    autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe }\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}