pyxis: importing docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-jax-py3-devel
pyxis: imported docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-jax-py3-devel
# BENCHMARK_BASELINE_OUTPUT_START
Baseline Flax: Mean time: 86.580 ms
# BENCHMARK_BASELINE_OUTPUT_END
# BENCHMARK_TE_UNFUSED_OUTPUT_START
TE Unfused: Mean time: 42.252 ms
# BENCHMARK_TE_UNFUSED_OUTPUT_END
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START
TE Unfused + TE Attention: Mean time: 35.054 ms
# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START
TE Unfused + TE Attention + FP8: Mean time: 22.638 ms
# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END
# BENCHMARK_TE_FUSED_FP8_OUTPUT_START
TE Fused + TE Attention + FP8: Mean time: 23.703 ms
# BENCHMARK_TE_FUSED_FP8_OUTPUT_END
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START
TE TransformerLayer + FP8: Mean time: 22.812 ms
# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END
Summary written to getting_started_jax_summary.csv