Unverified commit 7042d7ae, authored by Sudhakar Singh and committed by GitHub

TE Gemma tutorial, attempt #2 (#1839)



* add tutorial files and other local changes
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* remove extraneous code for easier debugging
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* make CUDA graphs work with non-paged and paged attention
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* performance improvements for KV cache ops
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add code for calibration
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* optimize kv_cache reindex and copy kernels
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* changes to make quantizers work with fp8_calibration
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* avoid reindexing from the Python side
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename variable from previous commit
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use quantizer only if needed
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* tutorial functionality tested and performance checked
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* remove files and update headers/licenses
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* update header/license
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* update tutorial for review
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* make weights downloadable on the fly; remove extra print statements
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* fix lint and update comments
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* add missing comma back (typo fix)
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* sequence_start_positions should be None for training
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* add paged attention numbers and update requirements.txt file
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* more fixes
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* make tutorial work on Blackwell
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* remove Gemma fine-tuning tutorial for now
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix heading placement and reword "attention" -> "KV caching"
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fixes from review comments
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* fix the images
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* miscellaneous fixes
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add more comments to te_gemma.py and clean up utils.py
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* add more information about the hierarchy of the classes used in the tutorial
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks
For more information, see https://pre-commit.ci

* add better CUDA graphs picture
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add updated CUDA graphs pictures
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add illustrated CUDA graphs
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* small fixes in documentation
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add torch.no_grad() to reduce memory usage
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* some fixes from recent comments
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* more fixes from remaining comments
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* add te_rope_emb to class description
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

* fix tutorial wording; add calibration fix to grouped_linear.py
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent ba37529c
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
width="1280"
height="375.32169"
overflow="hidden"
version="1.1"
id="svg62"
sodipodi:docname="calibration.svg"
inkscape:version="1.4.2 (f4327f4, 2025-05-13)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="1.2875"
inkscape:cx="594.17476"
inkscape:cy="301.74757"
inkscape:window-width="3440"
inkscape:window-height="1369"
inkscape:window-x="-8"
inkscape:window-y="-8"
inkscape:window-maximized="1"
inkscape:current-layer="g62" />
<defs
id="defs62">
<clipPath
clipPathUnits="userSpaceOnUse"
id="clipPath62">
<rect
style="fill:none"
id="rect63"
width="1371.8354"
height="375.32169"
x="-39.964806"
y="153.77762" />
</clipPath>
</defs>
<g
id="g62"
clip-path="url(#clipPath62)"
transform="translate(0,-153.77762)">
<rect
x="0"
y="0"
width="1280"
height="720"
fill="#ffffff"
id="rect1" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
id="text1"
transform="translate(39.6169,204)">FP8 with initial scaling factors</text>
<rect
x="25"
y="326"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect2" />
<rect
x="40"
y="351"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect3" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(76.8203,374)"
id="text3">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(63.067,390)"
id="text4">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(70.067,406)"
id="text5">weight</text>
<rect
x="40"
y="433"
width="103"
height="48"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#f7cbcb"
id="rect5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(74.3203,445)"
id="text6">Initial</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(55.7337,461)"
id="text7">FP8 scaling</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(68.6536,477)"
id="text8">factors</text>
<rect
x="183"
y="363"
width="82"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect8" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(212.27,382)"
id="text9">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(201.77,398)"
id="text10">Weight</text>
<rect
x="288"
y="307"
width="82"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect10" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(316.622,325)"
id="text11">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(312.202,341)"
id="text12">Input</text>
<rect
x="277"
y="224"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect12" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(314.289,247)"
id="text13">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(300.535,263)"
id="text14">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(312.455,279)"
id="text15">input</text>
<rect
x="288"
y="363"
width="82"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect15" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(316.619,382)"
id="text16">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(307.952,398)"
id="text17">GEMM</text>
<path
d="M 0.015735,-0.999876 34.0184,-0.464776 33.987,1.53498 -0.015735,0.999876 Z M 32.7325,-3.48538 40.6686,0.64 32.6066,4.51362 Z"
transform="matrix(1,0,0,-1,143,386.64)"
id="path17" />
<path
d="m 265,385 h 15.791 v 2 H 265 Z m 14.458,-3 8,4 -8,4 z"
id="path18" />
<path
d="m 330,351 v 5.349 h -2 V 351 Z m 3,4.016 -4,8 -4,-8 z"
id="path19" />
<path
d="m 330,295 v 5.349 h -2 V 295 Z m 3,4.016 -4,8 -4,-8 z"
id="path20" />
<path
d="m 246.452,367 3.86,5.911 -1.592,0.697 4.862,5.065 -1.593,0.846 6.011,8.481 -10.193,-6.499 1.944,-0.903 -6.334,-4.163 2.27,-1.286 -6.687,-4.367 z"
fill="#ff0000"
fill-rule="evenodd"
id="path21" />
<path
d="m 351.844,310 4.063,5.63 -1.676,0.664 5.118,4.824 -1.676,0.805 6.327,8.077 -10.73,-6.19 2.047,-0.859 -6.667,-3.965 2.389,-1.225 -7.039,-4.159 z"
fill="#ff0000"
fill-rule="evenodd"
id="path22" />
<path
d="m 353.452,367 3.86,5.63 -1.592,0.664 4.862,4.824 -1.592,0.805 6.01,8.077 -10.193,-6.19 1.944,-0.859 -6.333,-3.965 2.269,-1.225 -6.687,-4.159 z"
fill="#ff0000"
fill-rule="evenodd"
id="path23" />
<path
d="M 0.0369111,-0.999319 6.03282,-0.777852 5.959,1.22079 -0.0369111,0.999319 Z M 8.10061,-0.673792 14.0656,-0.0265058 13.8498,1.96182 7.88485,1.31454 Z M 16.1141,0.247765 22.0214,1.2984 21.6712,3.2675 15.7639,2.21686 Z M 24.0392,1.73603 29.8647,3.17232 29.386,5.11418 23.5604,3.67789 Z m 7.8008,2.03501 5.5037,1.74085 0.2724,0.10597 -0.7252,1.86389 -0.2422,-0.09424 0.061,0.0215 -5.4729,-1.73109 z m 7.64,2.57203 4.6694,1.81681 0.9564,0.44699 -0.8468,1.81193 -0.9264,-0.43303 0.0608,0.02603 -4.6387,-1.80484 z m 7.4377,3.1106 3.679,1.71943 1.7562,0.9737 -0.9698,1.7492 -1.726,-0.957 0.0615,0.0313 -3.6477,-1.7047 z m 7.1844,3.66303 2.5248,1.3999 2.6656,1.7467 -1.0962,1.6728 -2.6347,-1.7264 0.0632,0.0381 -2.4926,-1.382 z m 6.8632,4.2428 1.2157,0.7966 3.6532,2.8401 -1.2275,1.579 -3.6214,-2.8153 0.0657,0.0469 -1.1819,-0.7745 z m 6.4982,4.9445 4.161,3.8856 0.2694,0.3094 -1.5085,1.3132 -0.2354,-0.2704 0.0717,0.0743 -4.1232,-3.8503 z m 5.7436,5.7035 2.187,2.5121 1.5663,2.3083 -1.6549,1.123 -1.5326,-2.2585 0.0732,0.0951 -2.1475,-2.4668 z m 4.8763,6.4753 0.3629,0.5348 2.2811,4.6599 0.1233,0.5117 -1.9444,0.4684 -0.0976,-0.4049 0.0741,0.2055 -2.2018,-4.4979 0.0707,0.1218 -0.3232,-0.4763 z m 3.2357,7.6508 0.0649,0.2693 -1.9444,0.4684 -0.0649,-0.2693 z m 2.7866,-1.5294 -2.4916,8.5902 -5.3772,-7.1474 z"
transform="matrix(1,0,0,-1,143,457.194)"
id="path24" />
<path
d="M 0.0162082,-0.999869 6.01542,-0.902619 5.983,1.09712 -0.0162082,0.999869 Z M 8.01516,-0.870203 14.0144,-0.772954 13.9819,1.22678 7.98274,1.12953 Z m 7.99894,0.129666 1.4291,0.023166 4.5969,0.218922 -0.0952,1.997739 -4.5812,-0.21818 0.0314,0.001 -1.4134,-0.02291 z m 8.0237,0.337228 5.9932,0.28542 -0.0951,1.997739 -5.9932,-0.28542 z m 7.991,0.3805605 2.7203,0.1295545 3.2969,0.257334 -0.1557,1.99394 -3.2817,-0.25616 0.0303,0.0019 -2.7053,-0.12883 z M 40.0399,0.519777 46.0217,0.986686 45.8661,2.98062 39.8843,2.51371 Z M 48.0156,1.14232 51.7818,1.43629 54.021,1.67858 53.8059,3.66697 51.5816,3.42629 51.6113,3.42907 47.86,3.13626 Z m 7.9938,0.75141 5.9652,0.64546 -0.2152,1.98839 -5.9652,-0.64545 z m 7.9536,0.86061 4.4428,0.48073 1.5466,0.21469 -0.275,1.98101 -1.5317,-0.21262 0.0299,0.00369 -4.4278,-0.47911 z m 7.9704,0.97042 5.943,0.825 -0.275,1.981 -5.943,-0.82499 z m 7.924,1.09999 4.6281,0.64246 1.3393,0.22871 L 85.4882,7.66738 84.1642,7.4413 84.195,7.44607 79.5824,6.80576 Z M 87.7963,6.03256 93.7107,7.0425 93.374,9.01396 87.4596,8.00403 Z m 7.8858,1.34658 4.2034,0.71775 1.7325,0.35545 -0.401,1.95916 -1.7172,-0.3521 0.0326,0.0062 -4.1869,-0.71499 z m 7.8959,1.47505 5.877,1.20561 -0.402,1.9592 -5.877,-1.2056 z m 7.836,1.60741 3.057,0.6269 2.834,0.6896 -0.473,1.9433 -2.817,-0.6853 0.036,0.0079 -3.038,-0.6232 z m 7.834,1.7894 5.83,1.4186 -0.473,1.9433 -5.83,-1.4187 z m 7.773,1.8915 1.085,0.264 4.733,1.3611 -0.553,1.9221 -4.713,-1.3554 0.04,0.0106 -1.065,-0.2591 z m 7.74,2.1779 5.766,1.6583 -0.553,1.9221 -5.766,-1.6583 z m 7.707,2.3135 5.679,1.9375 -0.646,1.8928 -5.679,-1.9374 z m 7.572,2.5833 1.951,0.6658 3.7,1.5162 -0.759,1.8506 -3.672,-1.5048 0.057,0.0211 -1.923,-0.656 z m 7.501,2.9404 4.424,1.8129 1.151,0.5493 -0.861,1.8051 -1.126,-0.5372 0.052,0.0228 -4.398,-1.8022 z m 7.38,3.2235 1.496,0.7138 3.871,2.0659 -0.942,1.7644 -3.85,-2.0553 0.04,0.0204 -1.476,-0.7042 z m 7.13,3.8317 2.082,1.2578 2.982,2.0725 -1.142,1.6424 -2.956,-2.0544 0.054,0.0348 -2.054,-1.2413 z m 
6.687,4.6257 1.439,1.1759 2.398,2.378 0.577,0.9359 -1.703,1.0495 -0.514,-0.8346 0.147,0.1853 -2.279,-2.2602 0.072,0.0643 -1.402,-1.1454 z m 5.463,6.1924 0.25,0.4051 -1.703,1.0494 -0.249,-0.4051 z m 2.666,-1.7349 -0.904,8.8984 -6.576,-6.0624 z"
transform="matrix(1,0,0,-1,143,457.194)"
id="path25" />
<path
d="M 0.0156214,-0.999878 6.01489,-0.90615 5.98365,1.09361 -0.0156214,0.999878 Z M 8.01465,-0.874907 14.0139,-0.781179 13.9827,1.21858 7.9834,1.12485 Z m 7.99905,0.124971 5.9992,0.093728 -0.0312,1.999758 -5.9993,-0.09373 z m 8.0288,0.139661 5.9937,0.275105 -0.0917,1.9979 -5.9937,-0.27511 z m 7.9916,0.366807 5.9937,0.2751062 L 37.9361,2.02953 31.9424,1.75443 Z M 40.0257,0.12334 46.0194,0.398446 45.9277,2.39634 39.934,2.12124 Z m 8.0188,0.399862 5.9831,0.450096 -0.15,1.994362 -5.9831,-0.45009 z M 56.022,1.12333 62.0051,1.57343 61.855,3.56779 55.872,3.1177 Z m 7.9774,0.60013 5.9792,0.4498 0.0326,0.0034 -0.2075,1.98921 -0.0182,-0.00191 0.0287,0.00258 -5.9648,-0.44872 z M 72.0004,2.3841 77.968,3.00644 77.7606,4.99565 71.793,4.37332 Z m 7.9569,0.82979 5.9676,0.62233 -0.2074,1.98922 -5.9677,-0.62234 z m 7.9568,0.82978 4.5303,0.47244 1.4612,0.1955 -0.2652,1.98233 -1.4468,-0.19356 0.0289,0.00344 -4.5158,-0.47094 z m 7.9738,0.93315 5.9471,0.79566 -0.265,1.98233 -5.9473,-0.79565 z m 7.9291,1.06087 5.947,0.79565 -0.265,1.98234 -5.947,-0.79565 z m 7.93,1.06087 2.427,0.32477 3.534,0.58152 -0.325,1.97346 -3.519,-0.57908 0.03,0.00444 -2.413,-0.32277 z m 7.934,1.23107 5.92,0.97434 -0.324,1.97343 -5.921,-0.9743 z m 7.894,1.29912 5.92,0.97435 -0.325,1.9734 -5.92,-0.9743 z m 7.922,1.32085 5.886,1.1635 -0.387,1.9621 -5.887,-1.1636 z m 7.848,1.5514 5.887,1.1636 -0.388,1.962 -5.886,-1.1636 z m 7.849,1.5514 3.498,0.6916 2.404,0.5638 -0.457,1.9472 -2.387,-0.5598 0.035,0.0074 -3.481,-0.6881 z m 7.849,1.7121 5.842,1.3699 -0.457,1.9472 -5.842,-1.37 z m 7.789,1.8266 5.841,1.37 -0.456,1.9472 -5.842,-1.37 z m 7.806,1.8786 5.798,1.5413 -0.513,1.9329 -5.799,-1.5414 z m 7.748,2.0737 5.764,1.6661 -0.555,1.9214 -5.764,-1.6662 z m 7.708,2.2287 5.723,1.8017 -0.6,1.9077 -5.723,-1.8017 z m 7.631,2.4022 0.212,0.0667 5.489,1.8864 -0.65,1.8914 -5.477,-1.8821 0.025,0.0081 -0.199,-0.0628 z m 7.592,2.6032 0.055,0.0187 5.588,2.1046 -0.705,1.8716 -5.574,-2.0994 0.027,0.0099 -0.041,-0.014 z m 7.539,2.8566 
5.542,2.2986 -0.766,1.8474 -5.542,-2.2986 z m 7.401,3.1331 4.585,2.1085 0.887,0.4562 -0.915,1.7785 -0.867,-0.4461 0.04,0.0193 -4.566,-2.0993 z m 7.251,3.4796 2.806,1.4437 2.503,1.4574 -1.007,1.7284 -2.48,-1.4443 0.046,0.025 -2.783,-1.4317 z m 7.037,3.9076 0.695,0.4048 4.359,2.9197 0.07,0.055 -1.237,1.5713 -0.04,-0.0312 0.062,0.0452 -4.301,-2.8808 0.054,0.0333 -0.668,-0.389 z m 6.695,4.617 2.13,1.677 2.447,2.3395 -1.382,1.4456 -2.413,-2.3064 0.073,0.0628 -2.092,-1.6473 z m 6.009,5.5752 1.803,2.1966 1.687,2.8491 -1.721,1.0191 -1.648,-2.7833 0.087,0.1248 -1.754,-2.1376 z m 4.288,7.0798 0.519,1.4444 0.388,3.1989 -0.075,1.6336 -1.998,-0.0921 0.072,-1.5501 0.006,0.1665 -0.364,-3.0041 0.051,0.218 -0.481,-1.3383 z m 0.74,8.2748 -0.276,5.9936 -1.998,-0.092 0.276,-5.9937 z m -0.368,7.9915 -0.021,0.4617 -0.75,5.5753 -1.982,-0.2665 0.744,-5.5319 -0.008,0.0872 0.019,-0.4179 z m -1.038,8.0192 -0.552,4.1104 -0.423,1.8921 -1.952,-0.4361 0.413,-1.85 -0.015,0.0848 0.547,-4.0678 z m -1.411,7.9544 -1.308,5.8556 -1.952,-0.4361 1.308,-5.8556 z m -1.815,7.8602 -1.268,4.2755 -0.529,1.5 -1.886,-0.664 0.52,-1.477 -0.016,0.048 1.262,-4.2515 z m -2.461,7.6625 -0.475,1.35 -1.776,4.269 -1.847,-0.768 1.766,-4.244 -0.02,0.053 0.466,-1.325 z m -3.159,7.473 -1.084,2.197 -1.85,3.12 -1.72,-1.02 1.83,-3.087 -0.036,0.067 1.067,-2.162 z m -4.134,7.032 -1.238,1.698 -3.124,2.632 -1.289,-1.53 3.032,-2.554 -0.163,0.176 1.166,-1.6 z m -2.038,6.56 -8.945,-0.04 5.399,-7.131 z"
transform="matrix(1,0,0,-1,143,457.214)"
id="path26" />
<path
d="M 821,170 V 513.312"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path27" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
id="text27"
x="566.53845"
y="203.2233">Calibration</text>
<rect
x="461"
y="326"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect27" />
<rect
x="476"
y="351"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect28" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(513.235,374)"
id="text28">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(499.482,390)"
id="text29">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(506.482,406)"
id="text30">weight</text>
<rect
x="476"
y="433"
width="103"
height="48"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#ffffff"
id="rect30" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(493.898,453)"
id="text31">FP8 scaling</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(505.065,469)"
id="text32">factors</text>
<rect
x="679"
y="224"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect32" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(716.025,247)"
id="text33">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(702.272,263)"
id="text34">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(714.192,279)"
id="text35">input</text>
<rect
x="679"
y="351"
width="103"
height="70"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect35" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(716.026,374)"
id="text36">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(702.272,390)"
id="text37">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(709.692,406)"
id="text38">GEMM</text>
<path
d="M 0.00641402,-0.999979 93.1192,-0.402739 93.1064,1.59722 -0.00641402,0.999979 Z M 91.8051,-3.41123 99.7793,0.64 91.7538,4.5886 Z"
transform="matrix(1,0,0,-1,579,386.64)"
id="path38" />
<path
d="m 732,295 v 49.395 h -2 V 295 Z m 3,48.061 -4,8 -4,-8 z"
id="path39" />
<path
d="m 731.277,421.127 -0.235,1.83 -0.725,1.912 -1.158,1.852 -0.336,0.389 -1.512,-1.309 0.286,-0.33 -0.092,0.125 1.064,-1.703 -0.087,0.175 0.649,-1.71 -0.057,0.228 0.219,-1.713 z m -3.865,7.565 -1.743,1.579 -2.309,1.723 -0.835,0.493 -1.017,-1.721 0.789,-0.466 -0.09,0.059 2.228,-1.663 -0.073,0.061 1.707,-1.547 z m -6.609,4.813 -3.055,1.804 -2.319,1.069 -0.837,-1.817 2.273,-1.047 -0.09,0.047 3.011,-1.778 z m -7.191,3.709 -2.795,1.287 -2.806,1.05 -0.701,-1.873 2.772,-1.037 -0.068,0.028 2.761,-1.272 z m -7.474,3.039 -3.452,1.291 -2.263,0.706 -0.595,-1.91 2.237,-0.697 -0.053,0.019 3.425,-1.282 z m -7.625,2.592 -5.053,1.575 -0.727,0.191 -0.509,-1.935 0.706,-0.185 -0.044,0.013 5.032,-1.568 z m -7.715,2.274 -5.803,1.524 -0.508,-1.934 5.803,-1.524 z m -7.776,2.034 -5.857,1.301 -0.434,-1.952 5.857,-1.302 z m -7.81,1.735 -3.065,0.682 -2.844,0.533 -0.368,-1.966 2.827,-0.53 -0.033,0.007 3.049,-0.678 z m -7.875,1.583 -5.897,1.105 -0.368,-1.966 5.897,-1.105 z m -7.897,1.453 -5.929,0.925 -0.308,-1.976 5.928,-0.925 z m -7.905,1.233 -3.789,0.591 -2.176,0.276 -0.252,-1.984 2.162,-0.274 -0.028,0.004 3.775,-0.589 z m -7.949,1.118 -5.953,0.755 -0.251,-1.984 5.952,-0.755 z m -7.937,1.007 -0.987,0.125 -5.008,0.495 -0.197,-1.99 4.994,-0.494 -0.027,0.003 0.974,-0.123 z m -7.985,0.817 -5.971,0.59 -0.197,-1.991 5.971,-0.59 z m -7.992,0.749 -5.985,0.427 -0.142,-1.995 5.985,-0.427 z m -7.98,0.569 -4.417,0.315 -1.598,0.069 -0.087,-1.998 1.585,-0.069 -0.028,0.002 4.403,-0.314 z m -8.013,0.471 -5.994,0.261 -0.087,-1.998 5.994,-0.261 z m -7.992,0.348 -2.494,0.108 -3.532,0.053 -0.03,-2 3.519,-0.052 -0.029,10e-4 2.479,-0.108 z m -8.026,0.19 -1.98,0.03 -0.03,-2 1.98,-0.029 z m -0.603,3.01 -8.058,-3.881 7.94,-4.119 z"
id="path40" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
transform="translate(857.055,204)"
id="text40">FP8 with calibrated scaling factors</text>
<rect
x="868"
y="326"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect40" />
<rect
x="883"
y="351"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect41" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(919.685,374)"
id="text41">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(905.932,390)"
id="text42">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(912.932,406)"
id="text43">weight</text>
<rect
x="883"
y="433"
width="103"
height="48"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#92d050"
id="rect43" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(902.185,445)"
id="text44">Calibrated</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(898.599,461)"
id="text45">FP8 scaling</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(911.519,477)"
id="text46">factors</text>
<rect
x="1026"
y="363"
width="82"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect46" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(1055.14,382)"
id="text47">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(1044.64,398)"
id="text48">Weight</text>
<rect
x="1131"
y="307"
width="82"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect48" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(1159.49,325)"
id="text49">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(1155.07,341)"
id="text50">Input</text>
<rect
x="1120"
y="224"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect50" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(1157.15,247)"
id="text51">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(1143.4,263)"
id="text52">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(1155.32,279)"
id="text53">input</text>
<rect
x="1131"
y="363"
width="82"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect53" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(1159.48,382)"
id="text54">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(1150.82,398)"
id="text55">GEMM</text>
<path
d="M 0.015735,-0.999876 34.0184,-0.464776 33.987,1.53498 -0.015735,0.999876 Z M 32.7325,-3.48538 40.6686,0.64 32.6066,4.51362 Z"
transform="matrix(1,0,0,-1,986,386.64)"
id="path55" />
<path
d="m 1108,385 h 15.79 v 2 H 1108 Z m 14.46,-3 8,4 -8,4 z"
id="path56" />
<path
d="m 1173,351 v 5.349 h -2 V 351 Z m 3,4.016 -4,8 -4,-8 z"
id="path57" />
<path
d="m 1173,295 v 5.349 h -2 V 295 Z m 3,4.016 -4,8 -4,-8 z"
id="path58" />
<path
d="M 0.0369111,-0.999319 6.03282,-0.777852 5.959,1.22079 -0.0369111,0.999319 Z M 8.10061,-0.673792 14.0656,-0.0265058 13.8498,1.96182 7.88485,1.31454 Z M 16.1141,0.247765 22.0214,1.2984 21.6712,3.2675 15.7639,2.21686 Z M 24.0392,1.73603 29.8647,3.17232 29.386,5.11418 23.5604,3.67789 Z m 7.8008,2.03501 5.5037,1.74085 0.2724,0.10597 -0.7252,1.86389 -0.2422,-0.09424 0.061,0.0215 -5.4729,-1.73109 z m 7.64,2.57203 4.6694,1.81681 0.9564,0.44699 -0.8468,1.81193 -0.9264,-0.43303 0.0608,0.02603 -4.6387,-1.80484 z m 7.4377,3.1106 3.679,1.71943 1.7562,0.9737 -0.9698,1.7492 -1.726,-0.957 0.0615,0.0313 -3.6477,-1.7047 z m 7.1844,3.66303 2.5248,1.3999 2.6656,1.7467 -1.0962,1.6728 -2.6347,-1.7264 0.0632,0.0381 -2.4926,-1.382 z m 6.8632,4.2428 1.2157,0.7966 3.6532,2.8401 -1.2275,1.579 -3.6214,-2.8153 0.0657,0.0469 -1.1819,-0.7745 z m 6.4982,4.9445 4.161,3.8856 0.2694,0.3094 -1.5085,1.3132 -0.2354,-0.2704 0.0717,0.0743 -4.1232,-3.8503 z m 5.7436,5.7035 2.187,2.5121 1.5663,2.3083 -1.6549,1.123 -1.5326,-2.2585 0.0732,0.0951 -2.1475,-2.4668 z m 4.8763,6.4753 0.3629,0.5348 2.2811,4.6599 0.1233,0.5117 -1.9444,0.4684 -0.0976,-0.4049 0.0741,0.2055 -2.2018,-4.4979 0.0707,0.1218 -0.3232,-0.4763 z m 3.2357,7.6508 0.0649,0.2693 -1.9444,0.4684 -0.0649,-0.2693 z m 2.7866,-1.5294 -2.4916,8.5902 -5.3772,-7.1474 z"
transform="matrix(1,0,0,-1,986,457.194)"
id="path59" />
<path
d="M 0.0162082,-0.999869 6.01542,-0.902619 5.983,1.09712 -0.0162082,0.999869 Z M 8.01516,-0.870203 14.0144,-0.772954 13.9819,1.22678 7.98274,1.12953 Z m 7.99894,0.129666 1.4291,0.023166 4.5969,0.218922 -0.0952,1.997739 -4.5812,-0.21818 0.0314,0.001 -1.4134,-0.02291 z m 8.0237,0.337228 5.9932,0.28542 -0.0951,1.997739 -5.9932,-0.28542 z m 7.991,0.3805605 2.7203,0.1295545 3.2969,0.257334 -0.1557,1.99394 -3.2817,-0.25616 0.0303,0.0019 -2.7053,-0.12883 z M 40.0399,0.519777 46.0217,0.986686 45.8661,2.98062 39.8843,2.51371 Z M 48.0156,1.14232 51.7818,1.43629 54.021,1.67858 53.8059,3.66697 51.5816,3.42629 51.6113,3.42907 47.86,3.13626 Z m 7.9938,0.75141 5.9652,0.64546 -0.2152,1.98839 -5.9652,-0.64545 z m 7.9536,0.86061 4.4428,0.48073 1.5466,0.21469 -0.275,1.98101 -1.5317,-0.21262 0.0299,0.00369 -4.4278,-0.47911 z m 7.9704,0.97042 5.943,0.825 -0.275,1.981 -5.943,-0.82499 z m 7.924,1.09999 4.6281,0.64246 1.3393,0.22871 L 85.4882,7.66738 84.1642,7.4413 84.195,7.44607 79.5824,6.80576 Z M 87.7963,6.03256 93.7107,7.0425 93.374,9.01396 87.4596,8.00403 Z m 7.8858,1.34658 4.2034,0.71775 1.7325,0.35545 -0.401,1.95916 -1.7172,-0.3521 0.0326,0.0062 -4.1869,-0.71499 z m 7.8959,1.47505 5.877,1.20561 -0.402,1.9592 -5.877,-1.2056 z m 7.836,1.60741 3.057,0.6269 2.834,0.6896 -0.473,1.9433 -2.817,-0.6853 0.036,0.0079 -3.038,-0.6232 z m 7.834,1.7894 5.83,1.4186 -0.473,1.9433 -5.83,-1.4187 z m 7.773,1.8915 1.085,0.264 4.733,1.3611 -0.553,1.9221 -4.713,-1.3554 0.04,0.0106 -1.065,-0.2591 z m 7.74,2.1779 5.766,1.6583 -0.553,1.9221 -5.766,-1.6583 z m 7.707,2.3135 5.679,1.9375 -0.646,1.8928 -5.679,-1.9374 z m 7.572,2.5833 1.951,0.6658 3.7,1.5162 -0.759,1.8506 -3.672,-1.5048 0.057,0.0211 -1.923,-0.656 z m 7.501,2.9404 4.424,1.8129 1.151,0.5493 -0.861,1.8051 -1.126,-0.5372 0.052,0.0228 -4.398,-1.8022 z m 7.38,3.2235 1.496,0.7138 3.871,2.0659 -0.942,1.7644 -3.85,-2.0553 0.04,0.0204 -1.476,-0.7042 z m 7.13,3.8317 2.082,1.2578 2.982,2.0725 -1.142,1.6424 -2.956,-2.0544 0.054,0.0348 -2.054,-1.2413 z m 
6.687,4.6257 1.439,1.1759 2.398,2.378 0.577,0.9359 -1.703,1.0495 -0.514,-0.8346 0.147,0.1853 -2.279,-2.2602 0.072,0.0643 -1.402,-1.1454 z m 5.463,6.1924 0.25,0.4051 -1.703,1.0494 -0.249,-0.4051 z m 2.666,-1.7349 -0.904,8.8984 -6.576,-6.0624 z"
transform="matrix(1,0,0,-1,986,457.194)"
id="path60" />
<path
d="M 0.0156214,-0.999878 6.01489,-0.90615 5.98365,1.09361 -0.0156214,0.999878 Z M 8.01465,-0.874907 14.0139,-0.781179 13.9827,1.21858 7.9834,1.12485 Z m 7.99905,0.124971 5.9992,0.093728 -0.0312,1.999758 -5.9993,-0.09373 z m 8.0288,0.139661 5.9937,0.275105 -0.0917,1.9979 -5.9937,-0.27511 z m 7.9916,0.366807 5.9937,0.2751062 L 37.9361,2.02953 31.9424,1.75443 Z M 40.0257,0.12334 46.0194,0.398446 45.9277,2.39634 39.934,2.12124 Z m 8.0188,0.399862 5.9831,0.450096 -0.15,1.994362 -5.9831,-0.45009 z M 56.022,1.12333 62.0051,1.57343 61.855,3.56779 55.872,3.1177 Z m 7.9774,0.60013 5.9792,0.4498 0.0326,0.0034 -0.2075,1.98921 -0.0182,-0.00191 0.0287,0.00258 -5.9648,-0.44872 z M 72.0004,2.3841 77.968,3.00644 77.7606,4.99565 71.793,4.37332 Z m 7.9569,0.82979 5.9676,0.62233 -0.2074,1.98922 -5.9677,-0.62234 z m 7.9568,0.82978 4.5303,0.47244 1.4612,0.1955 -0.2652,1.98233 -1.4468,-0.19356 0.0289,0.00344 -4.5158,-0.47094 z m 7.9738,0.93315 5.9471,0.79566 -0.265,1.98233 -5.9473,-0.79565 z m 7.9291,1.06087 5.947,0.79565 -0.265,1.98234 -5.947,-0.79565 z m 7.93,1.06087 2.427,0.32477 3.534,0.58152 -0.325,1.97346 -3.519,-0.57908 0.03,0.00444 -2.413,-0.32277 z m 7.934,1.23107 5.92,0.97434 -0.324,1.97343 -5.921,-0.9743 z m 7.894,1.29912 5.92,0.97435 -0.325,1.9734 -5.92,-0.9743 z m 7.922,1.32085 5.886,1.1635 -0.387,1.9621 -5.887,-1.1636 z m 7.848,1.5514 5.887,1.1636 -0.388,1.962 -5.886,-1.1636 z m 7.849,1.5514 3.498,0.6916 2.404,0.5638 -0.457,1.9472 -2.387,-0.5598 0.035,0.0074 -3.481,-0.6881 z m 7.849,1.7121 5.842,1.3699 -0.457,1.9472 -5.842,-1.37 z m 7.789,1.8266 5.841,1.37 -0.456,1.9472 -5.842,-1.37 z m 7.806,1.8786 5.798,1.5413 -0.513,1.9329 -5.799,-1.5414 z m 7.748,2.0737 5.764,1.6661 -0.555,1.9214 -5.764,-1.6662 z m 7.708,2.2287 5.723,1.8017 -0.6,1.9077 -5.723,-1.8017 z m 7.631,2.4022 0.212,0.0667 5.489,1.8864 -0.65,1.8914 -5.477,-1.8821 0.025,0.0081 -0.199,-0.0628 z m 7.592,2.6032 0.055,0.0187 5.588,2.1046 -0.705,1.8716 -5.574,-2.0994 0.027,0.0099 -0.041,-0.014 z m 7.539,2.8566 
5.542,2.2986 -0.766,1.8474 -5.542,-2.2986 z m 7.401,3.1331 4.585,2.1085 0.887,0.4562 -0.915,1.7785 -0.867,-0.4461 0.04,0.0193 -4.566,-2.0993 z m 7.251,3.4796 2.806,1.4437 2.503,1.4574 -1.007,1.7284 -2.48,-1.4443 0.046,0.025 -2.783,-1.4317 z m 7.037,3.9076 0.695,0.4048 4.359,2.9197 0.07,0.055 -1.237,1.5713 -0.04,-0.0312 0.062,0.0452 -4.301,-2.8808 0.054,0.0333 -0.668,-0.389 z m 6.695,4.617 2.13,1.677 2.447,2.3395 -1.382,1.4456 -2.413,-2.3064 0.073,0.0628 -2.092,-1.6473 z m 6.009,5.5752 1.803,2.1966 1.687,2.8491 -1.721,1.0191 -1.648,-2.7833 0.087,0.1248 -1.754,-2.1376 z m 4.288,7.0798 0.519,1.4444 0.388,3.1989 -0.075,1.6336 -1.998,-0.0921 0.072,-1.5501 0.006,0.1665 -0.364,-3.0041 0.051,0.218 -0.481,-1.3383 z m 0.74,8.2748 -0.276,5.9936 -1.998,-0.092 0.276,-5.9937 z m -0.368,7.9915 -0.021,0.4617 -0.75,5.5753 -1.982,-0.2665 0.744,-5.5319 -0.008,0.0872 0.019,-0.4179 z m -1.038,8.0192 -0.552,4.1104 -0.423,1.8921 -1.952,-0.4361 0.413,-1.85 -0.015,0.0848 0.547,-4.0678 z m -1.411,7.9544 -1.308,5.8556 -1.952,-0.4361 1.308,-5.8556 z m -1.815,7.8602 -1.268,4.2755 -0.529,1.5 -1.886,-0.664 0.52,-1.477 -0.016,0.048 1.262,-4.2515 z m -2.461,7.6625 -0.475,1.35 -1.776,4.269 -1.847,-0.768 1.766,-4.244 -0.02,0.053 0.466,-1.325 z m -3.159,7.473 -1.084,2.197 -1.85,3.12 -1.72,-1.02 1.83,-3.087 -0.036,0.067 1.067,-2.162 z m -4.134,7.032 -1.238,1.698 -3.124,2.632 -1.289,-1.53 3.032,-2.554 -0.163,0.176 1.166,-1.6 z m -2.038,6.56 -8.945,-0.04 5.399,-7.131 z"
transform="matrix(1,0,0,-1,986,457.214)"
id="path61" />
<path
d="M 422,170 V 513.312"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path62" />
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
width="960"
height="388.06406"
overflow="hidden"
version="1.1"
id="svg41"
sodipodi:docname="calibration_1_half.svg"
inkscape:version="1.4.2 (f4327f4, 2025-05-13)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="3.0072355"
inkscape:cx="479.8427"
inkscape:cy="194.03203"
inkscape:window-width="3440"
inkscape:window-height="1369"
inkscape:window-x="-8"
inkscape:window-y="-8"
inkscape:window-maximized="1"
inkscape:current-layer="g41" />
<defs
id="defs41">
<clipPath
clipPathUnits="userSpaceOnUse"
id="clipPath41">
<rect
style="fill:none"
id="rect42"
width="1006.0705"
height="388.06406"
x="-19.692804"
y="26.643204" />
</clipPath>
</defs>
<g
id="g41"
clip-path="url(#clipPath41)"
transform="translate(0,-26.643204)">
<rect
x="0"
y="0"
width="960"
height="480"
fill="#ffffff"
id="rect1" />
<rect
x="81"
y="206"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect2" />
<rect
x="96"
y="231"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect3" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(133.202,254)"
id="text3">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(119.448,270)"
id="text4">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(126.448,286)"
id="text5">weight</text>
<rect
x="96"
y="313"
width="103"
height="48"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#f7cbcb"
id="rect5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(130.702,325)"
id="text6">Initial</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(112.115,341)"
id="text7">FP8 scaling</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(125.035,357)"
id="text8">factors</text>
<rect
x="240"
y="243"
width="82"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect8" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(268.651,262)"
id="text9">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(258.151,278)"
id="text10">Weight</text>
<rect
x="344"
y="187"
width="82"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect10" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(373.003,205)"
id="text11">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(368.583,221)"
id="text12">Input</text>
<rect
x="334"
y="104"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect12" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(370.67,127)"
id="text13">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(356.917,143)"
id="text14">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(368.837,159)"
id="text15">input</text>
<rect
x="344"
y="243"
width="82"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect15" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(373,262)"
id="text16">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(364.333,278)"
id="text17">GEMM</text>
<path
d="M 0.015735,-0.999876 34.0184,-0.464776 33.987,1.53498 -0.015735,0.999876 Z M 32.7325,-3.48538 40.6686,0.64 32.6066,4.51362 Z"
transform="matrix(1,0,0,-1,199,266.64)"
id="path17" />
<path
d="m 322,265 h 15.791 v 2 H 322 Z m 14.458,-3 8,4 -8,4 z"
id="path18" />
<path
d="m 386,231 v 5.349 h -2 V 231 Z m 3,4.016 -4,8 -4,-8 z"
id="path19" />
<path
d="m 386,175 v 5.349 h -2 V 175 Z m 3,4.016 -4,8 -4,-8 z"
id="path20" />
<path
d="m 302.844,247 4.063,5.911 -1.676,0.697 5.118,5.065 -1.676,0.846 6.327,8.481 -10.73,-6.499 2.047,-0.903 -6.667,-4.163 2.389,-1.286 -7.039,-4.367 z"
fill="#ff0000"
fill-rule="evenodd"
id="path21" />
<path
d="m 408.452,190 3.86,5.63 -1.592,0.664 4.862,4.824 -1.592,0.805 6.01,8.077 -10.193,-6.19 1.944,-0.859 -6.333,-3.965 2.269,-1.225 -6.687,-4.159 z"
fill="#ff0000"
fill-rule="evenodd"
id="path22" />
<path
d="m 409.452,247 3.86,5.63 -1.592,0.664 4.862,4.824 -1.592,0.805 6.01,8.077 -10.193,-6.19 1.944,-0.859 -6.333,-3.965 2.269,-1.225 -6.687,-4.159 z"
fill="#ff0000"
fill-rule="evenodd"
id="path23" />
<path
d="M 0.0369111,-0.999319 6.03282,-0.777852 5.959,1.22079 -0.0369111,0.999319 Z M 8.10061,-0.673792 14.0656,-0.0265058 13.8498,1.96182 7.88485,1.31454 Z M 16.1141,0.247765 22.0214,1.2984 21.6712,3.2675 15.7639,2.21686 Z M 24.0392,1.73603 29.8647,3.17232 29.386,5.11418 23.5604,3.67789 Z m 7.8008,2.03501 5.5037,1.74085 0.2724,0.10597 -0.7252,1.86389 -0.2422,-0.09424 0.061,0.0215 -5.4729,-1.73109 z m 7.64,2.57203 4.6694,1.81681 0.9564,0.44699 -0.8468,1.81193 -0.9264,-0.43303 0.0608,0.02603 -4.6387,-1.80484 z m 7.4377,3.1106 3.679,1.71943 1.7562,0.9737 -0.9698,1.7492 -1.726,-0.957 0.0615,0.0313 -3.6477,-1.7047 z m 7.1844,3.66303 2.5248,1.3999 2.6656,1.7467 -1.0962,1.6728 -2.6347,-1.7264 0.0632,0.0381 -2.4926,-1.382 z m 6.8632,4.2428 1.2157,0.7966 3.6532,2.8401 -1.2275,1.579 -3.6214,-2.8153 0.0657,0.0469 -1.1819,-0.7745 z m 6.4982,4.9445 4.161,3.8856 0.2694,0.3094 -1.5085,1.3132 -0.2354,-0.2704 0.0717,0.0743 -4.1232,-3.8503 z m 5.7436,5.7035 2.187,2.5121 1.5663,2.3083 -1.6549,1.123 -1.5326,-2.2585 0.0732,0.0951 -2.1475,-2.4668 z m 4.8763,6.4753 0.3629,0.5348 2.2811,4.6599 0.1233,0.5117 -1.9444,0.4684 -0.0976,-0.4049 0.0741,0.2055 -2.2018,-4.4979 0.0707,0.1218 -0.3232,-0.4763 z m 3.2357,7.6508 0.0649,0.2693 -1.9444,0.4684 -0.0649,-0.2693 z m 2.7866,-1.5294 -2.4916,8.5902 -5.3772,-7.1474 z"
transform="matrix(1,0,0,-1,199,337.194)"
id="path24" />
<path
d="M 0.0162082,-0.999869 6.01542,-0.902619 5.983,1.09712 -0.0162082,0.999869 Z M 8.01516,-0.870203 14.0144,-0.772953 13.982,1.22678 7.98274,1.12953 Z m 7.99894,0.129666 1.4291,0.023166 4.5969,0.218922 -0.0951,1.997739 -4.5813,-0.21818 0.0314,0.001 -1.4134,-0.02291 z m 8.0237,0.337229 5.9932,0.28542 -0.0951,1.997738 -5.9932,-0.28542 z m 7.991,0.3805603 2.7203,0.1295537 3.2969,0.257336 -0.1557,1.993938 -3.2817,-0.25616 0.0302,0.0019 -2.7052,-0.12883 z M 40.0399,0.519778 46.0217,0.986688 45.8661,2.98062 39.8843,2.51371 Z M 48.0156,1.14232 51.7818,1.43629 54.021,1.67858 53.8059,3.66698 51.5815,3.42629 51.6113,3.42907 47.86,3.13626 Z m 7.9938,0.75141 5.9652,0.64546 -0.2152,1.9884 -5.9651,-0.64546 z m 7.9536,0.86061 4.4428,0.48073 1.5466,0.2147 -0.275,1.981 -1.5317,-0.21262 0.0299,0.00369 -4.4278,-0.4791 z m 7.9704,0.97043 5.943,0.82499 -0.275,1.98101 -5.943,-0.825 z m 7.924,1.09999 4.628,0.64245 1.3394,0.22872 L 85.4882,7.66739 84.1641,7.4413 84.195,7.44607 79.5824,6.80576 Z M 87.7963,6.03257 93.7107,7.0425 93.374,9.01397 87.4596,8.00403 Z m 7.8858,1.34658 4.2033,0.71774 1.7326,0.35546 -0.401,1.95925 -1.7172,-0.3522 0.0326,0.0062 -4.1869,-0.71499 z m 7.8959,1.47505 5.877,1.2056 -0.402,1.9592 -5.877,-1.2056 z m 7.836,1.6074 3.057,0.6269 2.834,0.6896 -0.473,1.9433 -2.817,-0.6853 0.036,0.0079 -3.038,-0.6232 z m 7.834,1.7894 5.83,1.4186 -0.473,1.9433 -5.83,-1.4186 z m 7.773,1.8915 1.085,0.264 4.733,1.3611 -0.553,1.9221 -4.713,-1.3554 0.04,0.0106 -1.065,-0.2591 z m 7.74,2.1779 5.766,1.6583 -0.553,1.9221 -5.766,-1.6583 z m 7.707,2.3135 5.679,1.9375 -0.646,1.8929 -5.679,-1.9375 z m 7.572,2.5833 1.951,0.6658 3.7,1.5162 -0.759,1.8506 -3.672,-1.5048 0.057,0.0211 -1.923,-0.656 z m 7.501,2.9405 4.424,1.8128 1.151,0.5493 -0.861,1.8051 -1.126,-0.5372 0.051,0.0228 -4.397,-1.8022 z m 7.38,3.2234 1.496,0.7138 3.871,2.0659 -0.942,1.7644 -3.85,-2.0553 0.04,0.0204 -1.476,-0.7041 z m 7.13,3.8317 2.082,1.2578 2.982,2.0726 -1.142,1.6423 -2.956,-2.0544 0.054,0.0348 -2.054,-1.2413 z m 
6.687,4.6258 1.439,1.1758 2.398,2.378 0.577,0.936 -1.703,1.0494 -0.514,-0.8346 0.147,0.1853 -2.279,-2.2602 0.072,0.0643 -1.402,-1.1454 z m 5.463,6.1923 0.25,0.4051 -1.703,1.0494 -0.249,-0.405 z m 2.666,-1.7349 -0.904,8.8984 -6.576,-6.0625 z"
transform="matrix(1,0,0,-1,199,337.194)"
id="path25" />
<path
d="M 0.0156214,-0.999878 6.01489,-0.90615 5.98365,1.09361 -0.0156214,0.999878 Z M 8.01465,-0.874907 14.0139,-0.781179 13.9827,1.21858 7.9834,1.12485 Z m 7.99905,0.124971 5.9992,0.093728 -0.0312,1.999758 -5.9993,-0.09373 z m 8.0288,0.139661 5.9937,0.275105 -0.0917,1.9979 -5.9937,-0.27511 z m 7.9916,0.366807 5.9937,0.2751062 L 37.9361,2.02953 31.9424,1.75443 Z M 40.0257,0.12334 46.0194,0.398446 45.9277,2.39634 39.934,2.12124 Z m 8.0188,0.399862 5.9831,0.450096 -0.15,1.994362 -5.9831,-0.45009 z M 56.022,1.12333 62.0051,1.57343 61.855,3.56779 55.872,3.1177 Z m 7.9774,0.60013 5.9792,0.4498 0.0326,0.0034 -0.2075,1.98921 -0.0182,-0.00191 0.0287,0.00258 -5.9648,-0.44872 z M 72.0004,2.3841 77.968,3.00644 77.7606,4.99565 71.793,4.37332 Z m 7.9569,0.82979 5.9676,0.62233 -0.2074,1.98922 -5.9677,-0.62234 z m 7.9568,0.82978 4.5303,0.47244 1.4612,0.1955 -0.2652,1.98233 -1.4468,-0.19356 0.0289,0.00344 -4.5158,-0.47094 z m 7.9738,0.93315 5.9471,0.79566 -0.265,1.98233 -5.9473,-0.79565 z m 7.9291,1.06087 5.947,0.79565 -0.265,1.98234 -5.947,-0.79565 z m 7.93,1.06087 2.427,0.32477 3.534,0.58152 -0.325,1.97346 -3.519,-0.57908 0.03,0.00444 -2.413,-0.32277 z m 7.934,1.23107 5.92,0.97434 -0.324,1.97343 -5.921,-0.9743 z m 7.894,1.29912 5.92,0.97435 -0.325,1.9734 -5.92,-0.9743 z m 7.922,1.32085 5.886,1.1635 -0.387,1.9621 -5.887,-1.1636 z m 7.848,1.5514 5.887,1.1636 -0.388,1.962 -5.886,-1.1636 z m 7.849,1.5514 3.498,0.6916 2.404,0.5638 -0.457,1.9472 -2.387,-0.5598 0.035,0.0074 -3.481,-0.6881 z m 7.849,1.7121 5.842,1.3699 -0.457,1.9472 -5.842,-1.37 z m 7.789,1.8266 5.841,1.37 -0.456,1.9472 -5.842,-1.37 z m 7.806,1.8786 5.798,1.5413 -0.513,1.9329 -5.799,-1.5414 z m 7.748,2.0737 5.764,1.6661 -0.555,1.9214 -5.764,-1.6662 z m 7.708,2.2287 5.723,1.8017 -0.6,1.9077 -5.723,-1.8017 z m 7.631,2.4022 0.212,0.0667 5.489,1.8864 -0.65,1.8914 -5.477,-1.8821 0.025,0.0081 -0.199,-0.0628 z m 7.592,2.6032 0.055,0.0187 5.588,2.1046 -0.705,1.8716 -5.574,-2.0994 0.027,0.0099 -0.041,-0.014 z m 7.539,2.8566 
5.542,2.2986 -0.766,1.8474 -5.542,-2.2986 z m 7.401,3.1331 4.585,2.1085 0.887,0.4562 -0.915,1.7785 -0.867,-0.4461 0.04,0.0193 -4.566,-2.0993 z m 7.251,3.4796 2.806,1.4437 2.503,1.4574 -1.007,1.7284 -2.48,-1.4443 0.046,0.025 -2.783,-1.4317 z m 7.037,3.9076 0.695,0.4048 4.359,2.9197 0.07,0.055 -1.237,1.5713 -0.04,-0.0312 0.062,0.0452 -4.301,-2.8808 0.054,0.0333 -0.668,-0.389 z m 6.695,4.617 2.13,1.677 2.447,2.3395 -1.382,1.4456 -2.413,-2.3064 0.073,0.0628 -2.092,-1.6473 z m 6.009,5.5752 1.803,2.1966 1.687,2.8491 -1.721,1.0191 -1.648,-2.7833 0.087,0.1248 -1.754,-2.1376 z m 4.288,7.0798 0.519,1.4444 0.388,3.1989 -0.075,1.6336 -1.998,-0.0921 0.072,-1.5501 0.006,0.1665 -0.364,-3.0041 0.051,0.218 -0.481,-1.3383 z m 0.74,8.2748 -0.276,5.9936 -1.998,-0.092 0.276,-5.9937 z m -0.368,7.9915 -0.021,0.4617 -0.75,5.5753 -1.982,-0.2665 0.744,-5.5319 -0.008,0.0872 0.019,-0.4179 z m -1.038,8.0192 -0.552,4.1104 -0.423,1.8921 -1.952,-0.4361 0.413,-1.85 -0.015,0.0848 0.547,-4.0678 z m -1.411,7.9544 -1.308,5.8556 -1.952,-0.4361 1.308,-5.8556 z m -1.815,7.8602 -1.268,4.2755 -0.529,1.5 -1.886,-0.664 0.52,-1.477 -0.016,0.048 1.262,-4.2515 z m -2.461,7.6625 -0.475,1.35 -1.776,4.269 -1.847,-0.768 1.766,-4.244 -0.02,0.053 0.466,-1.325 z m -3.159,7.473 -1.084,2.197 -1.85,3.12 -1.72,-1.02 1.83,-3.087 -0.036,0.067 1.067,-2.162 z m -4.134,7.032 -1.238,1.698 -3.124,2.632 -1.289,-1.53 3.032,-2.554 -0.163,0.176 1.166,-1.6 z m -2.038,6.56 -8.945,-0.04 5.399,-7.131 z"
transform="matrix(1,0,0,-1,199,337.214)"
id="path26" />
<rect
x="518"
y="206"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect26" />
<rect
x="533"
y="231"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect27" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(569.617,254)"
id="text27">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(555.863,270)"
id="text28">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(562.863,286)"
id="text29">weight</text>
<rect
x="533"
y="313"
width="103"
height="48"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#ffffff"
id="rect29" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(550.28,333)"
id="text30">FP8 scaling</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(561.447,349)"
id="text31">factors</text>
<rect
x="735"
y="104"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect31" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(772.407,127)"
id="text32">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(758.653,143)"
id="text33">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(770.573,159)"
id="text34">input</text>
<rect
x="735"
y="231"
width="103"
height="70"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect34" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(772.407,254)"
id="text35">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(758.653,270)"
id="text36">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(766.073,286)"
id="text37">GEMM</text>
<path
d="M 0.00641402,-0.999979 93.1192,-0.402739 93.1064,1.59722 -0.00641402,0.999979 Z M 91.8051,-3.41123 99.7793,0.64 91.7538,4.5886 Z"
transform="matrix(1,0,0,-1,636,266.64)"
id="path37" />
<path
d="m 788,175 v 49.395 h -2 V 175 Z m 3,48.061 -4,8 -4,-8 z"
id="path38" />
<path
d="m 788.277,301.127 -0.235,1.83 -0.725,1.912 -1.157,1.852 -0.337,0.389 -1.512,-1.309 0.286,-0.33 -0.092,0.124 1.064,-1.702 -0.087,0.175 0.649,-1.71 -0.057,0.228 0.219,-1.713 z m -3.865,7.565 -1.743,1.579 -2.308,1.723 -0.836,0.493 -1.017,-1.722 0.789,-0.465 -0.09,0.059 2.228,-1.663 -0.073,0.061 1.707,-1.547 z m -6.609,4.813 -3.055,1.804 -2.319,1.069 -0.837,-1.817 2.274,-1.047 -0.091,0.047 3.011,-1.778 z m -7.191,3.709 -2.795,1.287 -2.806,1.05 -0.701,-1.873 2.772,-1.037 -0.068,0.028 2.762,-1.272 z m -7.474,3.039 -3.452,1.291 -2.263,0.706 -0.595,-1.91 2.237,-0.697 -0.053,0.019 3.425,-1.282 z m -7.625,2.592 -5.053,1.575 -0.727,0.191 -0.509,-1.935 0.706,-0.185 -0.044,0.013 5.032,-1.568 z m -7.715,2.274 -5.803,1.524 -0.508,-1.934 5.803,-1.524 z m -7.776,2.034 -5.857,1.301 -0.434,-1.952 5.857,-1.302 z m -7.81,1.735 -3.065,0.682 -2.844,0.533 -0.368,-1.966 2.827,-0.53 -0.033,0.007 3.049,-0.678 z m -7.875,1.583 -5.897,1.105 -0.368,-1.966 5.897,-1.105 z m -7.897,1.453 -5.929,0.925 -0.308,-1.976 5.929,-0.925 z m -7.905,1.233 -3.789,0.591 -2.176,0.276 -0.252,-1.984 2.163,-0.274 -0.029,0.004 3.775,-0.589 z m -7.949,1.118 -5.952,0.755 -0.252,-1.984 5.952,-0.755 z m -7.937,1.007 -0.987,0.125 -5.008,0.495 -0.197,-1.99 4.995,-0.494 -0.028,0.003 0.974,-0.123 z m -7.985,0.817 -5.971,0.59 -0.197,-1.991 5.971,-0.59 z m -7.992,0.749 -5.985,0.427 -0.142,-1.995 5.985,-0.427 z m -7.98,0.569 -4.417,0.315 -1.598,0.069 -0.087,-1.998 1.585,-0.069 -0.028,0.002 4.403,-0.314 z m -8.013,0.471 -5.994,0.261 -0.087,-1.998 5.994,-0.261 z m -7.992,0.348 -2.493,0.109 -3.533,0.052 -0.03,-2 3.519,-0.052 -0.029,10e-4 2.479,-0.108 z m -8.026,0.19 -1.98,0.03 -0.03,-2 1.98,-0.029 z m -0.603,3.01 -8.058,-3.881 7.94,-4.118 z"
id="path39" />
<path
d="M 479,50 V 393.312"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path40" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="22px"
transform="translate(105.552,72)"
id="text40">FP8 with initial scaling factors</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="22px"
id="text41"
x="641.10864"
y="71.334938">Calibration</text>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
width="960"
height="366.05447"
overflow="hidden"
version="1.1"
id="svg36"
sodipodi:docname="calibration_2_half.svg"
inkscape:version="1.4.2 (f4327f4, 2025-05-13)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="3.1880501"
inkscape:cx="479.91718"
inkscape:cy="183.02724"
inkscape:window-width="3440"
inkscape:window-height="1369"
inkscape:window-x="-8"
inkscape:window-y="-8"
inkscape:window-maximized="1"
inkscape:current-layer="g36" />
<defs
id="defs36">
<clipPath
clipPathUnits="userSpaceOnUse"
id="clipPath36">
<rect
style="fill:none"
id="rect37"
width="993.32819"
height="366.05447"
x="-13.900802"
y="40.544006"
ry="36.489601" />
</clipPath>
</defs>
<g
id="g36"
clip-path="url(#clipPath36)"
transform="translate(0,-40.544006)">
<rect
x="0"
y="0"
width="960"
height="480"
fill="#ffffff"
id="rect1" />
<path
d="M 446,56 V 399.312"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path1" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="22px"
id="text1"
x="194.95409"
y="88">Calibration</text>
<rect
x="87"
y="211"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect2" />
<rect
x="102"
y="236"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect3" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(138.558,260)"
id="text3">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(124.805,276)"
id="text4">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(131.805,292)"
id="text5">weight</text>
<rect
x="102"
y="319"
width="103"
height="48"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#ffffff"
id="rect5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(119.222,339)"
id="text6">FP8 scaling</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(130.388,355)"
id="text7">factors</text>
<rect
x="304"
y="109"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect7" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(341.349,132)"
id="text8">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(327.595,148)"
id="text9">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(339.515,164)"
id="text10">input</text>
<rect
x="304"
y="236"
width="103"
height="70"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect10" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(341.348,259)"
id="text11">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(327.595,275)"
id="text12">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(335.015,291)"
id="text13">GEMM</text>
<path
d="M 0.00641402,-0.999979 93.1192,-0.402739 93.1064,1.59722 -0.00641402,0.999979 Z M 91.8051,-3.41123 99.7793,0.64 91.7538,4.5886 Z"
transform="matrix(1,0,0,-1,205,271.64)"
id="path13" />
<path
d="m 357,180 v 49.395 h -2 V 180 Z m 3,48.061 -4,8 -4,-8 z"
id="path14" />
<path
d="m 357.277,306.127 -0.235,1.83 -0.725,1.912 -1.157,1.852 -0.337,0.389 -1.512,-1.309 0.286,-0.33 -0.092,0.124 1.064,-1.702 -0.087,0.175 0.649,-1.71 -0.057,0.228 0.219,-1.713 z m -3.865,7.565 -1.743,1.579 -2.308,1.723 -0.836,0.493 -1.017,-1.722 0.789,-0.465 -0.09,0.059 2.228,-1.663 -0.073,0.061 1.707,-1.547 z m -6.609,4.813 -3.055,1.804 -2.319,1.069 -0.837,-1.817 2.274,-1.047 -0.091,0.047 3.011,-1.778 z m -7.191,3.709 -2.795,1.287 -2.806,1.05 -0.701,-1.873 2.772,-1.037 -0.068,0.028 2.761,-1.272 z m -7.474,3.039 -3.452,1.291 -2.263,0.706 -0.595,-1.91 2.237,-0.697 -0.053,0.019 3.425,-1.282 z m -7.625,2.592 -5.053,1.575 -0.727,0.191 -0.509,-1.935 0.706,-0.185 -0.044,0.013 5.032,-1.568 z m -7.715,2.274 -5.803,1.524 -0.508,-1.934 5.803,-1.524 z m -7.776,2.034 -5.857,1.301 -0.434,-1.952 5.857,-1.302 z m -7.81,1.735 -3.065,0.682 -2.844,0.533 -0.368,-1.966 2.827,-0.53 -0.033,0.007 3.049,-0.678 z m -7.875,1.583 -5.897,1.105 -0.368,-1.966 5.897,-1.105 z m -7.897,1.453 -5.929,0.925 -0.308,-1.976 5.929,-0.925 z m -7.905,1.233 -3.789,0.591 -2.176,0.276 -0.252,-1.984 2.163,-0.274 -0.029,0.004 3.775,-0.589 z m -7.949,1.118 -5.953,0.755 -0.251,-1.984 5.952,-0.755 z m -7.937,1.007 -0.987,0.125 -5.008,0.495 -0.197,-1.99 4.995,-0.494 -0.028,0.003 0.974,-0.123 z m -7.985,0.817 -5.971,0.59 -0.197,-1.991 5.971,-0.59 z m -7.992,0.749 -5.985,0.427 -0.142,-1.995 5.985,-0.427 z m -7.98,0.569 -4.417,0.315 -1.598,0.069 -0.087,-1.998 1.585,-0.069 -0.028,0.002 4.403,-0.314 z m -8.013,0.471 -5.994,0.261 -0.087,-1.998 5.994,-0.261 z m -7.992,0.348 -2.493,0.109 -3.533,0.052 -0.03,-2 3.519,-0.052 -0.029,10e-4 2.479,-0.108 z m -8.026,0.19 -1.98,0.03 -0.03,-2 1.98,-0.029 z m -0.603,3.01 -8.058,-3.881 7.94,-4.118 z"
id="path15" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="22px"
transform="translate(500.235,88)"
id="text15">FP8 with calibrated scaling factors</text>
<rect
x="493"
y="211"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect15" />
<rect
x="508"
y="236"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect16" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(545.009,260)"
id="text16">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(531.255,276)"
id="text17">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(538.255,292)"
id="text18">weight</text>
<rect
x="508"
y="319"
width="103"
height="48"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#92d050"
id="rect18" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(527.509,331)"
id="text19">Calibrated</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(523.922,347)"
id="text20">FP8 scaling</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(536.842,363)"
id="text21">factors</text>
<rect
x="652"
y="249"
width="82"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect21" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(680.458,267)"
id="text22">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(669.958,283)"
id="text23">Weight</text>
<rect
x="756"
y="192"
width="82"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect23" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(784.81,210)"
id="text24">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(780.39,226)"
id="text25">Input</text>
<rect
x="745"
y="109"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect25" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(782.477,132)"
id="text26">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(768.723,148)"
id="text27">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(780.643,164)"
id="text28">input</text>
<rect
x="756"
y="249"
width="82"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect28" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(784.807,267)"
id="text29">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
transform="translate(776.14,283)"
id="text30">GEMM</text>
<path
d="M 0.015735,-0.999876 34.0184,-0.464776 33.987,1.53498 -0.015735,0.999876 Z M 32.7325,-3.48538 40.6686,0.64 32.6066,4.51362 Z"
transform="matrix(1,0,0,-1,611,271.64)"
id="path30" />
<path
d="m 734,270 h 15.791 v 2 H 734 Z m 14.458,-3 8,4 -8,4 z"
id="path31" />
<path
d="m 798,237 v 5.349 h -2 V 237 Z m 3,4.016 -4,8 -4,-8 z"
id="path32" />
<path
d="m 798,180 v 5.349 h -2 V 180 Z m 3,4.016 -4,8 -4,-8 z"
id="path33" />
<path
d="M 0.0369111,-0.999319 6.03282,-0.777852 5.959,1.22079 -0.0369111,0.999319 Z M 8.10061,-0.673792 14.0656,-0.0265058 13.8498,1.96182 7.88485,1.31454 Z M 16.1141,0.247765 22.0214,1.2984 21.6712,3.2675 15.7639,2.21686 Z M 24.0392,1.73603 29.8647,3.17232 29.386,5.11418 23.5604,3.67789 Z m 7.8008,2.03501 5.5037,1.74085 0.2724,0.10597 -0.7252,1.86389 -0.2422,-0.09424 0.061,0.0215 -5.4729,-1.73109 z m 7.64,2.57203 4.6694,1.81681 0.9564,0.44699 -0.8468,1.81193 -0.9264,-0.43303 0.0608,0.02603 -4.6387,-1.80484 z m 7.4377,3.1106 3.679,1.71943 1.7562,0.9737 -0.9698,1.7492 -1.726,-0.957 0.0615,0.0313 -3.6477,-1.7047 z m 7.1844,3.66303 2.5248,1.3999 2.6656,1.7467 -1.0962,1.6728 -2.6347,-1.7264 0.0632,0.0381 -2.4926,-1.382 z m 6.8632,4.2428 1.2157,0.7966 3.6532,2.8401 -1.2275,1.579 -3.6214,-2.8153 0.0657,0.0469 -1.1819,-0.7745 z m 6.4982,4.9445 4.161,3.8856 0.2694,0.3094 -1.5085,1.3132 -0.2354,-0.2704 0.0717,0.0743 -4.1232,-3.8503 z m 5.7436,5.7035 2.187,2.5121 1.5663,2.3083 -1.6549,1.123 -1.5326,-2.2585 0.0732,0.0951 -2.1475,-2.4668 z m 4.8763,6.4753 0.3629,0.5348 2.2811,4.6599 0.1233,0.5117 -1.9444,0.4684 -0.0976,-0.4049 0.0741,0.2055 -2.2018,-4.4979 0.0707,0.1218 -0.3232,-0.4763 z m 3.2357,7.6508 0.0649,0.2693 -1.9444,0.4684 -0.0649,-0.2693 z m 2.7866,-1.5294 -2.4916,8.5902 -5.3772,-7.1474 z"
transform="matrix(1,0,0,-1,611,342.194)"
id="path34" />
<path
d="M 0.0162082,-0.999869 6.01542,-0.902619 5.983,1.09712 -0.0162082,0.999869 Z M 8.01516,-0.870203 14.0144,-0.772953 13.982,1.22678 7.98274,1.12953 Z m 7.99894,0.129666 1.4291,0.023166 4.5969,0.218922 -0.0951,1.997739 -4.5813,-0.21818 0.0314,0.001 -1.4134,-0.02291 z m 8.0237,0.337229 5.9932,0.28542 -0.0951,1.997738 -5.9932,-0.28542 z m 7.991,0.3805603 2.7203,0.1295537 3.2969,0.257336 -0.1557,1.993938 -3.2817,-0.25616 0.0302,0.0019 -2.7052,-0.12883 z M 40.0399,0.519778 46.0217,0.986688 45.8661,2.98062 39.8843,2.51371 Z M 48.0156,1.14232 51.7818,1.43629 54.021,1.67858 53.8059,3.66698 51.5815,3.42629 51.6113,3.42907 47.86,3.13626 Z m 7.9938,0.75141 5.9652,0.64546 -0.2152,1.9884 -5.9651,-0.64546 z m 7.9536,0.86061 4.4428,0.48073 1.5466,0.2147 -0.275,1.981 -1.5317,-0.21262 0.0299,0.00369 -4.4278,-0.4791 z m 7.9704,0.97043 5.943,0.82499 -0.275,1.98101 -5.943,-0.825 z m 7.924,1.09999 4.628,0.64245 1.3394,0.22872 L 85.4882,7.66739 84.1641,7.4413 84.195,7.44607 79.5824,6.80576 Z M 87.7963,6.03257 93.7107,7.0425 93.374,9.01397 87.4596,8.00403 Z m 7.8858,1.34658 4.2033,0.71774 1.7326,0.35546 -0.401,1.95925 -1.7172,-0.3522 0.0326,0.0062 -4.1869,-0.71499 z m 7.8959,1.47505 5.877,1.2056 -0.402,1.9592 -5.877,-1.2056 z m 7.836,1.6074 3.057,0.6269 2.834,0.6896 -0.473,1.9433 -2.817,-0.6853 0.036,0.0079 -3.038,-0.6232 z m 7.834,1.7894 5.83,1.4186 -0.473,1.9433 -5.83,-1.4186 z m 7.773,1.8915 1.085,0.264 4.733,1.3611 -0.553,1.9221 -4.713,-1.3554 0.04,0.0106 -1.065,-0.2591 z m 7.74,2.1779 5.766,1.6583 -0.553,1.9221 -5.766,-1.6583 z m 7.707,2.3135 5.679,1.9375 -0.646,1.8929 -5.679,-1.9375 z m 7.572,2.5833 1.951,0.6658 3.7,1.5162 -0.759,1.8506 -3.672,-1.5048 0.057,0.0211 -1.923,-0.656 z m 7.501,2.9405 4.424,1.8128 1.151,0.5493 -0.861,1.8051 -1.126,-0.5372 0.051,0.0228 -4.397,-1.8022 z m 7.38,3.2234 1.496,0.7138 3.871,2.0659 -0.942,1.7644 -3.85,-2.0553 0.04,0.0204 -1.476,-0.7041 z m 7.13,3.8317 2.082,1.2578 2.982,2.0726 -1.142,1.6423 -2.956,-2.0544 0.054,0.0348 -2.054,-1.2413 z m 
6.687,4.6258 1.439,1.1758 2.398,2.378 0.577,0.936 -1.703,1.0494 -0.514,-0.8346 0.147,0.1853 -2.279,-2.2602 0.072,0.0643 -1.402,-1.1454 z m 5.463,6.1923 0.25,0.4051 -1.703,1.0494 -0.249,-0.405 z m 2.666,-1.7349 -0.904,8.8984 -6.576,-6.0625 z"
transform="matrix(1,0,0,-1,611,342.194)"
id="path35" />
<path
d="M 0.0156214,-0.999878 6.01489,-0.90615 5.98365,1.09361 -0.0156214,0.999878 Z M 8.01465,-0.874907 14.0139,-0.781179 13.9827,1.21858 7.9834,1.12485 Z m 7.99905,0.124971 5.9992,0.093728 -0.0312,1.999758 -5.9993,-0.09373 z m 8.0288,0.139661 5.9937,0.275105 -0.0917,1.9979 -5.9937,-0.27511 z m 7.9916,0.366807 5.9937,0.2751062 L 37.9361,2.02953 31.9424,1.75443 Z M 40.0257,0.12334 46.0194,0.398446 45.9277,2.39634 39.934,2.12124 Z m 8.0188,0.399862 5.9831,0.450096 -0.15,1.994362 -5.9831,-0.45009 z M 56.022,1.12333 62.0051,1.57343 61.855,3.56779 55.872,3.1177 Z m 7.9774,0.60013 5.9792,0.4498 0.0326,0.0034 -0.2075,1.98921 -0.0182,-0.00191 0.0287,0.00258 -5.9648,-0.44872 z M 72.0004,2.3841 77.968,3.00644 77.7606,4.99565 71.793,4.37332 Z m 7.9569,0.82979 5.9676,0.62233 -0.2074,1.98922 -5.9677,-0.62234 z m 7.9568,0.82978 4.5303,0.47244 1.4612,0.1955 -0.2652,1.98233 -1.4468,-0.19356 0.0289,0.00344 -4.5158,-0.47094 z m 7.9738,0.93315 5.9471,0.79566 -0.265,1.98233 -5.9473,-0.79565 z m 7.9291,1.06087 5.947,0.79565 -0.265,1.98234 -5.947,-0.79565 z m 7.93,1.06087 2.427,0.32477 3.534,0.58152 -0.325,1.97346 -3.519,-0.57908 0.03,0.00444 -2.413,-0.32277 z m 7.934,1.23107 5.92,0.97434 -0.324,1.97343 -5.921,-0.9743 z m 7.894,1.29912 5.92,0.97435 -0.325,1.9734 -5.92,-0.9743 z m 7.922,1.32085 5.886,1.1635 -0.387,1.9621 -5.887,-1.1636 z m 7.848,1.5514 5.887,1.1636 -0.388,1.962 -5.886,-1.1636 z m 7.849,1.5514 3.498,0.6916 2.404,0.5638 -0.457,1.9472 -2.387,-0.5598 0.035,0.0074 -3.481,-0.6881 z m 7.849,1.7121 5.842,1.3699 -0.457,1.9472 -5.842,-1.37 z m 7.789,1.8266 5.841,1.37 -0.456,1.9472 -5.842,-1.37 z m 7.806,1.8786 5.798,1.5413 -0.513,1.9329 -5.799,-1.5414 z m 7.748,2.0737 5.764,1.6661 -0.555,1.9214 -5.764,-1.6662 z m 7.708,2.2287 5.723,1.8017 -0.6,1.9077 -5.723,-1.8017 z m 7.631,2.4022 0.212,0.0667 5.489,1.8864 -0.65,1.8914 -5.477,-1.8821 0.025,0.0081 -0.199,-0.0628 z m 7.592,2.6032 0.055,0.0187 5.588,2.1046 -0.705,1.8716 -5.574,-2.0994 0.027,0.0099 -0.041,-0.014 z m 7.539,2.8566 
5.542,2.2986 -0.766,1.8474 -5.542,-2.2986 z m 7.401,3.1331 4.585,2.1085 0.887,0.4562 -0.915,1.7785 -0.867,-0.4461 0.04,0.0193 -4.566,-2.0993 z m 7.251,3.4796 2.806,1.4437 2.503,1.4574 -1.007,1.7284 -2.48,-1.4443 0.046,0.025 -2.783,-1.4317 z m 7.037,3.9076 0.695,0.4048 4.359,2.9197 0.07,0.055 -1.237,1.5713 -0.04,-0.0312 0.062,0.0452 -4.301,-2.8808 0.054,0.0333 -0.668,-0.389 z m 6.695,4.617 2.13,1.677 2.447,2.3395 -1.382,1.4456 -2.413,-2.3064 0.073,0.0628 -2.092,-1.6473 z m 6.009,5.5752 1.803,2.1966 1.687,2.8491 -1.721,1.0191 -1.648,-2.7833 0.087,0.1248 -1.754,-2.1376 z m 4.288,7.0798 0.519,1.4444 0.388,3.1989 -0.075,1.6336 -1.998,-0.0921 0.072,-1.5501 0.006,0.1665 -0.364,-3.0041 0.051,0.218 -0.481,-1.3383 z m 0.74,8.2748 -0.276,5.9936 -1.998,-0.092 0.276,-5.9937 z m -0.368,7.9915 -0.021,0.4617 -0.75,5.5753 -1.982,-0.2665 0.744,-5.5319 -0.008,0.0872 0.019,-0.4179 z m -1.038,8.0192 -0.552,4.1104 -0.423,1.8921 -1.952,-0.4361 0.413,-1.85 -0.015,0.0848 0.547,-4.0678 z m -1.411,7.9544 -1.308,5.8556 -1.952,-0.4361 1.308,-5.8556 z m -1.815,7.8602 -1.268,4.2755 -0.529,1.5 -1.886,-0.664 0.52,-1.477 -0.016,0.048 1.262,-4.2515 z m -2.461,7.6625 -0.475,1.35 -1.776,4.269 -1.847,-0.768 1.766,-4.244 -0.02,0.053 0.466,-1.325 z m -3.159,7.473 -1.084,2.197 -1.85,3.12 -1.72,-1.02 1.83,-3.087 -0.036,0.067 1.067,-2.162 z m -4.134,7.032 -1.238,1.698 -3.124,2.632 -1.289,-1.53 3.032,-2.554 -0.163,0.176 1.166,-1.6 z m -2.038,6.56 -8.945,-0.04 5.399,-7.131 z"
transform="matrix(1,0,0,-1,611,342.214)"
id="path36" />
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
width="1280"
height="379.66562"
overflow="hidden"
version="1.1"
id="svg31"
sodipodi:docname="fp8_model_init.svg"
inkscape:version="1.4.2 (f4327f4, 2025-05-13)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="1.8208"
inkscape:cx="685.41302"
inkscape:cy="184.80888"
inkscape:window-width="3440"
inkscape:window-height="1369"
inkscape:window-x="-8"
inkscape:window-y="-8"
inkscape:window-maximized="1"
inkscape:current-layer="g31" />
<defs
id="defs31">
<clipPath
clipPathUnits="userSpaceOnUse"
id="clipPath31">
<rect
style="fill:none"
id="rect32"
width="1390.9491"
height="379.66562"
x="-54.734409"
y="146.82722"
ry="36.489601" />
</clipPath>
</defs>
<g
id="g31"
clip-path="url(#clipPath31)"
transform="translate(0,-146.82722)">
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
id="text1"
x="153.29384"
y="195.21265">FP32/BF16</text>
<path
d="M 821,170 V 513.312"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path1" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
id="text2"
x="616.69165"
y="194.66344">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
id="text3"
x="908.73199"
y="193.56503">FP8 with fp8_model_init()</text>
<rect
x="868"
y="326"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect3" />
<rect
x="882.45081"
y="381.1239"
width="101"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect4" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text4"
x="920.40778"
y="400.1239">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text5"
x="911.3208"
y="416.1239">weight</text>
<rect
x="1078.4508"
y="381.1239"
width="82"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text6"
x="1107.5007"
y="400.1239">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text7"
x="1098.8308"
y="416.1239">GEMM</text>
<path
d="m 983.45079,403.1239 h 89.04001 v 2 h -89.04001 z m 87.71001,-3 8,4 -8,4 z"
id="path7" />
<path
d="M 422,170 V 513.312"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path9" />
<rect
x="54"
y="326"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect9" />
<rect
x="67.45079"
y="367.47629"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect10" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text10"
x="104.84079"
y="390.47629">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text11"
x="91.087494"
y="406.47629">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text12"
x="98.087494"
y="422.47629">weight</text>
<rect
x="270.45081"
y="240.47627"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect12" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text13"
x="307.6308"
y="263.47629">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text14"
x="293.87778"
y="279.47629">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text15"
x="305.79779"
y="295.47629">input</text>
<rect
x="270.45081"
y="367.47629"
width="103"
height="70"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect15" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text16"
x="307.6308"
y="390.47629">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text17"
x="293.87778"
y="406.47629">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text18"
x="301.29779"
y="422.47629">GEMM</text>
<path
d="m 170.4572,404.11625 93.11279,-0.59724 -0.0128,-1.99996 -93.11281,0.59725 z m 91.79869,2.41125 7.9742,-4.05123 -8.0255,-3.9486 z"
id="path18" />
<path
d="m 323.45079,311.47627 v 49.395 h -2 v -49.395 z m 3,48.061 -4,8 -4,-8 z"
id="path19" />
<rect
x="447"
y="326"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect19" />
<rect
x="460.90158"
y="368.57471"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect20" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text20"
x="497.76358"
y="392.57471">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text21"
x="484.01059"
y="408.57471">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text22"
x="491.01059"
y="424.57471">weight</text>
<rect
x="604.90161"
y="381.57471"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23"
x="633.21356"
y="399.57471">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24"
x="622.71356"
y="415.57471">Weight</text>
<g
id="g33"
transform="translate(70.847981,7.139719)">
<rect
x="638.21271"
y="302.41418"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22-2" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23-7"
x="666.52472"
y="320.41418">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24-6"
x="662.06604"
y="336.96341">Input</text>
</g>
<rect
x="708.90161"
y="381.57471"
width="82"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect26" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text27"
x="737.56158"
y="399.57471">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text28"
x="728.89557"
y="415.57471">GEMM</text>
<path
d="m 563.91732,405.21457 34.00266,-0.5351 -0.0314,-1.99976 -34.00273,0.53511 z m 32.71676,2.4855 7.9361,-4.12538 -8.062,-3.87362 z"
id="path28" />
<path
d="m 685.90158,402.57469 h 15.791 v 2 h -15.791 z m 14.458,-3 8,4 -8,4 z"
id="path29" />
<path
d="m 750.90158,284.49209 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30"
style="stroke-width:0.53037" />
<path
d="m 751.17135,355.90367 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-2"
style="stroke-width:0.53037" />
<rect
x="701.05359"
y="215.25253"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect23" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text29"
x="738.23358"
y="238.25253">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text32"
x="724.48059"
y="254.25255">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text33"
x="736.40057"
y="270.25253">input</text>
<g
id="g33-9"
transform="translate(441.10986,7.0509646)">
<rect
x="638.21271"
y="302.41418"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22-2-5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23-7-4"
x="666.52472"
y="320.41418">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24-6-3"
x="662.06604"
y="336.96341">Input</text>
</g>
<path
d="m 1121.1635,284.40334 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-1"
style="stroke-width:0.53037" />
<path
d="m 1121.4332,355.81492 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-2-2"
style="stroke-width:0.53037" />
<rect
x="1071.3154"
y="215.16379"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect23-3" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text29-3"
x="1108.4955"
y="238.16379">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text32-4"
x="1094.7424"
y="254.1638">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text33-1"
x="1106.6625"
y="270.16376">input</text>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
width="960"
height="373.58408"
overflow="hidden"
version="1.1"
id="svg23"
sodipodi:docname="fp8_model_init_1_half.svg"
inkscape:version="1.4.2 (f4327f4, 2025-05-13)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="3.1237948"
inkscape:cx="479.86506"
inkscape:cy="186.79204"
inkscape:window-width="3440"
inkscape:window-height="1369"
inkscape:window-x="-8"
inkscape:window-y="-8"
inkscape:window-maximized="1"
inkscape:current-layer="g23" />
<defs
id="defs23">
<clipPath
clipPathUnits="userSpaceOnUse"
id="clipPath23">
<rect
style="fill:none"
id="rect24"
width="997.38257"
height="373.58408"
x="-11.584002"
y="41.702408"
ry="36.489601" />
</clipPath>
</defs>
<g
id="g23"
clip-path="url(#clipPath23)"
transform="translate(0,-41.702408)">
<rect
x="0"
y="0"
width="960"
height="480"
fill="#ffffff"
id="rect1" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="22px"
transform="translate(195.4,93)"
id="text1">FP32/BF16</text>
<path
d="M 461,61 V 404.312"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path1" />
<rect
x="92"
y="217"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect2" />
<rect
x="105.07926"
y="266.32938"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect3" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text3"
x="142.27226"
y="289.32938">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text4"
x="128.51926"
y="305.32938">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text5"
x="135.51926"
y="321.32938">weight</text>
<rect
x="308.07925"
y="138.32938"
width="103"
height="72"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text6"
x="345.06326"
y="162.32938">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text7"
x="331.31027"
y="178.32938">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text8"
x="343.23026"
y="194.32938">input</text>
<rect
x="308.07925"
y="266.32938"
width="103"
height="70"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect8" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text9"
x="345.06326"
y="289.32938">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text10"
x="331.30927"
y="305.32938">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text11"
x="338.72925"
y="321.32938">GEMM</text>
<path
d="m 208.08567,302.96936 93.11279,-0.59724 -0.0128,-1.99996 -93.11281,0.59724 z m 91.79869,2.41125 7.9742,-4.05123 -8.0255,-3.9486 z"
id="path11" />
<path
d="m 360.07926,210.32938 v 49.395 h -2 v -49.395 z m 3,48.061 -4,8 -4,-8 z"
id="path12" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="22px"
transform="translate(645.181,91)"
id="text23">FP8</text>
<rect
x="495.63504"
y="222.57803"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect19" />
<rect
x="509.53662"
y="265.15271"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect20" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text20"
x="546.39862"
y="289.15274">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text21"
x="532.64563"
y="305.15274">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text22"
x="539.64563"
y="321.15274">weight</text>
<rect
x="653.53668"
y="278.15274"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23-3"
x="681.84863"
y="296.15274">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24"
x="671.34863"
y="312.15274">Weight</text>
<g
id="g33"
transform="translate(119.48305,-96.282252)">
<rect
x="638.21271"
y="302.41418"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22-2" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23-7"
x="666.52472"
y="320.41418">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24-6"
x="662.06604"
y="336.96341">Input</text>
</g>
<rect
x="757.53668"
y="278.15274"
width="82"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect26" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text27"
x="786.19666"
y="296.15274">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text28"
x="777.53064"
y="312.15274">GEMM</text>
<path
d="m 612.55239,301.7926 34.00266,-0.5351 -0.0314,-1.99976 -34.00273,0.53511 z m 32.71676,2.4855 7.9361,-4.12538 -8.062,-3.87362 z"
id="path28" />
<path
d="m 734.53665,299.15272 h 15.791 v 2 h -15.791 z m 14.458,-3 8,4 -8,4 z"
id="path29" />
<path
d="m 799.53665,181.07012 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30"
style="stroke-width:0.53037" />
<path
d="m 799.80642,252.4817 v 21.98469 h -2 V 252.4817 Z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-2"
style="stroke-width:0.53037" />
<rect
x="749.68866"
y="111.83057"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect23" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text29"
x="786.86865"
y="134.83058">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text32"
x="773.11566"
y="150.83058">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text33"
x="785.03564"
y="166.83057">input</text>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
width="960"
height="379.95526"
overflow="hidden"
version="1.1"
id="svg19"
sodipodi:docname="fp8_model_init_2_half.svg"
inkscape:version="1.4.2 (f4327f4, 2025-05-13)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="2.1718178"
inkscape:cx="502.34416"
inkscape:cy="194.07705"
inkscape:window-width="3440"
inkscape:window-height="1369"
inkscape:window-x="-8"
inkscape:window-y="-8"
inkscape:window-maximized="1"
inkscape:current-layer="svg19" />
<defs
id="defs19">
<clipPath
clipPathUnits="userSpaceOnUse"
id="clipPath20">
<rect
style="fill:none"
id="rect21"
width="1014.7587"
height="379.95526"
x="-21.430403"
y="44.598408" />
</clipPath>
</defs>
<g
id="g19"
clip-path="url(#clipPath20)"
transform="translate(-76.837568,-52.086815)" />
<path
d="M 434.81331,26.957307 V 370.26931"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path1" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
id="text2"
x="216.69165"
y="33.663437">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
id="text3"
x="508.73199"
y="32.565033">FP8 with fp8_model_init()</text>
<rect
x="481.81332"
y="182.95731"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect3" />
<rect
x="496.26413"
y="238.08121"
width="101"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect4" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text4"
x="534.22107"
y="257.08121">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text5"
x="525.13409"
y="273.08121">weight</text>
<rect
x="692.2641"
y="238.08121"
width="82"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text6"
x="721.31403"
y="257.08121">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text7"
x="712.6441"
y="273.08121">GEMM</text>
<path
d="m 597.2641,260.08121 h 89.04001 v 2 H 597.2641 Z m 87.71001,-3 8,4 -8,4 z"
id="path7" />
<rect
x="60.813313"
y="182.95731"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect19" />
<rect
x="74.714897"
y="225.53201"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect20" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text20"
x="111.5769"
y="249.53201">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text21"
x="97.823906"
y="265.53201">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text22"
x="104.82391"
y="281.53201">weight</text>
<rect
x="218.71492"
y="238.53201"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23"
x="247.02687"
y="256.53201">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24"
x="236.52687"
y="272.53201">Weight</text>
<g
id="g33"
transform="translate(-315.33871,-135.90297)">
<rect
x="638.21271"
y="302.41418"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22-2" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23-7"
x="666.52472"
y="320.41418">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24-6"
x="662.06604"
y="336.96341">Input</text>
</g>
<rect
x="322.71494"
y="238.53201"
width="82"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect26" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text27"
x="351.37491"
y="256.53201">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text28"
x="342.70889"
y="272.53201">GEMM</text>
<path
d="m 177.73063,262.17188 34.00266,-0.5351 -0.0314,-1.99976 -34.00273,0.53511 z m 32.71676,2.4855 7.9361,-4.12538 -8.062,-3.87362 z"
id="path28" />
<path
d="m 299.71489,259.532 h 15.791 v 2 h -15.791 z m 14.458,-3 8,4 -8,4 z"
id="path29" />
<path
d="m 364.71489,141.4494 v 21.98469 h -2 V 141.4494 Z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30"
style="stroke-width:0.53037" />
<path
d="m 364.98466,212.86098 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-2"
style="stroke-width:0.53037" />
<rect
x="314.86691"
y="72.209839"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect23" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text29"
x="352.04691"
y="95.209839">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text32"
x="338.29391"
y="111.20985">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text33"
x="350.2139"
y="127.20984">input</text>
<g
id="g33-9"
transform="translate(54.923173,-135.99173)">
<rect
x="638.21271"
y="302.41418"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22-2-5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23-7-4"
x="666.52472"
y="320.41418">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24-6-3"
x="662.06604"
y="336.96341">Input</text>
</g>
<path
d="m 734.97681,141.36065 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-1"
style="stroke-width:0.53037" />
<path
d="m 735.24651,212.77223 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-2-2"
style="stroke-width:0.53037" />
<rect
x="685.12872"
y="72.121094"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect23-3" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text29-3"
x="722.30878"
y="95.121094">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text32-4"
x="708.55573"
y="111.12111">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text33-1"
x="720.47577"
y="127.12106">input</text>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
width="1280"
height="303.21127"
overflow="hidden"
version="1.1"
id="svg12"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<defs
id="defs12">
<clipPath
clipPathUnits="userSpaceOnUse"
id="clipPath16">
<rect
style="fill:none;stroke-width:0.96471"
id="rect16"
width="1344.0338"
height="303.21124"
x="-32.356411"
y="174.8833" />
</clipPath>
</defs>
<g
id="g12"
transform="translate(1.1556091e-7,-174.8833)"
clip-path="url(#clipPath16)">
<rect
x="0"
y="0"
width="1280"
height="720"
fill="#ffffff"
id="rect1" />
<path
d="M 645,209 V 446.818"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path1" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
transform="translate(201.111,246)"
id="text1">Without CUDA Graphs</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
transform="translate(855.749,246)"
id="text2">With CUDA Graphs</text>
<rect
x="64"
y="319"
width="91"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#f2f2f2"
id="rect2" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(75.6135,349)"
id="text3">Launch 1</text>
<rect
x="155"
y="371"
width="90"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect3" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(169.288,401)"
id="text4">Kernel 1</text>
<rect
x="245"
y="319"
width="91"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#f2f2f2"
id="rect4" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(256.462,349)"
id="text5">Launch 2</text>
<rect
x="336"
y="371"
width="90"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(350.136,401)"
id="text6">Kernel 2</text>
<rect
x="426"
y="319"
width="91"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#f2f2f2"
id="rect6" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(437.31,349)"
id="text7">Launch 3</text>
<rect
x="517"
y="371"
width="90"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect7" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(530.984,401)"
id="text8">Kernel 3</text>
<path
d="m 47,368 h 574.291 v 4 H 47 Z m 572.291,-4 12,6 -12,6 z"
id="path8" />
<rect
x="680"
y="319"
width="145"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#f2f2f2"
id="rect8" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(694.058,349)"
id="text9">Launch Graph 1</text>
<rect
x="830"
y="370"
width="91"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect9" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(844.463,400)"
id="text10">Kernel 1</text>
<rect
x="924"
y="370"
width="90"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect10" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(938.451,400)"
id="text11">Kernel 2</text>
<rect
x="1018"
y="370"
width="90"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect11" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(1032.44,400)"
id="text12">Kernel 3</text>
<path
d="m 663,368 h 574.29 v 4 H 663 Z m 572.29,-4 12,6 -12,6 z"
id="path12" />
</g>
</svg>
transformers==4.55.0
accelerate==1.10.0
datasets==4.0.0
sentencepiece==0.2.1
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
from contextlib import contextmanager
from typing import Optional
from functools import partial
from collections import OrderedDict
import torch
from torch.amp import autocast
import transformer_engine as te
from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding
from transformer_engine.common.recipe import Format, DelayedScaling
from transformer_engine.pytorch.fp8 import get_default_fp8_recipe
import transformers
from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel
import torch.nn.functional as F
"""
Top-level description of the classes defined in this file and used in the tutorial.
----------------------------------------------------------------------
HuggingFace Gemma Model implementation hierarchy:
----------------------------------
GemmaDecoderLayer:
├── self_attn:
│ ├── norm: (nn.LayerNorm)
│ ├── qkv_proj: (nn.Linear)
│ ├── attention: (SDPA, FlashAttention, etc.)
│ └── o_proj: (nn.Linear)
├── ffn:
│ ├── norm: (nn.LayerNorm)
│ ├── gate_proj: (nn.Linear)
│ ├── up_proj: (nn.Linear)
│ └── down_proj: (nn.Linear)
GemmaModel:
├── embed_tokens : Token embedding layer
├── layers : GemmaDecoderLayer × N
├── norm : GemmaRMSNorm
└── rotary_emb : GemmaRotaryEmbedding
GemmaForCausalLM:
├── model : instance of GemmaModel
├── lm_head : (nn.Linear) hidden states to vocabulary logits for generation
└── generate : generate method (input prompt -> GemmaForCausalLM -> next tokens)
How `generate()` works in HF's GemmaForCausalLM:
1. prefill (input prompt -> model -> lm_head -> logits -> next token)
2. loop until max_new_tokens:
- next token -> model -> lm_head -> logits -> next token
3. return all tokens
NOTE: Notice how "prefill" and "loop until next tokens" are just part of the `generate()` method.
This is a common pattern in HF models.
TransformerEngine's Gemma Model Hierarchy:
----------------------------------------
HF's `GemmaDecoderLayer` is monkey-patched with `TEGemmaDecoderLayer` before `GemmaForCausalLM` is initialized. This way,
the model is still downloaded from HuggingFace and most of the code still runs through HF's `GemmaForCausalLM`, but the
underlying "transformer layer" blocks are actually from TransformerEngine.
TEGemmaDecoderLayer (inherits from te.TransformerLayer):
├── te.MultiHeadAttention:
│ ├── linear_qkv: (te.LayerNormLinear)
│ ├── attention: (te.DotProductAttention)
│ └── out_proj: (te.LayerNormLinear)
├── te.LayerNormMLP:
│ ├── fc1: (te.LayerNormLinear)
│ ├── fc2: (te.Linear)
│ └── activation: (te.GeGLU)
To be able to use `model.generate()`, an entry point is needed. `TEGemmaForCausalLM` is the entry point which
subclasses HF's `GemmaForCausalLM` and adds a few attributes and methods.
TEGemmaForCausalLM (inherits from HF's GemmaForCausalLM)
├─ model : inherited from HF's GemmaForCausalLM but with monkey-patched TEGemmaDecoderLayer × N
├─ lm_head : directly inherited from HF's GemmaForCausalLM
├─ te_rope_emb : RotaryPositionEmbedding (reusing the same for all layers for CUDA graphs compatibility)
├─ hidden_states_buffer : shape [b, max_ctx, h] (static)
├─ generation_buffer : shape [b, 1, h] (view of `hidden_states_buffer`) (static)
├─ inference_params : TransformerEngine KV cache
├─ model_context_phase : GemmaModelWrapper → uses (model, lm_head, inference_params) for full-sequence prefill
├─ model_generation_phase : GemmaGenerationWrapper → uses (model, lm_head, inference_params) for single-token decode
└─ generate : generate method (input prompt -> TEGemmaForCausalLM -> next tokens)
Notice how "prefill" and "loop until next tokens" are specialized into wrapper subroutines - "model_context_phase" and
"model_generation_phase" respectively - which makes it easier to use CUDA Graphs. Just one more abstraction is needed:
TEGemmaForCausalLMCudaGraphs (inherits from TEGemmaForCausalLM)
├─ model : unchanged (HF's GemmaModel with monkey-patched TEGemmaDecoderLayer × N)
├─ lm_head : unchanged
├─ hidden_states_buffer : unchanged
├─ generation_buffer : unchanged
├─ inference_params : unchanged
├─ record : utility function to record the graphed callable
├─ model_context_phase : GraphedCallable(for Context/prefill) replaced by `record`
├─ model_generation_phase : GraphedCallable(for Generation) replaced by `record`
└─ generate : unchanged
How `generate()` works in TEGemmaForCausalLM/TEGemmaForCausalLMCudaGraphs:
1. model_context_phase (input prompt -> model -> lm_head -> logits -> next token)
2. model_generation_phase:
- loop until max_new_tokens:
- next token -> model -> lm_head -> logits -> next token
3. return all tokens
NOTE: In the tutorial, `record` is called when initializing the model.
Additional notes and clarifications
-----------------------------------
- Wrappers, not submodules:
`model_context_phase` and `model_generation_phase` are convenience wrappers over the same
`model` (GemmaModel) and `lm_head`. They own no parameters; they standardize buffer usage,
masks (context uses "padding_causal", generation uses "padding"), rotary embeddings, and
KV-cache (`InferenceParams`) flow for TE-optimized inference.
- Buffer relationship:
`hidden_states_buffer` has shape [b, max_ctx, h]. `generation_buffer` is a contiguous view
of size [b, 1, h] carved from its start to avoid non-contiguous indexing. Generation updates
`generation_buffer` in-place with next-token embeddings.
- Padding policy:
Inputs may arrive left-padded (HF-style). Before TE execution, padding is shifted to the end
to match TE attention mask expectations and to keep shapes contiguous for capture/replay.
- CUDA Graphs specifics:
`record()` captures two separate callables (context/prefill and generation) with fixed shapes and
stable pointers, then replaces the wrappers with these GraphedCallables. Under graphs, the
functional behavior is identical; only allocation/pointer churn and CPU overhead are removed.
"""
class TEGemmaDecoderLayer(te.pytorch.TransformerLayer):
"""
Wrapper class over TE's `TransformerLayer`. This makes the wrapper very
similar to HF's `GemmaDecoderLayer`, so it can act as a drop-in replacement.
Args:
config: GemmaConfig
args: positional args (for compatibility with `GemmaDecoderLayer`)
kwargs: keyword args (for compatibility with `GemmaDecoderLayer`)
"""
def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs):
self.gemma_config = config
super().__init__(
hidden_size=config.hidden_size,
ffn_hidden_size=config.intermediate_size,
num_attention_heads=config.num_attention_heads,
bias=False,
layernorm_epsilon=config.rms_norm_eps,
hidden_dropout=0,
attention_dropout=0,
fuse_qkv_params=config.fuse_qkv_params,
normalization="RMSNorm",
activation="geglu",
attn_input_format="bshd",
num_gqa_groups=config.num_key_value_heads,
kv_channels=self.gemma_config.head_dim,
layer_number=(
layer_idx + 1
), # Layer numbers in TE start from 1, not 0 as in HF.
zero_centered_gamma=True,
)
def forward(self, *args, **kwargs): # We need to additionally pass positional encoding.
# filter out HF specific args
keys_to_remove = [
"position_ids",
"past_key_value",
"output_attentions",
"use_cache",
"cache_position",
]
for key in keys_to_remove:
kwargs.pop(key, None)
rope_emb = kwargs.pop("rope_emb", None)
# Return tuple to be compatible with HF.
return (super().forward(*args, rotary_pos_emb=rope_emb, **kwargs),)
class GemmaModelWrapper(torch.nn.Module):
"""
Encapsulates the HuggingFace GemmaModel class as a wrapper whose
forward pass is compatible with CUDA Graphs.
"""
def __init__(
self,
model: GemmaModel,
dtype: torch.dtype,
lm_head: torch.nn.Module,
):
super().__init__()
self.model = model
self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype)
self.lm_head = lm_head
def set_inference_params(self, inference_params):
self.inference_params = inference_params
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor = None,
attn_mask_type: str = "arbitrary",
rope_emb: torch.Tensor = None,
):
with torch.no_grad():
# static operation - for CUDA graphs
hidden_states.data[:] = hidden_states.data[:] * self.normalizer
for i, decoder_layer in enumerate(self.model.layers):
hidden_states.data[:] = decoder_layer(
hidden_states,
attention_mask=attention_mask,
self_attn_mask_type=attn_mask_type,
inference_params=self.inference_params,
rope_emb=rope_emb,
)[
0
] # static copy - for CUDA graphs
hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs
logits = self.lm_head(hidden_states)
# This is not needed for generation but is needed for training
# or finetuning.
if self.training:
logits = logits.float()
return logits
class GemmaGenerationWrapper(torch.nn.Module):
"""
Gets token embeddings for a batch of single tokens, runs the forward pass, and
returns the batch of next tokens. Also compatible with CUDA Graphs. Not a
subclass of `GemmaModel` since the model layers are simply reused here.
"""
def __init__(
self,
model: GemmaModel,
lm_head: torch.nn.Module,
dtype: torch.dtype,
):
super().__init__()
self.model = model
self.gemma_layers = GemmaModelWrapper(model, dtype, lm_head)
def set_inference_params(self, inference_params):
self.inference_params = inference_params
self.gemma_layers.set_inference_params(inference_params)
def forward(
self,
hidden_states: torch.Tensor,
mask: torch.Tensor = None,
attn_mask_type: str = "arbitrary",
rope_emb: torch.Tensor = None,
):
logits = self.gemma_layers(
hidden_states, attention_mask=mask, attn_mask_type=attn_mask_type, rope_emb=rope_emb
)
assert logits.shape[0] == hidden_states.shape[0] # b
assert logits.shape[1] == hidden_states.shape[1] # seq_len
# Fetch the logits for the last token
logits = logits[:, -1, :]
next_tokens = torch.argmax(logits, dim=1)
# static copy for CUDA graphs
hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1))
return next_tokens
@contextmanager
def replace_decoder(te_decoder_cls):
"""
Monkey-patches `GemmaDecoderLayer` with the custom `TEGemmaDecoderLayer`
class.
"""
original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer
transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls
try:
yield
finally:
transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls
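The context manager above swaps the decoder class for the duration of model construction and guarantees restoration even on error. A minimal self-contained sketch of the same pattern, with toy classes standing in for the real ones (`Original`, `Patched`, and `registry` are illustrative names, not part of the tutorial):

```python
from contextlib import contextmanager

class Original:          # stands in for HF's GemmaDecoderLayer
    tag = "hf"

class Patched:           # stands in for TEGemmaDecoderLayer
    tag = "te"

class registry:          # stands in for transformers.models.gemma.modeling_gemma
    DecoderLayer = Original

@contextmanager
def swap_decoder(patched_cls):
    original = registry.DecoderLayer
    registry.DecoderLayer = patched_cls
    try:
        yield
    finally:
        # Restored even if model construction raises.
        registry.DecoderLayer = original

with swap_decoder(Patched):
    inside = registry.DecoderLayer.tag   # any model built here uses the patched class
after = registry.DecoderLayer.tag        # original class restored on exit
```

Any `GemmaForCausalLM` constructed inside the `with` block picks up the patched class, while code outside the block is unaffected.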
class TEGemmaForCausalLM(GemmaForCausalLM):
"""
Causal LM created with `GemmaModel`. The underlying `GemmaDecoderLayer`
class is monkey-patched with `TEGemmaDecoderLayer` class before
initializing the causal LM with `GemmaForCausalLM`.
Args:
config: Gemma model config that HF uses to initialize the model.
"""
def __init__(self, config: GemmaConfig):
dtype = torch.bfloat16
with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer):
super().__init__(config)
self.config = config
self.to(dtype).cuda()
self.hidden_size = config.hidden_size
self._model_context_phase = GemmaModelWrapper(self.model, dtype, self.lm_head)
self._model_generation_phase = GemmaGenerationWrapper(
lm_head=self.lm_head,
model=self.model,
dtype=dtype,
)
if self.config.fp8:
self.fp8_recipe = get_default_fp8_recipe()
# The rotary position embedding is the same for all layers, so it is
# created once here. This also makes it compatible with CUDA Graphs.
self.te_rope_emb = RotaryPositionEmbedding(self.config.head_dim)(
max_seq_len=self.config.max_position_embeddings
).cuda()
@staticmethod
def _padding_to_end(inputs, lengths, max_seq_len=None):
"""
Takes a tensor whose sequences are padded at the beginning and
updates it in place so that they are padded at the end.
Parameters
----------
inputs : Tensor, tensor with shape [b, s] containing token ids,
padded at the beginning of each sequence.
lengths: Tensor, tensor with shape [b] containing the lengths of the sequences.
"""
# `max_seq_len` is effectively fixed by the input tensor's shape; the
# parameter is kept for API compatibility with the CUDA Graphs path.
batch_size, max_seq_len = inputs.shape
new_input_ids = inputs.clone()
for i in range(batch_size):
new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len]
new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])]
inputs.data = new_input_ids
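The per-row shift performed above can be illustrated without torch; this toy sketch (`padding_to_end_row` is a hypothetical helper, not part of the tutorial) moves the leading padding of one left-padded sequence to its end:

```python
def padding_to_end_row(row, length):
    """Move the leading padding of a single left-padded sequence to its end."""
    seq_len = len(row)
    tokens = row[seq_len - length:]    # real tokens sit at the end when left-padded
    padding = row[:seq_len - length]   # padding sits at the front
    return tokens + padding

# A batch of two left-padded sequences with lengths 3 and 2.
batch = [[0, 0, 5, 6, 7], [0, 0, 0, 8, 9]]
lengths = [3, 2]
shifted = [padding_to_end_row(r, l) for r, l in zip(batch, lengths)]
# shifted == [[5, 6, 7, 0, 0], [8, 9, 0, 0, 0]]
```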
def _create_or_fetch_hidden_states_buffer(self, input_ids: torch.Tensor):
"""
Returns a tensor of shape [b, s, hd] where `b` is the batch size,
`s` is the sequence length, and `hd` is the hidden size.
This function is overridden in TEGemmaForCausalLMCudaGraphs.
"""
tensor = torch.empty(
(input_ids.shape[0], input_ids.shape[1], self.hidden_size),
device="cuda",
dtype=torch.float32,
)
return tensor
def _create_or_fetch_inference_params(self, *args, **kwargs):
"""
Creates an InferenceParams object.
This function is overridden in TEGemmaForCausalLMCudaGraphs.
"""
infer_params = InferenceParams(*args, **kwargs)
return infer_params
def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None):
"""
Returns a tensor of shape [b, 1, hd] where `b` is the batch size and
`hd` is the hidden size.
The generation buffer is carved out of the beginning of the hidden
states buffer. This function returns a view of it and, if provided,
copies `data_to_copy` into it.
"""
# hidden_states_buffer has shape [b, s, hd]
# generation_buffer will have shape [b, 1, hd]
# Notice that `hidden_states_buffer[:, 0, :].unsqueeze(1)` would return
# a non-contiguous view, which we want to avoid.
output = hidden_states_buffer.view(-1)[
: hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2]
]
if data_to_copy is not None:
output.copy_(data_to_copy.reshape(-1))
generation_buffer = output.view(
(hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2])
)
return generation_buffer
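The contiguity claim behind this flat-view trick can be checked directly. A small sketch with numpy standing in for torch (the stride logic is the same idea; `buf`, `naive`, and `gen` are illustrative names):

```python
import numpy as np

b, s, h = 2, 4, 3
buf = np.arange(b * s * h).reshape(b, s, h)   # stands in for hidden_states_buffer

# Indexing the first token position yields a non-contiguous view: within a
# batch element rows are h elements apart, but batches are s*h elements apart.
naive = buf[:, 0:1, :]
assert not naive.flags["C_CONTIGUOUS"]

# Slicing the first b*h elements of the flat buffer stays contiguous and
# reuses the start of the same allocation, as `_get_generation_buffer` does.
gen = buf.reshape(-1)[: b * h].reshape(b, 1, h)
assert gen.flags["C_CONTIGUOUS"]
assert np.shares_memory(gen, buf)
```

Keeping the generation buffer contiguous and aliased to the prefill buffer means no new allocation happens between prefill and decode, which is what CUDA Graph capture requires.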
def setup_and_run_context_phase(
self, input_ids: torch.Tensor, inference_params: InferenceParams
):
"""
Runs the context or prefill phase of the model.
This function is overridden in TEGemmaForCausalLMCudaGraphs.
"""
hidden_states = self._create_or_fetch_hidden_states_buffer(input_ids)
hidden_states.copy_(self.model.embed_tokens(input_ids))
# Update offsets before every forward pass (including context/prefill
# phase) to make cache work properly.
lengths = input_ids.ne(0).sum(dim=1)
inference_params.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist())))
logits = self._model_context_phase(
hidden_states,
attention_mask=None,
attn_mask_type="padding_causal",
rope_emb=self.te_rope_emb,
)
logits = logits[torch.arange(logits.size(0)), lengths - 1, :]
next_tokens = torch.argmax(logits, dim=1)
# `hidden_states` has shape [b, s, hd].
# Return the hidden state for the last token - output has shape [b, 1, hd].
hidden_states = self._get_generation_buffer(
hidden_states, self.model.embed_tokens(next_tokens)
)
return hidden_states, next_tokens
@torch.no_grad()
def generate(
self,
input_ids: Optional[torch.Tensor] = None,
pad_token_id: int = 0,
max_new_tokens: int = 0,
*args,
**kwargs,
):
"""
Generates next tokens auto-regressively for a batch of input tokens.
"""
self.eval()
# Both autocasts are needed: FP8 for operations that can run in lower
# precision and BF16 for those that cannot.
with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast(
enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None
):
lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze()
# If padding is at the beginning, then shift it to the end
TEGemmaForCausalLM._padding_to_end(
input_ids,
lengths,
max_seq_len=(
self.config.cuda_graphs_static_max_context_len
if self.config.generation_cuda_graphs
else None
),
)
batch_size = input_ids.shape[0]
# For benchmark generation runs, this is set explicitly.
max_input_sequence_len = self.config.max_seq_length
# InferenceParams is a cache where the keys and values of previously
# processed tokens are stored. It also tracks the current running
# lengths of the sequences in the batch.
# A helper function is used to create the inference params object
# because this `generate` method is shared by TEGemmaForCausalLM
# and TEGemmaForCausalLMCudaGraphs. In the CUDA Graphs case, the
# function is overridden to simply return the inference params object
# that is already created in TEGemmaForCausalLMCudaGraphs'
# constructor.
inference_params = self._create_or_fetch_inference_params(
max_batch_size=batch_size,
max_sequence_length=max_input_sequence_len,
num_heads_kv=self.config.num_key_value_heads,
head_dim_v=self.config.head_dim,
head_dim_k=self.config.head_dim,
dtype=torch.bfloat16,
is_paged=self.config.is_paged,
page_size=16,
total_num_pages=batch_size * max_input_sequence_len // 16,
)
# Set the inference params for both the context/prefill phase and
# generation phase objects.
self._model_context_phase.set_inference_params(inference_params)
self._model_generation_phase.set_inference_params(inference_params)
# Context/prefill phase.
hidden_states, next_tokens = self.setup_and_run_context_phase(
input_ids, inference_params
)
# Generation phase.
lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int)
inference_params.pre_step(
OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist()))
)
output_tokens = [next_tokens]
for _ in range(max_new_tokens):
next_tokens = self._model_generation_phase(
hidden_states,
mask=None,
attn_mask_type="padding",
rope_emb=self.te_rope_emb,
)
# Increase sequence offsets by one because we generated one token
# for every sequence.
lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int)
inference_params.pre_step(
OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist()))
)
# `next_tokens` is a static output tensor, so we need to clone
# it because it gets changed every iteration.
output_tokens.append(next_tokens.clone())
result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1)
return result
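Stripped of buffers, masks, and caches, the control flow of `generate` is one prefill step followed by a fixed-length decode loop. A toy sketch of just that flow (the arithmetic `prefill`/`decode_step` are stand-ins for the real context and generation phases, not the tutorial's API):

```python
def prefill(prompt):
    # Stand-in for the context phase: consume the whole prompt,
    # emit the first new token.
    return prompt[-1] + 1

def decode_step(token):
    # Stand-in for the generation phase: one token in, one token out.
    return token + 1

def generate_sketch(prompt, max_new_tokens):
    next_token = prefill(prompt)
    output = [next_token]
    for _ in range(max_new_tokens):
        next_token = decode_step(next_token)
        output.append(next_token)
    return prompt + output

result = generate_sketch([3, 4, 5], 3)   # -> [3, 4, 5, 6, 7, 8, 9]
```

Because the two phases are separate callables with fixed-shape inputs, each can be captured as its own CUDA Graph, which is exactly what `TEGemmaForCausalLMCudaGraphs` exploits.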
def forward(self, *args, **kwargs):
"""
Forward pass for the model. This is used in the calibration step, where
forward passes are needed to collect FP8 calibration data.
"""
self._model_context_phase.set_inference_params(None)
hidden_states = self.model.embed_tokens(kwargs["input_ids"])
logits = self._model_context_phase(
hidden_states,
attention_mask=(
kwargs["input_ids"] == 0
), # Hardcoded, this only applies to bshd/sbhd layouts.
attn_mask_type="padding_causal",
)
return logits
class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM):
"""
TEGemmaForCausalLMCudaGraphs extends TEGemmaForCausalLM and uses CUDA
Graphs to speed up generation. One trade-off is required: batch_size,
max_seq_len, and max_context_seq_len must be static, because generation
has to run without changing the pointers of the tensors recorded in
the graph.
"""
def __init__(self, config: GemmaConfig):
super().__init__(config)
self.config = config
# Preparation of the static buffer to hold the hidden states that are
# passed from one layer to the next.
self.hidden_states_buffer = torch.empty(
(
self.config.cuda_graphs_static_batch_size,
self.config.cuda_graphs_static_max_context_len,
self.config.hidden_size,
)
).cuda()
# This is in fact part of the buffer for hidden_states. Refer to the
# `_get_generation_buffer` function for more details.
self.generation_buffer = self._get_generation_buffer(
self.hidden_states_buffer,
)
# InferenceParams contains the keys and values cache. Refer to the
# original call in TEGemmaForCausalLM's `generate` method for more
# details.
self.inference_params = InferenceParams(
max_batch_size=self.config.cuda_graphs_static_batch_size,
max_sequence_length=self.config.cuda_graphs_static_max_context_len,
num_heads_kv=self.config.num_key_value_heads,
head_dim_v=self.config.head_dim,
head_dim_k=self.config.head_dim,
dtype=torch.bfloat16,
is_paged=self.config.is_paged,
page_size=16,
total_num_pages=self.config.cuda_graphs_static_batch_size
* self.config.cuda_graphs_static_max_context_len
// 16,
)
self._model_generation_phase.set_inference_params(self.inference_params)
self._model_context_phase.set_inference_params(self.inference_params)
def record(self):
"""
This is where "the trick" happens. `_model_context_phase` and
`_model_generation_phase` from TEGemmaForCausalLM are replaced with
their recorded versions. Once the graphs are recorded, they can be
replayed with minimal CPU overhead, which leads to the speedup.
"""
# Record the model with training=False, because it will be used in
# generation.
self.eval()
# Setup the recording for context/prefill phase.
input_shape = (
self.config.cuda_graphs_static_batch_size,
self.config.cuda_graphs_static_max_context_len,
)
# Hardcoded value for the context length.
lengths = torch.tensor([9] * self.config.cuda_graphs_static_batch_size).to(
device="cuda", dtype=torch.int32
)
self.inference_params.pre_step(
OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))
)
# Record the graph for context/prefill phase.
self._model_context_phase = self.record_graph(
self._model_context_phase,
self.hidden_states_buffer,
attn_mask_type="padding_causal",
rope_emb=self.te_rope_emb,
)
# Setup the recording for generation phase.
input_shape = (self.config.cuda_graphs_static_batch_size, 1)
lengths = torch.tensor(input_shape[0] * [1], device="cuda", dtype=torch.int32)
self.inference_params.pre_step(
OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))
)
# Record the graph for generation phase.
self._model_generation_phase = self.record_graph(
self._model_generation_phase,
self.generation_buffer,
attn_mask_type="padding",
rope_emb=self.te_rope_emb,
)
def _create_or_fetch_hidden_states_buffer(self, *args, **kwargs):
"""
Overridden to make `hidden_states` static, i.e., its pointer in memory
does not change between invocations.
Returns the static buffer for `hidden states` which is already created
in the constructor. This is the same buffer as used in the
context/prefill phase.
"""
return self.hidden_states_buffer
def _create_or_fetch_inference_params(self, *args, **kwargs):
"""
Overridden to make `inference_params` static, i.e., its pointer in memory
does not change between invocations.
Returns the static buffer for `inference_params` which is already created
in the constructor.
"""
self.inference_params.reset()
return self.inference_params
@torch.no_grad()
def record_graph(self, function, input_tensor, **sample_kwargs):
"""
Records the graph for the given function. The function is invoked on
the argument `(input_tensor,)` and all launched kernels are recorded.
It then returns the captured callable, which can be replayed later with
minimal CPU usage.
"""
fp8_recipe = get_default_fp8_recipe()
# We need both autocasts: FP8 for operations that can run in lower
# precision and BF16 for those that cannot.
with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False):
graphed_function = te.pytorch.make_graphed_callables(
function,
(input_tensor,),
fp8_enabled=self.config.fp8,
fp8_recipe=fp8_recipe,
allow_unused_input=True,
num_warmup_iters=5,
sample_kwargs=sample_kwargs,
)
return graphed_function
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import re
import gc
import torch
from typing import List
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformers.modeling_utils import load_state_dict
from transformers.utils.hub import get_checkpoint_shard_files
"""
This file contains the logic for mapping HuggingFace GemmaModel parameters
to TransformerEngine TransformerLayer parameters. Once the Transformer models
have been initialized with both HF and TE, parameters can be copied from the
former to the latter.
"""
def _load_weights_for_fp8_model(vanilla_model, hyperparams):
"""
Loads weights and FP8 metadata from a calibrated weights file.
The weights are in BF16 precision, but the state dict also contains
fp8 metadata computed by the calibration procedure.
"""
fp8_metadata_sd = torch.load(hyperparams.fp8_model_weights_filename)
# A hack to remove the core_attention module's extra state from
# fp8_metadata_sd.
fp8_metadata_sd = {
k: v for k, v in fp8_metadata_sd.items() if "core_attention._extra_state" not in k
}
vanilla_model.load_state_dict(
fp8_metadata_sd,
# Non-strict loading is needed because some weights are referenced through
# multiple pointers, e.g. via both
# vanilla_model._model_context_phase.model and
# vanilla_model._model_generation_phase.model.
strict=False,
)
def _load_weights_for_standard_model(vanilla_model, config):
"""
Loads weights from the HuggingFace checkpoint.
"""
archive_file = os.path.join(config.weights_cache_dir, "model.safetensors.index.json")
resolved_archive_file, _ = get_checkpoint_shard_files(config.weights_cache_dir, archive_file)
total_dict = {}
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
total_dict.update(state_dict)
replace_params(
total_dict,
vanilla_model.state_dict(),
config,
qkv_fused_and_interleaved=config.fuse_qkv_params,
)
# Copy remaining parameters like embedding.
vanilla_model.load_state_dict(total_dict, strict=False)
# Force mem release. Taken from huggingface code.
del total_dict
gc.collect()
def load_te_model(cls, config):
"""
Loads the TE model with proper weights.
"""
# Force the dtype to bfloat16 while loading the model.
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
"""
Custom method adapted from `from_pretrained` method in HuggingFace
Transformers repo:
https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
"""
config.use_cache = False # To make TransformerLayer compatible with GemmaModel
# Loading a model with FP8-only weights needs both of the following context managers:
# 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8-only weights.
# 2. torch.no_grad() during TE modules' initialization so that they respect
#    the `fp8_model_init` context manager.
with torch.no_grad(), fp8_model_init(config.fp8_model_init):
# Just create a model with random weights.
vanilla_model = cls(config).cuda()
# Copy the proper weights into the model. If loading weights with FP8 metadata,
# the source state dict already matches the TE model's parameter names.
# Otherwise, load the weights from the HuggingFace checkpoint and map the
# weight names from the HF model to the TE model.
if config.fp8_model_weights_filename is not None:
_load_weights_for_fp8_model(vanilla_model, config)
else:
_load_weights_for_standard_model(vanilla_model, config)
# Restore the original dtype.
torch.set_default_dtype(old_dtype)
return vanilla_model
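The dtype save/override/restore pattern used in `load_te_model` can be wrapped in a small context manager. Below is a minimal sketch; the `default_dtype` helper is hypothetical and not part of the tutorial code:

```python
from contextlib import contextmanager

import torch


@contextmanager
def default_dtype(dtype):
    # Temporarily override torch's default dtype, restoring it even on error.
    old = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(old)


with default_dtype(torch.bfloat16):
    w = torch.empty(2, 2)  # allocated in bfloat16

assert w.dtype == torch.bfloat16
assert torch.get_default_dtype() == torch.float32  # default restored
```

Using try/finally (rather than a bare set/restore pair) guarantees the default dtype is restored even if model construction raises.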


def _get_all_layer_prefixes_to_update(hf_state_dict):
"""
There are many parameters in hf_state_dict, whose name start with "model.layers.[number]."
This function extracts all strings like "model.layers.[number]."
that are starting strings of keys in hf_state_dict.
"""
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
layer_prefix_pat = r"model\.layers\.\d+\."
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())
return all_layer_prefixes
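For illustration, the prefix extraction above can be exercised on a toy set of keys (the key names below are hypothetical, merely mimicking a HF checkpoint):

```python
import re

# Hypothetical state-dict keys mimicking a HuggingFace checkpoint.
keys = [
    "model.embed_tokens.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.11.input_layernorm.weight",
]

# Same pattern as in _get_all_layer_prefixes_to_update, with dots escaped.
layer_prefix_pat = re.compile(r"model\.layers\.\d+\.")
prefixes = {m.group() for k in keys if (m := layer_prefix_pat.match(k))}

print(sorted(prefixes))  # ['model.layers.0.', 'model.layers.11.']
```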


def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False):
"""
Replaces params from TE TransformerLayer state_dict with corresponding parameters
from HuggingFace GemmaModel state_dict.
"""
all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict)
for layer_prefix in all_layer_prefixes:
def copy_from_ht_to_te(te_name, hf_name, start=None, end=None):
te_state_dict[layer_prefix + te_name].data[start:end].copy_(
hf_state_dict[layer_prefix + hf_name]
)
copy_from_ht_to_te(
"self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight"
)
copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight")
copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight")
copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight")
copy_from_ht_to_te(
"layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size
)
copy_from_ht_to_te(
"layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size
)
if qkv_fused_and_interleaved:
"""
When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor
in TE TransformerLayer. Moreover they are interleaved within each head.
Let q_i, k_i and v_i be query, key and value layers for i-th head respectively.
Then TE stores weight tensor in the form:
[q1 k1 v1 q2 k2 v2 ...]
This is done to maximally optimize performance time.
"""
te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"]
def copy_interleave(hf_name, idx):
src = hf_state_dict[layer_prefix + hf_name]
for head_nr in range(config.num_attention_heads):
dst_offset = head_nr * config.head_dim * 3
dst_slice = slice(
dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim
)
src_slice = slice(
head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim
)
te_qkv_layer[dst_slice, :] = src[src_slice, :]
copy_interleave("self_attn.q_proj.weight", 0)
copy_interleave("self_attn.k_proj.weight", 1)
copy_interleave("self_attn.v_proj.weight", 2)
else:
copy_from_ht_to_te(
"self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight"
)
copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight")
copy_from_ht_to_te(
"self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight"
)
return all_layer_prefixes
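As a sanity check on the interleaved layout that `copy_interleave` produces, here is a toy sketch with hypothetical sizes (2 heads, head_dim = 2), using labeled rows in place of weight tensors:

```python
# Toy illustration (not tutorial code) of the fused, per-head-interleaved
# QKV layout [q1 k1 v1 q2 k2 v2 ...] that replace_params builds.
num_heads, head_dim = 2, 2

# One label per weight row; e.g. "q0" marks a row of head 0's query projection.
q = [f"q{h}" for h in range(num_heads) for _ in range(head_dim)]
k = [f"k{h}" for h in range(num_heads) for _ in range(head_dim)]
v = [f"v{h}" for h in range(num_heads) for _ in range(head_dim)]

fused = [None] * (3 * num_heads * head_dim)
for head in range(num_heads):
    dst_offset = head * head_dim * 3  # each head owns one [q|k|v] block
    for idx, src in enumerate((q, k, v)):  # idx: 0 = q, 1 = k, 2 = v
        dst = slice(dst_offset + idx * head_dim, dst_offset + (idx + 1) * head_dim)
        src_rows = slice(head * head_dim, (head + 1) * head_dim)
        fused[dst] = src[src_rows]

print(fused)
# ['q0', 'q0', 'k0', 'k0', 'v0', 'v0', 'q1', 'q1', 'k1', 'k1', 'v1', 'v1']
```

The loop structure mirrors `copy_interleave`: the destination offset advances by `head_dim * 3` per head, and the `idx` slot selects the q, k or v sub-block within that head's region.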
{
"cells": [
{
"cell_type": "markdown",
"id": "87e8360b-8d08-44bc-9333-79ba949afe8c",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"# Accelerating Hugging Face Gemma Inference with Transformer Engine"
]
},
{
"cell_type": "markdown",
"id": "2da33092-eef5-46a4-b222-0188cc6e5079",
"metadata": {
"editable": true,
"slideshow": {
"slide_type": ""
},
"tags": []
},
"source": [
"## Introduction\n",
"\n",
"Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/generation_animation.gif\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\" >\n",
"<figcaption>\n",
"Animation 1: Hugging Face Gemma model token generation.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n",
"\n",
"In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n",
"\n",
"This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n",
"\n",
"### 1. From vanilla KV-caching to Paged Attention for inference in Transformer Engine\n",
"\n",
"The original [Attention mechanism](https://arxiv.org/pdf/1706.03762) ushered in an era of Large Language Models, but the same attention mechanism, if used for deployment in inference scenarios, can be computationally wasteful. It is primarily due to a lot of redundant computation that happens in attention when the Transformer models are used autoregressively to compute the next token. Several tutorials on the internet explain in detail how KV Caching helps to reduce that redundant computation, e.g., [tutorial 1](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms), [tutorial 2](https://medium.com/@joaolages/kv-caching-explained-276520203249), etc.\n",
"\n",
"\n",
"Further, even though the performance benefit of KV Cache is immense, it comes at the cost of increased memory usage, which becomes a problem especially for longer context lengths. The major problems are: \n",
"\n",
"1. Internal fragmentation\n",
"2. External Fragmentation\n",
"\n",
"More information can be found in the [Paged Attention](https://arxiv.org/pdf/2309.06180) paper. The authors solve the above problems by treating the KV cache as a virtual memory with the actual physical blocks being much smaller than the overall cache size. This makes it easier to swap them in and out of GPU HBM as needed - very similar to how Operating Systems implement virtual memory to swap the individual pages in and out of the CPU RAM.\n",
"\n",
"\n",
"Transformer Engine allows users to use both \"Non-paged\" and \"Paged\" forms of KV Caching, and the results in this tutorial are posted for both use cases.\n",
"\n",
"\n",
"### 2. CUDA Graphs API\n",
"\n",
"The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to finish processing and then launch the kernels, which can lead to significant overhead. CUDA Graphs can address this issue. When such blocks of computation are executed repeatedly, CUDA Graphs allow us to record and replay them with less CPU involvement. This becomes particularly useful in applications like token generation, where multiple \"Transformer/Decoder Layers\" are run for every token that needs to be generated.\n",
"\n",
"One can read more about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n",
"\n",
"PyTorch exposes graphs via a raw `torch.cuda.CUDAGraph` class and two convenience wrappers: `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. More information about the CUDA graphs in Pytorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/graphs.svg\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\" >\n",
"<figcaption>\n",
"Figure 1: CUDA Graphs reduce the overhead generated by the long time it takes to launch a single kernel. It enables the recording and replaying of subsequent launches, thus reducing the total time used by the CPU.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"### 3. FP8 Scaling Factors Calibration\n",
"\n",
"This tutorial uses the `DelayedScaling` recipe for FP8 precision, which relies on the correct calculation of \"scaling factors\".\n",
"\n",
"If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `fp8_autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n",
"\n",
"It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n",
"\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/calibration.svg\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\">\n",
"<figcaption>\n",
"Figure 2:\n",
"Assuming that the model is trained in FP32/BF16 precision and the goal is to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, the FP8 scaling parameters can be computed. This calibration allows the model to operate correctly in FP8 precision.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"### 4. FP8 Model Weights\n",
"\n",
"The typical approach is to store weights in higher precision and then cast them to FP8 before operations. This may prevent accuracy drops in training. However, for inference, this level of precision is not necessary.\n",
"\n",
"The Transformer Engine includes a wrapper `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/fp8_model_init.svg\" style=\"border: 1px solid #000; border-radius: 0;\" alt=\"\">\n",
"<figcaption>\n",
"Figure 3: Model under <b>fp8_autocast()</b> stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using <b>fp8_model_init()</b> results in storing model weights in FP8 by default, which can help with these potential issues.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"### Benchmarking\n",
"\n",
"We'll evaluate the generation time across one benchmark: token generation with context/prefill phase max sequence length = 20, batch size = 64, and number of generated tokens = 492 on random texts with random lengths. This is a purely synthetic benchmark.\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
" \n",
"This tutorial focuses on showcasing the mentioned features of the Transformer Engine in the context of token generation. It's important to note, however, that NVIDIA provides [TensorRT-LLM](https://docs.nvidia.com/tensorrt-llm/index.html), which is optimized for inference tasks and should be considered for such use cases.\n",
"</div>"
]
},
{
"cell_type": "markdown",
"id": "b18f91a9",
"metadata": {},
"source": [
"## Dependencies for this tutorial"
]
},
{
"cell_type": "markdown",
"id": "e5201d77",
"metadata": {},
"source": [
"The following files and media are necessary to effectively run this tutorial:\n",
"\n",
"1. `te_gemma.py`\n",
" - This file contains the code to load a Hugging Face Gemma checkpoint weights in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. Further, it contains necessary abstractions like a subclass of `GemmaForCausalLM` - `TEGemmaForCausalLM` that is used for generation with Transformer Engine's `TransformerLayer`, CUDA Graphs, and FP8 calibration for generation in FP8 precision.\n",
"2. `te_gemma_loading_weights.py`\n",
" - This file contains the logic of mapping the parameters from `GemmaDecoderLayer` into the `TransformerLayer`.\n",
"3. `utils.py`\n",
" - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training, and other miscellaneous tasks like restarting the Jupyter notebook from within the cell. \n",
"4. `requirements.txt`\n",
" - This file contains the necessary Python packages for this tutorial.\n",
"5. `media/`\n",
" - This directory contains the images and other artefacts used in this tutorial."
]
},
{
"cell_type": "markdown",
"id": "36767694-a1c5-4a00-a075-7addc55d8307",
"metadata": {},
"source": [
"### Setup and checks"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "1de3351b-fa21-4b95-bb9e-d01ac8bb7edf",
"metadata": {},
"outputs": [],
"source": [
"# Uncomment and run this cell when running the tutorial for the first time\n",
"# %pip install -r requirements.txt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "c756ebbd-24c9-4a54-a381-e7c02c555206",
"metadata": {},
"outputs": [],
"source": [
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"import torch\n",
"cudnn_version = torch.backends.cudnn.version()\n",
"assert cudnn_version >= 90100, \"cuDNN version >= 9.1.0 is needed to run this tutorial.\""
]
},
{
"cell_type": "markdown",
"id": "e8dfabbf",
"metadata": {},
"source": [
"## [Baseline] Running Hugging Face generation with Gemma model"
]
},
{
"cell_type": "markdown",
"id": "59560bff",
"metadata": {},
"source": [
"HuggingFace Transformers library offers generation API. \n",
"HuggingFace generation for the Gemma model will be used as a baseline."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "2803e0ec",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============================== Generation example 1 ==============================\n",
"Prompt: \"Here are the two facts about GPUs:\"\n",
"Generated text: \"\n",
"\n",
"1. They are very good at doing a lot of the same thing at the same time.\n",
"2. They are very bad at doing different things at the same time.\n",
"\n",
"The first fact is why GPUs are so good at graphics. The\"\n",
"============================== Generation example 2 ==============================\n",
"Prompt: \"Some facts about NVIDIA:\"\n",
"Generated text: \"\n",
"\n",
"* NVIDIA is a global technology company that designs and builds advanced computer graphics and video processing chips for the PC and video game console markets.\n",
"* The company is a leading provider of graphics processing units (GPUs) for the PC and video game\"\n",
"\n",
"================================================================================\n",
"Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
"Time: 46.60 s.\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"from utils import *\n",
"\n",
"# Provide Huggingface Access Token\n",
"run_config.hf_access_token = \"\"\n",
"assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
"run_config.model_name = \"google/gemma-7b\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"run_config.weights_cache_dir = \"\"\n",
"\n",
"# Set specific hyperparameters\n",
"# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
"run_config.batch_size = 64\n",
"run_config.max_seq_length = 512\n",
"\n",
"model = init_baseline_model(run_config)\n",
"\n",
"print_sample_of_generated_texts(model, run_config)\n",
"benchmark_generation(model, run_config)"
]
},
{
"cell_type": "markdown",
"id": "b3698dc6",
"metadata": {},
"source": [
"Let's put this time into the table for later comparison.\n",
"\n",
"| Models | Time | Speedup | \n",
"|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n",
"| HF (baseline) | 46.6 s | - |"
]
},
{
"cell_type": "markdown",
"id": "8bb40f45",
"metadata": {},
"source": [
"## [Optimization 1] Accelerating generation with Transformer Engine "
]
},
{
"cell_type": "markdown",
"id": "263b40f2",
"metadata": {},
"source": [
"Similar to the [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) finetuning tutorial, a `GemmaDecoderLayer` is substituted by a tuned `TransformerLayer` from the Transformer Engine library. Let's run it and compare the time with the baseline."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "9dceef93",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============================== Generation example 1 ==============================\n",
"Prompt: \"Here are the two facts about GPUs:\"\n",
"Generated text: \"\n",
"\n",
"1. They are very good at doing a lot of the same thing at the same time.\n",
"2. They are very bad at doing different things at the same time.\n",
"\n",
"The first fact is why they are so good at graphics. The second\"\n",
"============================== Generation example 2 ==============================\n",
"Prompt: \"Some facts about NVIDIA:\"\n",
"Generated text: \"\n",
"\n",
"* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n",
"* NVIDIA is the world leader in AI computing.\n",
"* NVIDIA is the world leader in graphics processing units (GP\"\n",
"\n",
"================================================================================\n",
"Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
"Time: 12.25 s.\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"from utils import *\n",
"\n",
"# Provide Huggingface Access Token\n",
"run_config.hf_access_token = \"\"\n",
"assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
"run_config.model_name = \"google/gemma-7b\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"run_config.weights_cache_dir = \"\"\n",
"\n",
"# Set specific hyperparameters\n",
"# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
"run_config.batch_size = 64\n",
"run_config.max_seq_length = 512\n",
"run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n",
"\n",
"model = init_te_gemma_model(run_config)\n",
"\n",
"print_sample_of_generated_texts(model, run_config)\n",
"benchmark_generation(model, run_config)"
]
},
{
"cell_type": "markdown",
"id": "b5d40836",
"metadata": {},
"source": [
"With just using Transformer Engine with default (non-paged) KV cache, a speedup of **3.8x** was obtained. Neat!"
]
},
{
"cell_type": "markdown",
"id": "006d18e8",
"metadata": {},
"source": [
"| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n",
"|---|---|---|---|---|\n",
"| HF (baseline) | 46.6 s | - | - | - |\n",
"| TE (subsitution of `GemmaDecoderLayer` with `te.TransformerLayer`) | 12.25 s | 3.8x | 12.24 s | 3.8x |"
]
},
{
"cell_type": "markdown",
"id": "21a89d9c",
"metadata": {},
"source": [
"## [Optimization 2] More acceleration with CUDA Graphs"
]
},
{
"cell_type": "markdown",
"id": "e2d53e7b",
"metadata": {},
"source": [
"Transformer Engine includes a function `transformer_engine.pytorch.make_graphed_callables`, which behaves similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from [te_gemma.py](./te_gemma.py) from class `TEGemmaForCausalLMCudaGraphs`:\n",
"```python\n",
" def __init__(self, config : GemmaConfig):\n",
" \"\"\"\n",
" Here \"the trick\" happens. `_model_context_phase` and\n",
" `_model_generation_phase` from TEGemmaForCausalLM are replaced with\n",
" their recorded version. Once the graphs are recorded, they can be\n",
" replayed with minimal usage of CPU and that leads to speedup.\n",
" \"\"\"\n",
" (...)\n",
" # Record the graph for context/prefill phase.\n",
" self._model_context_phase = \n",
" self.record_graph(self._model_context_phase, self.hidden_states_buffer)\n",
"\n",
" (...) \n",
" # Record the graph for generation phase.\n",
" self._model_generation_phase = \n",
" self.record_graph(self._model_generation_phase, self.generation_buffer)\n",
"\n",
" @torch.no_grad()\n",
" def record_graph(self, function, input_tensor):\n",
" \"\"\"\n",
" Records the graph for the given function. The function is invoked on\n",
" argument (self.hidden_states,) and all kernels are recorded.\n",
" It then returns the captured callable, which can be run later while\n",
" minimizing CPU usage.\n",
" \"\"\"\n",
" fp8_recipe = get_default_fp8_recipe()\n",
"\n",
" # We need both autocasts: FP8 for operations that can run in lower\n",
" # precision and BF16 for those that cannot.\n",
" with autocast(\"cuda\", dtype=torch.bfloat16, cache_enabled=False):\n",
" graphed_function = te.pytorch.make_graphed_callables(\n",
" function,\n",
" (input_tensor,),\n",
" fp8_enabled=self.config.fp8,\n",
" fp8_recipe=fp8_recipe,\n",
" allow_unused_input=True,\n",
" num_warmup_iters=5,\n",
" sample_kwargs=sample_kwargs,\n",
" )\n",
" return graphed_function\n",
"```\n",
"\n",
"It is strongly recommended to review the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let's now proceed to evaluate the performance improvement offered by CUDA Graphs.\n",
"\n",
"*Note the usage of static buffers and corresponding configuration in the following cell, which is necessary for CUDA Graphs to function.*"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "31a3a8a3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============================== Generation example 1 ==============================\n",
"Prompt: \"Here are the two facts about GPUs:\"\n",
"Generated text: \"\n",
"\n",
"1. They are very good at doing a lot of the same thing at the same time.\n",
"2. They are very bad at doing different things at the same time.\n",
"\n",
"The first fact is why they are so good at graphics. The second\"\n",
"============================== Generation example 2 ==============================\n",
"Prompt: \"Some facts about NVIDIA:\"\n",
"Generated text: \"\n",
"\n",
"* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n",
"* NVIDIA is the world leader in AI computing.\n",
"* NVIDIA is the world leader in graphics processing units (GP\"\n",
"\n",
"================================================================================\n",
"Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
"Time: 6.39 s.\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"from utils import *\n",
"\n",
"# Provide Huggingface Access Token\n",
"run_config.hf_access_token = \"\"\n",
"assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
"run_config.model_name = \"google/gemma-7b\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"run_config.weights_cache_dir = \"\"\n",
"\n",
"# Set specific hyperparameters\n",
"# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
"run_config.max_seq_length = 512\n",
"run_config.batch_size = 64\n",
"run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n",
"\n",
"# It is necessary to preallocate a static buffer.\n",
"# CUDA graphs require static input tensors for every kernel.\n",
"# This approach may result in a slight increase in memory consumption;\n",
"# however, the substantial speedup achieved makes it worthwhile.\n",
"run_config.generation_cuda_graphs = True\n",
"run_config.cuda_graphs_static_batch_size = 64\n",
"run_config.cuda_graphs_static_max_seq_len = 512\n",
"run_config.cuda_graphs_static_max_context_len = 512\n",
"\n",
"model = init_te_gemma_model(run_config)\n",
"\n",
"print_sample_of_generated_texts(model, run_config)\n",
"benchmark_generation(model, run_config)"
]
},
{
"cell_type": "markdown",
"id": "53bb430f",
"metadata": {},
"source": [
"A speed up of **7.2x** was obtained by using CUDA Graphs with TE's `TransformerLayer`.\n",
"\n",
"| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n",
"|---|---|---|---|---|\n",
"| HF (baseline) | 46.6 s | - | - | - |\n",
"| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n",
"| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |"
]
},
{
"cell_type": "markdown",
"id": "0a11b75c",
"metadata": {},
"source": [
"Let's profile the code from one of the cells above, which runs generation with the Gemma model, and examine the resulting traces in [NVIDIA Nsight Systems](https://developer.nvidia.com/nsight-systems) to understand the performance characteristics and sources of speedup. A few things to recap:\n",
"\n",
"1. For the TE Gemma model implementation, `model.generate()` internally calls `model_context_phase` and `model_generation_phase`.\n",
"2. They are just wrappers around the Gemma model's layers, and they are graphed separately when CUDA graphs are enabled.\n",
"3. So, for each token generated (after the first token), a single invocation of `model_generation_phase` happens as a complete CUDA graph. \n",
"4. The following illustration zooms in on a single `TransformerLayer` layer forward pass (within the larger `model_generation_phase` graphed callable) for clarity.\n",
"\n",
"(For details, refer to the implementation in [te_gemma.py](./te_gemma.py))\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/transformer_cuda_graphed.png\" width=\"80%\" \">\n",
"<figcaption>\n",
" \n",
"Figure 4: (Without CUDA graphs) Blue blobs in the top figure are GPU kernels, and whitespace b/w those indicates that GPUs are idle waiting for the CPU to finish processing and then launch kernels. (With CUDA graphs) The whitespace gets virtually eliminated because all the GPU kernels are bundled into a single highly optimized unit of work with no CPU time in between. (Note that for reference, the kernels are mapped across both cases, and the sizes of those kernels only seem different because of the presence of large voids in the former case, but the sizes are actually the same.)\n",
"</figcaption>\n",
"</figure>\n"
]
},
{
"cell_type": "markdown",
"id": "e6b171a0",
"metadata": {},
"source": [
"## [Optimization 3] Even more acceleration with FP8 precision "
]
},
{
"cell_type": "markdown",
"id": "1a80288b",
"metadata": {},
"source": [
"### Calibrating FP8 scaling factors for correctness\n",
"\n",
"Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `fp8_autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n",
"\n",
"1. Model weight tensors\n",
"2. Input tensors\n",
"\n",
"If the model is run in FP8 precision with incorrect scaling factors, the resulting FP8-cast model weights and FP8-cast inputs (both converted from BF16 precision) will be significantly misaligned, potentially leading to large errors and inaccurate results.\n",
"\n",
"To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n",
"\n",
"*Note that other tensors might need calibration in specific use-cases, but for the generation process in this tutorial, calibrating only the input and weight tensors is needed, and so only the forward pass is considered.*\n",
" \n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/calibration_1_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
"<figcaption>\n",
" Figure 5: The default FP8 scaling factors are incorrect, and so the BF16 to FP8 conversion, as is, can lead to numerical errors. Calibration allows for collecting statistics/metadata about the input and weight tensors in higher precision during the forward pass.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"\n",
"The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, the model is saved, and these weights will be utilized in subsequent steps."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "aecee0e1",
"metadata": {},
"outputs": [],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"import transformer_engine.pytorch as te\n",
"from utils import *\n",
"\n",
"# Provide Huggingface Access Token\n",
"run_config.hf_access_token = \"\"\n",
"assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
"run_config.model_name = \"google/gemma-7b\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"run_config.weights_cache_dir = \"\"\n",
"\n",
"run_config.fuse_qkv_params = True\n",
"model = init_te_gemma_model(run_config)\n",
"\n",
"# Calibration\n",
"with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(\n",
" device_type=\"cuda\", dtype=torch.bfloat16\n",
"):\n",
" model.train()\n",
" run_forward_pass(model, run_config, num_iters=64)\n",
"\n",
"# Compute scale_fwd with enabled fp8 autocast\n",
"with te.fp8_autocast(enabled=True), torch.autocast(\n",
" device_type=\"cuda\", dtype=torch.bfloat16\n",
"):\n",
" run_forward_pass(model, run_config, 1)\n",
"\n",
"# Some parameters are in pointing to the same tensors, double save is avoided here.\n",
"dict_to_save = {\n",
" k: v\n",
" for k, v in model.state_dict().items()\n",
" if (\"_context_phase\" not in k and \"_generation_phase\" not in k)\n",
"}\n",
"torch.save(\n",
" dict_to_save, \"calibrated_weights.pth\"\n",
") # <-- Add path to save calibrated weights."
]
},
{
"cell_type": "markdown",
"id": "b6dcd135",
"metadata": {},
"source": [
"### Generation with better FP8 scaling factors\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/calibration_2_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
"<figcaption>\n",
" Figure 6: After the calibration process, FP8 scaling factors are correct and prevent numerical errors.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"Now that the calibration has produced correct scaling factors, FP8 inference is ready to be run."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "a913f54d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============================== Generation example 1 ==============================\n",
"Prompt: \"Here are the two facts about GPUs:\"\n",
"Generated text: \"\n",
"\n",
"1. They are very good at doing the same thing over and over again.\n",
"2. They are very bad at doing different things at the same time.\n",
"\n",
"This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n",
"============================== Generation example 2 ==============================\n",
"Prompt: \"Some facts about NVIDIA:\"\n",
"Generated text: \"\n",
"\n",
"* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n",
"* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n",
"* NVIDIA is a key player\"\n",
"\n",
"================================================================================\n",
"Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
"Time: 8.73 s.\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"from utils import *\n",
"\n",
"# Provide Huggingface Access Token\n",
"run_config.hf_access_token = \"\"\n",
"assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
"run_config.model_name = \"google/gemma-7b\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"run_config.weights_cache_dir = \"\"\n",
"\n",
"# Set specific hyperparameters\n",
"# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
"run_config.fuse_qkv_params = True # This is needed by the last improvement.\n",
"run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n",
"\n",
"# CUDA Graphs related config\n",
"run_config.generation_cuda_graphs = True\n",
"run_config.cuda_graphs_static_batch_size = 64\n",
"run_config.cuda_graphs_static_max_seq_len = 512\n",
"run_config.cuda_graphs_static_max_context_len = 512\n",
"\n",
"# Enable FP8\n",
"run_config.fp8 = True\n",
"# Calibrated fp8 weights are loaded directly from the file.\n",
"run_config.fp8_model_weights_filename = (\n",
" \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n",
")\n",
"\n",
"model = init_te_gemma_model(run_config)\n",
"\n",
"print_sample_of_generated_texts(model, run_config)\n",
"benchmark_generation(model, run_config)"
]
},
{
"cell_type": "markdown",
"id": "8cdbb56c",
"metadata": {},
"source": [
"One can observe that the outputs are coherent; however, the generation time has increased. Why is this the case?\n",
"\n",
"### Use of FP8-only model weights\n",
"\n",
"Running the model in FP8 precision does not imply that the weights are stored in FP8. By default, they are stored in higher precision and are cast to FP8, using the saved scaling factors, just before the GEMM operations (matrix multiplications).\n",
"\n",
"This approach is appropriate during training: the backward pass produces gradients in higher precision, so keeping higher-precision copies of the weights helps, as they have enough dynamic range to absorb the incoming gradient updates. During the forward pass, the higher-precision weights and the batch inputs are cast to FP8, and the GEMMs run in FP8 precision. This saves training time overall as long as the time gained by running the GEMMs in FP8 (rather than in higher precision) exceeds the extra time spent on the cast operations.\n",
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/fp8_model_init_1_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
"<figcaption>\n",
" Figure 7: At higher precision, the model runs a single operation - the GEMM. In FP8, the inputs to the GEMM - the model weights and the batch inputs - must first be cast from higher precision to FP8, which adds extra cast kernels on top of the low-precision GEMM kernel.\n",
"</figcaption>\n",
"</figure>"
]
},
{
"cell_type": "markdown",
"id": "626aefa1-d5c4-4d8f-88d9-7d7943afde0d",
"metadata": {},
"source": [
"However, things change during inference. Since the weights remain frozen and need no updates, the higher precision copies can be avoided completely: the weights can be cast to FP8 once, at model initialization, with the appropriate scaling factors, and those FP8-only copies can then be used for the entirety of token generation. This provides a two-fold benefit:\n",
"\n",
"1. Lower memory usage - since the model weights are stored in FP8 precision only (compared to training, where both BF16 and FP8 copies end up being present in the memory during peak usage).\n",
"2. Faster forward pass - since the higher precision weights no longer need to be cast to FP8 before every GEMM operation. (One cast kernel remains for the batch inputs, unless they are already in FP8 precision.) \n",
"\n",
"\n",
"Transformer Engine supports maintaining FP8-only weights with the `fp8_model_init` context manager. Let's see a small example:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "4562ee82-8c95-4736-8815-cd386078a485",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Memory required for 16384x16384 linear layer: \n",
"FP32 - 1024.0 MB, \n",
"BF16 - 512.0 MB, \n",
"FP8 - 256.0 MB, \n",
"\n",
"Actual GPU memory usage with a TE FP32 linear layer: 1024.06 MB\n",
"Actual GPU memory usage with a TE BF16 linear layer: 512.03 MB\n",
"Actual GPU memory usage with a TE FP8 linear layer: 256.08 MB\n"
]
}
],
"source": [
"import torch\n",
"import transformer_engine.pytorch as te\n",
"\n",
"H = 2**14\n",
"D = 2**14\n",
"print(f\"Memory required for {H}x{D} linear layer: \\n\"\n",
" f\"FP32 - {H*D*4/1024**2} MB, \\n\"\n",
" f\"BF16 - {H*D*2/1024**2} MB, \\n\"\n",
" f\"FP8 - {H*D*1/1024**2} MB, \\n\")\n",
"\n",
"linear_fp32 = te.Linear(H, D, params_dtype=torch.float32) \n",
"print(f\"Actual GPU memory usage with a TE FP32 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n",
"del linear_fp32\n",
"\n",
"linear_bf16 = te.Linear(H, D, params_dtype=torch.bfloat16)\n",
"print(f\"Actual GPU memory usage with a TE BF16 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n",
"del linear_bf16\n",
"\n",
"# Initialize model weights in FP8 precision\n",
"with torch.no_grad(), te.fp8_model_init(enabled=True):\n",
" linear_fp8 = te.Linear(H, D)\n",
"print(f\"Actual GPU memory usage with a TE FP8 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n",
"del linear_fp8"
]
},
{
"cell_type": "markdown",
"id": "2a26aba9-f3ba-42c4-b4c3-9e845502ae1b",
"metadata": {},
"source": [
"\n",
"<figure align=\"center\">\n",
"<img src=\"./media/fp8_model_init_2_half.svg\" style=\"border: 1px solid #000; border-radius: 0;\">\n",
"<figcaption>\n",
" Figure 8: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n",
"</figcaption>\n",
"</figure>\n",
"\n",
"Let's run the code with `fp8_model_init`:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "96264b9c",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"============================== Generation example 1 ==============================\n",
"Prompt: \"Here are the two facts about GPUs:\"\n",
"Generated text: \"\n",
"\n",
"1. They are very good at doing the same thing over and over again.\n",
"2. They are very bad at doing different things at the same time.\n",
"\n",
"This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n",
"============================== Generation example 2 ==============================\n",
"Prompt: \"Some facts about NVIDIA:\"\n",
"Generated text: \"\n",
"\n",
"* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n",
"* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n",
"* NVIDIA is a key player\"\n",
"\n",
"================================================================================\n",
"Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n",
"Time: 4.99 s.\n"
]
}
],
"source": [
"# Restart the notebook (to flush the GPU memory)\n",
"from utils import restart_jupyter_notebook\n",
"restart_jupyter_notebook()\n",
"\n",
"# Import necessary packages and methods\n",
"from utils import *\n",
"\n",
"# Provide Huggingface Access Token\n",
"run_config.hf_access_token = \"\"\n",
"assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n",
"run_config.model_name = \"google/gemma-7b\"\n",
"\n",
"# Provide a directory to cache weights in to avoid downloading them every time.\n",
"# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n",
"run_config.weights_cache_dir = \"\"\n",
"\n",
"# Set specific hyperparameters\n",
"# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n",
"run_config.fuse_qkv_params = True # This is needed by the last improvement.\n",
"run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n",
"\n",
"# CUDA Graphs related config\n",
"run_config.generation_cuda_graphs = True\n",
"run_config.cuda_graphs_static_batch_size = 64\n",
"run_config.cuda_graphs_static_max_seq_len = 512\n",
"run_config.cuda_graphs_static_max_context_len = 512\n",
"\n",
"# Enable FP8 math and FP8 model weights\n",
"run_config.fp8 = True\n",
"run_config.fp8_model_init = True # This will result in storing only fp8 weights.\n",
"run_config.fp8_model_weights_filename = (\n",
" \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n",
")\n",
"\n",
"model = init_te_gemma_model(run_config)\n",
"\n",
"print_sample_of_generated_texts(model, run_config)\n",
"benchmark_generation(model, run_config)"
]
},
{
"cell_type": "markdown",
"id": "3e30ca5a",
"metadata": {},
"source": [
"The final speedup is **9.3x**. \n",
"\n",
"| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n",
"|---|---|---|---|---|\n",
"| HF (baseline) | 46.6 s | - | - | - |\n",
"| TE (substitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n",
"| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |\n",
"| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `fp8_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |"
]
},
{
"cell_type": "markdown",
"id": "c6e87275",
"metadata": {},
"source": [
"## Conclusions"
]
},
{
"cell_type": "markdown",
"id": "7bb2452d",
"metadata": {},
"source": [
"This tutorial focuses primarily on making the token generation faster with an off-the-shelf model downloaded from Hugging Face using the following features of the Transformer Engine:\n",
"\n",
"1. Support for KV Caching (both non-paged and paged),\n",
"2. Integration with CUDA Graphs,\n",
"3. FP8 scaling factors calibration,\n",
"4. Keeping model parameters in FP8 precision.\n",
"\n",
"It's worth noting that these TE features are also readily applicable to other use cases that this tutorial does not cover extensively: \n",
"\n",
"1. Longer context lengths (with paged KV cache) \n",
"2. Using less memory during generation (by storing weights in FP8 precision using `fp8_model_init`)\n",
"\n",
"Readers are encouraged to explore these use cases by playing around with this tutorial, especially with larger models."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import sys
import IPython
import random
import string
from te_gemma_loading_weights import load_te_model
import torch
from torch.utils.data import DataLoader
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
AutoConfig,
)
from transformers import DataCollatorForLanguageModeling
from datasets import load_dataset
from te_gemma import TEGemmaForCausalLM, TEGemmaForCausalLMCudaGraphs
random.seed(42)
torch.manual_seed(42)
class RunConfiguration:
def __init__(self):
self.mixed_precision = "bf16"
self.model_name = None
# FP8 precision settings
self.fp8 = False
self.fp8_model_weights_filename = None
self.fp8_model_init = False
# Cuda graphs
self.generation_cuda_graphs = False
self.cuda_graphs_static_batch_size = 64
self.cuda_graphs_static_max_seq_len = 512
self.cuda_graphs_static_max_context_len = 512
# Finetuning/calibration/generation settings
self.dataset_name = "timdettmers/openassistant-guanaco"
self.dataset_text_field = "text"
self.learning_rate = 1.41e-5
self.batch_size = 64
self.max_seq_length = 512
self.gradient_accumulation_steps = 1
self.num_warmup_steps = 5
self.num_training_steps = 10
# Coalesced QKV params or not
self.fuse_qkv_params = False
# Attention
self.is_paged = False
# This is either provided by the user or it will be set when the
# model weights are downloaded.
self.weights_cache_dir = ""
# Global variable for the run configuration so that it can be easily accessed
# throughout the jupyter notebook with an `import * from utils` statement
run_config = RunConfiguration()
def get_dataloaders(run_config):
"""
Returns a basic dataloader for the dataset which contains tokenized batches
of text.
"""
dataset = load_dataset(run_config.dataset_name, split="train")
tokenizer = AutoTokenizer.from_pretrained(run_config.model_name)
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token
def tokenize(element):
outputs = tokenizer(
element["text"],
truncation=True,
padding=False,
max_length=run_config.max_seq_length,
return_overflowing_tokens=False,
return_length=False,
)
return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]}
# Tokenize the dataset
dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)
# Simply pad to the multiple of 16 for both FP8 and BF16 precision
pad_to_multiple_of = 16
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False,
pad_to_multiple_of=pad_to_multiple_of,
)
dataloader_params = {
"batch_size": run_config.batch_size,
"collate_fn": data_collator,
"drop_last": True,
}
train_dataloader = DataLoader(dataset, **dataloader_params)
return train_dataloader
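The collator above pads every batch so its sequence length is a multiple of 16, which keeps tensor shapes aligned for both FP8 and BF16 GEMMs. A minimal pure-Python sketch of that rounding rule (`pad_batch` is a hypothetical helper, not part of `utils.py`; the real padding is done by `DataCollatorForLanguageModeling` with `pad_to_multiple_of=16`):

```python
def pad_batch(batch, pad_token_id=0, multiple=16):
    """Right-pad a batch of token-id lists to a common length that is a
    multiple of `multiple`, mirroring pad_to_multiple_of behavior."""
    longest = max(len(seq) for seq in batch)
    # Round the longest sequence up to the nearest multiple.
    target = ((longest + multiple - 1) // multiple) * multiple
    return [seq + [pad_token_id] * (target - len(seq)) for seq in batch]

padded = pad_batch([[5, 7, 9], [1] * 20])
assert all(len(seq) == 32 for seq in padded)  # 20 rounds up to 32
```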
def ensure_model_is_downloaded(run_config):
"""
Downloads and caches the model weights if not already downloaded. A valid
Huggingface Access Token is required to download the model weights.
"""
assert run_config.model_name in [
"google/gemma-7b",
], "Only Gemma 7B model is supported!"
# Login using Huggingface Hub API
from huggingface_hub import login
try:
login(run_config.hf_access_token)
except Exception as e:
if "Invalid token passed!" in str(e):
print(
"Please pass a valid HF Access Token! More info at"
" https://huggingface.co/docs/hub/en/security-tokens."
)
else:
print(f"Exception is {e}")
# Download the model if it doesn't exist
from huggingface_hub import snapshot_download
supplied_cache_dir = (
run_config.weights_cache_dir if run_config.weights_cache_dir != "" else None
)
run_config.weights_cache_dir = snapshot_download(
repo_id=run_config.model_name, cache_dir=supplied_cache_dir
)
def init_baseline_model(run_config):
"""
Initializes a baseline HF Gemma model with the model name provided in
the run_config.
"""
# Download and cache the weights if not already downloaded
ensure_model_is_downloaded(run_config)
# Init the model
config = AutoConfig.from_pretrained(run_config.model_name)
# Make sure to use flash_attention to do iso comparison with TEGemmaModel
config._attn_implementation = "flash_attention_2"
model = AutoModelForCausalLM.from_pretrained(
run_config.model_name,
config=config,
torch_dtype=torch.bfloat16,
).cuda()
return model
def init_te_gemma_model(run_config):
"""
Initializes a Gemma model with `GemmaDecoderLayer`s swapped with
`TransformerLayer`s from TransformerEngine. In case CUDA Graphs are enabled,
the model is initialized from `TEGemmaForCausalLMCudaGraphs` class.
"""
# Download and cache the weights if not already downloaded
ensure_model_is_downloaded(run_config)
cls = TEGemmaForCausalLMCudaGraphs if run_config.generation_cuda_graphs else TEGemmaForCausalLM
config = AutoConfig.from_pretrained(run_config.model_name)
# Inject all fields from the `run_config` to the model `config` to make the
# code simpler.
for key, value in run_config.__dict__.items():
setattr(config, key, value)
# Initialize the model and move it to the GPU.
model = load_te_model(cls, config).cuda()
# Record the model if CUDA Graphs are enabled.
if run_config.generation_cuda_graphs:
model.record()
return model
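The loop above copies every `run_config` field onto the HF model config so downstream code only needs one object. A minimal sketch of the same injection pattern, using `SimpleNamespace` stand-ins for both objects (the field values here are hypothetical, for illustration only):

```python
from types import SimpleNamespace

# Hypothetical stand-ins for the HF model config and the run configuration.
config = SimpleNamespace(hidden_size=3072)
run_config = SimpleNamespace(fp8=True, fuse_qkv_params=True)

# Copy every run_config field onto the model config, as done above.
for key, value in run_config.__dict__.items():
    setattr(config, key, value)

assert config.fp8 and config.fuse_qkv_params
assert config.hidden_size == 3072  # existing fields are untouched
```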
def restart_jupyter_notebook():
# Try restarting the Jupyter kernel
IPython.Application.instance().kernel.do_shutdown(True)
# Check whether the device memory has been flushed
if torch.cuda.memory_allocated() != 0:
import warnings
warnings.warn("The device memory hasn't been flushed, trying with a second method!")
# Try restarting the Jupyter kernel another way
# Restart the kernel
from IPython.core.display import HTML
HTML("<script>Jupyter.notebook.kernel.restart()</script>")
if torch.cuda.memory_allocated() != 0:
print(
"The device memory hasn't been flushed, try manually restarting the Jupyter kernel!"
)
# Suppress the warnings
if not sys.warnoptions:
import warnings
warnings.simplefilter("ignore")
torch.set_warn_always(False)
@torch.no_grad()
def run_forward_pass(model, run_config, num_iters):
"""
Runs the forward pass of the model with sample data. Intended to use for
warmup and/or calibration.
"""
train_dataloader = get_dataloaders(run_config)
model.train()
train_dataloader = enumerate(train_dataloader)
for _ in range(num_iters):
_, batch = next(train_dataloader)
batch["input_ids"] = batch["input_ids"].cuda()
batch["attention_mask"] = batch["attention_mask"].cuda()
model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
###############################################################################
# Benchmarking and example generation functions.
###############################################################################
def print_sample_of_generated_texts(model, run_config):
"""
Prints a sample of generated texts from the input model.
"""
tokenizer = AutoTokenizer.from_pretrained(run_config.model_name)
if getattr(tokenizer, "pad_token", None) is None:
tokenizer.pad_token = tokenizer.eos_token
prompts = [
"Here are the two facts about GPUs:",
"Some facts about NVIDIA:",
"The fundamental theorem of calculus for the layman:",
"A fact about AI:",
]
# Repeat prompts to match batch size
prompts *= run_config.batch_size // len(prompts)
inputs = tokenizer(prompts, return_tensors="pt", padding=True)
max_total_tokens = (
run_config.max_seq_length
if not run_config.generation_cuda_graphs
else run_config.cuda_graphs_static_max_seq_len
)
max_length = inputs["input_ids"].size(1)
new_length = ((max_length + 63) // 64) * max_total_tokens
# Add padding to the left
inputs["input_ids"] = torch.nn.functional.pad(
inputs["input_ids"], (new_length - max_length, 0), value=tokenizer.pad_token_id
)
# Add padding to the left (only intended for baseline generation with HF
# which expects padding to the left)
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"], (new_length - max_length, 0), value=0
)
inputs["input_ids"] = inputs["input_ids"].cuda()
inputs["attention_mask"] = inputs["attention_mask"].cuda()
outputs = model.generate(**inputs, max_new_tokens=50)
generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True)
def print_output(prompts, generated_texts, idx):
print("=" * 30 + f" Generation example {idx+1} " + "=" * 30)
print(f'Prompt: "{generated_texts[idx][: len(prompts[idx])]}"')
print(f'Generated text: "{generated_texts[idx][len(prompts[idx]) :]}"')
# Print the output from first two prompts
for i in range(2):
print_output(prompts, generated_texts, i)
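The function above pads prompts on the left because decoder-only HF models expect the real tokens to sit at the end of the context, immediately before generation starts. A pure-Python sketch of that left padding (`pad_left` is a hypothetical helper; the actual code uses `torch.nn.functional.pad` with a `(left, right)` pad tuple of `(new_length - max_length, 0)`):

```python
def pad_left(ids, target_len, pad_id):
    """Prepend pad tokens so the real tokens end up right-aligned."""
    return [pad_id] * (target_len - len(ids)) + ids

assert pad_left([11, 12, 13], 6, 0) == [0, 0, 0, 11, 12, 13]
```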
def _generate_random_words(num_words, max_word_length):
"""
Generates random words for the benchmark.
"""
words = []
for _ in range(num_words):
word_length = random.randint(1, max_word_length)
word = "".join(random.choices(string.ascii_lowercase, k=word_length))
words.append(word)
return words
def benchmark_generation(model, run_config, context_length=20):
"""
Benchmarks the generation time for a random input to the model.
"""
batch_size = run_config.batch_size
max_total_tokens = (
run_config.max_seq_length
if not run_config.generation_cuda_graphs
else run_config.cuda_graphs_static_max_seq_len
)
max_new_tokens = max_total_tokens - context_length
print("\n" + "=" * 80)
print(
f"Benchmarking for batch_size = {batch_size}, prefill tokens ="
f" {context_length} and max new tokens = {max_new_tokens}"
)
input_str = _generate_random_words(batch_size, context_length)
tokenizer = AutoTokenizer.from_pretrained(run_config.model_name)
inputs = tokenizer(input_str, return_tensors="pt", padding=True)
max_context_tokens = inputs["input_ids"].size(1)
# Add padding to the left
inputs["input_ids"] = torch.nn.functional.pad(
inputs["input_ids"],
(max_total_tokens - max_context_tokens, 0),
value=tokenizer.pad_token_id,
)
# Add padding to the left (only intended for baseline generation with HF
# which expects padding to the left)
inputs["attention_mask"] = torch.nn.functional.pad(
inputs["attention_mask"], (max_total_tokens - max_context_tokens, 0), value=0
)
inputs["input_ids"] = inputs["input_ids"].cuda()
inputs["attention_mask"] = inputs["attention_mask"].cuda()
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
torch.cuda.synchronize()
start.record()
model.generate(inputs["input_ids"], max_new_tokens=max_new_tokens)
torch.cuda.synchronize()
end.record()
print(f"Time: {start.elapsed_time(end)/1000:.2f} s.")
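`benchmark_generation` brackets the generation with CUDA events and explicit synchronization because GPU kernel launches are asynchronous, so naive wall-clock timing would mostly measure launch overhead. A CPU-only sketch of the same measure-around-a-callable pattern (using `perf_counter` on a placeholder workload, since no GPU is assumed here):

```python
import time

def timed(fn):
    # Time a callable with explicit start/stop points; on GPU this role is
    # played by torch.cuda.Event(enable_timing=True) together with
    # torch.cuda.synchronize() before reading the elapsed time.
    start = time.perf_counter()
    fn()
    return time.perf_counter() - start

elapsed = timed(lambda: sum(i * i for i in range(100_000)))
assert elapsed >= 0.0
print(f"Time: {elapsed:.2f} s.")
```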
......@@ -5,7 +5,7 @@
"id": "6a5b2993",
"metadata": {},
"source": [
"# Accelerating a Hugging Face Llama 2 and Llama 3 models with Transformer Engine\n",
"# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
......
......@@ -46,6 +46,7 @@ Transformer Engine documentation
examples/fp8_primer.ipynb
examples/advanced_optimizations.ipynb
examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
examples/onnx/onnx_export.ipynb
.. toctree::
......
......@@ -215,6 +215,17 @@ class InferenceParams:
device=torch.cuda.current_device(),
)
# This internal buffer holds the running length of each
# unfinished sequence in the batch and is updated in `pre_step()`
# method. One use of this buffer is applying RoPE to q and k tensors
# during inference by slicing ROPE Embeddings according to the
# current sequence length window.
self.pre_step_seqlens = torch.zeros(
self.max_batch_size,
dtype=torch.int32,
device=torch.cuda.current_device(),
)
def reset(self):
"""Reset InferenceParams state"""
self.sequences = OrderedDict()
......@@ -266,6 +277,15 @@ class InferenceParams:
for k, v in self.sequences.items():
self.sequences_pre_step[k] = v - step_dict[k]
pre_step_seqlens_temp = torch.Tensor(list(self.sequences_pre_step.values())).to(
dtype=torch.int32, device="cpu"
)
# Copy the pre-step seqlens to the device in CUDA Graphs safe manner.
self.pre_step_seqlens[: len(pre_step_seqlens_temp)].copy_(
pre_step_seqlens_temp, non_blocking=False
)
seqlens_q = list(step_dict.values())
cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)]
cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size)
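The `cu_seqlens_q` computation in the hunk above builds cumulative sequence-length offsets and pads the tail with the last offset so the tensor keeps a static `max_batch_size` shape for CUDA Graphs. A pure-Python sketch of the same computation (`build_cu_seqlens` is a hypothetical name):

```python
def build_cu_seqlens(seqlens_q, max_batch_size):
    """cu[i] is the offset where sequence i starts in the packed batch;
    the tail repeats the last offset up to max_batch_size entries + 1."""
    cu = [0]
    for length in seqlens_q:
        cu.append(cu[-1] + length)
    return cu + [cu[-1]] * (max_batch_size - len(seqlens_q))

# During decode every active sequence contributes one token per step.
assert build_cu_seqlens([1, 1, 1], 5) == [0, 1, 2, 3, 3, 3]
```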
......@@ -280,9 +300,7 @@ class InferenceParams:
def get_seqlens_pre_step(self):
"""Get cached sequence lengths before the stepping"""
return torch.Tensor(list(self.sequences_pre_step.values())).to(
dtype=torch.int32, device="cpu"
)
return self.pre_step_seqlens
def convert_paged_to_nonpaged(self, layer_number: int):
"""
......@@ -458,14 +476,14 @@ class NonPagedKVCacheManager(KVCacheManager):
finished_seqs = self.sequences.keys() - unfinished_seqs
unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs]
finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs]
self.batch_indices.copy_(
self.batch_indices.data[:].copy_(
torch.Tensor(
(
unfinished_indices
+ finished_indices
+ list(range(prev_batch_size, self.max_batch_size))
)
).to(dtype=torch.int32, device="cpu")
)
)
# Advance unfinished sequences
......
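The reindexing in the hunk above reorders the KV cache so unfinished sequences come first, finished ones follow, and the remaining slots up to `max_batch_size` keep their original positions. A pure-Python sketch of that ordering (`reorder_indices` is a hypothetical helper):

```python
def reorder_indices(sequences, unfinished, prev_batch_size, max_batch_size):
    """Return cache slot order: unfinished sequences first, then finished
    ones, then the untouched slots from prev_batch_size to max_batch_size."""
    unfinished_idx = [i for i, s in enumerate(sequences) if s in unfinished]
    finished_idx = [i for i, s in enumerate(sequences) if s not in unfinished]
    return unfinished_idx + finished_idx + list(range(prev_batch_size, max_batch_size))

order = reorder_indices(["a", "b", "c"], {"a", "c"}, 3, 5)
assert order == [0, 2, 1, 3, 4]  # "b" (finished) moves behind "a" and "c"
```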
......@@ -889,23 +889,11 @@ class MultiheadAttention(torch.nn.Module):
q_pos_emb, k_pos_emb = rotary_pos_emb
# adjust key and value for inference
if inference_params is not None:
if self.qkv_format == "sbhd":
sequence_length = key_layer.size(0)
elif self.qkv_format == "bshd":
sequence_length = key_layer.size(1)
else:
raise ValueError(
f"qkv_format={self.qkv_format} not supported for KV caching and RoPE."
)
sequence_start = inference_params.get_seqlens_pre_step()
# sequence_start = inference_params.seqlens[0]
sequence_end = sequence_start + sequence_length
q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
# Applying RoPE during inference requires the start position of
# each sequence at every iteration.
sequence_start_positions = (
inference_params.get_seqlens_pre_step() if inference_params is not None else None
)
if pad_between_seqs:
rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded
......@@ -922,6 +910,7 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens=rotary_pos_cu_seq_lens_q,
cp_size=self.cp_size,
cp_rank=self.cp_rank,
start_positions=sequence_start_positions,
interleaved=self.rotary_pos_interleaved,
)
key_layer = apply_rotary_pos_emb(
......@@ -932,6 +921,7 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens=rotary_pos_cu_seq_lens_kv,
cp_size=self.cp_size,
cp_rank=self.cp_rank,
start_positions=sequence_start_positions,
interleaved=self.rotary_pos_interleaved,
)
......
......@@ -28,9 +28,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output);
auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor
auto start_positions_cu = TensorWrapper(); // empty start_positions tensor
if (start_positions) {
start_positions_cu = makeTransformerEngineTensor(start_positions.value());
TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor");
}
if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
......
......@@ -883,7 +883,7 @@ class GroupedLinear(TransformerEngineBaseModule):
def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module."""
if not self.fp8:
if not self.fp8 and not self.fp8_calibration:
return [None] * self.num_gemms
weight_quantizers = [
self.quantizers["scaling_fwd"][
......