OpenDAS / TransformerEngine
Commit 9df0c4a3 authored Mar 03, 2026 by wenjh
Merge branch 'nv_main'
Parents: 0d874a4e f122b07d
Changes: 221
Showing 20 changed files with 1552 additions and 0 deletions (+1552, -0)
docs/features/low_precision_training/fp8_current_scaling/img/fp8_cast_process.svg (+55, -0)
docs/features/low_precision_training/fp8_current_scaling/img/fp8_current_scaling_all_gather.svg (+78, -0)
docs/features/low_precision_training/fp8_current_scaling/img/fp8_formats.svg (+164, -0)
docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg (+112, -0)
docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py (+33, -0)
docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py (+29, -0)
docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst (+163, -0)
docs/features/low_precision_training/fp8_delayed_scaling/img/scaling_comparison.svg (+82, -0)
docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py (+15, -0)
docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py (+39, -0)
docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py (+18, -0)
docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py (+37, -0)
docs/features/low_precision_training/index.rst (+17, -0)
docs/features/low_precision_training/introduction/autocast_jax.py (+83, -0)
docs/features/low_precision_training/introduction/autocast_pytorch.py (+69, -0)
docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py (+39, -0)
docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py (+52, -0)
docs/features/low_precision_training/introduction/img/fp8_linear_flow.svg (+172, -0)
docs/features/low_precision_training/introduction/img/fp_formats_comparison.svg (+183, -0)
docs/features/low_precision_training/introduction/img/master_weights_approaches.svg (+112, -0)
docs/features/low_precision_training/fp8_current_scaling/img/fp8_cast_process.svg (new file, mode 100644)

<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 220">
  <defs>
    <style>
      @import url("../_static/css/diagram-colors.css");
      .arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead-cast); }
    </style>
    <marker id="arrowhead-cast" markerWidth="10" markerHeight="10" refX="8" refY="3" orient="auto" markerUnits="strokeWidth">
      <polygon points="0 0, 10 3, 0 6" fill="#616161"/>
    </marker>
  </defs>
  <!-- Title -->
  <text x="450" y="30" class="title" text-anchor="middle">FP8 quantization</text>
  <!-- Step 1: High Precision Tensor -->
  <rect x="80" y="80" width="140" height="70" class="hp" rx="6"/>
  <text x="150" y="110" class="text" text-anchor="middle">High Precision</text>
  <text x="150" y="130" class="text" text-anchor="middle">Tensor</text>
  <!-- Arrow 1 -->
  <path d="M 220 115 L 270 115" class="arrow"/>
  <!-- Quantize container box -->
  <rect x="270" y="60" width="330" height="130" class="quantize" rx="6"/>
  <text x="435" y="205" class="text" style="font-weight: 600; font-size: 14px;" text-anchor="middle">Quantize</text>
  <!-- Step 2: Compute Amax (sub-box) -->
  <rect x="280" y="95" width="140" height="50" class="amax" rx="4"/>
  <text x="350" y="118" class="text" style="font-weight: 600;" text-anchor="middle">Compute amax</text>
  <text x="350" y="160" class="small-text" text-anchor="middle">1 tensor read</text>
  <!-- Arrow 2 (inside quantize box) -->
  <path d="M 420 120 L 450 120" class="arrow"/>
  <!-- Step 3: Apply Scale + Cast (sub-box) -->
  <rect x="450" y="95" width="140" height="50" class="quantize" rx="4"/>
  <text x="520" y="115" class="text" style="font-weight: 600;" text-anchor="middle">Apply Scale</text>
  <text x="520" y="130" class="text" style="font-weight: 600;" text-anchor="middle">+ Cast</text>
  <text x="520" y="160" class="small-text" text-anchor="middle">1 tensor read</text>
  <!-- Arrow 3 -->
  <path d="M 600 115 L 650 115" class="arrow"/>
  <!-- Step 4: FP8 Tensor -->
  <rect x="650" y="80" width="140" height="70" class="fp8" rx="6"/>
  <text x="720" y="110" class="text" text-anchor="middle">FP8</text>
  <text x="720" y="130" class="text" text-anchor="middle">Tensor</text>
</svg>
docs/features/low_precision_training/fp8_current_scaling/img/fp8_current_scaling_all_gather.svg (new file, mode 100644)

<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 950 170" width="950" height="170">
  <defs>
    <style>
      @import url("../_static/css/diagram-colors.css");
      /* Arrows */
      .arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead-ag); }
      /* All-gather operations - fallback if CSS doesn't load */
      .allgather { fill: #e1f5fe; stroke: #039be5; stroke-width: 2; }
    </style>
    <marker id="arrowhead-ag" markerWidth="6" markerHeight="6" refX="5" refY="2" orient="auto">
      <polygon points="0 0, 6 2, 0 4" fill="#616161"/>
    </marker>
  </defs>
  <!-- Title -->
  <text x="475" y="30" class="title">Quantization + all gather for FP8 current scaling</text>
  <!-- High Precision Tensor -->
  <rect x="30" y="80" width="110" height="55" class="hp" rx="6"/>
  <text x="85" y="103" class="text">High Precision</text>
  <text x="85" y="120" class="text">Tensor</text>
  <path d="M 140 107 L 165 107" class="arrow"/>
  <!-- Compute Amax -->
  <rect x="165" y="80" width="100" height="55" class="amax" rx="6"/>
  <text x="215" y="103" class="text">Compute</text>
  <text x="215" y="120" class="text">Amax</text>
  <path d="M 265 107 L 290 107" class="arrow"/>
  <!-- Synchronize Amax -->
  <rect x="290" y="80" width="100" height="55" class="amax" rx="6"/>
  <text x="340" y="103" class="text">Synchronize</text>
  <text x="340" y="120" class="text">Amax</text>
  <path d="M 390 107 L 415 107" class="arrow"/>
  <!-- Scale + Cast -->
  <rect x="415" y="80" width="100" height="55" class="quantize" rx="6"/>
  <text x="465" y="103" class="text">Scale +</text>
  <text x="465" y="120" class="text">Cast</text>
  <path d="M 515 107 L 540 107" class="arrow"/>
  <!-- FP8 Tensor (intermediate) -->
  <rect x="540" y="80" width="100" height="55" class="fp8" rx="6"/>
  <text x="590" y="103" class="text">FP8</text>
  <text x="590" y="120" class="text">Tensor</text>
  <path d="M 640 107 L 665 107" class="arrow"/>
  <!-- All-Gather -->
  <rect x="665" y="80" width="100" height="55" class="allgather" rx="6"/>
  <text x="715" y="112" class="text">All-Gather</text>
  <path d="M 765 107 L 790 107" class="arrow"/>
  <!-- FP8 Gathered Tensor -->
  <rect x="790" y="80" width="130" height="55" class="fp8" rx="6"/>
  <text x="855" y="103" class="text">FP8 Gathered</text>
  <text x="855" y="120" class="text">Tensor</text>
</svg>
docs/features/low_precision_training/fp8_current_scaling/img/fp8_formats.svg (new file, mode 100644)

<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 280">
  <defs>
    <style>
      @import url("../_static/css/diagram-colors.css");
      .sign-bit { fill: #9db4d0; stroke: #333; stroke-width: 1; }
      .exponent-bit { fill: #d9a066; stroke: #333; stroke-width: 1; }
      .mantissa-bit { fill: #a8d99c; stroke: #333; stroke-width: 1; }
      .bit-text { fill: #000; text-anchor: middle; dominant-baseline: middle; font-size: 16px; }
      .header-text { fill: #555; font-weight: normal; text-anchor: middle; font-size: 18px; }
      .value-text { fill: #333; font-size: 18px; }
      .format-label { fill: #333; font-weight: bold; text-anchor: middle; dominant-baseline: middle; font-size: 20px; }
    </style>
  </defs>
  <!-- Header labels - centered -->
  <text x="149" y="18" class="header-text">sign</text>
  <text x="220" y="18" class="header-text">exponent</text>
  <text x="420" y="18" class="header-text">mantissa</text>
  <!-- FP16 Format (16 bits: 1 + 5 + 10) -->
  <text x="60" y="60" class="format-label">FP16</text>
  <rect x="140" y="45" width="18" height="30" class="sign-bit"/><text x="149" y="60" class="bit-text">0</text>
  <rect x="163" y="45" width="18" height="30" class="exponent-bit"/><text x="172" y="60" class="bit-text">0</text>
  <rect x="186" y="45" width="18" height="30" class="exponent-bit"/><text x="195" y="60" class="bit-text">1</text>
  <rect x="209" y="45" width="18" height="30" class="exponent-bit"/><text x="218" y="60" class="bit-text">1</text>
  <rect x="232" y="45" width="18" height="30" class="exponent-bit"/><text x="241" y="60" class="bit-text">0</text>
  <rect x="255" y="45" width="18" height="30" class="exponent-bit"/><text x="264" y="60" class="bit-text">1</text>
  <rect x="278" y="45" width="18" height="30" class="mantissa-bit"/><text x="287" y="60" class="bit-text">1</text>
  <rect x="301" y="45" width="18" height="30" class="mantissa-bit"/><text x="310" y="60" class="bit-text">0</text>
  <rect x="324" y="45" width="18" height="30" class="mantissa-bit"/><text x="333" y="60" class="bit-text">0</text>
  <rect x="347" y="45" width="18" height="30" class="mantissa-bit"/><text x="356" y="60" class="bit-text">1</text>
  <rect x="370" y="45" width="18" height="30" class="mantissa-bit"/><text x="379" y="60" class="bit-text">0</text>
  <rect x="393" y="45" width="18" height="30" class="mantissa-bit"/><text x="402" y="60" class="bit-text">1</text>
  <rect x="416" y="45" width="18" height="30" class="mantissa-bit"/><text x="425" y="60" class="bit-text">0</text>
  <rect x="439" y="45" width="18" height="30" class="mantissa-bit"/><text x="448" y="60" class="bit-text">0</text>
  <rect x="462" y="45" width="18" height="30" class="mantissa-bit"/><text x="471" y="60" class="bit-text">1</text>
  <rect x="485" y="45" width="18" height="30" class="mantissa-bit"/><text x="494" y="60" class="bit-text">1</text>
  <text x="540" y="60" class="value-text">= 0.395264</text>
  <!-- BF16 Format (16 bits: 1 + 8 + 7) -->
  <text x="60" y="120" class="format-label">BF16</text>
  <rect x="140" y="105" width="18" height="30" class="sign-bit"/><text x="149" y="120" class="bit-text">0</text>
  <rect x="163" y="105" width="18" height="30" class="exponent-bit"/><text x="172" y="120" class="bit-text">0</text>
  <rect x="186" y="105" width="18" height="30" class="exponent-bit"/><text x="195" y="120" class="bit-text">1</text>
  <rect x="209" y="105" width="18" height="30" class="exponent-bit"/><text x="218" y="120" class="bit-text">1</text>
  <rect x="232" y="105" width="18" height="30" class="exponent-bit"/><text x="241" y="120" class="bit-text">1</text>
  <rect x="255" y="105" width="18" height="30" class="exponent-bit"/><text x="264" y="120" class="bit-text">1</text>
  <rect x="278" y="105" width="18" height="30" class="exponent-bit"/><text x="287" y="120" class="bit-text">1</text>
  <rect x="301" y="105" width="18" height="30" class="exponent-bit"/><text x="310" y="120" class="bit-text">0</text>
  <rect x="324" y="105" width="18" height="30" class="exponent-bit"/><text x="333" y="120" class="bit-text">1</text>
  <rect x="347" y="105" width="18" height="30" class="mantissa-bit"/><text x="356" y="120" class="bit-text">1</text>
  <rect x="370" y="105" width="18" height="30" class="mantissa-bit"/><text x="379" y="120" class="bit-text">0</text>
  <rect x="393" y="105" width="18" height="30" class="mantissa-bit"/><text x="402" y="120" class="bit-text">0</text>
  <rect x="416" y="105" width="18" height="30" class="mantissa-bit"/><text x="425" y="120" class="bit-text">1</text>
  <rect x="439" y="105" width="18" height="30" class="mantissa-bit"/><text x="448" y="120" class="bit-text">0</text>
  <rect x="462" y="105" width="18" height="30" class="mantissa-bit"/><text x="471" y="120" class="bit-text">1</text>
  <rect x="485" y="105" width="18" height="30" class="mantissa-bit"/><text x="494" y="120" class="bit-text">0</text>
  <text x="540" y="120" class="value-text">= 0.394531</text>
  <!-- FP8 E4M3 Format (8 bits: 1 + 4 + 3) -->
  <text x="60" y="180" class="format-label">FP8 E4M3</text>
  <rect x="140" y="165" width="18" height="30" class="sign-bit"/><text x="149" y="180" class="bit-text">0</text>
  <rect x="163" y="165" width="18" height="30" class="exponent-bit"/><text x="172" y="180" class="bit-text">0</text>
  <rect x="186" y="165" width="18" height="30" class="exponent-bit"/><text x="195" y="180" class="bit-text">1</text>
  <rect x="209" y="165" width="18" height="30" class="exponent-bit"/><text x="218" y="180" class="bit-text">0</text>
  <rect x="232" y="165" width="18" height="30" class="exponent-bit"/><text x="241" y="180" class="bit-text">1</text>
  <rect x="255" y="165" width="18" height="30" class="mantissa-bit"/><text x="264" y="180" class="bit-text">1</text>
  <rect x="278" y="165" width="18" height="30" class="mantissa-bit"/><text x="287" y="180" class="bit-text">0</text>
  <rect x="301" y="165" width="18" height="30" class="mantissa-bit"/><text x="310" y="180" class="bit-text">1</text>
  <text x="355" y="180" class="value-text">= 0.40625</text>
  <!-- FP8 E5M2 Format (8 bits: 1 + 5 + 2) -->
  <text x="60" y="240" class="format-label">FP8 E5M2</text>
  <rect x="140" y="225" width="18" height="30" class="sign-bit"/><text x="149" y="240" class="bit-text">0</text>
  <rect x="163" y="225" width="18" height="30" class="exponent-bit"/><text x="172" y="240" class="bit-text">0</text>
  <rect x="186" y="225" width="18" height="30" class="exponent-bit"/><text x="195" y="240" class="bit-text">1</text>
  <rect x="209" y="225" width="18" height="30" class="exponent-bit"/><text x="218" y="240" class="bit-text">1</text>
  <rect x="232" y="225" width="18" height="30" class="exponent-bit"/><text x="241" y="240" class="bit-text">0</text>
  <rect x="255" y="225" width="18" height="30" class="exponent-bit"/><text x="264" y="240" class="bit-text">1</text>
  <rect x="278" y="225" width="18" height="30" class="mantissa-bit"/><text x="287" y="240" class="bit-text">1</text>
  <rect x="301" y="225" width="18" height="30" class="mantissa-bit"/><text x="310" y="240" class="bit-text">0</text>
  <text x="355" y="240" class="value-text">= 0.375</text>
</svg>
docs/features/low_precision_training/fp8_current_scaling/img/fp8_scaling_concept.svg (new file, mode 100644)

<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 900 380">
  <style>
    @import url("../_static/css/diagram-colors.css");
    .axis-line { stroke: #333; stroke-width: 2.5; }
    .value-dot { fill: #2196f3; stroke: #1976d2; stroke-width: 1; }
    .arrow { fill: #4caf50; }
    .arrow-line { stroke: #4caf50; stroke-width: 3; }
    .range-label { font-size: 14px; fill: #555; font-weight: 500; }
  </style>
  <!-- Top: Original values (before scaling) -->
  <text x="450" y="55" class="section-title" text-anchor="middle">Original Tensor Values</text>
  <line x1="80" y1="85" x2="820" y2="85" class="axis-line"/>
  <!-- Zero marker (center) -->
  <line x1="450" y1="80" x2="450" y2="90" stroke="#333" stroke-width="2"/>
  <text x="450" y="108" class="text" text-anchor="middle" font-size="12px">0</text>
  <!-- Value dots (before scaling - irregular, not symmetric around zero) -->
  <circle cx="118" cy="85" r="6" fill="#e53935" stroke="#c62828" stroke-width="2"/>
  <circle cx="159" cy="85" r="5" class="value-dot"/>
  <circle cx="167" cy="85" r="5" class="value-dot"/>
  <circle cx="187" cy="85" r="5" class="value-dot"/>
  <circle cx="199" cy="85" r="5" class="value-dot"/>
  <circle cx="228" cy="85" r="5" class="value-dot"/>
  <circle cx="326" cy="85" r="5" class="value-dot"/>
  <circle cx="368" cy="85" r="5" class="value-dot"/>
  <circle cx="442" cy="85" r="5" class="value-dot"/>
  <circle cx="621" cy="85" r="5" class="value-dot"/>
  <circle cx="649" cy="85" r="5" class="value-dot"/>
  <circle cx="725" cy="85" r="5" class="value-dot"/>
  <!-- amax label -->
  <text x="118" y="70" class="text" fill="#e53935" font-weight="700" font-size="14px" text-anchor="middle">amax</text>
  <!-- Original range bracket spanning all values -->
  <line x1="118" y1="100" x2="118" y2="110" stroke="#666" stroke-width="1.5"/>
  <line x1="118" y1="110" x2="725" y2="110" stroke="#666" stroke-width="1.5"/>
  <line x1="725" y1="100" x2="725" y2="110" stroke="#666" stroke-width="1.5"/>
  <text x="750" y="114" class="range-label" text-anchor="start">Original range</text>
  <!-- Trapezoid showing compression from original range to FP8 range -->
  <polygon points="118,115 725,115 650,165 250,165" fill="#e53935" opacity="0.2" stroke="#e53935" stroke-width="1.5"/>
  <!-- Bottom: After scaling -->
  <text x="450" y="190" class="section-title" text-anchor="middle">Scaled Values (fit FP8 range)</text>
  <line x1="80" y1="220" x2="820" y2="220" class="axis-line"/>
  <!-- Zero marker (center) -->
  <line x1="450" y1="215" x2="450" y2="225" stroke="#333" stroke-width="2"/>
  <text x="450" y="238" class="text" text-anchor="middle" font-size="12px">0</text>
  <!-- FP8 range bracket -->
  <line x1="250" y1="245" x2="250" y2="255" stroke="#4caf50" stroke-width="1.5"/>
  <line x1="250" y1="255" x2="650" y2="255" stroke="#4caf50" stroke-width="1.5"/>
  <line x1="650" y1="245" x2="650" y2="255" stroke="#4caf50" stroke-width="1.5"/>
  <text x="750" y="259" class="range-label" text-anchor="start" fill="#4caf50">FP8 range</text>
  <!-- Value dots (after scaling - homogeneous scaling from zero, all fit into FP8 range) -->
  <circle cx="250" cy="220" r="6" fill="#e53935" stroke="#c62828" stroke-width="2"/>
  <text x="250" y="205" class="text" fill="#e53935" font-weight="700" font-size="12px" text-anchor="middle">- FP8 range max</text>
  <circle cx="275" cy="220" r="5" class="value-dot"/>
  <circle cx="280" cy="220" r="5" class="value-dot"/>
  <circle cx="292" cy="220" r="5" class="value-dot"/>
  <circle cx="299" cy="220" r="5" class="value-dot"/>
  <circle cx="316" cy="220" r="5" class="value-dot"/>
  <circle cx="375" cy="220" r="5" class="value-dot"/>
  <circle cx="401" cy="220" r="5" class="value-dot"/>
  <circle cx="445" cy="220" r="5" class="value-dot"/>
  <circle cx="553" cy="220" r="5" class="value-dot"/>
  <circle cx="569" cy="220" r="5" class="value-dot"/>
  <circle cx="615" cy="220" r="5" class="value-dot"/>
  <!-- Third line: After cast to FP8 (quantized values) -->
  <text x="450" y="290" class="section-title" text-anchor="middle">Cast to FP8 (quantized values)</text>
  <line x1="80" y1="320" x2="820" y2="320" class="axis-line"/>
  <!-- Zero marker (center) -->
  <line x1="450" y1="315" x2="450" y2="325" stroke="#333" stroke-width="2"/>
  <text x="450" y="338" class="text" text-anchor="middle" font-size="12px">0</text>
  <!-- FP8 range bracket -->
  <line x1="250" y1="345" x2="250" y2="355" stroke="#4caf50" stroke-width="1.5"/>
  <line x1="250" y1="355" x2="650" y2="355" stroke="#4caf50" stroke-width="1.5"/>
  <line x1="650" y1="345" x2="650" y2="355" stroke="#4caf50" stroke-width="1.5"/>
  <text x="750" y="359" class="range-label" text-anchor="start" fill="#4caf50">FP8 range</text>
  <!-- Quantized dots - merged close values to show FP8 granularity -->
  <circle cx="250" cy="320" r="6" fill="#e53935" stroke="#c62828" stroke-width="2"/>
  <!-- merged: 275+280 -->
  <circle cx="278" cy="317" r="4.5" class="value-dot"/>
  <circle cx="278" cy="323" r="4.5" class="value-dot"/>
  <!-- merged: 292+299 -->
  <circle cx="296" cy="317" r="4.5" class="value-dot"/>
  <circle cx="296" cy="323" r="4.5" class="value-dot"/>
  <circle cx="318" cy="320" r="5" class="value-dot"/>
  <circle cx="378" cy="320" r="5" class="value-dot"/>
  <circle cx="404" cy="320" r="5" class="value-dot"/>
  <circle cx="450" cy="320" r="5" class="value-dot"/>
  <!-- merged: 553+569 -->
  <circle cx="562" cy="317" r="4.5" class="value-dot"/>
  <circle cx="562" cy="323" r="4.5" class="value-dot"/>
  <circle cx="615" cy="320" r="5" class="value-dot"/>
</svg>
docs/features/low_precision_training/fp8_current_scaling/jax_current_scaling_example.py (new file, mode 100644)

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# START_CURRENT_SCALING_EXAMPLE
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import Float8CurrentScaling, Format

# Create FP8 Current Scaling recipe
# Available formats:
# - Format.HYBRID (default) -- E4M3 for forward pass, E5M2 for backward pass
# - Format.E4M3 -- E4M3 for both forward and backward pass
recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)

with te.autocast(enabled=True, recipe=recipe):
    # Create and initialize layer
    layer = DenseGeneral(features=1024)
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
    var_collect = layer.init(key, x)

    # Forward and backward pass
    def loss_fn(var_collect):
        output = layer.apply(var_collect, x)
        return output.sum()

    loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_CURRENT_SCALING_EXAMPLE
docs/features/low_precision_training/fp8_current_scaling/pytorch_current_scaling_example.py (new file, mode 100644)

# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

# START_CURRENT_SCALING_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling, Format

# Create FP8 Current Scaling recipe
# Available formats:
# - Format.HYBRID (default) -- E4M3 for forward pass, E5M2 for backward pass
# - Format.E4M3 -- E4M3 for both forward and backward pass
recipe = Float8CurrentScaling(fp8_format=Format.HYBRID)

# Create a simple linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)

# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
    output = layer(inp)
loss = output.sum()
loss.backward()
# END_CURRENT_SCALING_EXAMPLE
docs/features/low_precision_training/fp8_delayed_scaling/fp8_delayed_scaling.rst (new file, mode 100644)

..
    Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

FP8 Delayed Scaling
===================================

The FP8 Delayed Scaling recipe estimates scaling factors from historical amax values rather than
computing them for each tensor. Compared to the Current Scaling recipe, this reduces the number of
tensor reads per quantization from two to one, improving memory efficiency.

Both this recipe and the :doc:`FP8 Current Scaling <../fp8_current_scaling/fp8_current_scaling>` recipe use
the same FP8 formats (E4M3/E5M2) with one FP32 scaling factor per tensor.
Reading the FP8 Current Scaling documentation first is recommended.

Quantization with delayed scaling factors
-----------------------------------------

FP8 Current Scaling requires two tensor reads per quantization: one to compute amax and
one to cast. FP8 Delayed Scaling eliminates the first read by predicting the scaling factor
from historical amax values, hence *delayed* (using past values) versus *current* (using present values).

The quantization process works as follows:

1. **Compute scaling factor from history** (no tensor read needed):
   the scaling factor is derived from the stored ``amax_history`` using the formula
   ``scaling_factor = FP8_MAX / amax``,
   where ``amax`` is computed from the history with either the ``max`` (maximum over the window, default)
   or the ``most_recent`` algorithm.

2. **Quantize the tensor** (one tensor read):
   apply the scaling factor and cast to FP8. Values exceeding the FP8 range are clipped.

3. **Update history**:
   record the actual amax observed during this quantization for future iterations.

Each module maintains an ``amax_history`` tensor of configurable length (``amax_history_len``)
for each quantized tensor.
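Step 1 above can be sketched in a few lines of plain Python (a minimal illustration of the formula, not Transformer Engine's internal implementation; the ``FP8_MAX`` value of 448.0 assumes the E4M3 format, and the small-value guard is an assumption for the all-zero first iteration):

```python
E4M3_MAX = 448.0  # largest representable magnitude in FP8 E4M3 (assumed FP8_MAX)

def scale_from_history(amax_history, algo="max"):
    """Derive a scaling factor from stored amax values, without reading the data tensor."""
    if algo == "max":
        amax = max(amax_history)   # maximum over the sliding window (default)
    else:
        amax = amax_history[0]     # "most_recent": position 0 holds the latest amax
    amax = max(amax, 1e-38)        # guard against an all-zero history (e.g. first step)
    return E4M3_MAX / amax         # scaling_factor = FP8_MAX / amax

scale = scale_from_history([3.5, 2.0, 4.0, 1.5])  # 448.0 / 4.0 = 112.0
```

Because no pass over the data tensor is needed, this step is essentially free compared to the amax read that Current Scaling performs.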
.. raw:: html
   :file: img/scaling_comparison.svg

*Figure 1. Comparison of the FP8 Current Scaling and FP8 Delayed Scaling quantization processes.*

Amax History Management
-----------------------

The ``amax_history`` buffer acts as a sliding window of recent amax values.
Position 0 serves as a staging area for the current amax, while positions 1 to N-1
store the history from oldest to newest. Each quantization writes the observed amax
to position 0, and after the pass completes, the history is rotated:

.. code-block:: text

    Before rotation: [amax_N, amax_1, amax_2, ..., amax_N-1]  (amax_N = current, amax_1 = oldest)
    After rotation:  [0, amax_2, ..., amax_N-1, amax_N]       (amax_1 dropped, amax_N appended)

The scaling factor is computed **before** the rotation, so it uses all ``amax_history_len`` values.
Position 0 is zeroed after the scale update, ready to receive the next iteration's amax.
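The rotation described above can be sketched in plain Python (illustrative only; a list stands in for the 1-D history buffer, and the real implementations batch this update on the GPU):

```python
def rotate_amax_history(history):
    """Rotate the amax history: drop the oldest value (position 1),
    append the current amax (position 0) as newest, then zero position 0."""
    current = history[0]                    # staging slot: this iteration's amax
    return [0.0] + history[2:] + [current]  # [0, amax_2, ..., amax_N-1, amax_N]

# current amax = 4.0 (position 0), oldest = 1.0 (position 1)
rotated = rotate_amax_history([4.0, 1.0, 2.0, 3.0])  # [0.0, 2.0, 3.0, 4.0]
```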
The implementation differs between PyTorch and JAX:

.. tabs::

   .. tab:: PyTorch

      Each module creates two ``amax_history`` tensors, initialized to zero:

      - Forward: shape ``(amax_history_len, num_gemms * 3)``, three FP8 tensors per GEMM (input, weight, output)
      - Backward: shape ``(amax_history_len, num_gemms * 2)``, two FP8 tensors per GEMM (grad_output, grad_input)

      When the autocast context exits, a single CUDA kernel processes all tensors at once,
      performing the amax reduction across GPUs and the history rotation. This batched approach
      minimizes kernel launch overhead compared to updating each tensor separately.

   .. tab:: JAX

      Each quantizer maintains its own ``amax_history`` with shape ``(amax_history_len,)``
      and updates it independently.

Here is how to use FP8 Delayed Scaling in PyTorch and JAX:

.. tabs::

   .. tab:: PyTorch

      .. raw:: html

         <div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
         Requires SM89 (Ada) or later
         </div>

      .. literalinclude:: pytorch_delayed_scaling_example.py
         :language: python
         :start-after: # START_DELAYED_SCALING_EXAMPLE
         :end-before: # END_DELAYED_SCALING_EXAMPLE

   .. tab:: JAX

      .. raw:: html

         <div style="background: #f0f4f8; border-left: 3px solid #5c7cfa; padding: 6px 12px; font-size: 13px; color: #495057; margin-bottom: 0; border-radius: 4px 4px 0 0;">
         Requires SM89 (Ada) or later
         </div>

      .. literalinclude:: jax_delayed_scaling_example.py
         :language: python
         :start-after: # START_DELAYED_SCALING_EXAMPLE
         :end-before: # END_DELAYED_SCALING_EXAMPLE

Distributed Training
--------------------

FP8 Delayed Scaling uses the same data formats as FP8 Current Scaling, so quantized all-gather
is supported. However, amax reduction works slightly differently across frameworks.

.. tabs::

   .. tab:: PyTorch

      Amax reduction is controlled by two parameters:

      - ``reduce_amax`` in the recipe: enables/disables reduction (required for SP and CP)
      - ``amax_reduction_group`` in ``autocast``: specifies the process group for the reduction

      We recommend reducing amax across all GPUs where the tensor is sharded,
      including data parallel ranks.

      .. literalinclude:: pytorch_delayed_scaling_distributed_example.py
         :language: python
         :start-after: # START_AMAX_REDUCTION_EXAMPLE
         :end-before: # END_AMAX_REDUCTION_EXAMPLE

      In data parallel training, some modules may not execute on certain ranks
      (e.g., MoE experts that receive no tokens). This is handled as follows:

      - **First iteration**: all modules must execute on all ranks to register
        their ``amax_history`` tensors in the global buffer. Mismatched registration
        would cause the ``all_reduce`` to hang due to different tensor sizes across ranks.
      - **Subsequent iterations**: the ``autocast`` context must be entered and exited
        on all ranks (this triggers the collective reduction). Individual modules can be
        skipped: if no rank executes a module, its history is not rotated and its scale
        remains unchanged.

   .. tab:: JAX

      Amax reduction is always enabled and managed automatically.
      The reduction scope covers all parallelism axes except pipeline parallelism (TP, SP, DP/FSDP).

      .. literalinclude:: jax_delayed_scaling_distributed_example.py
         :language: python
         :start-after: # START_AMAX_REDUCTION_EXAMPLE
         :end-before: # END_AMAX_REDUCTION_EXAMPLE

Supported devices
-----------------

Ada and later (SM 8.9+)
docs/features/low_precision_training/fp8_delayed_scaling/img/scaling_comparison.svg (new file, mode 100644)

<?xml version="1.0" encoding="UTF-8"?>
<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 1000 420">
  <defs>
    <style>
      /* Common styles loaded from diagram-colors.css: .hp, .fp8, .quantize, .amax, .text, .title, .label, .box-orange, .box-dashed */
      /* Diagram-specific styles for arrows */
      .arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
    </style>
    <marker id="arrowhead" markerWidth="10" markerHeight="10" refX="8" refY="3" orient="auto" markerUnits="strokeWidth">
      <polygon points="0 0, 10 3, 0 6" fill="#616161"/>
    </marker>
  </defs>
  <!-- Current Scaling Section -->
  <text x="250" y="30" class="title">Current Scaling</text>
  <!-- Tensor box -->
  <rect x="150" y="60" width="200" height="60" class="hp" rx="5"/>
  <text x="250" y="95" class="text">Tensor</text>
  <!-- Arrow to amax computation -->
  <path d="M 250 120 L 250 160" class="arrow"/>
  <!-- Amax computation box -->
  <rect x="150" y="160" width="200" height="60" class="amax" rx="5"/>
  <text x="250" y="195" class="text">Amax Computation</text>
  <!-- Arrow to quantization -->
  <path d="M 250 220 L 250 260" class="arrow"/>
  <!-- Quantization box -->
  <rect x="125" y="260" width="250" height="60" class="quantize" rx="5"/>
  <text x="250" y="285" class="text">Quantization</text>
  <text x="250" y="305" class="label">(uses tensor + amax)</text>
  <!-- Arrow to FP8 tensor -->
  <path d="M 250 320 L 250 360" class="arrow"/>
  <!-- FP8 Tensor result -->
  <rect x="150" y="360" width="200" height="40" class="fp8" rx="5"/>
  <text x="250" y="385" class="text">FP8 Tensor</text>
  <!-- Delayed Scaling Section -->
  <text x="750" y="30" class="title">Delayed Scaling</text>
  <!-- Tensor box with amax history subbox -->
  <rect x="650" y="60" width="200" height="80" class="hp" rx="5"/>
  <text x="750" y="90" class="text">Tensor</text>
  <!-- Amax history subbox (below tensor) -->
  <rect x="660" y="110" width="180" height="25" class="box-orange box-dashed" rx="3"/>
  <text x="750" y="127" class="label">amax history</text>
  <!-- Arrow to quantization -->
  <path d="M 750 140 L 750 180" class="arrow"/>
  <text x="820" y="162" class="small-text" style="text-anchor: start;">read amax</text>
  <!-- Quantization box -->
  <rect x="625" y="180" width="250" height="80" class="quantize" rx="5"/>
  <text x="750" y="210" class="text">Quantization</text>
  <text x="750" y="230" class="label">(uses tensor + amax from history)</text>
  <text x="750" y="250" class="label">(updates amax history)</text>
  <!-- Arrow back to history (curved) -->
  <path d="M 625 220 Q 590 220 590 127 L 660 127" class="arrow"/>
  <text x="565" y="175" class="small-text" style="text-anchor: end;">update amax</text>
  <!-- Arrow to FP8 tensor -->
  <path d="M 750 260 L 750 300" class=
"arrow"
/>
<!-- FP8 Tensor result -->
<rect
x=
"650"
y=
"300"
width=
"200"
height=
"40"
class=
"fp8"
rx=
"5"
/>
<text
x=
"750"
y=
"325"
class=
"text"
>
FP8 Tensor
</text>
</svg>
docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_distributed_example.py
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# START_AMAX_REDUCTION_EXAMPLE
import transformer_engine.jax as te
from transformer_engine.common.recipe import DelayedScaling

# Amax reduction scope is managed internally
recipe = DelayedScaling(reduce_amax=True)  # Must be True in JAX

with te.autocast(enabled=True, recipe=recipe):
    output = layer.apply(params, inp)
# END_AMAX_REDUCTION_EXAMPLE
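The effect of ``reduce_amax`` can be mimicked without any GPUs: an all-reduce with the MAX op leaves every rank holding the largest locally observed amax, so all ranks derive identical scales. A toy sketch (the per-rank values are made up):

```python
# Per-rank amax values observed during one iteration (hypothetical numbers)
local_amax = {0: 1.2, 1: 7.5, 2: 0.3, 3: 2.9}

# all_reduce(op=MAX) leaves every rank holding the global maximum
global_amax = max(local_amax.values())

# Every rank now quantizes with the same scale, so the FP8 representation
# of replicated tensors agrees bit-for-bit across the data-parallel group.
per_rank_amax = {rank: global_amax for rank in local_amax}
assert all(v == 7.5 for v in per_rank_amax.values())
```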
docs/features/low_precision_training/fp8_delayed_scaling/jax_delayed_scaling_example.py
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from transformer_engine.jax.quantize import get_device_compute_capability

# Requires Ada (SM89) or newer for FP8 support
assert get_device_compute_capability() >= 89, "This example requires SM89 (Ada) or newer"

# START_DELAYED_SCALING_EXAMPLE
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import DenseGeneral
from transformer_engine.common.recipe import DelayedScaling

# Create FP8 Delayed Scaling recipe
recipe = DelayedScaling(
    margin=0,  # Margin for scaling factor computation (default: 0)
    amax_history_len=1024,  # Length of amax history window (default: 1024)
    amax_compute_algo="max",  # How to compute amax from history (default: "max")
)

with te.autocast(enabled=True, recipe=recipe):
    # Initialize layer and data
    layer = DenseGeneral(features=1024)
    key = jax.random.PRNGKey(0)
    x = jax.random.normal(key, (32, 128, 1024), dtype=jnp.bfloat16)
    var_collect = layer.init(key, x)

    # Forward and backward pass
    def loss_fn(var_collect):
        output = layer.apply(var_collect, x)
        return output.sum()

    loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_DELAYED_SCALING_EXAMPLE
docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_distributed_example.py
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# START_AMAX_REDUCTION_EXAMPLE
import torch.distributed as dist
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Create process group for amax reduction (e.g., all 8 GPUs)
amax_reduction_group = dist.new_group(ranks=[0, 1, 2, 3, 4, 5, 6, 7])

recipe = DelayedScaling(reduce_amax=True)

with te.autocast(recipe=recipe, amax_reduction_group=amax_reduction_group):
    output = model(inp)
# END_AMAX_REDUCTION_EXAMPLE
docs/features/low_precision_training/fp8_delayed_scaling/pytorch_delayed_scaling_example.py
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch

# Requires Ada (SM89) or newer for FP8 support
assert torch.cuda.get_device_capability()[0] >= 9 or (
    torch.cuda.get_device_capability()[0] == 8
    and torch.cuda.get_device_capability()[1] >= 9
), "This example requires SM89 (Ada) or newer"

# START_DELAYED_SCALING_EXAMPLE
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling

# Create FP8 Delayed Scaling recipe
recipe = DelayedScaling(
    margin=0,  # Margin for scaling factor computation (default: 0)
    amax_history_len=1024,  # Length of amax history window (default: 1024)
    amax_compute_algo="max",  # How to compute amax from history (default: "max")
)

# Create a linear layer with bfloat16 parameters
layer = te.Linear(1024, 1024, params_dtype=torch.bfloat16)

# Forward and backward pass
inp = torch.randn(32, 128, 1024, dtype=torch.bfloat16, device="cuda")
with te.autocast(enabled=True, recipe=recipe):
    output = layer(inp)
loss = output.sum()
loss.backward()
# END_DELAYED_SCALING_EXAMPLE
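The recipe parameters above combine into a scale factor; a hedged sketch, assuming the usual formula ``scale = (fp8_max / amax) * 2**(-margin)`` with E4M3's ``fp8_max = 448`` (the real Transformer Engine kernels may differ in edge-case handling):

```python
def delayed_scaling_scale(amax_history, margin=0, fp8_max=448.0, algo="max"):
    """Derive an FP8 scale factor from an amax history window (sketch only)."""
    amax = max(amax_history) if algo == "max" else amax_history[0]
    if amax == 0.0:
        return 1.0  # nothing observed yet: keep the identity scale
    return (fp8_max / amax) * (2.0 ** -margin)

scale = delayed_scaling_scale([1.5, 7.0, 3.2])
assert scale * 7.0 == 448.0  # the historical amax lands exactly on fp8_max

# A positive margin halves the scale per unit, leaving headroom for growth
assert delayed_scaling_scale([1.5, 7.0, 3.2], margin=1) == scale / 2.0
```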
docs/features/low_precision_training/index.rst
..
    Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

    See LICENSE for license information.

Low precision training
======================

.. toctree::

   introduction/introduction.rst
   performance_considerations/performance_considerations.rst
   fp8_current_scaling/fp8_current_scaling.rst
   fp8_delayed_scaling/fp8_delayed_scaling.rst
   fp8_blockwise_scaling/fp8_blockwise_scaling.rst
   mxfp8/mxfp8.rst
   nvfp4/nvfp4.rst
docs/features/low_precision_training/introduction/autocast_jax.py
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from transformer_engine.jax.quantize import get_device_compute_capability

# Requires Ada (SM89) or newer for FP8 support
assert get_device_compute_capability() >= 89, "This example requires SM89 (Ada) or newer"

# START_AUTOCAST_BASIC
import jax
import jax.numpy as jnp
import transformer_engine.jax as te
from transformer_engine.jax.flax import TransformerLayer
from transformer_engine.common.recipe import DelayedScaling, Format

# Set up recipe
recipe = DelayedScaling()

# Model initialization must happen inside autocast
with te.autocast(enabled=True, recipe=recipe):
    layer = TransformerLayer(
        hidden_size=1024,
        mlp_hidden_size=4096,
        num_attention_heads=16,
    )
    init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(init_key, (32, 128, 1024), dtype=jnp.bfloat16)
    var_collect = layer.init({"params": init_key, "dropout": dropout_key}, x)

    # Forward and backward pass (both inside autocast for JAX)
    def loss_fn(var_collect):
        output = layer.apply(var_collect, x, rngs={"dropout": dropout_key})
        return output.sum()

    loss, grads = jax.value_and_grad(loss_fn)(var_collect)
# END_AUTOCAST_BASIC

# START_AUTOCAST_SEQUENTIAL
encoder_recipe = DelayedScaling(fp8_format=Format.E4M3)
decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID)

with te.autocast(enabled=True, recipe=encoder_recipe):
    encoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
    encoder_var_collect = encoder.init({"params": init_key, "dropout": dropout_key}, x)
    hidden = encoder.apply(encoder_var_collect, x, rngs={"dropout": dropout_key})

with te.autocast(enabled=True, recipe=decoder_recipe):
    decoder = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
    decoder_var_collect = decoder.init({"params": init_key, "dropout": dropout_key}, hidden)
    output = decoder.apply(decoder_var_collect, hidden, rngs={"dropout": dropout_key})
# END_AUTOCAST_SEQUENTIAL

# START_AUTOCAST_NESTED
outer_recipe = DelayedScaling(fp8_format=Format.E4M3)
inner_recipe = DelayedScaling(fp8_format=Format.HYBRID)

with te.autocast(enabled=True, recipe=outer_recipe):
    # layer1 uses outer_recipe
    layer1 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
    var_collect1 = layer1.init({"params": init_key, "dropout": dropout_key}, x)
    hidden = layer1.apply(var_collect1, x, rngs={"dropout": dropout_key})

    with te.autocast(enabled=True, recipe=inner_recipe):
        # layer2 uses inner_recipe (overrides outer)
        layer2 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
        var_collect2 = layer2.init({"params": init_key, "dropout": dropout_key}, hidden)
        hidden = layer2.apply(var_collect2, hidden, rngs={"dropout": dropout_key})

    # layer3 uses outer_recipe again
    layer3 = TransformerLayer(hidden_size=1024, mlp_hidden_size=4096, num_attention_heads=16)
    var_collect3 = layer3.init({"params": init_key, "dropout": dropout_key}, hidden)
    output = layer3.apply(var_collect3, hidden, rngs={"dropout": dropout_key})
# END_AUTOCAST_NESTED
docs/features/low_precision_training/introduction/autocast_pytorch.py
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import torch

# Requires Ada (SM89) or newer for FP8 support
assert torch.cuda.get_device_capability()[0] >= 9 or (
    torch.cuda.get_device_capability()[0] == 8
    and torch.cuda.get_device_capability()[1] >= 9
), "This example requires SM89 (Ada) or newer"

# START_AUTOCAST_BASIC
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import DelayedScaling, Format

recipe = DelayedScaling()
layer = te.Linear(1024, 1024)
inp = torch.randn(32, 1024, dtype=torch.float32, device="cuda")

with te.autocast(enabled=True, recipe=recipe):
    output = layer(inp)

# .backward() is called outside of autocast
loss = output.sum()
loss.backward()
# END_AUTOCAST_BASIC

# START_AUTOCAST_SEQUENTIAL
encoder_recipe = DelayedScaling(fp8_format=Format.E4M3)
decoder_recipe = DelayedScaling(fp8_format=Format.HYBRID)
encoder = te.Linear(1024, 1024)
decoder = te.Linear(1024, 1024)

with te.autocast(enabled=True, recipe=encoder_recipe):
    hidden = encoder(inp)
with te.autocast(enabled=True, recipe=decoder_recipe):
    output = decoder(hidden)
# END_AUTOCAST_SEQUENTIAL

# START_AUTOCAST_NESTED
outer_recipe = DelayedScaling(fp8_format=Format.E4M3)
inner_recipe = DelayedScaling(fp8_format=Format.HYBRID)
layer1 = te.Linear(1024, 1024)
layer2 = te.Linear(1024, 1024)
layer3 = te.Linear(1024, 1024)

with te.autocast(enabled=True, recipe=outer_recipe):
    # layer1 uses outer_recipe
    x = layer1(inp)
    with te.autocast(enabled=True, recipe=inner_recipe):
        # layer2 uses inner_recipe (overrides outer)
        x = layer2(x)
    # layer3 uses outer_recipe again
    output = layer3(x)
# END_AUTOCAST_NESTED
docs/features/low_precision_training/introduction/bf16_fp16_training_jax.py
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# START_BF16_FP16_TRAINING
import jax
import jax.numpy as jnp
from transformer_engine.jax.flax import TransformerLayer


def run_forward_backward(params_dtype, compute_dtype):
    # Create TransformerLayer
    layer = TransformerLayer(
        hidden_size=1024,
        mlp_hidden_size=4096,
        num_attention_heads=16,
        dtype=params_dtype,
    )

    # Initialize parameters
    init_key, dropout_key = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(init_key, (32, 128, 1024), dtype=compute_dtype)
    var_collect = layer.init({"params": init_key, "dropout": dropout_key}, x)

    # Forward and backward pass
    def loss_fn(var_collect):
        output = layer.apply(var_collect, x, rngs={"dropout": dropout_key})
        assert output.dtype == compute_dtype
        return output.sum()

    loss, grads = jax.value_and_grad(loss_fn)(var_collect)


run_forward_backward(jnp.float32, jnp.float32)   # high precision training
run_forward_backward(jnp.float32, jnp.bfloat16)  # bfloat16 training with master weights in FP32
run_forward_backward(jnp.bfloat16, jnp.bfloat16) # bfloat16 training with weights in BF16
# END_BF16_FP16_TRAINING
docs/features/low_precision_training/introduction/bf16_fp16_training_pytorch.py
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
# START_BF16_FP16_TRAINING
import torch
import transformer_engine.pytorch as te
from contextlib import nullcontext


def run_forward_backward(params_dtype, autocast_precision, grad_scaler_enabled):
    if grad_scaler_enabled:
        grad_scaler = torch.amp.GradScaler("cuda")
    layer = te.TransformerLayer(
        hidden_size=1024,
        ffn_hidden_size=4096,
        num_attention_heads=16,
        params_dtype=params_dtype,
    )
    x = torch.randn(32, 128, 1024, dtype=params_dtype, device="cuda")
    autocast_ctx = (
        torch.autocast(device_type="cuda", dtype=autocast_precision)
        if autocast_precision is not None
        else nullcontext()
    )
    with autocast_ctx:
        output = layer(x)
    assert output.dtype == (
        autocast_precision if autocast_precision is not None else params_dtype
    )
    loss = output.sum()
    if grad_scaler_enabled:
        grad_scaler.scale(loss).backward()
    else:
        loss.backward()


run_forward_backward(torch.float32, torch.float32, False)    # high precision training
run_forward_backward(torch.float32, torch.bfloat16, False)   # bfloat16 training with master weights in FP32
run_forward_backward(torch.float32, torch.float16, True)     # fp16 training with master weights in FP32, needs loss scaling
run_forward_backward(torch.bfloat16, torch.bfloat16, False)  # bfloat16 training with weights in BF16
# END_BF16_FP16_TRAINING
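The comment about FP16 needing loss scaling comes from gradient underflow: values below FP16's smallest subnormal (about 6e-8) flush to zero when cast. A quick NumPy demonstration of the idea (illustrative only; real training uses `torch.amp.GradScaler`, and the scale value 1024 is an arbitrary choice):

```python
import numpy as np

tiny_grad = np.float32(2e-8)  # below FP16's smallest subnormal (~5.96e-8)

# Casting directly to FP16 flushes the gradient to zero
assert float(np.float16(tiny_grad)) == 0.0

# Scaling the loss (and therefore every gradient) first keeps it representable
scale = np.float32(1024.0)
scaled = np.float16(tiny_grad * scale)
assert float(scaled) > 0.0

# Unscaling in FP32 recovers a usable gradient close to the original
recovered = np.float32(scaled) / scale
assert abs(float(recovered) - 2e-8) < 1e-8
```

BF16 keeps FP32's 8 exponent bits, so its dynamic range covers such tiny gradients and no loss scaling is needed; this is why only the FP16 call above passes `grad_scaler_enabled=True`.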
docs/features/low_precision_training/introduction/img/fp8_linear_flow.svg
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
xmlns=
"http://www.w3.org/2000/svg"
viewBox=
"0 0 850 780"
width=
"850"
height=
"780"
>
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
</style>
<marker
id=
"arrowhead"
markerWidth=
"6"
markerHeight=
"6"
refX=
"5"
refY=
"2"
orient=
"auto"
>
<polygon
points=
"0 0, 6 2, 0 4"
fill=
"#616161"
/>
</marker>
</defs>
<!-- Title -->
<text
x=
"425"
y=
"30"
class=
"title"
>
FP8 Linear Layer – Forward and Backward Pass
</text>
<!-- Forward Pass Section -->
<text
x=
"425"
y=
"65"
class=
"section-title"
style=
"fill: #1565c0;"
>
Forward Pass
</text>
<!-- Forward: Input^T FP8 (top, saved for backward) -->
<rect
x=
"270"
y=
"70"
width=
"80"
height=
"50"
class=
"fp8"
rx=
"6"
/>
<text
x=
"310"
y=
"100"
class=
"text"
>
Input
<tspan
baseline-shift=
"super"
style=
"font-size: 9px;"
>
T
</tspan></text>
<!-- Forward: Input High Precision -->
<rect
x=
"30"
y=
"130"
width=
"100"
height=
"50"
class=
"hp"
rx=
"6"
/>
<text
x=
"80"
y=
"160"
class=
"text"
>
Input
</text>
<!-- Forward: Arrow -->
<path
d=
"M 130 155 L 155 155"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<!-- Forward: Quantize Input -->
<rect
x=
"155"
y=
"130"
width=
"90"
height=
"50"
class=
"quantize"
rx=
"6"
/>
<text
x=
"200"
y=
"160"
class=
"text"
>
Quantize
</text>
<!-- Forward: Arrow to Input^T (going up) -->
<path
d=
"M 245 140 L 270 110"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<!-- Forward: Arrow to Input -->
<path
d=
"M 245 155 L 270 155"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<!-- Forward: Input FP8 -->
<rect
x=
"270"
y=
"130"
width=
"80"
height=
"50"
class=
"fp8"
rx=
"6"
/>
<text
x=
"310"
y=
"160"
class=
"text"
>
Input
</text>
<!-- Forward: Arrow from Input to GEMM -->
<path
d=
"M 350 155 L 400 170"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<text
x=
"370"
y=
"145"
class=
"text"
style=
"font-size: 10px; font-weight: bold;"
>
N
</text>
<!-- Forward: Weights High Precision -->
<rect
x=
"30"
y=
"195"
width=
"100"
height=
"50"
class=
"hp"
rx=
"6"
/>
<text
x=
"80"
y=
"225"
class=
"text"
>
Weight
</text>
<!-- Forward: Arrow -->
<path
d=
"M 130 220 L 155 220"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<!-- Forward: Quantize Weights -->
<rect
x=
"155"
y=
"195"
width=
"90"
height=
"50"
class=
"quantize"
rx=
"6"
/>
<text
x=
"200"
y=
"225"
class=
"text"
>
Quantize
</text>
<!-- Forward: Arrow to Weight -->
<path
d=
"M 245 220 L 270 220"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<!-- Forward: Arrow to Weight^T (going down) -->
<path
d=
"M 245 235 L 270 270"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<!-- Forward: Weights FP8 -->
<rect
x=
"270"
y=
"195"
width=
"80"
height=
"50"
class=
"fp8"
rx=
"6"
/>
<text
x=
"310"
y=
"225"
class=
"text"
>
Weight
</text>
<!-- Forward: Weight^T FP8 (bottom, saved for backward) -->
<rect
x=
"270"
y=
"255"
width=
"80"
height=
"50"
class=
"fp8"
rx=
"6"
/>
<text
x=
"310"
y=
"285"
class=
"text"
>
Weight
<tspan
baseline-shift=
"super"
style=
"font-size: 9px;"
>
T
</tspan></text>
<!-- Forward: Arrow from Weight to GEMM -->
<path
d=
"M 350 220 L 400 200"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<text
x=
"370"
y=
"230"
class=
"text"
style=
"font-size: 10px; font-weight: bold;"
>
T
</text>
<!-- Forward: GEMM -->
<rect
x=
"400"
y=
"160"
width=
"130"
height=
"50"
class=
"gemm"
rx=
"6"
/>
<text
x=
"465"
y=
"180"
class=
"text"
style=
"font-weight: 600;"
>
FP8 GEMM
</text>
<text
x=
"465"
y=
"200"
class=
"text"
style=
"font-size: 11px;"
>
(TN)
</text>
<!-- Forward: Arrow -->
<path
d=
"M 530 185 L 580 185"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<!-- Forward: Output -->
<rect
x=
"580"
y=
"160"
width=
"110"
height=
"50"
class=
"hp"
rx=
"6"
/>
<text
x=
"635"
y=
"190"
class=
"text"
>
Output
</text>
<!-- Divider Line -->
<line
x1=
"30"
y1=
"310"
x2=
"820"
y2=
"310"
stroke=
"#ddd"
stroke-width=
"2"
/>
<!-- Backward Pass Section -->
<text
x=
"425"
y=
"345"
class=
"section-title"
style=
"fill: #c62828;"
>
Backward Pass
</text>
<!-- Backward: Weight^T (from forward, top input to GEMM1) -->
<rect
x=
"495"
y=
"355"
width=
"80"
height=
"50"
class=
"fp8"
rx=
"6"
/>
<text
x=
"535"
y=
"385"
class=
"text"
>
Weight
<tspan
baseline-shift=
"super"
style=
"font-size: 9px;"
>
T
</tspan></text>
<!-- Backward: Output gradient High Precision -->
<rect
x=
"30"
y=
"480"
width=
"130"
height=
"50"
class=
"hp"
rx=
"6"
/>
<text
x=
"95"
y=
"510"
class=
"text"
>
Output grad.
</text>
<!-- Backward: Arrow -->
<path
d=
"M 160 505 L 180 505"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<!-- Backward: Quantize Output gradient -->
<rect
x=
"180"
y=
"480"
width=
"90"
height=
"50"
class=
"quantize"
rx=
"6"
/>
<text
x=
"225"
y=
"510"
class=
"text"
>
Quantize
</text>
<!-- Backward: Arrow to Output grad (going up) -->
<path
d=
"M 270 490 L 290 465"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<!-- Backward: Arrow to Output grad^T (going down) -->
<path
d=
"M 270 520 L 290 545"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<!-- Backward: Output gradient FP8 (for input gradient) -->
<rect
x=
"290"
y=
"440"
width=
"110"
height=
"50"
class=
"fp8"
rx=
"6"
/>
<text
x=
"345"
y=
"470"
class=
"text"
>
Output grad.
</text>
<!-- Backward: Output gradient^T FP8 (for weight gradient) -->
<rect
x=
"290"
y=
"520"
width=
"110"
height=
"50"
class=
"fp8"
rx=
"6"
/>
<text
x=
"345"
y=
"550"
class=
"text"
>
Output grad.
<tspan
baseline-shift=
"super"
style=
"font-size: 9px;"
>
T
</tspan></text>
<!-- Backward: GEMM 1 (for input gradient) -->
<rect
x=
"470"
y=
"440"
width=
"130"
height=
"50"
class=
"gemm"
rx=
"6"
/>
<text
x=
"535"
y=
"460"
class=
"text"
style=
"font-weight: 600;"
>
FP8 GEMM
</text>
<text
x=
"535"
y=
"480"
class=
"text"
style=
"font-size: 11px;"
>
(TN)
</text>
<!-- Backward: Input gradient -->
<rect
x=
"640"
y=
"440"
width=
"130"
height=
"50"
class=
"hp"
rx=
"6"
/>
<text
x=
"705"
y=
"470"
class=
"text"
>
Input grad.
</text>
<!-- Backward: GEMM 2 (for weight gradient) -->
<rect
x=
"470"
y=
"520"
width=
"130"
height=
"50"
class=
"gemm"
rx=
"6"
/>
<text
x=
"535"
y=
"540"
class=
"text"
style=
"font-weight: 600;"
>
FP8 GEMM
</text>
<text
x=
"535"
y=
"560"
class=
"text"
style=
"font-size: 11px;"
>
(TN)
</text>
<!-- Backward: Weight gradient -->
<rect
x=
"640"
y=
"520"
width=
"130"
height=
"50"
class=
"hp"
rx=
"6"
/>
<text
x=
"705"
y=
"550"
class=
"text"
>
Weight grad.
</text>
<!-- Backward: Input^T (from forward, bottom input to GEMM2) -->
<rect
x=
"495"
y=
"605"
width=
"80"
height=
"50"
class=
"fp8"
rx=
"6"
/>
<text
x=
"535"
y=
"635"
class=
"text"
>
Input
<tspan
baseline-shift=
"super"
style=
"font-size: 9px;"
>
T
</tspan></text>
<!-- Backward: Arrows -->
<!-- Output gradient FP8 to top GEMM -->
<path
d=
"M 400 465 L 470 465"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<text
x=
"430"
y=
"457"
class=
"text"
style=
"font-size: 10px; font-weight: bold;"
>
N
</text>
<!-- Weight^T to top GEMM -->
<path
d=
"M 535 405 L 535 440"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<text
x=
"543"
y=
"427"
class=
"text"
style=
"font-size: 10px; font-weight: bold;"
>
T
</text>
<!-- Top GEMM to input gradient -->
<path
d=
"M 600 465 L 640 465"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<!-- Output gradient^T FP8 to bottom GEMM -->
<path
d=
"M 400 545 L 470 545"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<text
x=
"430"
y=
"537"
class=
"text"
style=
"font-size: 10px; font-weight: bold;"
>
N
</text>
<!-- Input^T to bottom GEMM -->
<path
d=
"M 535 605 L 535 570"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<text
x=
"543"
y=
"597"
class=
"text"
style=
"font-size: 10px; font-weight: bold;"
>
T
</text>
<!-- Bottom GEMM to weight gradient -->
<path
d=
"M 600 545 L 640 545"
stroke=
"#616161"
stroke-width=
"2"
fill=
"none"
marker-end=
"url(#arrowhead)"
/>
<!-- Legend -->
<g
transform=
"translate(30, 680)"
>
<!-- Higher Precision -->
<rect
x=
"0"
y=
"0"
width=
"80"
height=
"40"
rx=
"5"
class=
"hp"
/>
<text
x=
"95"
y=
"23"
class=
"text"
style=
"text-anchor: start;"
>
Higher Precision (FP32/BF16/FP16)
</text>
<!-- Lower Precision -->
<rect
x=
"380"
y=
"0"
width=
"80"
height=
"40"
rx=
"5"
class=
"fp8"
/>
<text
x=
"475"
y=
"23"
class=
"text"
style=
"text-anchor: start;"
>
Lower Precision (FP8, MXFP8 etc.)
</text>
</g>
</svg>
docs/features/low_precision_training/introduction/img/fp_formats_comparison.svg
<svg
xmlns=
"http://www.w3.org/2000/svg"
viewBox=
"0 0 900 210"
>
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.sign-bit { fill: #9db4d0; stroke: #333; stroke-width: 1; }
.exponent-bit { fill: #d9a066; stroke: #333; stroke-width: 1; }
.mantissa-bit { fill: #a8d99c; stroke: #333; stroke-width: 1; }
.bit-text { fill: #000; text-anchor: middle; dominant-baseline: middle; font-size: 16px; }
.header-text { fill: #555; font-weight: normal; text-anchor: middle; font-size: 18px; }
.value-text { fill: #333; font-size: 18px; }
.format-label { fill: #333; font-size: 20px; font-weight: bold; text-anchor: middle; dominant-baseline: middle; }
</style>
</defs>
<!-- Header labels - centered -->
<text
x=
"79"
y=
"18"
class=
"header-text"
>
sign
</text>
<text
x=
"173"
y=
"18"
class=
"header-text"
>
exponent
</text>
<text
x=
"530"
y=
"18"
class=
"header-text"
>
mantissa
</text>
<!-- FP32 Format (32 bits: 1 + 8 + 23) -->
<text
x=
"30"
y=
"60"
class=
"format-label"
>
FP32
</text>
<!-- Sign bit (1) -->
<rect
x=
"70"
y=
"45"
width=
"18"
height=
"30"
class=
"sign-bit"
/>
<text
x=
"79"
y=
"60"
class=
"bit-text"
>
0
</text>
<!-- Exponent bits (8) -->
<rect
x=
"93"
y=
"45"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"102"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"116"
y=
"45"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"125"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"139"
y=
"45"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"148"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"162"
y=
"45"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"171"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"185"
y=
"45"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"194"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"208"
y=
"45"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"217"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"231"
y=
"45"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"240"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"254"
y=
"45"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"263"
y=
"60"
class=
"bit-text"
>
1
</text>
<!-- Mantissa bits (23) -->
<rect
x=
"277"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"286"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"300"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"309"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"323"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"332"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"346"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"355"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"369"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"378"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"392"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"401"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"415"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"424"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"438"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"447"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"461"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"470"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"484"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"493"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"507"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"516"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"530"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"539"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"553"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"562"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"576"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"585"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"599"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"608"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"622"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"631"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"645"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"654"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"668"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"677"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"691"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"700"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"714"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"723"
y=
"60"
class=
"bit-text"
>
1
</text>
<rect
x=
"737"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"746"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"760"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"769"
y=
"60"
class=
"bit-text"
>
0
</text>
<rect
x=
"783"
y=
"45"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"792"
y=
"60"
class=
"bit-text"
>
0
</text>
<text
x=
"820"
y=
"60"
class=
"value-text"
>
= 0.3952
</text>
<!-- BF16 Format (16 bits: 1 + 8 + 7) -->
<text
x=
"30"
y=
"120"
class=
"format-label"
>
BF16
</text>
<!-- Sign bit (1) -->
<rect
x=
"70"
y=
"105"
width=
"18"
height=
"30"
class=
"sign-bit"
/>
<text
x=
"79"
y=
"120"
class=
"bit-text"
>
0
</text>
<!-- Exponent bits (8) -->
<rect
x=
"93"
y=
"105"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"102"
y=
"120"
class=
"bit-text"
>
0
</text>
<rect
x=
"116"
y=
"105"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"125"
y=
"120"
class=
"bit-text"
>
1
</text>
<rect
x=
"139"
y=
"105"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"148"
y=
"120"
class=
"bit-text"
>
1
</text>
<rect
x=
"162"
y=
"105"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"171"
y=
"120"
class=
"bit-text"
>
1
</text>
<rect
x=
"185"
y=
"105"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"194"
y=
"120"
class=
"bit-text"
>
1
</text>
<rect
x=
"208"
y=
"105"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"217"
y=
"120"
class=
"bit-text"
>
1
</text>
<rect
x=
"231"
y=
"105"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"240"
y=
"120"
class=
"bit-text"
>
0
</text>
<rect
x=
"254"
y=
"105"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"263"
y=
"120"
class=
"bit-text"
>
1
</text>
<!-- Mantissa bits (7) -->
<rect
x=
"277"
y=
"105"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"286"
y=
"120"
class=
"bit-text"
>
1
</text>
<rect
x=
"300"
y=
"105"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"309"
y=
"120"
class=
"bit-text"
>
0
</text>
<rect
x=
"323"
y=
"105"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"332"
y=
"120"
class=
"bit-text"
>
0
</text>
<rect
x=
"346"
y=
"105"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"355"
y=
"120"
class=
"bit-text"
>
1
</text>
<rect
x=
"369"
y=
"105"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"378"
y=
"120"
class=
"bit-text"
>
0
</text>
<rect
x=
"392"
y=
"105"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"401"
y=
"120"
class=
"bit-text"
>
1
</text>
<rect
x=
"415"
y=
"105"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"424"
y=
"120"
class=
"bit-text"
>
0
</text>
<text
x=
"820"
y=
"120"
class=
"value-text"
>
≈ 0.3945
</text>
<!-- FP16 Format (16 bits: 1 + 5 + 10) -->
<text
x=
"30"
y=
"180"
class=
"format-label"
>
FP16
</text>
<!-- Sign bit (1) -->
<rect
x=
"70"
y=
"165"
width=
"18"
height=
"30"
class=
"sign-bit"
/>
<text
x=
"79"
y=
"180"
class=
"bit-text"
>
0
</text>
<!-- Exponent bits (5) -->
<rect
x=
"93"
y=
"165"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"102"
y=
"180"
class=
"bit-text"
>
0
</text>
<rect
x=
"116"
y=
"165"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"125"
y=
"180"
class=
"bit-text"
>
1
</text>
<rect
x=
"139"
y=
"165"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"148"
y=
"180"
class=
"bit-text"
>
1
</text>
<rect
x=
"162"
y=
"165"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"171"
y=
"180"
class=
"bit-text"
>
0
</text>
<rect
x=
"185"
y=
"165"
width=
"18"
height=
"30"
class=
"exponent-bit"
/>
<text
x=
"194"
y=
"180"
class=
"bit-text"
>
1
</text>
<!-- Mantissa bits (10) -->
<rect
x=
"208"
y=
"165"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"217"
y=
"180"
class=
"bit-text"
>
1
</text>
<rect
x=
"231"
y=
"165"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"240"
y=
"180"
class=
"bit-text"
>
0
</text>
<rect
x=
"254"
y=
"165"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"263"
y=
"180"
class=
"bit-text"
>
0
</text>
<rect
x=
"277"
y=
"165"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"286"
y=
"180"
class=
"bit-text"
>
1
</text>
<rect
x=
"300"
y=
"165"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"309"
y=
"180"
class=
"bit-text"
>
0
</text>
<rect
x=
"323"
y=
"165"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"332"
y=
"180"
class=
"bit-text"
>
1
</text>
<rect
x=
"346"
y=
"165"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"355"
y=
"180"
class=
"bit-text"
>
0
</text>
<rect
x=
"369"
y=
"165"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"378"
y=
"180"
class=
"bit-text"
>
0
</text>
<rect
x=
"392"
y=
"165"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"401"
y=
"180"
class=
"bit-text"
>
1
</text>
<rect
x=
"415"
y=
"165"
width=
"18"
height=
"30"
class=
"mantissa-bit"
/>
<text
x=
"424"
y=
"180"
class=
"bit-text"
>
0
</text>
<text
x=
"820"
y=
"180"
class=
"value-text"
>
≈ 0.3950
</text>
</svg>
docs/features/low_precision_training/introduction/img/master_weights_approaches.svg
0 → 100644
View file @
9df0c4a3
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
xmlns=
"http://www.w3.org/2000/svg"
viewBox=
"0 0 1050 580"
width=
"1050"
height=
"580"
>
<defs>
<style>
@import url("../_static/css/diagram-colors.css");
.arrow { stroke: #616161; stroke-width: 2; fill: none; marker-end: url(#arrowhead); }
.column-title { font-family: 'Segoe UI', Arial, sans-serif; font-size: 14px; font-weight: 600; text-anchor: middle; fill: #424242; }
.divider { stroke: #bdbdbd; stroke-width: 1.5; stroke-dasharray: 8,6; }
</style>
<marker
id=
"arrowhead"
markerWidth=
"6"
markerHeight=
"6"
refX=
"5"
refY=
"2"
orient=
"auto"
>
<polygon
points=
"0 0, 6 2, 0 4"
fill=
"#616161"
/>
</marker>
</defs>
<!-- Title -->
<text
x=
"525"
y=
"30"
class=
"title"
>
Master Weights Storage Approaches
</text>
<!-- Vertical dividers (dashed lines) -->
<line
x1=
"350"
y1=
"50"
x2=
"350"
y2=
"560"
class=
"divider"
/>
<line
x1=
"700"
y1=
"50"
x2=
"700"
y2=
"560"
class=
"divider"
/>
<!-- Column 1: Low Precision Only -->
<text
x=
"175"
y=
"75"
class=
"column-title"
>
Low Precision Weights
</text>
<text
x=
"175"
y=
"93"
class=
"small-text"
>
(no master weights)
</text>
<!-- Model box -->
<rect
x=
"60"
y=
"145"
width=
"230"
height=
"90"
rx=
"6"
fill=
"#f5f5f5"
stroke=
"#9e9e9e"
stroke-width=
"1.5"
/>
<text
x=
"175"
y=
"168"
class=
"label"
>
Model
</text>
<rect
x=
"80"
y=
"183"
width=
"190"
height=
"40"
class=
"hp"
rx=
"4"
/>
<text
x=
"175"
y=
"208"
class=
"text"
>
Weights (BF16/FP16)
</text>
<!-- Arrow down -->
<path
d=
"M 175 235 L 175 300"
class=
"arrow"
/>
<!-- Computation -->
<rect
x=
"90"
y=
"300"
width=
"170"
height=
"50"
class=
"gemm"
rx=
"6"
/>
<text
x=
"175"
y=
"330"
class=
"text"
>
Forward/Backward
</text>
<!-- Arrow down -->
<path
d=
"M 175 350 L 175 415"
class=
"arrow"
/>
<!-- Optimizer box -->
<rect
x=
"60"
y=
"415"
width=
"230"
height=
"90"
rx=
"6"
fill=
"#f5f5f5"
stroke=
"#9e9e9e"
stroke-width=
"1.5"
/>
<text
x=
"175"
y=
"438"
class=
"label"
>
Optimizer
</text>
<rect
x=
"80"
y=
"453"
width=
"190"
height=
"40"
class=
"fp32"
rx=
"4"
/>
<text
x=
"175"
y=
"478"
class=
"text"
>
State (FP32)
</text>
<!-- Column 2: Master Weights in Model -->
<text
x=
"525"
y=
"75"
class=
"column-title"
>
Master Weights in Model
</text>
<!-- Model box -->
<rect
x=
"410"
y=
"145"
width=
"230"
height=
"90"
rx=
"6"
fill=
"#f5f5f5"
stroke=
"#9e9e9e"
stroke-width=
"1.5"
/>
<text
x=
"525"
y=
"168"
class=
"label"
>
Model
</text>
<rect
x=
"430"
y=
"183"
width=
"190"
height=
"40"
class=
"fp32"
rx=
"4"
/>
<text
x=
"525"
y=
"208"
class=
"text"
>
Weights (FP32)
</text>
<!-- Arrow down with cast -->
<path
d=
"M 525 235 L 525 300"
class=
"arrow"
/>
<rect
x=
"465"
y=
"255"
width=
"120"
height=
"26"
class=
"quantize"
rx=
"4"
/>
<text
x=
"525"
y=
"273"
class=
"small-text"
>
cast to BF16/FP16
</text>
<!-- Computation -->
<rect
x=
"440"
y=
"300"
width=
"170"
height=
"50"
class=
"gemm"
rx=
"6"
/>
<text
x=
"525"
y=
"330"
class=
"text"
>
Forward/Backward
</text>
<!-- Arrow down -->
<path
d=
"M 525 350 L 525 415"
class=
"arrow"
/>
<!-- Optimizer box -->
<rect
x=
"410"
y=
"415"
width=
"230"
height=
"90"
rx=
"6"
fill=
"#f5f5f5"
stroke=
"#9e9e9e"
stroke-width=
"1.5"
/>
<text
x=
"525"
y=
"438"
class=
"label"
>
Optimizer
</text>
<rect
x=
"430"
y=
"453"
width=
"190"
height=
"40"
class=
"fp32"
rx=
"4"
/>
<text
x=
"525"
y=
"478"
class=
"text"
>
State (FP32)
</text>
<!-- Column 3: Master Weights in Optimizer -->
<text
x=
"875"
y=
"75"
class=
"column-title"
>
Master Weights in Optimizer
</text>
<!-- Cast box above Model -->
<rect
x=
"815"
y=
"105"
width=
"120"
height=
"26"
class=
"quantize"
rx=
"4"
/>
<text
x=
"875"
y=
"123"
class=
"small-text"
>
cast to BF16/FP16
</text>
<!-- Arrow from cast to Model -->
<path
d=
"M 875 131 L 875 145"
class=
"arrow"
/>
<!-- Model box -->
<rect
x=
"760"
y=
"145"
width=
"230"
height=
"90"
rx=
"6"
fill=
"#f5f5f5"
stroke=
"#9e9e9e"
stroke-width=
"1.5"
/>
<text
x=
"875"
y=
"168"
class=
"label"
>
Model
</text>
<rect
x=
"780"
y=
"183"
width=
"190"
height=
"40"
class=
"hp"
rx=
"4"
/>
<text
x=
"875"
y=
"208"
class=
"text"
>
Weights (BF16/FP16)
</text>
<!-- Arrow down -->
<path
d=
"M 875 235 L 875 300"
class=
"arrow"
/>
<!-- Computation -->
<rect
x=
"790"
y=
"300"
width=
"170"
height=
"50"
class=
"gemm"
rx=
"6"
/>
<text
x=
"875"
y=
"330"
class=
"text"
>
Forward/Backward
</text>
<!-- Arrow down -->
<path
d=
"M 875 350 L 875 415"
class=
"arrow"
/>
<!-- Optimizer box with State and Master -->
<rect
x=
"760"
y=
"415"
width=
"230"
height=
"140"
rx=
"6"
fill=
"#f5f5f5"
stroke=
"#9e9e9e"
stroke-width=
"1.5"
/>
<text
x=
"875"
y=
"438"
class=
"label"
>
Optimizer
</text>
<rect
x=
"780"
y=
"453"
width=
"190"
height=
"40"
class=
"fp32"
rx=
"4"
/>
<text
x=
"875"
y=
"478"
class=
"text"
>
State (FP32)
</text>
<rect
x=
"780"
y=
"503"
width=
"190"
height=
"40"
class=
"fp32"
rx=
"4"
/>
<text
x=
"875"
y=
"528"
class=
"text"
>
Master (FP32)
</text>
<!-- Arrow from Master to cast -->
<path
d=
"M 970 523 L 1010 523 L 1010 118 L 935 118"
class=
"arrow"
/>
</svg>