Unverified Commit 7042d7ae authored by Sudhakar Singh's avatar Sudhakar Singh Committed by GitHub
Browse files

TE Gemma tutorial attempt#2 (#1839)



* add tutorial files and other local changes
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* remove extraneous code for easy debu
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* make cuda graphs work with non-paged and paged attention
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* perf imp for kv cache ops
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add code for calibration
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* optimize kv_cache reindex and copy kernels
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* changes to make quantizers work with fp8_calibration
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* avoid reindexing from python side
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename variable from previous commit
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use quantizer only if needed
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* functionality of the tutorial tested and perf checked
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* remove files and update headers/licenses
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* update header/license
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update tutorial for review
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make weights downloadable on the fly; remove extra print statements
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint and update comments
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add comma back, typo
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* sequence_start_positions should be None for training
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add paged attention numberes and update requirements.txt file
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* more fixes
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make tutorial work on blackwell
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* remove gemma FT tutorial for now
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fixing the headings placement and rewording attention -> kv caching
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fixes from comments
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix the images
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* misc fixes
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add more comments to te_gemma.py and cleanup utils.py
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add more information about the hierarchy of the classes used in the tutorial
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add better cuda graphs picture
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* addd updated cuda graphs pictures
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add illustrated cuda graphs
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fix
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* small fixes in documentation
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add torch.no_grad() to force reduced memory usage
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* some fixes from recent comments
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* more fixes from remaining comments
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add te_rope_emb to class desc
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fix tutorial wording; add calibration fix to grouped_linear.py
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent ba37529c
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
width="1280"
height="379.66562"
overflow="hidden"
version="1.1"
id="svg31"
sodipodi:docname="fp8_model_init.svg"
inkscape:version="1.4.2 (f4327f4, 2025-05-13)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="1.8208"
inkscape:cx="685.41302"
inkscape:cy="184.80888"
inkscape:window-width="3440"
inkscape:window-height="1369"
inkscape:window-x="-8"
inkscape:window-y="-8"
inkscape:window-maximized="1"
inkscape:current-layer="g31" />
<defs
id="defs31">
<clipPath
clipPathUnits="userSpaceOnUse"
id="clipPath31">
<rect
style="fill:none"
id="rect32"
width="1390.9491"
height="379.66562"
x="-54.734409"
y="146.82722"
ry="36.489601" />
</clipPath>
</defs>
<g
id="g31"
clip-path="url(#clipPath31)"
transform="translate(0,-146.82722)">
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
id="text1"
x="153.29384"
y="195.21265">FP32/BF16</text>
<path
d="M 821,170 V 513.312"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path1" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
id="text2"
x="616.69165"
y="194.66344">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
id="text3"
x="908.73199"
y="193.56503">FP8 with fp8_model_init()</text>
<rect
x="868"
y="326"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect3" />
<rect
x="882.45081"
y="381.1239"
width="101"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect4" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text4"
x="920.40778"
y="400.1239">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text5"
x="911.3208"
y="416.1239">weight</text>
<rect
x="1078.4508"
y="381.1239"
width="82"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text6"
x="1107.5007"
y="400.1239">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text7"
x="1098.8308"
y="416.1239">GEMM</text>
<path
d="m 983.45079,403.1239 h 89.04001 v 2 h -89.04001 z m 87.71001,-3 8,4 -8,4 z"
id="path7" />
<path
d="M 422,170 V 513.312"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path9" />
<rect
x="54"
y="326"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect9" />
<rect
x="67.45079"
y="367.47629"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect10" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text10"
x="104.84079"
y="390.47629">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text11"
x="91.087494"
y="406.47629">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text12"
x="98.087494"
y="422.47629">weight</text>
<rect
x="270.45081"
y="240.47627"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect12" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text13"
x="307.6308"
y="263.47629">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text14"
x="293.87778"
y="279.47629">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text15"
x="305.79779"
y="295.47629">input</text>
<rect
x="270.45081"
y="367.47629"
width="103"
height="70"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect15" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text16"
x="307.6308"
y="390.47629">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text17"
x="293.87778"
y="406.47629">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text18"
x="301.29779"
y="422.47629">GEMM</text>
<path
d="m 170.4572,404.11625 93.11279,-0.59724 -0.0128,-1.99996 -93.11281,0.59725 z m 91.79869,2.41125 7.9742,-4.05123 -8.0255,-3.9486 z"
id="path18" />
<path
d="m 323.45079,311.47627 v 49.395 h -2 v -49.395 z m 3,48.061 -4,8 -4,-8 z"
id="path19" />
<rect
x="447"
y="326"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect19" />
<rect
x="460.90158"
y="368.57471"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect20" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text20"
x="497.76358"
y="392.57471">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text21"
x="484.01059"
y="408.57471">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text22"
x="491.01059"
y="424.57471">weight</text>
<rect
x="604.90161"
y="381.57471"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23"
x="633.21356"
y="399.57471">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24"
x="622.71356"
y="415.57471">Weight</text>
<g
id="g33"
transform="translate(70.847981,7.139719)">
<rect
x="638.21271"
y="302.41418"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22-2" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23-7"
x="666.52472"
y="320.41418">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24-6"
x="662.06604"
y="336.96341">Input</text>
</g>
<rect
x="708.90161"
y="381.57471"
width="82"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect26" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text27"
x="737.56158"
y="399.57471">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text28"
x="728.89557"
y="415.57471">GEMM</text>
<path
d="m 563.91732,405.21457 34.00266,-0.5351 -0.0314,-1.99976 -34.00273,0.53511 z m 32.71676,2.4855 7.9361,-4.12538 -8.062,-3.87362 z"
id="path28" />
<path
d="m 685.90158,402.57469 h 15.791 v 2 h -15.791 z m 14.458,-3 8,4 -8,4 z"
id="path29" />
<path
d="m 750.90158,284.49209 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30"
style="stroke-width:0.53037" />
<path
d="m 751.17135,355.90367 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-2"
style="stroke-width:0.53037" />
<rect
x="701.05359"
y="215.25253"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect23" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text29"
x="738.23358"
y="238.25253">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text32"
x="724.48059"
y="254.25255">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text33"
x="736.40057"
y="270.25253">input</text>
<g
id="g33-9"
transform="translate(441.10986,7.0509646)">
<rect
x="638.21271"
y="302.41418"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22-2-5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23-7-4"
x="666.52472"
y="320.41418">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24-6-3"
x="662.06604"
y="336.96341">Input</text>
</g>
<path
d="m 1121.1635,284.40334 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-1"
style="stroke-width:0.53037" />
<path
d="m 1121.4332,355.81492 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-2-2"
style="stroke-width:0.53037" />
<rect
x="1071.3154"
y="215.16379"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect23-3" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text29-3"
x="1108.4955"
y="238.16379">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text32-4"
x="1094.7424"
y="254.1638">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text33-1"
x="1106.6625"
y="270.16376">input</text>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
width="960"
height="373.58408"
overflow="hidden"
version="1.1"
id="svg23"
sodipodi:docname="fp8_model_init_1_half.svg"
inkscape:version="1.4.2 (f4327f4, 2025-05-13)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="3.1237948"
inkscape:cx="479.86506"
inkscape:cy="186.79204"
inkscape:window-width="3440"
inkscape:window-height="1369"
inkscape:window-x="-8"
inkscape:window-y="-8"
inkscape:window-maximized="1"
inkscape:current-layer="g23" />
<defs
id="defs23">
<clipPath
clipPathUnits="userSpaceOnUse"
id="clipPath23">
<rect
style="fill:none"
id="rect24"
width="997.38257"
height="373.58408"
x="-11.584002"
y="41.702408"
ry="36.489601" />
</clipPath>
</defs>
<g
id="g23"
clip-path="url(#clipPath23)"
transform="translate(0,-41.702408)">
<rect
x="0"
y="0"
width="960"
height="480"
fill="#ffffff"
id="rect1" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="22px"
transform="translate(195.4,93)"
id="text1">FP32/BF16</text>
<path
d="M 461,61 V 404.312"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path1" />
<rect
x="92"
y="217"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect2" />
<rect
x="105.07926"
y="266.32938"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect3" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text3"
x="142.27226"
y="289.32938">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text4"
x="128.51926"
y="305.32938">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text5"
x="135.51926"
y="321.32938">weight</text>
<rect
x="308.07925"
y="138.32938"
width="103"
height="72"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text6"
x="345.06326"
y="162.32938">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text7"
x="331.31027"
y="178.32938">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text8"
x="343.23026"
y="194.32938">input</text>
<rect
x="308.07925"
y="266.32938"
width="103"
height="70"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect8" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text9"
x="345.06326"
y="289.32938">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text10"
x="331.30927"
y="305.32938">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text11"
x="338.72925"
y="321.32938">GEMM</text>
<path
d="m 208.08567,302.96936 93.11279,-0.59724 -0.0128,-1.99996 -93.11281,0.59724 z m 91.79869,2.41125 7.9742,-4.05123 -8.0255,-3.9486 z"
id="path11" />
<path
d="m 360.07926,210.32938 v 49.395 h -2 v -49.395 z m 3,48.061 -4,8 -4,-8 z"
id="path12" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="22px"
transform="translate(645.181,91)"
id="text23">FP8</text>
<rect
x="495.63504"
y="222.57803"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect19" />
<rect
x="509.53662"
y="265.15271"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect20" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text20"
x="546.39862"
y="289.15274">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text21"
x="532.64563"
y="305.15274">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text22"
x="539.64563"
y="321.15274">weight</text>
<rect
x="653.53668"
y="278.15274"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23-3"
x="681.84863"
y="296.15274">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24"
x="671.34863"
y="312.15274">Weight</text>
<g
id="g33"
transform="translate(119.48305,-96.282252)">
<rect
x="638.21271"
y="302.41418"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22-2" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23-7"
x="666.52472"
y="320.41418">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24-6"
x="662.06604"
y="336.96341">Input</text>
</g>
<rect
x="757.53668"
y="278.15274"
width="82"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect26" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text27"
x="786.19666"
y="296.15274">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text28"
x="777.53064"
y="312.15274">GEMM</text>
<path
d="m 612.55239,301.7926 34.00266,-0.5351 -0.0314,-1.99976 -34.00273,0.53511 z m 32.71676,2.4855 7.9361,-4.12538 -8.062,-3.87362 z"
id="path28" />
<path
d="m 734.53665,299.15272 h 15.791 v 2 h -15.791 z m 14.458,-3 8,4 -8,4 z"
id="path29" />
<path
d="m 799.53665,181.07012 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30"
style="stroke-width:0.53037" />
<path
d="m 799.80642,252.4817 v 21.98469 h -2 V 252.4817 Z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-2"
style="stroke-width:0.53037" />
<rect
x="749.68866"
y="111.83057"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect23" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text29"
x="786.86865"
y="134.83058">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text32"
x="773.11566"
y="150.83058">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text33"
x="785.03564"
y="166.83057">input</text>
</g>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
width="960"
height="379.95526"
overflow="hidden"
version="1.1"
id="svg19"
sodipodi:docname="fp8_model_init_2_half.svg"
inkscape:version="1.4.2 (f4327f4, 2025-05-13)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview
id="namedview1"
pagecolor="#ffffff"
bordercolor="#000000"
borderopacity="0.25"
inkscape:showpageshadow="2"
inkscape:pageopacity="0.0"
inkscape:pagecheckerboard="0"
inkscape:deskcolor="#d1d1d1"
inkscape:zoom="2.1718178"
inkscape:cx="502.34416"
inkscape:cy="194.07705"
inkscape:window-width="3440"
inkscape:window-height="1369"
inkscape:window-x="-8"
inkscape:window-y="-8"
inkscape:window-maximized="1"
inkscape:current-layer="svg19" />
<defs
id="defs19">
<clipPath
clipPathUnits="userSpaceOnUse"
id="clipPath20">
<rect
style="fill:none"
id="rect21"
width="1014.7587"
height="379.95526"
x="-21.430403"
y="44.598408" />
</clipPath>
</defs>
<g
id="g19"
clip-path="url(#clipPath20)"
transform="translate(-76.837568,-52.086815)" />
<path
d="M 434.81331,26.957307 V 370.26931"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path1" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
id="text2"
x="216.69165"
y="33.663437">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
id="text3"
x="508.73199"
y="32.565033">FP8 with fp8_model_init()</text>
<rect
x="481.81332"
y="182.95731"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect3" />
<rect
x="496.26413"
y="238.08121"
width="101"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect4" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text4"
x="534.22107"
y="257.08121">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text5"
x="525.13409"
y="273.08121">weight</text>
<rect
x="692.2641"
y="238.08121"
width="82"
height="45"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text6"
x="721.31403"
y="257.08121">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text7"
x="712.6441"
y="273.08121">GEMM</text>
<path
d="m 597.2641,260.08121 h 89.04001 v 2 H 597.2641 Z m 87.71001,-3 8,4 -8,4 z"
id="path7" />
<rect
x="60.813313"
y="182.95731"
width="129"
height="164"
stroke="#042433"
stroke-width="2"
stroke-miterlimit="8"
fill="#e8e8e8"
id="rect19" />
<rect
x="74.714897"
y="225.53201"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect20" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text20"
x="111.5769"
y="249.53201">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text21"
x="97.823906"
y="265.53201">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text22"
x="104.82391"
y="281.53201">weight</text>
<rect
x="218.71492"
y="238.53201"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23"
x="247.02687"
y="256.53201">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24"
x="236.52687"
y="272.53201">Weight</text>
<g
id="g33"
transform="translate(-315.33871,-135.90297)">
<rect
x="638.21271"
y="302.41418"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22-2" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23-7"
x="666.52472"
y="320.41418">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24-6"
x="662.06604"
y="336.96341">Input</text>
</g>
<rect
x="322.71494"
y="238.53201"
width="82"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#c1e5f5"
id="rect26" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text27"
x="351.37491"
y="256.53201">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text28"
x="342.70889"
y="272.53201">GEMM</text>
<path
d="m 177.73063,262.17188 34.00266,-0.5351 -0.0314,-1.99976 -34.00273,0.53511 z m 32.71676,2.4855 7.9361,-4.12538 -8.062,-3.87362 z"
id="path28" />
<path
d="m 299.71489,259.532 h 15.791 v 2 h -15.791 z m 14.458,-3 8,4 -8,4 z"
id="path29" />
<path
d="m 364.71489,141.4494 v 21.98469 h -2 V 141.4494 Z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30"
style="stroke-width:0.53037" />
<path
d="m 364.98466,212.86098 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-2"
style="stroke-width:0.53037" />
<rect
x="314.86691"
y="72.209839"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect23" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text29"
x="352.04691"
y="95.209839">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text32"
x="338.29391"
y="111.20985">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text33"
x="350.2139"
y="127.20984">input</text>
<g
id="g33-9"
transform="translate(54.923173,-135.99173)">
<rect
x="638.21271"
y="302.41418"
width="81"
height="44"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#fbe3d6"
id="rect22-2-5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text23-7-4"
x="666.52472"
y="320.41418">FP8</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text24-6-3"
x="662.06604"
y="336.96341">Input</text>
</g>
<path
d="m 734.97681,141.36065 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-1"
style="stroke-width:0.53037" />
<path
d="m 735.24651,212.77223 v 21.98469 h -2 v -21.98469 z m 3,21.60945 -4,2.25033 -4,-2.25033 z"
id="path30-2-2"
style="stroke-width:0.53037" />
<rect
x="685.12872"
y="72.121094"
width="103"
height="71"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect23-3" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text29-3"
x="722.30878"
y="95.121094">High</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text32-4"
x="708.55573"
y="111.12111">precision</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="13px"
id="text33-1"
x="720.47577"
y="127.12106">input</text>
</svg>
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
width="1280"
height="303.21127"
overflow="hidden"
version="1.1"
id="svg12"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<defs
id="defs12">
<clipPath
clipPathUnits="userSpaceOnUse"
id="clipPath16">
<rect
style="fill:none;stroke-width:0.96471"
id="rect16"
width="1344.0338"
height="303.21124"
x="-32.356411"
y="174.8833" />
</clipPath>
</defs>
<g
id="g12"
transform="translate(1.1556091e-7,-174.8833)"
clip-path="url(#clipPath16)">
<rect
x="0"
y="0"
width="1280"
height="720"
fill="#ffffff"
id="rect1" />
<path
d="M 645,209 V 446.818"
stroke="#000000"
stroke-width="2"
stroke-miterlimit="8"
fill="none"
fill-rule="evenodd"
id="path1" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
transform="translate(201.111,246)"
id="text1">Without CUDA Graphs</text>
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="700"
font-size="24px"
transform="translate(855.749,246)"
id="text2">With CUDA Graphs</text>
<rect
x="64"
y="319"
width="91"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#f2f2f2"
id="rect2" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(75.6135,349)"
id="text3">Launch 1</text>
<rect
x="155"
y="371"
width="90"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect3" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(169.288,401)"
id="text4">Kernel 1</text>
<rect
x="245"
y="319"
width="91"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#f2f2f2"
id="rect4" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(256.462,349)"
id="text5">Launch 2</text>
<rect
x="336"
y="371"
width="90"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect5" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(350.136,401)"
id="text6">Kernel 2</text>
<rect
x="426"
y="319"
width="91"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#f2f2f2"
id="rect6" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(437.31,349)"
id="text7">Launch 3</text>
<rect
x="517"
y="371"
width="90"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect7" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(530.984,401)"
id="text8">Kernel 3</text>
<path
d="m 47,368 h 574.291 v 4 H 47 Z m 572.291,-4 12,6 -12,6 z"
id="path8" />
<rect
x="680"
y="319"
width="145"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#f2f2f2"
id="rect8" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(694.058,349)"
id="text9">Launch Graph 1</text>
<rect
x="830"
y="370"
width="91"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect9" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(844.463,400)"
id="text10">Kernel 1</text>
<rect
x="924"
y="370"
width="90"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect10" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(938.451,400)"
id="text11">Kernel 2</text>
<rect
x="1018"
y="370"
width="90"
height="49"
stroke="#000000"
stroke-width="2"
stroke-linejoin="round"
stroke-miterlimit="10"
fill="#d9f2d0"
id="rect11" />
<text
font-family="'NVIDIA Sans', 'NVIDIA Sans_MSFontService', sans-serif"
font-weight="400"
font-size="16px"
transform="translate(1032.44,400)"
id="text12">Kernel 3</text>
<path
d="m 663,368 h 574.29 v 4 H 663 Z m 572.29,-4 12,6 -12,6 z"
id="path12" />
</g>
</svg>
transformers==4.55.0
accelerate==1.10.0
datasets==4.0.0
sentencepiece==0.2.1
This diff is collapsed.
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import os
import re
import gc
import torch
from typing import List
from transformer_engine.pytorch.fp8 import fp8_model_init
from transformers.modeling_utils import load_state_dict
from transformers.utils.hub import get_checkpoint_shard_files
"""
This file contains logic of mapping the HuggingFace GemmaModel parameters
with TransformerEngine TransformerLayer. When we have initialized Transformer models
both with HF and with TE, we can copy parameters from the first to the second.
"""
def _load_weights_for_fp8_model(vanilla_model, hyperparams):
"""
Loads weights and FP8 metadata from a calibrated weights file.
The weights are in BF16 precision, but the state dict also contains
fp8 metadata computed by the calibration procedure.
"""
fp8_metadata_sd = torch.load(hyperparams.fp8_model_weights_filename)
# A hack to remove the extra state from the fp8_metadata_sd
# that contains the extra state from the core_attention module.
fp8_metadata_sd = {
k: v for k, v in fp8_metadata_sd.items() if "core_attention._extra_state" not in k
}
vanilla_model.load_state_dict(
fp8_metadata_sd,
strict=False,
# Because some parameters have multiple pointers to the same weight
# vanilla_model._model_context_phase.model and
# vanilla_model._model_generation_phase.model we need to load the
# weights in a non-strict manner.
)
def _load_weights_for_standard_model(vanilla_model, config):
"""
Loads weights from the HuggingFace checkpoint.
"""
archive_file = os.path.join(config.weights_cache_dir, "model.safetensors.index.json")
resolved_archive_file, _ = get_checkpoint_shard_files(config.weights_cache_dir, archive_file)
total_dict = {}
for shard_file in resolved_archive_file:
state_dict = load_state_dict(shard_file)
total_dict.update(state_dict)
replace_params(
total_dict,
vanilla_model.state_dict(),
config,
qkv_fused_and_interleaved=config.fuse_qkv_params,
)
# Copy remaining parameters like embedding.
vanilla_model.load_state_dict(total_dict, strict=False)
# Force mem release. Taken from huggingface code.
del total_dict
gc.collect()
def load_te_model(cls, config):
"""
Loads the TE model with proper weights.
"""
# Force the dtype to bfloat16 while loading the model.
old_dtype = torch.get_default_dtype()
torch.set_default_dtype(torch.bfloat16)
"""
Custom method adapted from `from_pretrained` method in HuggingFace
Transformers repo:
https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579
"""
config.use_cache = False # To make TransformerLayer compatible with GemmaModel
# Loading model with FP8 only weights needs both the following context managers.
# 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8 only weights.
# 2. torch.no_grad() during TE modules' initilization so that they respect
# the `fp8_model_init` context manager.
with torch.no_grad(), fp8_model_init(config.fp8_model_init):
# Just create a model with random weights.
vanilla_model = cls(config).cuda()
# Copy proper weights into the model. If loading weights with FP8 metadata,
# then the source weights are basically the same as the weights in the model.
# If not, then we need to load the weights from the HuggingFace checkpoint
# and do mapping of the weight names from HF to the TE model.
if config.fp8_model_weights_filename is not None:
_load_weights_for_fp8_model(vanilla_model, config)
else:
_load_weights_for_standard_model(vanilla_model, config)
# Restore the original dtype.
torch.set_default_dtype(old_dtype)
return vanilla_model
def _get_all_layer_prefixes_to_update(hf_state_dict):
"""
There are many parameters in hf_state_dict, whose name start with "model.layers.[number]."
This function extracts all strings like "model.layers.[number]."
that are starting strings of keys in hf_state_dict.
"""
all_layer_prefixes = set()
for param_key in hf_state_dict.keys():
layer_prefix_pat = "model.layers.\d+."
m = re.match(layer_prefix_pat, param_key)
if m is not None:
all_layer_prefixes.add(m.group())
return all_layer_prefixes
def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False):
"""
Replaces params from TE TransformerLayer state_dict with corresponding parameters
from HuggingFace GemmaModel state_dict.
"""
all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict)
for layer_prefix in all_layer_prefixes:
def copy_from_ht_to_te(te_name, hf_name, start=None, end=None):
te_state_dict[layer_prefix + te_name].data[start:end].copy_(
hf_state_dict[layer_prefix + hf_name]
)
copy_from_ht_to_te(
"self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight"
)
copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight")
copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight")
copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight")
copy_from_ht_to_te(
"layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size
)
copy_from_ht_to_te(
"layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size
)
if qkv_fused_and_interleaved:
"""
When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor
in TE TransformerLayer. Moreover they are interleaved within each head.
Let q_i, k_i and v_i be query, key and value layers for i-th head respectively.
Then TE stores weight tensor in the form:
[q1 k1 v1 q2 k2 v2 ...]
This is done to maximally optimize performance time.
"""
te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"]
def copy_interleave(hf_name, idx):
src = hf_state_dict[layer_prefix + hf_name]
for head_nr in range(config.num_attention_heads):
dst_offset = head_nr * config.head_dim * 3
dst_slice = slice(
dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim
)
src_slice = slice(
head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim
)
te_qkv_layer[dst_slice, :] = src[src_slice, :]
copy_interleave("self_attn.q_proj.weight", 0)
copy_interleave("self_attn.k_proj.weight", 1)
copy_interleave("self_attn.v_proj.weight", 2)
else:
copy_from_ht_to_te(
"self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight"
)
copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight")
copy_from_ht_to_te(
"self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight"
)
return all_layer_prefixes
This diff is collapsed.
This diff is collapsed.
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
"id": "6a5b2993", "id": "6a5b2993",
"metadata": {}, "metadata": {},
"source": [ "source": [
"# Accelerating a Hugging Face Llama 2 and Llama 3 models with Transformer Engine\n", "# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n",
"\n", "\n",
"<div class=\"alert alert-info\">\n", "<div class=\"alert alert-info\">\n",
"\n", "\n",
......
...@@ -46,6 +46,7 @@ Transformer Engine documentation ...@@ -46,6 +46,7 @@ Transformer Engine documentation
examples/fp8_primer.ipynb examples/fp8_primer.ipynb
examples/advanced_optimizations.ipynb examples/advanced_optimizations.ipynb
examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
examples/onnx/onnx_export.ipynb examples/onnx/onnx_export.ipynb
.. toctree:: .. toctree::
......
...@@ -215,6 +215,17 @@ class InferenceParams: ...@@ -215,6 +215,17 @@ class InferenceParams:
device=torch.cuda.current_device(), device=torch.cuda.current_device(),
) )
# This internal buffer holds the running length of each
# unfinished sequence in the batch and is updated in `pre_step()`
# method. One use of this buffer is applying RoPE to q and k tensors
# during inference by slicing ROPE Embeddings according to the
# current sequence length window.
self.pre_step_seqlens = torch.zeros(
self.max_batch_size,
dtype=torch.int32,
device=torch.cuda.current_device(),
)
def reset(self): def reset(self):
"""Reset InferenceParams state""" """Reset InferenceParams state"""
self.sequences = OrderedDict() self.sequences = OrderedDict()
...@@ -266,6 +277,15 @@ class InferenceParams: ...@@ -266,6 +277,15 @@ class InferenceParams:
for k, v in self.sequences.items(): for k, v in self.sequences.items():
self.sequences_pre_step[k] = v - step_dict[k] self.sequences_pre_step[k] = v - step_dict[k]
pre_step_seqlens_temp = torch.Tensor(list(self.sequences_pre_step.values())).to(
dtype=torch.int32, device="cpu"
)
# Copy the pre-step seqlens to the device in CUDA Graphs safe manner.
self.pre_step_seqlens[: len(pre_step_seqlens_temp)].copy_(
pre_step_seqlens_temp, non_blocking=False
)
seqlens_q = list(step_dict.values()) seqlens_q = list(step_dict.values())
cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)] cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)]
cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size) cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size)
...@@ -280,9 +300,7 @@ class InferenceParams: ...@@ -280,9 +300,7 @@ class InferenceParams:
def get_seqlens_pre_step(self): def get_seqlens_pre_step(self):
"""Get cached sequence lengths before the stepping""" """Get cached sequence lengths before the stepping"""
return torch.Tensor(list(self.sequences_pre_step.values())).to( return self.pre_step_seqlens
dtype=torch.int32, device="cpu"
)
def convert_paged_to_nonpaged(self, layer_number: int): def convert_paged_to_nonpaged(self, layer_number: int):
""" """
...@@ -458,14 +476,14 @@ class NonPagedKVCacheManager(KVCacheManager): ...@@ -458,14 +476,14 @@ class NonPagedKVCacheManager(KVCacheManager):
finished_seqs = self.sequences.keys() - unfinished_seqs finished_seqs = self.sequences.keys() - unfinished_seqs
unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs]
finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs]
self.batch_indices.copy_( self.batch_indices.data[:].copy_(
torch.Tensor( torch.Tensor(
( (
unfinished_indices unfinished_indices
+ finished_indices + finished_indices
+ list(range(prev_batch_size, self.max_batch_size)) + list(range(prev_batch_size, self.max_batch_size))
) )
).to(dtype=torch.int32, device="cpu") )
) )
# Advance unfinished sequences # Advance unfinished sequences
......
...@@ -889,23 +889,11 @@ class MultiheadAttention(torch.nn.Module): ...@@ -889,23 +889,11 @@ class MultiheadAttention(torch.nn.Module):
q_pos_emb, k_pos_emb = rotary_pos_emb q_pos_emb, k_pos_emb = rotary_pos_emb
# adjust key and value for inference # Applyig RoPE for inference needs start positions of sequences
if inference_params is not None: # for each iteration.
if self.qkv_format == "sbhd": sequence_start_positions = (
sequence_length = key_layer.size(0) inference_params.get_seqlens_pre_step() if inference_params is not None else None
elif self.qkv_format == "bshd": )
sequence_length = key_layer.size(1)
else:
raise ValueError(
f"qkv_format={self.qkv_format} not supported for KV caching and RoPE."
)
sequence_start = inference_params.get_seqlens_pre_step()
# sequence_start = inference_params.seqlens[0]
sequence_end = sequence_start + sequence_length
q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]
if pad_between_seqs: if pad_between_seqs:
rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded
...@@ -922,6 +910,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -922,6 +910,7 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens=rotary_pos_cu_seq_lens_q, cu_seqlens=rotary_pos_cu_seq_lens_q,
cp_size=self.cp_size, cp_size=self.cp_size,
cp_rank=self.cp_rank, cp_rank=self.cp_rank,
start_positions=sequence_start_positions,
interleaved=self.rotary_pos_interleaved, interleaved=self.rotary_pos_interleaved,
) )
key_layer = apply_rotary_pos_emb( key_layer = apply_rotary_pos_emb(
...@@ -932,6 +921,7 @@ class MultiheadAttention(torch.nn.Module): ...@@ -932,6 +921,7 @@ class MultiheadAttention(torch.nn.Module):
cu_seqlens=rotary_pos_cu_seq_lens_kv, cu_seqlens=rotary_pos_cu_seq_lens_kv,
cp_size=self.cp_size, cp_size=self.cp_size,
cp_rank=self.cp_rank, cp_rank=self.cp_rank,
start_positions=sequence_start_positions,
interleaved=self.rotary_pos_interleaved, interleaved=self.rotary_pos_interleaved,
) )
......
...@@ -28,9 +28,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, ...@@ -28,9 +28,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs,
auto freqs_cu = makeTransformerEngineTensor(freqs); auto freqs_cu = makeTransformerEngineTensor(freqs);
auto output_cu = makeTransformerEngineTensor(output); auto output_cu = makeTransformerEngineTensor(output);
auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor auto start_positions_cu = TensorWrapper(); // empty start_positions tensor
if (start_positions) { if (start_positions) {
start_positions_cu = makeTransformerEngineTensor(start_positions.value()); start_positions_cu = makeTransformerEngineTensor(start_positions.value());
TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor");
} }
if (qkv_format == NVTE_QKV_Format::NVTE_THD) { if (qkv_format == NVTE_QKV_Format::NVTE_THD) {
......
...@@ -883,7 +883,7 @@ class GroupedLinear(TransformerEngineBaseModule): ...@@ -883,7 +883,7 @@ class GroupedLinear(TransformerEngineBaseModule):
def _get_weight_quantizers(self) -> List[Quantizer]: def _get_weight_quantizers(self) -> List[Quantizer]:
"""Get the weight quantizers of the module.""" """Get the weight quantizers of the module."""
if not self.fp8: if not self.fp8 and not self.fp8_calibration:
return [None] * self.num_gemms return [None] * self.num_gemms
weight_quantizers = [ weight_quantizers = [
self.quantizers["scaling_fwd"][ self.quantizers["scaling_fwd"][
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment