{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7b3e6954",
   "metadata": {},
   "source": [
    "# Using FP8 and FP4 with Transformer Engine\n",
    "\n",
    "The H100 GPU introduced support for a new datatype, FP8 (8-bit floating point), enabling higher throughput of matrix multiplies and convolutions. Blackwell added support for the NVFP4 and MXFP8 datatypes. In this example we introduce these low precision datatypes and show how to use them with Transformer Engine.\n",
    "\n",
    "## Introduction to FP8\n",
    "\n",
    "### Structure\n",
    "\n",
    "The FP8 datatype supported by H100 actually comprises two distinct datatypes, useful in different parts of the training of neural networks:\n",
    "\n",
    "* E4M3 - it consists of 1 sign bit, 4 exponent bits and 3 bits of mantissa. It can store values up to +/-448 and `nan`.\n",
    "* E5M2 - it consists of 1 sign bit, 5 exponent bits and 2 bits of mantissa. It can store values up to +/-57344, +/- `inf` and `nan`. The tradeoff of the increased dynamic range is lower precision of the stored values.\n",
    "\n",
    "<figure align=\"center\" id=\"fig_1\">\n",
    "<img src=\"fp8_formats.png\" width=\"60%\">\n",
    "<figcaption> Figure 1: Structure of the floating point datatypes. All of the values shown (in FP16, BF16, FP8 E4M3 and FP8 E5M2) are the closest representations of value 0.3952.</figcaption>\n",
    "</figure>\n",
    "\n",
    "During neural network training both of these types may be utilized. Typically forward activations and weights require more precision, so the E4M3 datatype is best used during the forward pass. In the backward pass, however, gradients flowing through the network are typically less susceptible to the loss of precision, but require a higher dynamic range. Therefore they are best stored using the E5M2 data format. H100 Tensor Cores support any combination of these types as inputs, enabling us to store each tensor using its preferred precision.\n",
    "\n",
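"We can make the E4M3 limits concrete with a small plain-Python decoder (an illustrative sketch written for this primer - Transformer Engine does not expose such a helper):\n",
"\n",
"```python\n",
"def decode_e4m3(byte):\n",
"    # Decode one FP8 E4M3 byte (1 sign, 4 exponent, 3 mantissa bits; bias 7).\n",
"    # Simplified: ignores the special NaN encoding (S.1111.111).\n",
"    sign = -1.0 if byte & 0x80 else 1.0\n",
"    exp = (byte >> 3) & 0xF\n",
"    man = byte & 0x7\n",
"    if exp == 0:  # subnormal numbers\n",
"        return sign * (man / 8) * 2.0 ** -6\n",
"    return sign * (1 + man / 8) * 2.0 ** (exp - 7)\n",
"\n",
"# Largest finite E4M3 magnitude: exponent 1111, mantissa 110 -> 448\n",
"assert decode_e4m3(0b0111_1110) == 448.0\n",
"```\n",
"\n",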
    "### Mixed precision training - a quick introduction\n",
    "\n",
    "In order to understand how FP8 can be used for training Deep Learning models, it is useful to first remind ourselves how mixed precision works with other datatypes, especially FP16.\n",
    "\n",
    "The mixed precision recipe for FP16 training has two components: choosing which operations should be performed in FP16 and dynamic loss scaling.\n",
    "\n",
    "* Choosing the operations to be performed in FP16 precision requires analysis of the numerical behavior of the outputs with respect to the inputs of the operation, as well as the expected performance benefit. This enables marking operations like matrix multiplies and convolutions as safe, while leaving operations like `norm` or `exp` as requiring high precision.\n",
    "* Dynamic loss scaling enables avoiding both over- and underflows of the gradients during training. Those may happen because, while the dynamic range of FP16 is enough to store the distribution of the gradient values, this distribution may be centered around values too high or too low for FP16 to handle. Scaling the loss shifts those distributions into the range representable in FP16 (using only powers of 2, so numerics are unaffected).\n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"loss_scaling.png\" width=\"50%\">\n",
    "<figcaption> Figure 2: Scaling the loss enables shifting the gradient distribution into the representable range of FP16 datatype. </figcaption>\n",
    "</figure>\n",
    "\n",
    "### Mixed precision training with FP8\n",
    "\n",
    "While the dynamic range provided by the FP8 types is sufficient to store any particular activation or gradient, it is not sufficient for all of them at the same time. This makes the single loss scaling factor strategy, which worked for FP16, infeasible for FP8 training and instead requires using distinct scaling factors for each FP8 tensor.\n",
    "\n",
    "There are multiple strategies for choosing a scaling factor that is appropriate for a given FP8 tensor:\n",
    "\n",
    "* just-in-time scaling. This strategy chooses the scaling factor based on the maximum of absolute values (amax) of the tensor being produced. In practice it is impractical, as it requires multiple passes through the data - the operator produces and writes out the output in higher precision, then the maximum absolute value of the output is found and applied to all values in order to obtain the final FP8 output. This results in a lot of overhead, severely diminishing the gains from using FP8.\n",
    "* delayed scaling. This strategy chooses the scaling factor based on the maximums of absolute values seen in some number of previous iterations. This enables full performance of FP8 computation, but requires storing the history of maximums as additional parameters of the FP8 operators. \n",
    "\n",
    "<figure align=\"center\">\n",
    "<img src=\"delayed_scaling.png\" width=\"80%\">\n",
    "<figcaption> Figure 3: Delayed scaling strategy. The FP8 operator uses scaling factor obtained using the history of amaxes (maximums of absolute values) seen in some number of previous iterations and produces both the FP8 output and the current amax, which gets stored in the history.</figcaption>\n",
    "</figure>\n",
    "\n",
    "As one can see in Figure 3, the delayed scaling strategy requires not only storing the history of amaxes, but also choosing a recipe for converting that history into the scaling factor used in the next iteration."
   ]
  },
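  {
   "cell_type": "markdown",
   "id": "delayed-scaling-sketch",
   "metadata": {},
   "source": [
"The delayed scaling recipe from Figure 3 can be sketched in plain Python (a simplified illustration written for this primer, not Transformer Engine's implementation; `update_scale` is a made-up name):\n",
"\n",
"```python\n",
"E4M3_MAX = 448.0  # largest magnitude representable in FP8 E4M3\n",
"\n",
"def update_scale(amax_history, margin=0):\n",
"    # Reduce the amax history to a single value (the \"max\" algorithm),\n",
"    # then pick a scale that maps that amax to the top of the E4M3 range.\n",
"    amax = max(amax_history)\n",
"    return E4M3_MAX / amax / 2.0 ** margin\n",
"\n",
"amax_history = [0.5, 2.0, 1.5]  # amaxes observed in previous iterations\n",
"scale = update_scale(amax_history)  # 448 / 2.0 = 224.0\n",
"x = [-2.0, 0.25, 1.0]\n",
"x_scaled = [v * scale for v in x]  # now within +/-448, castable to E4M3\n",
"```\n",
"\n",
"The stored FP8 tensor holds the scaled values; the matching `1 / scale` is applied when the tensor is consumed, e.g. inside the matrix multiply."
   ]
  },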
  {
   "cell_type": "markdown",
   "id": "f03b58ed-71e8-422a-95be-35c1cc60c4e2",
   "metadata": {},
   "source": [
    "## MXFP8 and block scaling\n",
    "\n",
    "NVIDIA Blackwell architecture introduced support for a new variant of the FP8 format: [MXFP8](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf). \n",
    "\n",
    "### MXFP8 vs FP8\n",
    "\n",
    "The main difference between \"regular\" FP8 and MXFP8 lies in the granularity of the scaling. In FP8, each tensor has a single FP32 scaling factor, so all values in the tensor need to \"fit\" within the dynamic range of the FP8 datatype. This requires using the less precise E5M2 format to represent some tensors in the network (like gradients).\n",
    "\n",
    "MXFP8 addresses this by assigning a different scaling factor to each block of 32 [consecutive](#handling-transposes) values. This allows all values to be represented with the E4M3 datatype.\n",
    "\n",
    "<figure align=\"center\" id=\"fig_4\">\n",
    "<img src=\"MXFP8_FP8_comparison_1.png\" width=\"100%\">\n",
    "<figcaption> Figure 4: MXFP8 uses multiple scaling factors for a single tensor. The picture shows only 4 values per block for simplicity, but real MXFP8 has 32 values per block.</figcaption>\n",
    "</figure>\n",
    "\n",
    "<figure align=\"center\" id=\"fig_5\">\n",
    "<img src=\"MXFP8_FP8_comparison_2.png\" width=\"100%\">\n",
    "<figcaption> Figure 5: With multiple scaling factors, the tensor's dynamic range requirements are reduced, so the E4M3 format can be used, as far fewer elements get saturated to 0.</figcaption>\n",
    "</figure>\n",
    "\n",
    "The second difference is the datatype used to store the scaling factors. FP8 uses FP32 (E8M23) while MXFP8 uses an 8-bit representation of a power of 2 (E8M0).\n",
    "\n",
    "<figure align=\"center\" id=\"fig_6\">\n",
    "<img src=\"E8M0.png\" width=\"100%\">\n",
    "<figcaption> Figure 6: Structure of the E8M0 datatype used for storing scaling factors in MXFP8.</figcaption>\n",
    "</figure>\n",
    "\n",
    "### Handling transposes\n",
    "\n",
    "The forward and backward passes of linear layers involve multiple matrix multiplications with different reduction dimensions. Blackwell Tensor Cores require MXFP8 data to be \"consecutive\" over the reduction dimension, so MXFP8 training uses non-transposed and transposed MXFP8 tensors at different points. However, while transposing FP8 data is numerically trivial, transposing MXFP8 data requires requantization.\n",
    "\n",
    "To avoid loss of precision connected with this double quantization, Transformer Engine creates both regular and transposed copies of the tensor from the original high precision input.\n",
    "\n",
    "<figure align=\"center\" id=\"fig_7\">\n",
    "<img src=\"linear_mxfp8.png\" width=\"80%\">\n",
    "<figcaption> Figure 7: Linear layer in MXFP8. Calculating both forward and backward pass requires tensors quantized in both directions.</figcaption>\n",
    "</figure>"
   ]
  },
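  {
   "cell_type": "markdown",
   "id": "mxfp8-scale-sketch",
   "metadata": {},
   "source": [
"The per-block scaling can be made concrete with a short plain-Python sketch (illustrative only, written for this primer - not how Transformer Engine computes MXFP8 scales):\n",
"\n",
"```python\n",
"import math\n",
"\n",
"E4M3_MAX = 448.0\n",
"BLOCK = 32\n",
"\n",
"def block_scale(block):\n",
"    # E8M0 scales are powers of 2: pick the smallest one such that the\n",
"    # block's amax, divided by the scale, still fits in the E4M3 range.\n",
"    amax = max(abs(v) for v in block)\n",
"    return 2.0 ** math.ceil(math.log2(amax / E4M3_MAX))\n",
"\n",
"small = [0.001 * (i + 1) for i in range(BLOCK)]  # a block of tiny values\n",
"large = [100.0 * (i + 1) for i in range(BLOCK)]  # a block of huge values\n",
"scales = [block_scale(small), block_scale(large)]\n",
"# Each block gets its own power-of-2 scale, so both blocks fit in E4M3.\n",
"```\n",
"\n",
"With a single per-tensor scale, the tiny block above would be flushed to zero or the huge one would overflow; per-block scales avoid that tradeoff."
   ]
  },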
  {
   "cell_type": "markdown",
   "id": "fd7b4f37-50a2-4d41-9067-cf0c471cb2d7",
   "metadata": {},
   "source": [
    "## Beyond FP8 - training with NVFP4\n",
    "\n",
    "In addition to MXFP8, NVIDIA Blackwell introduced support for an even smaller, 4-bit format called NVFP4. Values are represented in the E2M1 format, which can store magnitudes up to 6.\n",
    "\n",
    "<figure align=\"center\" id=\"fig_8\">\n",
    "<img src=\"FP4_format.png\" width=\"50%\">\n",
    "<figcaption> Figure 8: FP4 E2M1 format can represent values between +/-6.</figcaption>\n",
    "</figure>\n",
    "\n",
    "### NVFP4 Format\n",
    "\n",
    "NVFP4 format is similar to MXFP8 - it also uses granular scaling to preserve the dynamic range. The differences are:\n",
    "\n",
    " - Granularity of the scaling factors: in NVFP4 format a single scaling factor is used per block of 16 elements, whereas MXFP8 uses 1 scaling factor per block of 32 elements\n",
    " - Datatype of the scaling factors: NVFP4 uses FP8 E4M3 as the scaling factor per block, whereas MXFP8 uses E8M0 as the scaling factor datatype. Choice of E4M3 for the scaling factor enables preservation of more information about mantissa, but does not enable the full dynamic range of FP32. Therefore, NVFP4 uses an additional single per-tensor FP32 scaling factor to avoid overflows.\n",
    "\n",
    "In the NVFP4 training recipe for weight tensors we use a different variant of the NVFP4 quantization, where a single scaling factor is shared by a 2D block of 16x16 elements. This is similar to the weight quantization scheme employed in [DeepSeek-v3 training](https://arxiv.org/abs/2412.19437v1), but with a much finer granularity.\n",
    "\n",
    "### NVFP4 training recipe\n",
    "\n",
    "The NVFP4 training recipe implemented in Transformer Engine is described in [Pretraining Large Language Models with NVFP4](https://arxiv.org/abs/2509.25149v1) paper. The main elements of the recipe are:\n",
    "\n",
    " - Stochastic Rounding. When quantizing gradients to NVFP4, we use stochastic rounding to avoid the bias introduced by quantization. With stochastic rounding values are rounded probabilistically to one of their two nearest representable numbers, with probabilities inversely\n",
    "proportional to their distances.\n",
    " - 2D Scaling. The non-square size of the quantization blocks, while increasing granularity, has a property that the quantized tensor and its transpose no longer hold the same values. This is important since the transposed tensors are used when calculating gradients of the linear layers. While most tensors are not sensitive to this issue during training, it does affect the training accuracy when applied to the weight tensors. Therefore, the weights of the linear layers are quantized using a 2D scheme, where a single scaling factor is shared by a 2D block of 16x16 elements.\n",
    " - Random Hadamard Transforms. While microscaling reduces the dynamic range needed to represent tensor values, outliers can still have a\n",
    "disproportionate impact on FP4 formats, degrading model accuracy. Random Hadamard transforms address this by reshaping the tensor distribution to be more Gaussian-like, which smooths outliers and makes tensors easier to represent accurately in NVFP4. In Transformer Engine, we use a 16x16 Hadamard matrix for activations and gradients when performing weight gradient computation.\n",
    " - Last few layers in higher precision. The last few layers of the LLM are more sensitive to the quantization, so we recommend running them in higher precision (for example MXFP8). This is not done automatically in Transformer Engine, since TE does not have full information about the structure of the network being trained. It can easily be achieved, though, by modifying the model training code to run the last few layers under a different `autocast` (or by nesting 2 autocasts in order to override the recipe for a part of the network).\n",
    "\n",
    "The full linear layer utilizing NVFP4 is presented in Figure 9.\n",
    "\n",
    "<figure align=\"center\" id=\"fig_9\">\n",
    "<img src=\"FP4_linear.png\" width=\"80%\">\n",
    "<figcaption> Figure 9: Linear layer utilizing NVFP4</figcaption>\n",
    "</figure>"
   ]
  },
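  {
   "cell_type": "markdown",
   "id": "nvfp4-sr-sketch",
   "metadata": {},
   "source": [
"Stochastic rounding can be illustrated on a toy grid of representable values (a plain-Python sketch written for this primer; the real recipe rounds to the E2M1 grid after scaling):\n",
"\n",
"```python\n",
"import random\n",
"\n",
"def stochastic_round(x, step=0.5):\n",
"    # Round x to one of its two nearest multiples of `step`, with\n",
"    # probability inversely proportional to the distance, so that the\n",
"    # expected value of the result equals x (rounding introduces no bias).\n",
"    lo = step * (x // step)\n",
"    frac = (x - lo) / step  # in [0, 1): how far x sits above lo\n",
"    return lo + step if random.random() < frac else lo\n",
"\n",
"random.seed(0)\n",
"samples = [stochastic_round(1.3) for _ in range(100000)]\n",
"mean = sum(samples) / len(samples)  # close to 1.3\n",
"```\n",
"\n",
"Round-to-nearest would map every 1.3 to 1.5, a systematic bias that accumulates over the many gradient updates of a long training run; stochastic rounding keeps the expectation at 1.3."
   ]
  },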
  {
   "cell_type": "markdown",
   "id": "cf5e0b0d",
   "metadata": {},
   "source": [
    "## Using FP8 and FP4 with Transformer Engine\n",
    "\n",
    "The Transformer Engine library provides tools that make it easy to train with the FP8 and FP4 datatypes using different strategies.\n",
    "\n",
    "### FP8 recipe\n",
    "\n",
    "Transformer Engine defines a range of different low precision recipes to choose from in the `transformer_engine.common.recipe` module.\n",
    "\n",
    " - The [DelayedScaling](../api/common.rst#transformer_engine.common.recipe.DelayedScaling) recipe stores all of the required options for training with FP8 delayed scaling: length of the amax history to use for scaling factor computation, FP8 data format, etc.\n",
    " - [Float8CurrentScaling](../api/common.rst#transformer_engine.common.recipe.Float8CurrentScaling) recipe enables current per-tensor scaling with FP8.\n",
    " - [Float8BlockScaling](../api/common.rst#transformer_engine.common.recipe.Float8BlockScaling) recipe enables block scaling with FP8 as described in [DeepSeek-v3 paper](https://arxiv.org/abs/2412.19437v1).\n",
    " - [MXFP8BlockScaling](../api/common.rst#transformer_engine.common.recipe.MXFP8BlockScaling) recipe enables MXFP8 training.\n",
    " - [NVFP4BlockScaling](../api/common.rst#transformer_engine.common.recipe.NVFP4BlockScaling) recipe enables NVFP4 training."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0c8fd0ef",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformer_engine.common.recipe import Format, DelayedScaling, MXFP8BlockScaling, NVFP4BlockScaling\n",
    "\n",
    "fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass\n",
    "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n",
    "mxfp8_format = Format.E4M3  # E4M3 used everywhere\n",
    "mxfp8_recipe = MXFP8BlockScaling(fp8_format=mxfp8_format)\n",
    "nvfp4_recipe = NVFP4BlockScaling()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f9591eb5",
   "metadata": {},
   "source": [
    "Such a recipe is then used to configure the low precision training."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "734d3934",
   "metadata": {},
   "source": [
    "### FP8 autocasting\n",
    "\n",
    "Not every operation is safe to perform in FP8. All of the modules provided by the Transformer Engine library were designed to extract the maximum performance benefit from the FP8 datatype while maintaining accuracy. In order to enable FP8 operations, TE modules need to be wrapped inside the [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f8b1ff7f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import transformer_engine.pytorch as te\n",
    "import torch\n",
    "\n",
    "torch.manual_seed(12345)\n",
    "\n",
    "my_linear = te.Linear(768, 768, bias=True)\n",
    "\n",
    "inp = torch.rand((1024, 768)).cuda()\n",
    "\n",
    "with te.autocast(enabled=True, recipe=fp8_recipe):\n",
    "    out_fp8 = my_linear(inp)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e41161f1",
   "metadata": {},
   "source": [
    "The `autocast` context manager hides the complexity of handling FP8:\n",
    "\n",
    "- All FP8-safe operations have their inputs cast to FP8\n",
    "- Amax history is updated\n",
    "- New scaling factors are computed and ready for the next iteration\n",
    "\n",
    "<div class=\"alert alert-info\">\n",
    "\n",
    "<b>Note</b>\n",
    "\n",
    "Support for FP8 in the Linear layer of Transformer Engine is currently limited to tensors with shapes where both dimensions are divisible by 16. In terms of the input to the full Transformer network, this typically requires padding the sequence length to a multiple of 16.\n",
    "\n",
    "</div>"
   ]
  },
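  {
   "cell_type": "markdown",
   "id": "fp8-padding-sketch",
   "metadata": {},
   "source": [
"A quick way to satisfy the divisibility requirement from the note above is to pad up to the next multiple of 16 (a plain-Python sketch; `padded_length` is a made-up helper, not a Transformer Engine API):\n",
"\n",
"```python\n",
"def padded_length(seq_len, multiple=16):\n",
"    # Smallest length >= seq_len that is divisible by `multiple`.\n",
"    return seq_len + (-seq_len) % multiple\n",
"\n",
"padded_length(100)  # 112\n",
"padded_length(128)  # 128 (already divisible)\n",
"```"
   ]
  },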
  {
   "cell_type": "markdown",
   "id": "f7bb2de9",
   "metadata": {},
   "source": [
    "### Handling backward pass\n",
    "\n",
    "When a model is run inside the `autocast` region, especially in multi-GPU training, some communication is required in order to synchronize the scaling factors and amax history. To perform that communication without introducing much overhead, the `autocast` context manager aggregates the tensors before performing it.\n",
    "\n",
    "Due to this aggregation, the backward call needs to happen outside of the `autocast` context manager. This has no impact on the computation precision - the precision of the backward pass is determined by the precision of the forward pass."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e012bc8d",
   "metadata": {},
   "outputs": [],
   "source": [
    "loss_fp8 = out_fp8.mean()\n",
    "\n",
    "loss_fp8.backward()  # This backward pass uses FP8, since out_fp8 was calculated inside autocast\n",
    "\n",
    "out_fp32 = my_linear(inp)\n",
    "loss_fp32 = out_fp32.mean()\n",
    "loss_fp32.backward()  # This backward pass does not use FP8, since out_fp32 was calculated outside autocast"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1a6723ca",
   "metadata": {},
   "source": [
    "### Precision\n",
    "\n",
    "If we compare the results of the FP32 and FP8 execution, we will see that they are relatively close, but different:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "41e9a37b",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2276,  0.2629,  0.3000,  ...,  0.1297, -0.3702,  0.1807],\n",
       "        [-0.0963, -0.3724,  0.1717,  ..., -0.1250, -0.8501, -0.1669],\n",
       "        [ 0.4526,  0.3479,  0.5976,  ...,  0.1685, -0.8864, -0.1977],\n",
       "        ...,\n",
       "        [ 0.1698,  0.6062,  0.0385,  ...,  0.4038, -0.4564,  0.0143],\n",
       "        [ 0.0679,  0.2947,  0.2750,  ..., -0.3271, -0.4990,  0.1198],\n",
       "        [ 0.1865,  0.2353,  0.9170,  ...,  0.0673, -0.5567,  0.1246]],\n",
       "       device='cuda:0', grad_fn=<_LinearBackward>)"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_fp8"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b328ae0e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[ 0.2373,  0.2674,  0.2980,  ...,  0.1134, -0.3661,  0.1650],\n",
       "        [-0.0767, -0.3778,  0.1862,  ..., -0.1370, -0.8448, -0.1770],\n",
       "        [ 0.4615,  0.3593,  0.5813,  ...,  0.1696, -0.8826, -0.1826],\n",
       "        ...,\n",
       "        [ 0.1914,  0.6038,  0.0382,  ...,  0.4049, -0.4729,  0.0118],\n",
       "        [ 0.0864,  0.2895,  0.2719,  ..., -0.3337, -0.4922,  0.1240],\n",
       "        [ 0.2019,  0.2275,  0.9027,  ...,  0.0706, -0.5481,  0.1356]],\n",
       "       device='cuda:0', grad_fn=<_LinearBackward>)"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_fp32"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a9413c0a",
   "metadata": {},
   "source": [
    "That happens because in the FP8 case both the input and weights are cast to FP8 before the computation. We can see this if, instead of the original inputs, we use inputs that are exactly representable in FP8 (using a function defined in [quickstart_utils.py](quickstart_utils.py)):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ea939581",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.2276,  0.2629,  0.3000,  ...,  0.1297, -0.3702,  0.1807],\n",
      "        [-0.0963, -0.3724,  0.1717,  ..., -0.1250, -0.8501, -0.1669],\n",
      "        [ 0.4526,  0.3479,  0.5976,  ...,  0.1685, -0.8864, -0.1977],\n",
      "        ...,\n",
      "        [ 0.1698,  0.6062,  0.0385,  ...,  0.4038, -0.4564,  0.0143],\n",
      "        [ 0.0679,  0.2947,  0.2750,  ..., -0.3271, -0.4990,  0.1198],\n",
      "        [ 0.1865,  0.2353,  0.9170,  ...,  0.0673, -0.5567,  0.1246]],\n",
      "       device='cuda:0', grad_fn=<_LinearBackward>)\n"
     ]
    }
   ],
   "source": [
    "from quickstart_utils import cast_to_representable\n",
    "\n",
    "inp_representable = cast_to_representable(inp)\n",
    "my_linear.weight.data = cast_to_representable(my_linear.weight.data)\n",
    "\n",
    "out_fp32_representable = my_linear(inp_representable)\n",
    "\n",
    "print(out_fp32_representable)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "03e703bd",
   "metadata": {},
   "source": [
    "This time the difference is really small:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "78f1c2eb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        ...,\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.],\n",
       "        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:0',\n",
       "       grad_fn=<SubBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "out_fp8 - out_fp32_representable"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "63ff9b8c",
   "metadata": {},
   "source": [
    "The differences in results coming from FP8 execution do not matter during the training process, but it is good to understand them, e.g. when debugging the model."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d45e8b6c-803b-4a4f-8835-c19b0a94bc6a",
   "metadata": {},
   "source": [
    "### Using multiple recipes in the same training run\n",
    "\n",
    "Sometimes it is desirable to use multiple recipes in the same training run. An example of this is NVFP4 training, where the last few layers of the network should be run in higher precision. This can be achieved by using multiple autocasts, either completely separately or in a nested way (nesting can be useful when e.g. we want a configurable overarching recipe but still hardcode a different recipe for some parts of the network)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "c663f694-41d6-47c0-a397-5fc56e692542",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor([[ 0.0547,  0.0039, -0.0664,  ..., -0.2061,  0.2344, -0.3223],\n",
      "        [ 0.0131, -0.1436,  0.0168,  ..., -0.4258,  0.1562, -0.0371],\n",
      "        [ 0.1074, -0.2773,  0.0576,  ..., -0.2070,  0.0640, -0.1611],\n",
      "        ...,\n",
      "        [ 0.0825, -0.0630,  0.0571,  ..., -0.3711,  0.1562, -0.4062],\n",
      "        [-0.1729, -0.1138, -0.0620,  ..., -0.4238,  0.0703, -0.2070],\n",
      "        [-0.0908, -0.2148,  0.2676,  ..., -0.4551,  0.1836, -0.4551]],\n",
      "       device='cuda:0', dtype=torch.bfloat16, grad_fn=<_LinearBackward>)\n"
     ]
    }
   ],
   "source": [
    "my_linear1 = te.Linear(768, 768).bfloat16()  # The first linear - we want to run it in FP4\n",
    "my_linear2 = te.Linear(768, 768).bfloat16()  # The second linear - we want to run it in MXFP8\n",
    "\n",
    "inp = inp.bfloat16()\n",
    "\n",
    "with te.autocast(recipe=nvfp4_recipe):\n",
    "    y = my_linear1(inp)\n",
    "    with te.autocast(recipe=mxfp8_recipe):\n",
    "        out = my_linear2(y)\n",
    "\n",
    "print(out)\n",
    "\n",
    "out.mean().backward()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}