Unverified commit 6afca29c authored by Paweł Gadziński, committed by GitHub

[PyTorch Debug] More advanced stats for Quantized Tensors (#1897)



* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* turn on userbuffers for layers without debug
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* working change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tests and fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* update nvinspect version
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* docs change
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* test
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix default
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix default
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix default
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* tests fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fixes
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent ebca6153
......@@ -12,14 +12,7 @@ Let's look deeper into how Nvidia-DL-Framework-Inspect with Transformer Engine w
Fig 1: Example of Nvidia-DL-Framework-Inspect affecting a training script with 1 Linear layer. For tensors mentioned in ``config.yaml``, the behavior of the ``modify_tensor_enabled()`` and ``modify_tensor()`` calls is substituted with definitions from the feature class. Other calls return default values; in effect, they do nothing.
On this page, all calls from Transformer Engine to Nvidia-DL-Framework-Inspect for each GEMM are listed. The order of these calls is illustrated in the image below.
.. figure:: ./img/api_calls2.svg
:align: center
Fig 2: The calls made from Transformer Engine to Nvidia-DL-Framework-Inspect. There are 2 types of calls: GEMM calls and routing calls.
On this page, all calls from Transformer Engine to Nvidia-DL-Framework-Inspect for each GEMM are listed.
There are 2 categories of API calls, each used for a different purpose:
- GEMM calls - invoked during every GEMM, used to process or quantize tensors and collect information about them,
......@@ -32,14 +25,15 @@ if fusions happen. An important remark is that if no feature is used for the lay
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor_enabled
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.fp8_gemm_enabled
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_enabled
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize_enabled
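For orientation, below is a minimal sketch (not part of this commit) of a user-defined feature class implementing one routing call and one GEMM call, following the signatures these directives document and which appear later in this diff; the class name, the always-on routing logic, and the printed statistic are hypothetical.

from typing import Dict, Optional

import torch
from nvdlfw_inspect.registry import Registry, api_method


@Registry.register_feature(namespace="transformer_engine")
class LogTensorNorm:  # hypothetical feature, for illustration only
    """Logs the L2 norm of selected tensors."""

    @api_method
    def inspect_tensor_enabled(self, config: Dict, layer_name: str, tensor_name: str, iteration: int):
        # Routing call: returning a plain bool is allowed; a (bool, Optional[int])
        # tuple may also be returned to tell TE when the feature is next needed.
        return True

    @api_method
    def inspect_tensor(
        self,
        config: Dict,
        layer_name: str,
        tensor_name: str,
        iteration: int,
        tp_group: torch.distributed.ProcessGroup,
        tensor: torch.Tensor,
        rowwise_quantized_tensor: Optional[torch.Tensor] = None,
        columnwise_quantized_tensor: Optional[torch.Tensor] = None,
        quantizer=None,
    ):
        # GEMM call: receives the high-precision tensor and, after this PR,
        # optionally its quantized counterparts and the quantizer.
        print(f"{layer_name}/{tensor_name} iter={iteration} norm={tensor.norm().item():.4f}")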
[SVG image file (the api_calls2.svg diagram referenced above): for Tensor A and Tensor B it shows the GEMM calls (inspect_tensor, modify_tensor or fp8 cast, inspect_tensor_postquantize, inspect_tensor) feeding the GEMM, and the routing calls fp8_gemm_enabled, modify_tensor_enabled, inspect_tensor_enabled, and inspect_tensor_postquantize_enabled.]
......@@ -23,6 +23,7 @@ pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
......
......@@ -364,6 +364,40 @@ def test_log_distributed(parallel_mode, gather_weight, **kwargs):
set_weight_tensor_tp_group_reduce(True) # reset
@run_debug_test
def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs):
from test_log import LOG_QUANTIZED_CONFIG
kwargs["config_file"].write(LOG_QUANTIZED_CONFIG)
kwargs["config_file"].flush()
_init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS)
set_weight_tensor_tp_group_reduce(gather_weight)
if WORLD_SIZE % 2 != 0:
return # skip
TP_SIZE = WORLD_SIZE // 2
DP_SIZE = 2
TP_RANK = WORLD_RANK % TP_SIZE
DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE
debug_api.set_tensor_reduction_group(NCCL_WORLD)
x, weight = _get_tensors(
parallel_mode,
weight_seed=TP_RANK * 1234,
data_seed=DP_RANK * 1234,
tp_size=TP_SIZE,
tp_rank=TP_RANK,
)
tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)]
tp_group = dist.new_group(ranks=tp_group_ranks)
model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group)
_run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group)
set_weight_tensor_tp_group_reduce(True) # reset
@run_debug_test
def test_log_expert_parallel(**kwargs):
"""
......
......@@ -36,11 +36,6 @@ def test_transformer_engine_no_config(feature_dirs):
"decoder.1.attn.qkv", tensor_name="activation", iteration=0
)[0]
# inspect_tensor_postquantize - (False, None) by default
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0
)[0]
finally:
debug_api.end_debug()
......@@ -236,13 +231,12 @@ def test_statistics_collection(configs_dir, feature_dirs):
)
tensor = torch.randn((100, 100, 5)).cuda()
tensor_fp8 = Float8Tensor(
data=tensor.to(torch.uint8).cuda(),
fp8_scale_inv=torch.full([1], 1.0).cuda(),
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=tensor.shape,
dtype=torch.float32,
)
tensor_fp8 = quantizer(tensor)
def log():
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
......@@ -260,6 +254,9 @@ def test_statistics_collection(configs_dir, feature_dirs):
tensor_name="activation",
iteration=200,
tp_group=None,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
stats = log()
assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max()
......@@ -269,44 +266,52 @@ def test_statistics_collection(configs_dir, feature_dirs):
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.2.mlp.fc1", tensor_name="activation", iteration=200
)[0]
assert not debug_api.transformer_engine.inspect_tensor_enabled(
expected_underflows = (
((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5)
)
assert debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)[0]
expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5)
# TE FP8 tensor stats --
assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
assert debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="gradient", iteration=200
)[0]
debug_api.transformer_engine.inspect_tensor_postquantize(
debug_api.transformer_engine.inspect_tensor(
"decoder.1.mlp.fc1",
tensor=tensor_fp8,
tensor_name="gradient",
iteration=200,
rowwise=True,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
stats = log()
torch.testing.assert_close(
stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows
)
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.1.mlp.fc1", tensor_name="activation", iteration=201
)[0]
assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled(
"decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200
assert not debug_api.transformer_engine.inspect_tensor_enabled(
"decoder.2.mlp.fc1", tensor_name="gradient", iteration=200
)[0]
# Second config in same yaml
tensor = torch.rand((100, 100, 5))
debug_api.transformer_engine.inspect_tensor(
"decoder.6.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=200,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
stats = log()
stats_names = [x[3] for x in stats.keys()]
......@@ -315,10 +320,13 @@ def test_statistics_collection(configs_dir, feature_dirs):
debug_api.transformer_engine.inspect_tensor(
"decoder.7.mlp.fc1",
tensor=tensor,
tensor_name="weight",
iteration=200,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
stats = log()
stats_names = [x[3] for x in stats.keys()]
......@@ -342,21 +350,16 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
default_logging_enabled=False,
)
def feed(tensor, tensor_fp8):
def feed(tensor, tensor_fp8, quantizer):
debug_api.transformer_engine.inspect_tensor(
"decoder.5.mlp.fc1",
tensor=tensor,
tensor_name="activation",
iteration=1,
tp_group=None,
)
debug_api.transformer_engine.inspect_tensor_postquantize(
"decoder.5.mlp.fc1",
tensor=tensor_fp8,
tensor_name="activation",
iteration=1,
rowwise=True,
tp_group=None,
quantizer=quantizer,
rowwise_quantized_tensor=tensor_fp8,
columnwise_quantized_tensor=tensor_fp8,
)
def log_stats():
......@@ -364,26 +367,26 @@ def test_statistics_multi_run(configs_dir, feature_dirs):
return STATS_BUFFERS.log_stats()
quantizer = Float8Quantizer(
scale=torch.full([1], 1.0).cuda(),
amax=torch.full([1], 1.0).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
)
def fp8_tensor(t):
return Float8Tensor(
data=t.to(torch.uint8).cuda(),
fp8_scale_inv=torch.ones([1]).cuda(),
fp8_dtype=tex.DType.kFloat8E4M3,
shape=t.shape,
dtype=torch.float32,
)
return quantizer(t.cuda())
shape = [1024, 1024]
tensors = [torch.randn(shape) for _ in range(2)]
tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)]
feed(tensors[0], tensors_fp8[0])
feed(tensors[1], tensors_fp8[1])
feed(tensors[0], tensors_fp8[0], quantizer)
feed(tensors[1], tensors_fp8[1], quantizer)
stats1 = log_stats()
tensor2 = torch.cat((tensors[0], tensors[1])).cuda()
fp8tensor2 = fp8_tensor(tensor2)
feed(tensor2, fp8tensor2)
feed(tensor2, fp8tensor2, quantizer)
stats2 = log_stats()
assert len(stats1.keys()) > 0
......
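The test above relies on count-based statistics accumulating across micro-batches. Below is a plain-PyTorch illustration (no TE involved; the helper and the rounding stand-in are mine) of why feeding two micro-batches matches feeding their concatenation:

import torch

def underflow_count(dequantized, original):
    # elements that were non-zero before quantization but became exactly zero after it
    return ((dequantized == 0) & (original != 0)).sum().item()

a, b = torch.randn(4, 8), torch.randn(4, 8)
qa, qb = a.round(), b.round()  # crude stand-in for a quantize/dequantize round trip
assert underflow_count(qa, a) + underflow_count(qb, b) == \
    underflow_count(torch.cat((qa, qb)), torch.cat((a, b)))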
......@@ -12,7 +12,7 @@ test:
freq: 3
LogFp8TensorStats:
enabled: True
tensors: weight
tensors: activation
stats: [underflows%]
start_step: 1
freq: 5
......@@ -2,16 +2,207 @@
#
# See LICENSE for license information.
import pytest
import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.debug
import transformer_engine.pytorch as te
import torch
import tempfile
from transformer_engine.common import recipe
from transformer_engine.pytorch.fp8 import RecipeState
import pytest
import contextlib
import os
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.debug.pytorch.debug_state import TEDebugState
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
LOG_QUANTIZED_CONFIG_BASE = """
log:
layers:
layer_name_regex_pattern: .*
enabled:
True
transformer_engine:
LogFp8TensorStats:
enabled: True
stats: [
{stats}
]
tensors: [activation, gradient, weight]
freq: 2
start_step: 0
end_step: 10
"""
recipes = [
"fp8_delayed_scaling",
"fp8_current_scaling",
"fp8_block_scaling",
"mxfp8",
]
bare_stats = [
"underflows%",
"scale_inv_min",
"scale_inv_max",
"mse",
]
all_stats = []
for r in recipes:
for stat in bare_stats:
for columnwise_postfix in ["", "_columnwise"]:
if (
r in ["fp8_current_scaling", "fp8_block_scaling"]
and torch.cuda.get_device_capability()[0] < 9
):
# hopper is needed for current-scaling, block-scaling
continue
if r == "mxfp8" and torch.cuda.get_device_capability()[0] < 10:
# blackwell is needed for mxfp8
continue
if (
r in ["fp8_delayed_scaling", "fp8_current_scaling"]
and columnwise_postfix == "_columnwise"
):
# columnwise stats are not supported for fp8_delayed_scaling and fp8_current_scaling
continue
all_stats.append(f"{r}_{stat}{columnwise_postfix}")
all_stats.append("fp8_delayed_scaling_overflows%") # only delayed-scaling supports overflows%
@contextlib.contextmanager
def debug_session(config_str: str, feature_dirs):
"""
Helper context manager that
1. writes the YAML `config_str` to a temporary file,
2. starts a debug session, and
3. yields the directory that contains the statistics log.
The session is closed automatically – even on exceptions – so every test
stays concise and leak-free.
"""
with tempfile.NamedTemporaryFile(
mode="w", delete=False
) as cfg_file, tempfile.TemporaryDirectory() as log_dir:
cfg_file.write(config_str)
cfg_file.flush()
debug_api.initialize(
config_file=cfg_file.name,
feature_dirs=feature_dirs,
log_dir=log_dir,
)
try:
yield log_dir
finally:
debug_api.end_debug()
def read_log(log_dir: str) -> str:
"""Return the content of the statistics log produced by `debug_session`."""
stat_path = os.path.join(
log_dir,
"nvdlfw_inspect_statistics_logs",
"nvdlfw_inspect_globalrank-0.log",
)
with open(stat_path, "r") as f:
return f.read()
def test_sanity(feature_dirs):
log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats))
with debug_session(log_all_stats_config, feature_dirs) as log_dir:
model = te.Linear(128, 128, params_dtype=torch.bfloat16)
inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda()
for _ in range(10):
with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()):
output = model(inp)
loss = output.sum()
loss.backward()
debug_api.step()
output = read_log(log_dir)
assert output, "Output is empty"
for stat in all_stats:
assert stat in output, f"Stat {stat} not found in output"
fp8_recipes = [
recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]
@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_numerics(fp8_recipe, feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling():
pytest.skip(reason_for_no_mxfp8)
if not fp8_block_scaling_available and fp8_recipe == recipe.Float8BlockScaling():
pytest.skip(reason_for_no_fp8_block_scaling)
log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(bare_stats))
with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir:
recipe_state = RecipeState.create(
fp8_recipe,
mode="forward",
num_quantizers=3,
)
tensor = torch.zeros(1024, 1024).cuda()
tensor[0, :] = 1000
quantizer = recipe_state.make_quantizers()[0]
quantized_tensor = quantizer(tensor)
debug_api.transformer_engine.inspect_tensor(
layer_name="layer_name",
tensor_name="activation",
iteration=0,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=quantized_tensor,
columnwise_quantized_tensor=quantized_tensor,
)
debug_api.step()
dequantized_tensor = quantized_tensor.dequantize()
output = read_log(log_dir)
for line in output.splitlines():
if "underflows%" in line:
underflows = float(line.split("value=")[1])
expected = (
((dequantized_tensor == 0).sum() - (tensor == 0).sum())
/ dequantized_tensor.numel()
* 100
)
assert underflows == pytest.approx(expected.cpu(), abs=1e-4)
if "mse" in line:
mse = float(line.split("value=")[1])
expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean")
assert mse == pytest.approx(expected.cpu(), abs=1e-6)
if "overflows%" in line:
overflows = float(line.split("value=")[1])
expected = (
(abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100
)
assert overflows == pytest.approx(expected.cpu(), abs=1e-4)
@pytest.mark.parametrize("layer", ["linear", "transformer"])
......@@ -35,7 +226,7 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
else:
raise ValueError(f"Invalid layer: {layer}")
for i in range(11):
for i in range(20):
x = torch.randn(4, 128, 128).cuda()
with te.fp8_autocast(enabled=True):
y = model(x)
......@@ -49,7 +240,7 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
"r",
) as f:
file_content = f.read()
for i in range(1, 11):
for i in range(1, 20):
if i % 3 == 0 or i % 5 == 0:
assert f"iteration={i:06d}" in file_content
else:
......
......@@ -5,6 +5,7 @@
"""API definition for nvidia-dlframework-inspect."""
import copy
import warnings
from typing import Dict, Union, Tuple, Optional
from nvdlfw_inspect.base import BaseNamespaceAPI, BaseConfigAPIMapper
from nvdlfw_inspect.registry import Registry
......@@ -114,7 +115,7 @@ class TEDefaultFeatures:
If the tensor is processed using *modify_tensor* or fp8 autocast is not enabled,
the result of this call does not matter.
This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be disabled.
It can return (bool, None) if the feature will never be enabled for that layer and gemm.
Returning the next enabled iteration can help optimize CPU usage.
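To make the (bool, Optional[int]) convention concrete, here is a hedged sketch of a routing decision that runs a feature every `freq` iterations and reports when it is next needed; the helper name is illustrative (the PR's own helper is transformer_engine.debug.features.utils.next_enabled_iter).

def every_n_steps_enabled(iteration: int, freq: int):
    # Illustrative only: run now if this is a logging step, otherwise tell the
    # framework the next iteration at which to call the feature again.
    if iteration % freq == 0:
        return True, iteration + freq
    return False, iteration + (freq - iteration % freq)

assert every_n_steps_enabled(0, freq=5) == (True, 5)
assert every_n_steps_enabled(7, freq=5) == (False, 10)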
......@@ -244,6 +245,9 @@ class TEDefaultFeatures:
layer_name: str,
tensor_name: str,
tensor: torch.Tensor,
rowwise_quantized_tensor: Optional[torch.Tensor],
columnwise_quantized_tensor: Optional[torch.Tensor],
quantizer: Optional[Quantizer],
iteration: int,
tp_group: torch.distributed.ProcessGroup,
) -> None:
......@@ -260,6 +264,12 @@ class TEDefaultFeatures:
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
tensor: torch.Tensor
tensor in high precision,
rowwise_quantized_tensor: Optional[torch.Tensor]
rowwise quantized tensor,
columnwise_quantized_tensor: Optional[torch.Tensor]
columnwise quantized tensor,
quantizer: Optional[Quantizer]
quantizer used to produce the quantized tensors,
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
tp_group: torch.distributed.ProcessGroup
......@@ -277,12 +287,15 @@ class TEDefaultFeatures:
config: Dict,
layer_name: str,
tensor_name: str,
gemm: str,
tensor: torch.Tensor,
iteration: int,
tp_group: torch.distributed.ProcessGroup,
rowwise: bool,
) -> None:
"""
This call is deprecated; we advise using *inspect_tensor* instead.
Similar to *inspect_tensor*, but it runs after the FP8 cast or *modify_tensor*, whichever is applied. If neither the FP8 cast nor *modify_tensor* is invoked, then *inspect_tensor_postquantize* is not invoked either. The LogFp8TensorStats feature uses this call to collect FP8 statistics after quantization.
Parameters
......@@ -295,8 +308,6 @@ class TEDefaultFeatures:
one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
tensor: torch.Tensor
tensor in fp8 or processed tensor after the modify_tensor call,
gemm: str
one of [`fprop`, `dgrad`, `wgrad`],
iteration: int
iteration number - equal to the number of times `debug_api.step()` was called.
tp_group: torch.distributed.ProcessGroup
......@@ -352,6 +363,8 @@ class TEDefaultFeatures:
iteration: int,
) -> bool | Tuple[bool, Optional[int]]:
"""
This call is deprecated; we advise using *inspect_tensor* and *inspect_tensor_enabled* instead.
It is a routing call, which is run at the initialization of the layer.
Determines if *inspect_tensor_postquantize* for a given GEMM and tensor will be invoked.
......@@ -399,8 +412,8 @@ class TransformerEngineAPI(BaseNamespaceAPI):
"modify_tensor": ["tensor_name", "gemm"],
"inspect_tensor": ["tensor_name"],
"inspect_tensor_postquantize": ["tensor_name"],
"inspect_tensor_enabled": ["tensor_name"],
"inspect_tensor_postquantize_enabled": ["tensor_name"],
"inspect_tensor_enabled": ["tensor_name", "iteration"],
"inspect_tensor_postquantize_enabled": ["tensor_name", "iteration"],
"modify_tensor_enabled": ["tensor_name"],
}
......@@ -460,6 +473,26 @@ class TransformerEngineAPI(BaseNamespaceAPI):
if kwargs["dtype"] is not None:
assert ret.dtype == kwargs["dtype"]
def call_feature(self, call, feat_config, layer_name, **kwargs):
"""
For backward compatibility, remove kwargs that the feature's implementation of the call does not accept.
"""
if call.__name__ == "inspect_tensor":
kwargs_copy = kwargs.copy()
for k in ["quantizer", "columnwise_quantized_tensor", "rowwise_quantized_tensor"]:
if k not in call.__code__.co_varnames:
kwargs_copy.pop(k)
else:
kwargs_copy = kwargs
if call.__name__ == "inspect_tensor_postquantize":
warnings.warn(
"inspect_tensor_postquantize is deprecated, use inspect_tensor instead.",
DeprecationWarning,
)
return call(feat_config, layer_name, **kwargs_copy)
def handle_multi_feature_output(
self, api_name, multi_feature_outputs, features_to_invoke, **kwargs
):
......@@ -474,19 +507,18 @@ class TransformerEngineAPI(BaseNamespaceAPI):
# representing the number of steps after the feature will be enabled next time.
# If the second value is None, that means that the feature will never be enabled.
all_ret_tuple = all(
isinstance(feature_output, tuple)
for feature_output in multi_feature_outputs.values()
isinstance(feature_output, tuple) for feature_output in multi_feature_outputs
)
if all_ret_tuple:
run_current = any(
feature_output[0] for feature_output in multi_feature_outputs.values()
)
run_current = any(feature_output[0] for feature_output in multi_feature_outputs)
next_iter = None
for feature_output in multi_feature_outputs.values():
if feature_output[1] is not None:
for feature_output in multi_feature_outputs:
if next_iter is None:
next_iter = feature_output[1]
elif feature_output[1] is not None:
next_iter = min(next_iter, feature_output[1])
return run_current, next_iter
run_current = any(feature_output for feature_output in multi_feature_outputs.values())
run_current = any(feature_output for feature_output in multi_feature_outputs)
return run_current, None
return super().handle_multi_feature_output(
api_name, multi_feature_outputs, features_to_invoke, **kwargs
......
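The combination rule in handle_multi_feature_output above can be restated as a stand-alone sketch (illustrative only): run now if any feature wants to run, and schedule the next check at the earliest requested iteration.

def combine(outputs):
    # outputs: list of (run_now, next_iteration_or_None) tuples from individual features
    run_now = any(run for run, _ in outputs)
    next_iters = [it for _, it in outputs if it is not None]
    return run_now, (min(next_iters) if next_iters else None)

assert combine([(False, 12), (True, 8)]) == (True, 8)
assert combine([(False, None), (False, None)]) == (False, None)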
......@@ -50,4 +50,4 @@ class DisableFP8GEMM(TEConfigAPIMapper):
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behaviour in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, None
return False, iteration + 1
......@@ -41,7 +41,7 @@ class DisableFP8Layer:
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, None
return False, iteration + 1
def parse_config_and_api(self, config, **_kwargs):
"""Determines whether to run the API
......
......@@ -4,55 +4,119 @@
"""LogFp8TensorStats Feature support for nvidia-dlframework-inspect"""
from typing import Dict, Union
from typing import Dict, Optional, List, Tuple
from contextlib import contextmanager
import torch
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import next_enabled_iter
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]
def _get_recipe_name(quantizer: Optional[Quantizer]):
if quantizer is None:
return ""
if isinstance(quantizer, Float8Quantizer):
return "fp8_delayed_scaling"
if isinstance(quantizer, Float8CurrentScalingQuantizer):
return "fp8_current_scaling"
if isinstance(quantizer, MXFP8Quantizer):
return "mxfp8"
if isinstance(quantizer, Float8BlockQuantizer):
return "fp8_block_scaling"
raise ValueError(f"Unsupported quantizer type: {type(quantizer)}")
def _get_new_quantizer(recipe_name, fp8_dtype):
if recipe_name == "fp8_block_scaling":
return Float8BlockQuantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True)
if recipe_name == "fp8_current_scaling":
return Float8CurrentScalingQuantizer(
fp8_dtype=fp8_dtype, device=torch.device("cuda"), rowwise=True, columnwise=True
)
if recipe_name == "mxfp8":
return MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True)
if recipe_name == "fp8_delayed_scaling":
raise ValueError("Cannot recreate quantizer for fp8_delayed_scaling")
raise ValueError(f"Unsupported recipe name: {recipe_name}")
@Registry.register_feature(namespace="transformer_engine")
class LogFp8TensorStats(BaseLogTensorStats):
"""
This feature handles logging of FP8 tensor stats.
Logs statistics of quantized tensors.
Supports computing statistics for the current recipe, and also
allows you to see what would happen if different recipes were used for these tensors in the current iteration.
For example, during delayed-scaling training you may wish to track
"current_scaling_underflows%" to measure the accuracy of the current scaling
factors; note that this requires an extra cast and therefore adds overhead.
Using a logging frequency (`freq`) greater than 1 is recommended in this case.
Computing the stats matching the training recipe does not require an extra cast.
In a distributed setting, the auxiliary stats are computed on each rank and gathered after
the `debug_api.step()` call. Do not forget to invoke `debug_api.step()` at every step to log
stats!
Statistics are identified by the pattern `<recipe>_<stat>` with an optional `_columnwise` suffix (e.g.
`fp8_delayed_scaling_underflows%` or `mxfp8_scale_inv_min_columnwise`).
One can provide `<stat>` only; then the current training recipe is used.
`LogFp8TensorStats` supports micro-batching. If multiple forward/backward passes are invoked
per `debug_api.step()`, then stats for all tensors except weights will be accumulated.
Stats for delayed-scaling cannot be collected if delayed-scaling is not the current training recipe.
`LogFp8TensorStats` can induce significant overhead. To mitigate this issue, logging stats
with `freq > 1` is recommended. If `LogFp8TensorStats` is not used in a given step, the
overhead is smaller. If no other feature is used for the layer, the TE layer will
run as fast as it would without `debug_api` initialized.
In distributed runs each rank first computes its local statistics; the values
are gathered the next time `debug_api.step()` is called. Remember to call
`debug_api.step()` every training step so the logs are flushed.
The feature is micro-batch aware: if several forward/backward passes occur
between successive `debug_api.step()` calls, statistics are accumulated for all
tensors except weights.
Collecting FP8 statistics is expensive. Choosing a larger `freq` reduces the
overhead, and if the feature is skipped for a step the additional cost is
minimal. When no other debug feature is active, the layer runs at normal
Transformer Engine speed.
Parameters
----------
stats: List[str]
list of statistics to log
Each stat is a string of the form `<recipe>_<stat>`, with an optional `_columnwise` suffix (i.e., `<recipe>_<stat>_columnwise`).
If `<recipe>` is omitted, the current training recipe is used.
For mxfp8 and fp8_block_scaling, the `_columnwise` suffix can be provided; the stat is then computed on the columnwise (transposed)
version of the tensor, which can be numerically different from the rowwise (non-transposed) tensor.
The "_columnwise" suffix is not supported for fp8_delayed_scaling and fp8_current_scaling.
recipes:
- fp8_delayed_scaling,
- fp8_current_scaling,
- mxfp8,
- fp8_block_scaling,
stats:
- underflows% - percentage of non-zero elements of tensor clipped to 0 after quantization,
- overflows% - percentage of elements of tensor that were clipped to the max/min value of the FP8 range - supported only for fp8_delayed_scaling,
- scale_inv_min - minimum of the inverse of the scaling factors,
- scale_inv_max - maximum of the inverse of the scaling factors,
- mse - mean squared error of the quantized tensor and the original tensor = sum((quantized_tensor - original_tensor)**2) / num_elements,
- underflows% - percentage of elements of the tensor equal to 0,
tensors/tensors_struct: List[str]
list of tensors to log
- activation,
- gradient,
- weight,
- activation
- gradient
- weight
freq: Optional[int], default = 1
frequency of logging stats, stats will be logged every `freq` steps
start_step: Optional[int], default = None
......@@ -75,7 +139,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
enabled: True
tensors_struct:
- tensor: activation
stats: [underflows%]
stats: [mxfp8_underflows%]
freq: 1
- tensor: gradient
stats: [underflows%]
......@@ -84,13 +148,106 @@ class LogFp8TensorStats(BaseLogTensorStats):
end_step: 80
"""
def _get_supported_stats_list(self):
"""Returns stats this feature can log."""
return {"underflows%"}
def check_if_stat_is_supported(self, stat: str, current_recipe: str):
"""Returns True if stat is supported, raises ValueError otherwise."""
columnwise = stat.endswith("_columnwise")
if columnwise:
stat = stat[: -len("_columnwise")]
recipe_from_stat, _ = self.get_recipe_from_stat(stat, default_recipe=current_recipe)
stat_without_recipe = stat.replace(recipe_from_stat + "_", "")
if current_recipe == "" and recipe_from_stat == "":
raise ValueError(
f"Stat {stat} does not contain a recipe name and the current recipe is not set."
)
if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES:
raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}")
if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise:
raise ValueError(
f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for"
" fp8_delayed_scaling and fp8_current_scaling."
)
if recipe_from_stat == "fp8_delayed_scaling" and stat_without_recipe == "overflows%":
return True
if recipe_from_stat in ["fp8_block_scaling"] and torch.cuda.get_device_capability()[0] < 9:
raise ValueError(f"Stat {stat} needs Hopper or later GPU.")
if recipe_from_stat == "mxfp8" and torch.cuda.get_device_capability()[0] < 10:
raise ValueError(f"Stat {stat} needs Blackwell or later GPU.")
supported_stats = ["underflows%", "scale_inv_min", "scale_inv_max", "mse"]
if stat_without_recipe not in supported_stats:
raise ValueError(
f"Stat {stat} contains an unsupported stat name: {stat_without_recipe}"
)
return True
def get_recipe_from_stat(self, stat: str, default_recipe: str = ""):
"""Returns the recipe name from the stat string."""
columnwise_stat = stat.endswith("_columnwise")
for recipe_name in ALL_RECIPE_NAMES:
if recipe_name in stat:
return recipe_name, columnwise_stat
return default_recipe, columnwise_stat
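The stat-name convention handled by check_if_stat_is_supported and get_recipe_from_stat can be restated in a few lines; this is an illustrative re-statement, not the code above:

RECIPES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]

def parse_stat(stat: str, default_recipe: str = ""):
    # <recipe>_<stat>[_columnwise]; a missing recipe falls back to the training recipe
    columnwise = stat.endswith("_columnwise")
    recipe = next((r for r in RECIPES if r in stat), default_recipe)
    return recipe, columnwise

assert parse_stat("mxfp8_scale_inv_min_columnwise") == ("mxfp8", True)
assert parse_stat("underflows%", default_recipe="fp8_current_scaling") == ("fp8_current_scaling", False)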
@contextmanager
def update_aux_dict(
self,
aux_dict: Dict,
recipe_name: str,
quantized_tensor: QuantizedTensor,
quantizer: Quantizer,
original_tensor: torch.Tensor,
recipes_in_stats: List[Tuple[str, bool]],
):
"""
Updates the aux_dict with the quantized tensor for each recipe provided in recipes_in_stats.
It allows computing stats for different recipes in the same iteration,
without recomputing the quantized tensor for each recipe for each stat.
Also updates the rowwise/columnwise usage of the quantized tensor as needed.
Yields the aux_dict.
Needs to clean up after use, because it may change the usage of the quantized tensor.
"""
fp8_dtype = None
if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]:
assert isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer)
)
fp8_dtype = quantizer.dtype
aux_dict = {
recipe_name: quantized_tensor,
}
old_rowwise_usage = quantizer.rowwise_usage
old_columnwise_usage = quantizer.columnwise_usage
for cur_recipe_name, cur_columnwise_stat in recipes_in_stats:
if recipe_name is not cur_recipe_name:
quantizer = _get_new_quantizer(cur_recipe_name, fp8_dtype)
aux_dict[cur_recipe_name] = quantizer(original_tensor)
elif isinstance(quantized_tensor, QuantizedTensor):
if cur_columnwise_stat:
quantized_tensor.update_usage(columnwise_usage=True)
else:
quantized_tensor.update_usage(rowwise_usage=True)
aux_dict[""] = quantized_tensor
aux_dict[cur_recipe_name] = quantized_tensor
try:
yield aux_dict
finally:
if isinstance(quantized_tensor, QuantizedTensor):
quantized_tensor.update_usage(
rowwise_usage=old_rowwise_usage, columnwise_usage=old_columnwise_usage
)
@api_method
def inspect_tensor_postquantize_enabled(
self, config: Dict, layer_name: str, gemm: str, tensor_name: str, iteration: int
def inspect_tensor_enabled(
self, config: Dict, layer_name: str, tensor_name: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run inspect_tensor_postquantize() in the forward."""
run_current, next_iter = next_enabled_iter(
......@@ -104,29 +261,34 @@ class LogFp8TensorStats(BaseLogTensorStats):
return run_current, next_iter
@api_method
def inspect_tensor_postquantize(
def inspect_tensor(
self,
config: Dict,
layer_name: str,
tensor_name: str,
tensor: Union[torch.Tensor, QuantizedTensor],
rowwise: bool,
iteration: int,
tp_group: torch.distributed.ProcessGroup,
tensor: torch.Tensor,
rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
quantizer: Optional[Quantizer] = None,
):
"""
API call used to collect the data about the tensor after process_tensor()/quantization.
"""
assert rowwise_quantized_tensor is columnwise_quantized_tensor
assert (
quantizer is not None
), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe."
assert type(tensor) in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase], (
f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be a quantized tensor when using"
" log_fp8_tensor_stats. Use log_tensor_stats for high precision tensors."
)
quantized_tensor = rowwise_quantized_tensor
assert isinstance(
quantized_tensor, QuantizedTensor
), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor."
recipe_name = _get_recipe_name(quantizer)
# This API can be invoked twice - with the tensor and with the transpose.
# We want to collect the stats once.
if not rowwise:
return # tensor was already seen rowwise in the other gemm
for stat in config["stats"]:
self.check_if_stat_is_supported(stat, recipe_name)
options = (
config.get("start_step", None),
......@@ -135,19 +297,9 @@ class LogFp8TensorStats(BaseLogTensorStats):
"fp8",
)
skip_reduction = False
reduction_group = debug_api.get_tensor_reduction_group()
reduce_within_microbatch = tensor_name != "weight"
if tensor_name == "weight":
if TEDebugState.weight_tensor_tp_group_reduce:
reduction_group = tp_group
else:
skip_reduction = True
for stat in config["stats"]:
assert (
stat in self._get_supported_stats_list()
), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported."
skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
tensor_name, tp_group
)
STATS_BUFFERS.try_add_buffer(
layer_name=layer_name,
......@@ -158,10 +310,30 @@ class LogFp8TensorStats(BaseLogTensorStats):
reduce_within_microbatch=reduce_within_microbatch,
)
STATS_BUFFERS.feed(layer_name, tensor_name, options, tensor, iteration, skip_reduction)
recipes_in_stats = [
self.get_recipe_from_stat(stat, default_recipe=recipe_name) for stat in config["stats"]
]
with self.update_aux_dict(
aux_dict={},
recipe_name=recipe_name,
quantized_tensor=quantized_tensor,
quantizer=quantizer,
original_tensor=tensor,
recipes_in_stats=recipes_in_stats,
) as aux_dict:
STATS_BUFFERS.feed(
layer_name,
tensor_name,
options,
tensor,
iteration,
skip_reduction,
aux_dict=aux_dict,
)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=inspect_tensor_postquantize: {tensor_name}",
f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}",
layer_name,
extra_cachable_args=(tensor_name,),
)
......@@ -4,7 +4,7 @@
"""LogTensorStats Feature support for nvidia-dlframework-inspect"""
from typing import Dict, Union
from typing import Dict, Optional
import torch
......@@ -12,14 +12,13 @@ from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as Bas
from nvdlfw_inspect.registry import Registry, api_method
import nvdlfw_inspect.api as debug_api
from transformer_engine.pytorch.tensor import QuantizedTensor
from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import next_enabled_iter
from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params
@Registry.register_feature(namespace="transformer_engine")
......@@ -114,10 +113,13 @@ class LogTensorStats(BaseLogTensorStats):
config: Dict,
layer_name: str,
tensor_name: str,
tensor: Union[torch.Tensor, QuantizedTensor],
iteration: int,
tp_group: torch.distributed.ProcessGroup,
):
tensor: torch.Tensor,
rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
quantizer: Optional[Quantizer] = None,
): # pylint: disable=unused-argument
"""API call used to collect the data about the tensor before process_tensor()/quantization."""
assert (
......@@ -134,14 +136,9 @@ class LogTensorStats(BaseLogTensorStats):
config.get("start_end_list", None),
)
skip_reduction = False
reduction_group = debug_api.get_tensor_reduction_group()
reduce_within_microbatch = tensor_name != "weight"
if tensor_name == "weight":
if TEDebugState.weight_tensor_tp_group_reduce:
reduction_group = tp_group
else:
skip_reduction = True
skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
tensor_name, tp_group
)
for stat in config["stats"]:
assert (
......
......@@ -6,6 +6,26 @@
Utils for the debug features.
"""
import torch
import nvdlfw_inspect.api as debug_api
from transformer_engine.debug.pytorch.debug_state import TEDebugState
def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup):
"""
Returns the statistics reduction parameters for the tensor.
"""
skip_reduction = False
reduction_group = debug_api.get_tensor_reduction_group()
reduce_within_microbatch = tensor_name != "weight"
if tensor_name == "weight":
if TEDebugState.weight_tensor_tp_group_reduce:
reduction_group = tp_group
else:
skip_reduction = True
return skip_reduction, reduction_group, reduce_within_microbatch
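# Illustrative behaviour: for activations (tensor_name != "weight"), stats are reduced over
# debug_api.get_tensor_reduction_group() and across microbatches; for "weight", they are reduced
# over tp_group when TEDebugState.weight_tensor_tp_group_reduce is set, and not reduced otherwise.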
def next_enabled_iter(start_step, end_step, start_end_list, freq, iteration):
"""
......
......@@ -67,14 +67,17 @@ class _Buffer:
gathered_buffer, _ = gather_along_first_dim(
self._buffer.unsqueeze(0), process_group=self.reduction_group
)
return gathered_buffer[mask.to(bool)]
return gathered_buffer[mask.to(torch.bool)]
def feed(self, tensor, iteration):
def feed(self, tensor, iteration, aux_dict=None):
"""
feed() adds a tensor to the statistics computation.
Because of microbatching, feed() can be called multiple times for one log().
The aux_dict is used to share common computation between different stats;
for example, for LogFp8TensorStats it can contain quantized tensors in different precisions.
The main reason for this design is the need to combine results from already processed
tensors with the result for the new tensor.
"""
......@@ -97,7 +100,7 @@ class _Buffer:
# save stats for tensor to tmp buffer
for stat_name in self.stats_to_compute:
fn, _ = STATS[stat_name]
self._tmp_buffer[stats_to_num[stat_name]] = fn(tensor)
self._tmp_buffer[stats_to_num[stat_name]] = fn(tensor, aux_dict)
# [num_buffers, num_stats]
buffers = torch.cat((self._buffer.unsqueeze(0), self._tmp_buffer.unsqueeze(0)), dim=0)
......@@ -108,7 +111,7 @@ class _Buffer:
self._new_buffer[stats_to_num[stat_name]] = combinator(buffers)
else:
fn = STATS[stat_name][0]
self._new_buffer[stats_to_num[stat_name]] = fn(tensor)
self._new_buffer[stats_to_num[stat_name]] = fn(tensor, aux_dict)
self._buffer.copy_(self._new_buffer)
......@@ -127,7 +130,6 @@ class _Buffer:
for stat_name in self.stats_to_log:
combiner = STATS[stat_name][1]
stat_value = combiner(gathered_helper_stats)
MetricLogger.log_scalar(
f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration
)
......@@ -194,11 +196,18 @@ class StatsBuffers:
self.buffers[(layer_name, tensor_name, options)] = buffer
self.reduction_group_to_buffer[reduction_group].append(buffer)
def feed(self, layer_name, tensor_name, options, tensor, iteration, skip_reduction):
"""Feeds the tensor into the respective buffer."""
def feed(
self, layer_name, tensor_name, options, tensor, iteration, skip_reduction, aux_dict=None
):
"""
Feeds the tensor into the respective buffer.
The aux_dict is used to share common computation between different stats.
For example, for LogFp8TensorStats it can contain quantized tensors in different precisions.
"""
self.at_least_one_layer_fed = True
buffer = self.buffers[(layer_name, tensor_name, options)]
buffer.feed(tensor, iteration)
buffer.feed(tensor, iteration, aux_dict)
buffer.skip_reduction = skip_reduction
def log_stats(self):
......
......@@ -8,8 +8,9 @@ Mathematical functions used to tensor statistics computation.
import math
import torch
MAX_FP8_VALUE_INT8 = 126
import torch.nn.functional as F
import transformer_engine_torch as tex
from transformer_engine.common.recipe import Format
@torch.compile
......@@ -49,6 +50,29 @@ def compute_std(variances, numels, sums):
return torch.sqrt(compute_variance(variances, numels, sums))
def compute_fp8_delayed_scaling_overflows_num(tensor, quantized_tensor):
"""Computes the overflows of the tensor."""
scale_inv = quantized_tensor._scale_inv
dtype = quantized_tensor._fp8_dtype
# Map each supported FP8 dtype to its corresponding max forward value.
dtype_to_max = {
tex.DType.kFloat8E4M3: Format.E4M3.value.max_fwd,
tex.DType.kFloat8E5M2: Format.E5M2.value.max_fwd,
}
if dtype not in dtype_to_max:
raise ValueError(
f"Unsupported FP8 dtype {dtype} passed to compute_fp8_delayed_scaling_overflows_num()."
)
fp8_max = dtype_to_max[dtype]
fp8_min = -fp8_max
overflows = (tensor > fp8_max * scale_inv) | (tensor < fp8_min * scale_inv)
return overflows.sum()
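# Illustrative reading of the check above: an element overflows when its magnitude exceeds
# fp8_max * scale_inv, i.e. it would saturate to the FP8 maximum when quantized with the
# current scaling factor.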
# buffers is tensor of shape [nr_buffers, nr_stats]
def _get(buffers, stat_name):
stat_nr = stats_to_num[stat_name]
......@@ -68,10 +92,12 @@ stats_to_num = {
"cur_amax": 9,
"dynamic_range_top": 10,
"dynamic_range_bottom": 11,
"underflows_num": 12,
"std": 13,
"dynamic_range": 14,
"underflows%": 15,
"std": 12,
"dynamic_range": 13,
"fp8_delayed_scaling_overflows_num": 14,
"fp8_delayed_scaling_overflows%": 15,
"overflows_num": 16,
"overflows%": 17,
}
DEPENDENCIES = {
......@@ -87,62 +113,207 @@ DEPENDENCIES = {
"cur_amax": {"cur_amax"},
"dynamic_range_top": {"dynamic_range_top"},
"dynamic_range_bottom": {"dynamic_range_bottom"},
"underflows_num": {"underflows_num"},
"std": {"variance", "numel", "sum"},
"dynamic_range": {"dynamic_range_top", "dynamic_range_bottom"},
"underflows%": {"underflows_num", "numel"},
"fp8_delayed_scaling_overflows_num": {"fp8_delayed_scaling_overflows_num"},
"fp8_delayed_scaling_overflows%": {"fp8_delayed_scaling_overflows_num", "numel"},
"overflows_num": {"overflows_num"},
"overflows%": {"overflows_num", "numel"},
}
STATS = {
"min": (torch.min, lambda buffers: min(_get(buffers, "min"))),
"max": (torch.max, lambda buffers: max(_get(buffers, "max"))),
"sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))),
"mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))),
"min": (lambda x, aux_dict: torch.min(x), lambda buffers: min(_get(buffers, "min"))),
"max": (lambda x, aux_dict: torch.max(x), lambda buffers: max(_get(buffers, "max"))),
"sum": (lambda x, aux_dict: torch.sum(x), lambda buffers: sum(_get(buffers, "sum"))),
"mean": (
lambda x, aux_dict: torch.mean(x),
lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel")),
),
"numel": (
lambda x: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(),
lambda x, aux_dict: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(),
lambda buffers: sum(_get(buffers, "numel")),
),
"l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))),
"l1_norm": (
lambda x, aux_dict: torch.norm(x, p=1),
lambda buffers: sum(_get(buffers, "l1_norm")),
),
"l2_norm_square": (
lambda x: torch.sum(x**2),
lambda x, aux_dict: torch.sum(x**2),
lambda buffers: sum(_get(buffers, "l2_norm_square")),
),
"l2_norm": (
lambda x: torch.norm(x, p=2),
lambda x, aux_dict: torch.norm(x, p=2),
lambda buffers: math.sqrt(sum(_get(buffers, "l2_norm_square"))),
),
"variance": (
torch.var,
lambda x, aux_dict: torch.var(x),
lambda buffers: compute_variance(
_get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum")
),
),
"cur_amax": (lambda x: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))),
"cur_amax": (lambda x, aux_dict: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))),
"dynamic_range_top": (
_compute_dynamic_range_top,
lambda x, aux_dict: _compute_dynamic_range_top(x),
lambda buffers: max(_get(buffers, "dynamic_range_top")),
),
"dynamic_range_bottom": (
_compute_dynamic_range_bottom,
lambda x, aux_dict: _compute_dynamic_range_bottom(x),
lambda buffers: min(_get(buffers, "dynamic_range_bottom")),
),
"underflows_num": (
lambda x: (x.get_data_tensors()[0] == 0).sum(),
lambda buffers: sum(_get(buffers, "underflows_num")),
),
"std": (
torch.std,
lambda x, aux_dict: torch.std(x),
lambda buffers: compute_std(
_get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum")
),
),
"dynamic_range": (
lambda x: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x),
lambda x, aux_dict: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x),
lambda buffers: max(_get(buffers, "dynamic_range_top"))
- min(_get(buffers, "dynamic_range_bottom")),
),
"underflows%": (
lambda x: (x.get_data_tensors()[0] == 0).sum() / x.get_data_tensors()[0].numel() * 100,
lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")),
"fp8_delayed_scaling_overflows_num": (
lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(
x, aux_dict["fp8_delayed_scaling"]
),
lambda buffers: sum(_get(buffers, "fp8_delayed_scaling_overflows_num")),
),
"fp8_delayed_scaling_overflows%": (
lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(
x, aux_dict["fp8_delayed_scaling"]
)
/ x.numel()
* 100,
lambda buffers: 100
* sum(_get(buffers, "fp8_delayed_scaling_overflows_num"))
/ sum(_get(buffers, "numel")),
),
"overflows_num": (
lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(x, aux_dict[""]),
lambda buffers: sum(_get(buffers, "overflows_num")),
),
"overflows%": (
lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(x, aux_dict[""])
/ x.numel()
* 100,
lambda buffers: 100 * sum(_get(buffers, "overflows_num")) / sum(_get(buffers, "numel")),
),
}
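# Each STATS entry is a pair (per-microbatch fn(tensor, aux_dict), combiner(buffers)).
# Illustrative example: "mean" is combined across feed() calls as
# sum(per-batch sums) / sum(per-batch numels), so only the "sum" and "numel" columns
# of the stats buffer are needed to recover the global mean.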
def add_underflows_stats(recipe_name: str, columnwise: bool = False):
"""Register *both* underflow stats (num and %) for the given recipe."""
columnwise_suffix = "_columnwise" if columnwise else ""
# Stat names
stat_num = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows_num{columnwise_suffix}"
stat_pct = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows%{columnwise_suffix}"
stats_to_num[stat_num] = len(stats_to_num)
stats_to_num[stat_pct] = len(stats_to_num)
STATS[stat_num] = (
lambda x, aux_dict: (
aux_dict[recipe_name].get_data_tensors(
rowwise_data=not columnwise, columnwise_data=columnwise
)
== 0
).sum()
- (x == 0).sum(),
lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)),
)
STATS[stat_pct] = (
lambda x, aux_dict: (
aux_dict[recipe_name].get_data_tensors(
rowwise_data=not columnwise, columnwise_data=columnwise
)
== 0
).sum()
/ aux_dict[recipe_name].numel()
* 100,
lambda buffers, _sn_num=stat_num: 100
* sum(_get(buffers, _sn_num))
/ sum(_get(buffers, "numel")),
)
DEPENDENCIES[stat_num] = {stat_num}
DEPENDENCIES[stat_pct] = {stat_num, "numel"}
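# Illustrative registration: add_underflows_stats("mxfp8", columnwise=True) creates
# "mxfp8_underflows_num_columnwise" and "mxfp8_underflows%_columnwise"; the _num stat counts
# zeros in the columnwise quantized data in excess of the zeros already present in the
# high-precision tensor.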
def add_scale_inv_stats(recipe_name: str, columnwise: bool = False):
"""Register *both* scale-inv min and max stats for a given recipe.
This replaces the earlier separate helpers and avoids duplicated boilerplate.
"""
# Determine which attribute holds the scale-inverse tensor.
def get_scale_inv(quantized_tensor, columnwise):
if hasattr(quantized_tensor, "_scale_inv"):
return getattr(quantized_tensor, "_scale_inv")
if columnwise:
return getattr(quantized_tensor, "_columnwise_scale_inv")
return getattr(quantized_tensor, "_rowwise_scale_inv")
columnwise_suffix = "_columnwise" if columnwise else ""
# Prepare stat names.
stat_name_min = (
f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_min{columnwise_suffix}"
)
stat_name_max = (
f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_max{columnwise_suffix}"
)
# Assign indices in `stats_to_num` (order matters — keep insertion order deterministic).
stats_to_num[stat_name_min] = len(stats_to_num)
stats_to_num[stat_name_max] = len(stats_to_num)
# Bind the columnwise flag and the stat name into the lambdas via default args to avoid late binding.
STATS[stat_name_min] = (
lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).min(),
lambda buffers, _sn=stat_name_min: min(_get(buffers, _sn)),
)
STATS[stat_name_max] = (
lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).max(),
lambda buffers, _sn=stat_name_max: max(_get(buffers, _sn)),
)
DEPENDENCIES[stat_name_min] = {stat_name_min}
DEPENDENCIES[stat_name_max] = {stat_name_max}
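# Illustrative registration: add_scale_inv_stats("mxfp8") creates "mxfp8_scale_inv_min" and
# "mxfp8_scale_inv_max"; get_scale_inv reads _scale_inv for per-tensor scaled formats and
# _rowwise_scale_inv / _columnwise_scale_inv for block-scaled ones.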
def add_mse_stats(recipe_name: str, columnwise: bool = False):
"""Register mse and total_square_error stats for the recipe."""
columnwise_suffix = "_columnwise" if columnwise else ""
stat_mse = f"{recipe_name}{'_' if recipe_name != '' else ''}mse{columnwise_suffix}"
stat_err = (
f"{recipe_name}{'_' if recipe_name != '' else ''}total_square_error{columnwise_suffix}"
)
stats_to_num[stat_mse] = len(stats_to_num)
stats_to_num[stat_err] = len(stats_to_num)
STATS[stat_mse] = (
lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="mean"),
lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err))
/ sum(_get(buffers, "numel")),
)
STATS[stat_err] = (
lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="sum"),
lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err)),
)
DEPENDENCIES[stat_err] = {stat_err}
DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"}
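# Illustrative combination across microbatches: the per-batch total_square_error values are
# summed and divided by the total numel, so the logged mse matches a single pass over the
# concatenated tensors.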
for _columnwise in [True, False]:
for _recipe_name in [
"", # default recipe
"fp8_delayed_scaling",
"mxfp8",
"fp8_current_scaling",
"fp8_block_scaling",
]:
add_underflows_stats(_recipe_name, _columnwise)
add_scale_inv_stats(_recipe_name, _columnwise)
add_mse_stats(_recipe_name, _columnwise)
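# Illustrative effect of the loop above: every recipe (including the default "") gets both
# rowwise and columnwise variants registered, e.g. "fp8_current_scaling_mse",
# "fp8_block_scaling_scale_inv_max_columnwise", "underflows%_columnwise".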
......@@ -156,7 +156,6 @@ class DebugQuantizer(Quantizer):
gemm=self.columnwise_gemm_name,
)
)
return (
inspect_tensor_enabled,
inspect_tensor_postquantize_enabled_rowwise,
......@@ -259,6 +258,9 @@ class DebugQuantizer(Quantizer):
"tensor_name": self.tensor_name,
"iteration": TEDebugState.get_iteration(),
"tp_group": self.tp_group,
"columnwise_quantized_tensor": columnwise_gemm_tensor,
"rowwise_quantized_tensor": rowwise_gemm_tensor,
"quantizer": self.parent_quantizer,
}
if tensor is not None and self.inspect_tensor_enabled:
debug_api.transformer_engine.inspect_tensor(**args)
......@@ -266,6 +268,10 @@ class DebugQuantizer(Quantizer):
if self.output_tensor:
return
del args["columnwise_quantized_tensor"]
del args["rowwise_quantized_tensor"]
del args["quantizer"]
if (
self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_rowwise
......@@ -273,6 +279,7 @@ class DebugQuantizer(Quantizer):
args["tensor"] = rowwise_gemm_tensor
args["rowwise"] = True
debug_api.transformer_engine.inspect_tensor_postquantize(**args)
if (
self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE]
and self.inspect_tensor_postquantize_enabled_columnwise
......@@ -398,6 +405,7 @@ class DebugQuantizer(Quantizer):
"""Returns bool if there is at least one API call enabled."""
if self.output_tensor:
return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY
# pylint: disable=too-many-boolean-expressions
if (
self.inspect_tensor_enabled
or self.inspect_tensor_postquantize_enabled_rowwise
......
......@@ -1138,6 +1138,10 @@ def _all_gather_fp8_blockwise(
"Dequantizing and requantizing to Float8BlockwiseQTensor."
)
inp = quantizer(inp.dequantize())
# Construct Float8BlockwiseQTensor output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
quantizer.all_gather_usage = orig_all_gather_usage
# Begin to do network communication, need to make sure compact format
......@@ -1147,9 +1151,6 @@ def _all_gather_fp8_blockwise(
f"but found data_format={inp._data_format}"
)
# Construct Float8BlockwiseQTensor output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
# Coalesce NCCL collectives
with torch.distributed._coalescing_manager(
group=process_group,
......
......@@ -124,9 +124,15 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
self._columnwise_scale_inv = tensors[3]
return tensors[4:]
def get_data_tensors(self):
def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
"""Get this Tensor's data."""
return self._rowwise_data, self._columnwise_data
if rowwise_data and columnwise_data:
return self._rowwise_data, self._columnwise_data
if rowwise_data:
return self._rowwise_data
if columnwise_data:
return self._columnwise_data
raise ValueError("No data to get, both rowwise_data and columnwise_data are False")
def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor:
"""Takes dequantized columnwise data and permutes to a rowwise shape"""
......
......@@ -128,9 +128,15 @@ class Float8TensorBase(QuantizedTensorBase):
self._scale_inv = tensors[2]
return tensors[3:]
def get_data_tensors(self):
def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
"""Get this Tensor's data."""
return self._data, self._transpose
if rowwise_data and columnwise_data:
return self._data, self._transpose
if rowwise_data:
return self._data
if columnwise_data:
return self._transpose
raise ValueError("No data to get, both rowwise_data and columnwise_data are False")
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""Dequantize to a higher precision."""
......
......@@ -136,9 +136,15 @@ class MXFP8TensorBase(QuantizedTensorBase):
self._columnwise_scale_inv = tensors[3]
return tensors[4:]
def get_data_tensors(self):
def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True):
"""Get this Tensor's data."""
return self._rowwise_data, self._columnwise_data
if rowwise_data and columnwise_data:
return self._rowwise_data, self._columnwise_data
if rowwise_data:
return self._rowwise_data
if columnwise_data:
return self._columnwise_data
raise ValueError("No data to get, both rowwise_data and columnwise_data are False")
def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor:
"""Dequantize to a higher precision."""
......