Commit 87e3e56e authored by yuguo's avatar yuguo
Browse files

Merge commit '734bcedd' of...

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine
parents 2f11bd2e 734bcedd
...@@ -53,7 +53,8 @@ jobs: ...@@ -53,7 +53,8 @@ jobs:
|| github.actor == 'lhb8125' || github.actor == 'lhb8125'
|| github.actor == 'kunlunl' || github.actor == 'kunlunl'
|| github.actor == 'pstjohn' || github.actor == 'pstjohn'
|| github.actor == 'mk-61' || github.actor == 'vcherepanov-nv'
|| github.actor == 'tdophung'
) )
steps: steps:
- name: Check if comment is issued by authorized person - name: Check if comment is issued by authorized person
......
...@@ -9,11 +9,11 @@ import numpy as np ...@@ -9,11 +9,11 @@ import numpy as np
import torch import torch
import nvtx import nvtx
import transformer_engine import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import ( from tests.pytorch.utils import (
ModelConfig, ModelConfig,
_get_attention_backends, get_available_attention_backends,
_run_dot_product_attention,
) )
from tests.pytorch.attention.test_attention import _run_dot_product_attention
pd.set_option("display.precision", 4) pd.set_option("display.precision", 4)
...@@ -197,7 +197,7 @@ def main(): ...@@ -197,7 +197,7 @@ def main():
) )
for model in model_configs.keys(): for model in model_configs.keys():
config = model_configs[model] config = model_configs[model]
available_backends, fused_attn_backends = _get_attention_backends( available_backends, _, fused_attn_backends = get_available_attention_backends(
config, config,
qkv_dtype=dtype, qkv_dtype=dtype,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
......
...@@ -247,7 +247,7 @@ if __name__ == "__main__": ...@@ -247,7 +247,7 @@ if __name__ == "__main__":
num_gemms_list = [8] num_gemms_list = [8]
if args.profile: if args.profile:
mkns = [(4096, 4096, 4096)] mkns = [(4096 * 8, 4096, 4096)]
# in profile mode, only run one recipe specified in args.recipe # in profile mode, only run one recipe specified in args.recipe
assert args.recipe != "all", ( assert args.recipe != "all", (
"In profile mode, only one recipe can be specified, please specify the recipe as" "In profile mode, only one recipe can be specified, please specify the recipe as"
......
...@@ -14,19 +14,7 @@ from typing import List ...@@ -14,19 +14,7 @@ from typing import List
def install_requirements() -> List[str]: def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions.""" """Install dependencies for TE/PyTorch extensions."""
reqs = ["torch>=2.1", "einops"] return ["torch>=2.1", "einops"] # "onnxscript==0.3.1", "onnx"]
# reqs.append(
# "nvdlfw-inspect @"
# " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
# )
reqs.extend(
[
"torch>=2.1",
# "onnx",
# "onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871",
]
)
return reqs
def test_requirements() -> List[str]: def test_requirements() -> List[str]:
......
...@@ -19,6 +19,10 @@ Variables are available in `transformer_engine.jax.sharding`. ...@@ -19,6 +19,10 @@ Variables are available in `transformer_engine.jax.sharding`.
* JOINED_AXES: The logical axis of non-defined dimension. It is usually not sharded. * JOINED_AXES: The logical axis of non-defined dimension. It is usually not sharded.
Checkpointing
------------------------------------
When using checkpointing with Transformer Engine JAX, please be aware of the checkpointing policy being applied to your model. Any JAX checkpointing policy using `dot`, such as `jax.checkpoint_policies.dots_with_no_batch_dims`, may not work with GEMMs provided by Transformer Engine as they do not always use the `jax.lax.dot_general` primitive. Instead, you can use `transformer_engine.jax.checkpoint_policies.dots_and_te_gemms_with_no_batch_dims` or similar policies that are designed to work with Transformer Engine's GEMMs and `jax.lax.dot_general` GEMMs. You may also use any JAX policies that do not filter by primitive, such as `jax.checkpoint_policies.save_only_these_names` or `jax.checkpoint_policies.everything_saveable`.
Modules Modules
------------------------------------ ------------------------------------
.. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType .. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType
......
...@@ -21,7 +21,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea ...@@ -21,7 +21,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea
There are 4 things one needs to do to use Transformer Engine debug features: There are 4 things one needs to do to use Transformer Engine debug features:
1. Create a configuration YAML file to configure the desired features. 1. Create a configuration YAML file to configure the desired features.
2. Import, and initialize the `Nvidia-DL-Framework-Inspect <https://github.com/NVIDIA/nvidia-dlfw-inspect>`_ tool, which is installed as the dependency of the Transformer Engine. 2. Import, initialize, and install the `Nvidia-DL-Framework-Inspect <https://github.com/NVIDIA/nvidia-dlfw-inspect>`_ tool.
3. One can pass ``name="..."`` when creating TE layers to easier identify layer names. If this is not provided, names will be inferred automatically. 3. One can pass ``name="..."`` when creating TE layers to easier identify layer names. If this is not provided, names will be inferred automatically.
4. Invoke ``debug_api.step()`` at the end of one forward-backward pass. 4. Invoke ``debug_api.step()`` at the end of one forward-backward pass.
...@@ -141,7 +141,7 @@ Adjusting Python file ...@@ -141,7 +141,7 @@ Adjusting Python file
In the modified code above, the following changes were made: In the modified code above, the following changes were made:
1. Added an import for ``nvdlfw_inspect.api``. 1. Added an import for ``nvdlfw_inspect.api``.
2. Initialized the Nvidia-DL-Framework-Inspect by calling ``debug_api.initialize()`` with appropriate configuration, specifying the path to the config file, feature directories, and log directory. 2. Initialized the Nvidia-DL-Framework-Inspect by calling ``debug_api.initialize()`` with appropriate configuration, specifying the path to the config file, feature directories, and log directory. The directory with Transformer Engine features is located `here <https://github.com/NVIDIA/TransformerEngine/tree/main/transformer_engine/debug/features>`_. The full parameters description could be found :doc:`here <3_api_debug_setup>`.
3. Added ``debug_api.step()`` after each of the forward-backward pass. 3. Added ``debug_api.step()`` after each of the forward-backward pass.
Inspecting the logs Inspecting the logs
......
...@@ -12,14 +12,7 @@ Let's look deeper into how Nvidia-DL-Framework-Inspect with Transformer Engine w ...@@ -12,14 +12,7 @@ Let's look deeper into how Nvidia-DL-Framework-Inspect with Transformer Engine w
Fig 1: Example of Nvidia-DL-Framework-Inspect affecting training script with 1 Linear Layer. For tensors mentioned in ``config.yaml``, behavior of ``modify_tensor_enabled()`` and ``modify_tensor()`` calls are substituted with definitions from the feature class. Other calls return default values - in fact they do nothing. Fig 1: Example of Nvidia-DL-Framework-Inspect affecting training script with 1 Linear Layer. For tensors mentioned in ``config.yaml``, behavior of ``modify_tensor_enabled()`` and ``modify_tensor()`` calls are substituted with definitions from the feature class. Other calls return default values - in fact they do nothing.
In this page, all calls from TransformerEngine to the Nvidia-DL-Framework-Inspect for each GEMM are listed. The order of these calls is illustrated in the image below. In this page, all calls from TransformerEngine to the Nvidia-DL-Framework-Inspect for each GEMM are listed.
.. figure:: ./img/api_calls2.svg
:align: center
Fig 2: The calls to Nvidia-DL-Framework-Inspect done for Transformer Engine. There are 2 types of calls: GEMM calls and routing calls.
There are 2 categories of API calls, each is used for different purposes: There are 2 categories of API calls, each is used for different purposes:
- GEMM calls - invoked during every GEMM, used to process or quantize tensors and collect information about them, - GEMM calls - invoked during every GEMM, used to process or quantize tensors and collect information about them,
...@@ -32,14 +25,15 @@ if fusions happen. An important remark is that if no feature is used for the lay ...@@ -32,14 +25,15 @@ if fusions happen. An important remark is that if no feature is used for the lay
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor_enabled .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor_enabled
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.fp8_gemm_enabled .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.fp8_gemm_enabled
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_enabled .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_enabled
.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize_enabled .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize_enabled
<svg width="4235" height="2342" xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" xml:space="preserve" overflow="hidden"><g transform="translate(-41 -119)"><g><rect x="46.4999" y="1576.5" width="1564" height="734" stroke="#042433" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#FFFFFF" fill-opacity="1"/><rect x="630.5" y="125.5" width="580" height="151" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#FFFFFF" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 786.823 225)">Tensor A</text><rect x="303.5" y="337.5" width="1234" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#DCEAF7" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 684.099 414)">inspect_tensor</text><rect x="1258.5" y="596.5" width="617" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#FFFFFF" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 1440.36 673)">fp8 </text><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 1557.81 673)">cast</text><rect x="114.5" y="596.5" width="683" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#DCEAF7" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 227.611 673)">modify_tensor</text><rect x="303.5" y="826.5" width="1234" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#DCEAF7" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 460.753 903)">inspect_tensor_postquantize</text><rect x="1583.5" y="1123.5" width="1234" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#FFFFFF" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 2095.73 1200)">GEMM</text><rect x="1583.5" y="1310.5" width="1234" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#DCEAF7" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 1963.85 1387)">inspect_tensor</text><rect x="1859.5" y="1499.5" width="682" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#DCEAF7" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 1972.18 1576)">modify_tensor</text><rect x="115.5" y="1956.5" width="1402" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#DCEAF7" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 434.13 2033)">inspect_tensor_enabled</text><rect x="115.5" y="2103.5" width="1402" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#DCEAF7" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 210.785 2180)">inspect_tensor_postquantize_enabled</text><rect x="115.5" y="1660.5" width="1402" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#DCEAF7" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 502.961 1737)">fp8_gemm_enabled</text><rect x="115.5" y="1808.5" width="1402" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#DCEAF7" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 442.461 1885)">modify_tensor_enabled</text><path d="M1.07643-3.26461 444.129 142.822 441.977 149.351-1.07643 3.26461ZM443.006 131.593 464.817 153.263 434.395 157.71Z" fill="#000000" fill-rule="nonzero" fill-opacity="1" transform="matrix(-1 0 0 1 920.317 443.5)"/><path d="M921.293 440.155 1545.58 588.133 1543.99 594.822 919.707 446.845ZM1543.5 577.041 1567.09 596.763 1537.16 603.8Z" fill="#000000" fill-rule="nonzero" fill-opacity="1"/><path d="M923.938 443.5 923.938 803.572 917.063 803.572 917.063 443.5ZM934.25 798.988 920.5 826.488 906.75 798.988Z" fill="#000000" fill-rule="nonzero" fill-opacity="1"/><path d="M456.384 699.178 899.056 817.032 897.288 823.676 454.616 705.822ZM897.281 805.888 920.317 826.25 890.206 832.462Z" fill="#000000" fill-rule="nonzero" fill-opacity="1"/><path d="M0.646175-3.37622 624.723 116.066 623.431 122.818-0.646175 3.37622ZM622.16 105.076 646.585 123.75 616.991 132.085Z" fill="#000000" fill-rule="nonzero" fill-opacity="1" transform="matrix(-1 0 0 1 1567.09 702.5)"/><rect x="2945.5" y="125.5" width="579" height="151" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#FFFFFF" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 3100.27 225)">Tensor B</text><rect x="2617.5" y="337.5" width="1234" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#DCEAF7" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 2998.12 414)">inspect_tensor</text><rect x="3572.5" y="596.5" width="617" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#FFFFFF" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 3754.39 673)">fp8 </text><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 3871.84 673)">cast</text><rect x="2428.5" y="596.5" width="683" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#DCEAF7" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 2541.64 673)">modify_tensor</text><rect x="2617.5" y="826.5" width="1234" height="106" stroke="#000000" stroke-width="6.875" stroke-linecap="butt" stroke-linejoin="miter" stroke-miterlimit="8" stroke-opacity="1" fill="#DCEAF7" fill-opacity="1"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 2774.78 903)">inspect_tensor_postquantize</text><path d="M1.07643-3.26461 444.129 142.822 441.976 149.351-1.07643 3.26461ZM443.006 131.593 464.817 153.263 434.394 157.71Z" fill="#000000" fill-rule="nonzero" fill-opacity="1" transform="matrix(-1 0 0 1 3234.32 443.5)"/><path d="M3235.29 440.155 3859.58 588.133 3857.99 594.822 3233.71 446.845ZM3857.5 577.041 3881.09 596.763 3851.16 603.8Z" fill="#000000" fill-rule="nonzero" fill-opacity="1"/><path d="M3237.94 443.5 3237.94 803.572 3231.06 803.572 3231.06 443.5ZM3248.25 798.988 3234.5 826.488 3220.75 798.988Z" fill="#000000" fill-rule="nonzero" fill-opacity="1"/><path d="M2770.38 699.178 3213.06 817.032 3211.29 823.676 2768.62 705.822ZM3211.28 805.888 3234.32 826.25 3204.21 832.462Z" fill="#000000" fill-rule="nonzero" fill-opacity="1"/><path d="M0.646175-3.37622 624.723 116.066 623.431 122.818-0.646175 3.37622ZM622.16 105.076 646.585 123.75 616.991 132.085Z" fill="#000000" fill-rule="nonzero" fill-opacity="1" transform="matrix(-1 0 0 1 3881.09 702.5)"/><path d="M921.009 929.1 2178.11 1117.2 2177.09 1124 919.991 935.9ZM2175.09 1106.33 2200.26 1123.99 2171.02 1133.52Z" fill="#000000" fill-rule="nonzero" fill-opacity="1"/><path d="M0.625813-3.38005 1012.36 183.941 1011.11 190.702-0.625813 3.38005ZM1009.73 172.967 1034.27 191.493 1004.72 200.007Z" fill="#000000" fill-rule="nonzero" fill-opacity="1" transform="matrix(-1 0 0 1 3234.77 932.5)"/><path d="M3.4375-1.54131e-05 3.43776 57.5713-3.43724 57.5714-3.4375 1.54131e-05ZM13.7502 52.988 0.000360892 80.488-13.7498 52.9881Z" fill="#000000" fill-rule="nonzero" fill-opacity="1" transform="matrix(-1 0 0 1 2200.5 1229.5)"/><path d="M3.4375-1.54131e-05 3.43776 57.5713-3.43724 57.5714-3.4375 1.54131e-05ZM13.7502 52.988 0.000360892 80.488-13.7498 52.9881Z" fill="#000000" fill-rule="nonzero" fill-opacity="1" transform="matrix(-1 0 0 1 2200.5 1418.5)"/><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 623.083 2394)">Routing </text><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 882.041 2394)">calls</text><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 3298.55 1286)">GEMM </text><text fill="#000000" fill-opacity="1" font-family="Aptos,Aptos_MSFontService,sans-serif" font-style="normal" font-variant="normal" font-weight="400" font-stretch="normal" font-size="73" text-anchor="start" direction="ltr" writing-mode="lr-tb" unicode-bidi="normal" text-decoration="none" transform="matrix(1 0 0 1 3536.88 1286)">calls</text><path d="M923.938 276.5 923.938 314.619 917.063 314.62 917.063 276.5ZM934.25 310.036 920.5 337.536 906.75 310.036Z" fill="#000000" fill-rule="nonzero" fill-opacity="1"/><path d="M3237.94 276.5 3237.94 314.619 3231.06 314.62 3231.06 276.5ZM3248.25 310.036 3234.5 337.536 3220.75 310.036Z" fill="#000000" fill-rule="nonzero" fill-opacity="1"/></g></g></svg>
\ No newline at end of file
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
import os import os
import torch import torch
from typing import Tuple from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig from tests.pytorch.utils import ModelConfig
from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import DotProductAttention
# Initialize RNG state # Initialize RNG state
......
...@@ -375,7 +375,7 @@ ...@@ -375,7 +375,7 @@
"\n", "\n",
"Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n", "Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n",
"\n", "\n",
"For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts." "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
] ]
}, },
{ {
...@@ -394,10 +394,10 @@ ...@@ -394,10 +394,10 @@
"| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n", "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
"\n", "\n",
"Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n", "Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
"- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
"- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
"- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
"- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)" "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py)"
] ]
}, },
{ {
...@@ -458,7 +458,7 @@ ...@@ -458,7 +458,7 @@
" </tr>\n", " </tr>\n",
"</table>\n", "</table>\n",
"\n", "\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"\n", "\n",
"<div class=\"alert alert-info\">\n", "<div class=\"alert alert-info\">\n",
"<b>Note</b>\n", "<b>Note</b>\n",
...@@ -548,7 +548,7 @@ ...@@ -548,7 +548,7 @@
"id": "dda4a589", "id": "dda4a589",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n", "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py).\n",
"\n", "\n",
"### 3.3 Attention Bias\n", "### 3.3 Attention Bias\n",
"\n", "\n",
...@@ -594,7 +594,7 @@ ...@@ -594,7 +594,7 @@
"\n", "\n",
"The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n", "The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n",
"\n", "\n",
"More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)." "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)."
] ]
}, },
{ {
...@@ -612,7 +612,7 @@ ...@@ -612,7 +612,7 @@
"\n", "\n",
"- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n", "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n",
"\n", "\n",
"Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`." "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
] ]
} }
], ],
......
...@@ -9,11 +9,11 @@ import numpy as np ...@@ -9,11 +9,11 @@ import numpy as np
import torch import torch
import nvtx import nvtx
import transformer_engine import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import ( from tests.pytorch.utils import (
ModelConfig, ModelConfig,
_get_attention_backends, get_available_attention_backends,
_run_dot_product_attention,
) )
from tests.pytorch.attention.test_attention import _run_dot_product_attention
# data type # data type
dtype = torch.bfloat16 dtype = torch.bfloat16
...@@ -90,7 +90,7 @@ def main(): ...@@ -90,7 +90,7 @@ def main():
models = ["test_0"] models = ["test_0"]
for model in models: for model in models:
config = model_configs[model] config = model_configs[model]
available_backends, fused_attn_backends = _get_attention_backends( available_backends, _, fused_attn_backends = get_available_attention_backends(
config, config,
qkv_dtype=dtype, qkv_dtype=dtype,
qkv_layout=qkv_layout, qkv_layout=qkv_layout,
......
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Export to ONNX and inference using TensorRT\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Note:</b>\n",
"\n",
"Currently, export to ONNX is supported only for high precision, FP8 delayed scaling and MXFP8.\n",
"\n",
"</div>\n",
"\n",
"Transformer Engine (TE) is a library designed primarily for training DL models in low precision. It is not specifically optimized for inference tasks, so other dedicated solutions should be used. NVIDIA provides several [inference tools](https://www.nvidia.com/en-us/solutions/ai/inference/) that enhance the entire inference pipeline. Two prominent NVIDIA inference SDKs are [TensorRT](https://github.com/NVIDIA/TensorRT) and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM).\n",
"\n",
"This tutorial illustrates how one can export a PyTorch model to ONNX format and subsequently perform inference with TensorRT. This approach is particularly beneficial if model integrates Transformer Engine layers within more complex architectures. It's important to highlight that for Transformer-based large language models (LLMs), TensorRT-LLM could provide a more optimized inference experience. However, the ONNX-to-TensorRT approach described here may be more suitable for other models, such as diffusion-based architectures or vision transformers.\n",
"\n",
"#### Creating models with TE\n",
"\n",
"Let's begin by defining a simple model composed of layers both from Transformer Engine and standard PyTorch:\n"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch \n",
"import torch.nn as nn\n",
"import transformer_engine as te\n",
"import warnings\n",
"warnings.filterwarnings(\"ignore\")\n",
"\n",
"# batch size, sequence length, hidden dimension\n",
"B, S, H = 256, 512, 256\n",
"\n",
"class Model(torch.nn.Module):\n",
" def __init__(self, hidden_dim=H, num_non_te_layers=16, num_te_layers=4, num_te_heads=4):\n",
" super(Model, self).__init__()\n",
" self.non_te_part = nn.Sequential(\n",
" *[nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.GELU()) for _ in range(num_non_te_layers)]\n",
" )\n",
" self.te_part = nn.Sequential(\n",
" *[te.pytorch.TransformerLayer(hidden_dim, hidden_dim, num_te_heads) for _ in range(num_te_layers)]\n",
" )\n",
"\n",
" def forward(self, x):\n",
" x = self.non_te_part(x)\n",
" return self.te_part(x)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's run some simple inference benchmarks:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average inference time FP32: 0.065 ms\n",
"Average inference time FP8: 0.062 ms\n"
]
}
],
"source": [
"from utils import _measure_time\n",
"\n",
"model = Model().eval().cuda()\n",
"inps = (torch.randn([S, B, H], device=\"cuda\"),)\n",
"def _inference(fp8_enabled):\n",
" with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8_enabled):\n",
" model(*inps)\n",
"\n",
"te_fp32_time = _measure_time(lambda: _inference(fp8_enabled=False))\n",
"te_fp8_time = _measure_time(lambda: _inference(fp8_enabled=True))\n",
"\n",
"print(f\"Average inference time FP32: {te_fp32_time} ms\")\n",
"print(f\"Average inference time FP8: {te_fp8_time} ms\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Exporting the TE Model to ONNX Format\n",
"\n",
"PyTorch developed a new [ONNX exporter](https://pytorch.org/docs/stable/onnx.html) built on TorchDynamo and plans to phase out the existing TorchScript exporter. As this feature is currently in active development, we recommend running this process with the latest PyTorch version.\n",
"\n",
"\n",
"To export a Transformer Engine model into ONNX format, follow these steps:\n",
"\n",
"- Conduct warm-up run within autocast using the recipe intended for export.\n",
"- Encapsulate your export-related code within `te.onnx_export`, ensuring warm-up runs remain outside this wrapper.\n",
"- Use the PyTorch Dynamo ONNX exporter by invoking: `torch.onnx.export(..., dynamo=True)`."
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Exporting model_fp8.onnx\n",
"[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`...\n",
"[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`... ✅\n",
"[torch.onnx] Run decomposition...\n",
"[torch.onnx] Run decomposition... ✅\n",
"[torch.onnx] Translate the graph into ONNX...\n",
"[torch.onnx] Translate the graph into ONNX... ✅\n",
"Applied 12 of general pattern rewrite rules.\n",
"Exporting model_fp32.onnx\n",
"[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`...\n",
"[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`... ✅\n",
"[torch.onnx] Run decomposition...\n",
"[torch.onnx] Run decomposition... ✅\n",
"[torch.onnx] Translate the graph into ONNX...\n",
"[torch.onnx] Translate the graph into ONNX... ✅\n",
"Applied 12 of general pattern rewrite rules.\n"
]
}
],
"source": [
"from transformer_engine.pytorch.export import te_translation_table\n",
"\n",
"def export(model, fname, inputs, fp8=True):\n",
" with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8):\n",
" # ! IMPORTANT !\n",
" # Transformer Engine models must have warm-up run\n",
" # before export. FP8 recipe during warm-up should \n",
" # match the recipe used during export.\n",
" model(*inputs)\n",
" \n",
" # Only dynamo=True mode is supported;\n",
" # dynamo=False is deprecated and unsupported.\n",
" #\n",
" # te_translation_table contains necessary ONNX translations\n",
" # for FP8 quantize/dequantize operators.\n",
" print(f\"Exporting {fname}\")\n",
" with te.pytorch.onnx_export(enabled=True):\n",
" torch.onnx.export(\n",
" model,\n",
" inputs,\n",
" fname,\n",
" output_names=[\"output\"],\n",
" dynamo=True,\n",
" custom_translation_table=te_translation_table\n",
" )\n",
"\n",
"# Example usage:\n",
"export(model, \"model_fp8.onnx\", inps, fp8=True)\n",
"export(model, \"model_fp32.onnx\", inps, fp8=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### Inference with TensorRT\n",
"\n",
"TensorRT is a high-performance deep learning inference optimizer and runtime developed by NVIDIA. It enables optimized deployment of neural network models by maximizing inference throughput and reducing latency on NVIDIA GPUs. TensorRT performs various optimization techniques, including layer fusion, precision calibration, kernel tuning, and memory optimization. \n",
"For detailed information and documentation, refer to the official [TensorRT documentation](https://developer.nvidia.com/tensorrt).\n",
"\n",
"When using TensorRT, ONNX model must first be compiled into a TensorRT engine. This compilation step involves converting the ONNX model into an optimized representation tailored specifically to the target GPU platform. The compiled engine file can then be loaded into applications for rapid and efficient inference execution."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"!trtexec --onnx=model_fp32.onnx --saveEngine=model_fp32.engine > output_fp32.log 2>&1\n",
"!trtexec --onnx=model_fp8.onnx --saveEngine=model_fp8.engine > output_fp8.log 2>&1"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's run the benchmarks for inference:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Average inference time without TRT (FP32 for all layers): 0.065 ms\n",
"Average inference time without TRT (FP8 for TE layers, FP32 for non-TE layers): 0.062 ms, speedup = 1.05x\n",
"Average inference time with TRT (FP32 for all layers): 0.0500 ms, speedup = 1.30x\n",
"Average inference time with TRT (FP8 for TE layers, FP32 for non-TE layers): 0.0470 ms, speedup = 1.38x\n"
]
}
],
"source": [
"import tensorrt as trt\n",
"\n",
"# Output tensor is allocated - TRT needs static memory address.\n",
"output_tensor = torch.empty_like(model(*inps))\n",
"\n",
"# Loads TRT engine from file.\n",
"def load_engine(engine_file_path):\n",
" logger = trt.Logger(trt.Logger.WARNING)\n",
" runtime = trt.Runtime(logger)\n",
" \n",
" with open(engine_file_path, \"rb\") as f:\n",
" engine_data = f.read()\n",
" \n",
" engine = runtime.deserialize_cuda_engine(engine_data)\n",
" return engine\n",
"\n",
"def benchmark_inference(model_name):\n",
" engine = load_engine(model_name)\n",
" context = engine.create_execution_context()\n",
" stream = torch.cuda.Stream()\n",
" \n",
" # TRT need static input and output addresses.\n",
" # Here they are set.\n",
" for i in range(len(inps)):\n",
" context.set_tensor_address(engine.get_tensor_name(i), inps[i].data_ptr()) \n",
" context.set_tensor_address(\"output\", output_tensor.data_ptr())\n",
" \n",
" def _inference():\n",
" # The data is loaded from static input addresses\n",
" # and output is written to static output address.\n",
" context.execute_async_v3(stream_handle=stream.cuda_stream)\n",
" stream.synchronize()\n",
" \n",
" return _measure_time(_inference)\n",
"\n",
"\n",
"trt_fp8_time = benchmark_inference(\"model_fp8.engine\")\n",
"trt_fp32_time = benchmark_inference(\"model_fp32.engine\")\n",
"\n",
"print(f\"Average inference time without TRT (FP32 for all layers): {te_fp32_time} ms\")\n",
"print(f\"Average inference time without TRT (FP8 for TE layers, FP32 for non-TE layers): {te_fp8_time} ms, speedup = {te_fp32_time/te_fp8_time:.2f}x\")\n",
"print(f\"Average inference time with TRT (FP32 for all layers): {trt_fp32_time:.4f} ms, speedup = {te_fp32_time/trt_fp32_time:.2f}x\")\n",
"print(f\"Average inference time with TRT (FP8 for TE layers, FP32 for non-TE layers): {trt_fp8_time:.4f} ms, speedup = {te_fp32_time/trt_fp8_time:.2f}x\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<p>\n",
"\n",
"\n",
"| Run | Inference Time (ms) | Speedup |\n",
"| ----------------------------------| ------------------- | ------------------- |\n",
"| PyTorch + TE | 0.065 | 1.00x |\n",
"| PyTorch + TE (FP8 for TE layers) | 0.062 | 1.05x |\n",
"| TRT | 0.0500 | 1.30x |\n",
"| TRT (FP8 for TE layers) | 0.047 | 1.38x |\n",
"\n",
"Note that this example highlights how TensorRT can speed up models composed of both TE and non-TE layers.\n",
"If a larger part of the model's layers were implemented with TE, the benefits of using FP8 for inference could be greater.\n",
"\n",
"</p>\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We clearly observe performance improvements when using FP8 and the TensorRT inference engine. These improvements may become even more significant with more complex models, as TensorRT could potentially identify additional optimization opportunities.\n",
"\n",
"#### Appendix: Low Precision Operators in ONNX and TensorRT\n",
"\n",
"The ONNX standard does not currently support all precision types provided by the Transformer Engine. All available ONNX operators are listed on [this website](https://onnx.ai/onnx/operators/). Consequently, TensorRT and the Transformer Engine utilize certain specialized low-precision operators, detailed below.\n",
"\n",
"**TRT_FP8_QUANTIZE**\n",
"\n",
"- **Name**: TRT_FP8_QUANTIZE\n",
"- **Domain**: trt\n",
"- **Inputs**:\n",
" - `x`: float32 tensor\n",
" - `scale`: float32 scalar\n",
"- **Outputs**:\n",
" - `y`: int8 tensor\n",
"\n",
"Produces an int8 tensor that represents the binary encoding of FP8 values.\n",
"\n",
"**TRT_FP8_DEQUANTIZE**\n",
"\n",
"- **Name**: TRT_FP8_DEQUANTIZE\n",
"- **Domain**: trt\n",
"- **Inputs**:\n",
" - `x`: int8 tensor\n",
" - `scale`: float32 scalar\n",
"- **Outputs**:\n",
" - `y`: float32 tensor\n",
"\n",
"Converts FP8-encoded int8 tensor data back into float32 precision.\n",
"\n",
"<div class=\"alert alert-info\">\n",
"\n",
"<b>Note:</b>\n",
"\n",
"Since standard ONNX operators do not support certain input and output precision types, a workaround is employed: tensors are dequantized to higher precision (float32) before input into these operators or quantized to lower precision after processing. TensorRT recognizes such quantize-dequantize patterns and replaces them with optimized operations. More details are available in [this section](https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/work-quantized-types.html#tensorrt-processing-of-q-dq-networks) of the TensorRT documentation.\n",
"\n",
"</div>"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""
Utility functions for ONNX export.
"""
import time
import torch
def _measure_time(f):
time_taken = []
num_iterations = 10
f() # warm-up
for _ in range(num_iterations):
start_time = time.time()
f()
torch.cuda.synchronize()
end_time = time.time()
time_taken.append(end_time - start_time)
return round(sum(time_taken) / num_iterations, 3)
...@@ -46,6 +46,7 @@ Transformer Engine documentation ...@@ -46,6 +46,7 @@ Transformer Engine documentation
examples/fp8_primer.ipynb examples/fp8_primer.ipynb
examples/advanced_optimizations.ipynb examples/advanced_optimizations.ipynb
examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
examples/onnx/onnx_export.ipynb
.. toctree:: .. toctree::
:hidden: :hidden:
......
...@@ -25,7 +25,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" ...@@ -25,7 +25,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
: ${XML_LOG_DIR:=/logs} : ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR" mkdir -p "$XML_LOG_DIR"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"
pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
...@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" ...@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls # Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
if [ $RET -ne 0 ]; then if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES" echo "Error: some sub-tests failed: $FAILED_CASES"
......
...@@ -26,23 +26,24 @@ pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine- ...@@ -26,23 +26,24 @@ pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine-
VERSION=`cat $TE_PATH/build_tools/VERSION.txt` VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}" WHL_BASE="transformer_engine-${VERSION}"
# Core wheel. # Core wheel.
NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel" NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation -vvv --wheel-dir ./dist . || error_exit "Failed to setup bdist_wheel"
wheel unpack dist/* || error_exit "Failed to unpack dist/*" wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl"
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}"
rm dist/*.whl || error_exit "Failed to remove dist/*.whl" rm dist/*.whl || error_exit "Failed to remove dist/*.whl"
mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" mv *.whl dist/ || error_exit "Failed to move *.whl to dist/"
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage" NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup metapackage"
cd transformer_engine/jax cd transformer_engine/jax
NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist" NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup sdist"
pip3 install dist/* || error_exit "Failed to install dist/*" pip3 install --no-build-isolation --no-deps -vvv dist/* || error_exit "Failed to install dist/*"
cd $TE_PATH cd $TE_PATH
pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps" pip3 install --no-build-isolation --no-deps -vvv dist/*.whl || error_exit "Failed to install dist/*.whl --no-deps"
python3 $TE_PATH/tests/jax/test_sanity_import.py || test_fail "test_sanity_import.py" python3 $TE_PATH/tests/jax/test_sanity_import.py || test_fail "test_sanity_import.py"
......
...@@ -14,14 +14,23 @@ ...@@ -14,14 +14,23 @@
FAIL=0 FAIL=0
# It is not installed as a requirement,
# because it is not available on PyPI.
pip uninstall -y nvdlfw-inspect
pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
pip install pytest==8.2.1 pip install pytest==8.2.1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
# standard sanity and numerics tests with initialized debug # standard sanity and numerics tests with initialized debug
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
exit $FAIL exit $FAIL
...@@ -50,8 +50,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_ ...@@ -50,8 +50,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
......
...@@ -27,22 +27,22 @@ VERSION=`cat $TE_PATH/build_tools/VERSION.txt` ...@@ -27,22 +27,22 @@ VERSION=`cat $TE_PATH/build_tools/VERSION.txt`
WHL_BASE="transformer_engine-${VERSION}" WHL_BASE="transformer_engine-${VERSION}"
# Core wheel. # Core wheel.
NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel" NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation -vvv --wheel-dir ./dist . || error_exit "Failed to setup bdist_wheel"
wheel unpack dist/* || error_exit "Failed to unpack dist/*" wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl"
sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA"
mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info"
wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}"
rm dist/*.whl || error_exit "Failed to remove dist/*.whl" rm dist/*.whl || error_exit "Failed to remove dist/*.whl"
mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" mv *.whl dist/ || error_exit "Failed to move *.whl to dist/"
NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage" NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup metapackage"
cd transformer_engine/pytorch cd transformer_engine/pytorch
NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist" NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup sdist"
pip3 install dist/* || error_exit "Failed to install dist/*" pip3 install --no-build-isolation --no-deps -vvv dist/* || error_exit "Failed to install dist/*"
cd $TE_PATH cd $TE_PATH
pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps" pip3 install --no-build-isolation --no-deps -vvv dist/*.whl || error_exit "Failed to install dist/*.whl --no-deps"
python3 $TE_PATH/tests/pytorch/test_sanity_import.py || test_fail "test_sanity_import.py" python3 $TE_PATH/tests/pytorch/test_sanity_import.py || test_fail "test_sanity_import.py"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment