Unverified commit ba37529c authored by vthumbe1503, committed by GitHub

FP8 Output Quantization for GEMM (#2123)

* Test working as I think it should work
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* revert accidental change
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Restrict the number of cases for unfused quantization; some FP8->FP8 cases are handled by cuBLAS
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

fix merge conflict
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

bug: missed a } in the code
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
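
In user-facing terms, the change lets `general_gemm` take an output quantizer and fall back to an unfused quantize step when cuBLAS cannot produce the requested FP8 format directly. A minimal sketch, distilled from the unit test added in this commit (see the test diff below); illustrative, not an API reference:

```python
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.cpp_extensions import general_gemm
from transformer_engine.pytorch.module.base import get_workspace
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer

quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
a_fp8 = quantizer(torch.randn(32, 32, device="cuda", dtype=torch.bfloat16))
b_fp8 = quantizer(torch.randn(32, 32, device="cuda", dtype=torch.bfloat16))

# With an output quantizer, the GEMM runs to a higher-precision intermediate
# and the result is quantized afterwards (cuBLAS cannot emit MXFP8 directly).
d_fp8, *_ = general_gemm(
    a_fp8,
    b_fp8,
    get_workspace(),
    torch.float32,                  # dtype of the unquantized intermediate
    quantization_params=quantizer,  # request FP8 (here MXFP8) output
    bias=None,
    use_split_accumulator=False,
)
```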

Add cuBLASMp-backed GEMM-like API to TE common (#1824)

* Pick up cuBLASMp during build
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Change lib order to fix link error
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Context creation, incomplete...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Test fixture
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* A sanity AgGemm test, failing...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix axes
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Take care of uneven distribution
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use MPI to get position of local matrices
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor & fixes
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-RS
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-AR, not working...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fixes
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Setting all-reduce epilogue for gemm-ar
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use supported shapes for GEMM-AR
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tolerance
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* First shot at fp8
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use TensorHolder in tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More test configs
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Support comm_sm_count
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Parametrize dtypes for A, B and D separately
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak scaling
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Amax ptr
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Flags parity with cublas_gemm, saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Bias tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix bias test
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Aux, saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* aux_ld
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* A fix
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use test::Tensor
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Set scale inv
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unsupported test configs
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Replace libcal with NCCL
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add NVTX markers to API functions
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak GemmAr tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More test config
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix merge fallout
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove MPI dependency, comment API, add algo parameter
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem dependency
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem build
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Exclude CommGemm tests from L0_cppunittest
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add cpp_distributed sh file for CI
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Adapt to TensorAllocator
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Skip GemmAr test on unsupported HW
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Oversubscribe is needed on some clusters
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix incomplete libcal removal
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Move CI tests to L1
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Rename context to include NVTE prefix
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove leftover code
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* NVTE_WITH_CUBLASMP off by default
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed NVTE_CHECK diag
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Comment API
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Include stdbool header for legacy C compilers
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove now unused argument
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Abstract away cuBLASMp algo behind our own enum
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed shape diag messages
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/common/include/transformer_engine/comm_gemm.h
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com>

* Add license
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com>
Co-authored-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)

* FP8 AllGather in FP8 GroupedGEMM

1. Support current scaling FP8 quantization with a given amax.
2. Support FP8 AG in fwd and BF16 RS in bwd.
3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM (sketched below).
Signed-off-by: Ming Huang <mingh@nvidia.com>
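
The pipeline in step 3, sketched here with plain `torch.distributed` primitives purely for illustration (the actual change implements this in the JAX grouped-GEMM path; the byte-level all-gather and the E4M3 constant 448 are assumptions of the sketch):

```python
import torch
import torch.distributed as dist

def fp8_allgather(x: torch.Tensor, world_size: int):
    """Sketch of the AR-max -> FP8 Quant -> FP8 AG pipeline (illustrative only)."""
    # 1. AR-max: all-reduce the local amax so every rank uses the same scale.
    amax = x.abs().amax().float().reshape(1)
    dist.all_reduce(amax, op=dist.ReduceOp.MAX)
    # 2. Current-scaling FP8 quantization with the given (global) amax;
    #    448 is the E4M3 max-normal value.
    scale = 448.0 / torch.clamp(amax, min=1e-12)
    x_fp8 = (x.float() * scale).to(torch.float8_e4m3fn)
    # 3. FP8 AG: all-gather the FP8 payload as raw bytes (half of BF16 traffic).
    gathered = torch.empty(world_size * x_fp8.numel(), dtype=torch.uint8, device=x.device)
    dist.all_gather_into_tensor(gathered, x_fp8.reshape(-1).view(torch.uint8))
    out = gathered.view(torch.float8_e4m3fn).reshape(world_size * x.shape[0], *x.shape[1:])
    # 4. `out` (plus `scale`) then feeds the FP8 grouped GEMM.
    return out, scale
```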

* Slightly refactor
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding documents of new args.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding unit-tests.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding license.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Move unit-tests to L1.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Move quantizer store/reset into FP8 only.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding all layout support for Blackwell+
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adopt the feedback from code-review.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fixed the wrong stream used by d2d in groupedGEMM FFI.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] Delay MeshResource validation until first usage (#2124)

Delay MeshResource validation until first usage
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] Decouple Recipe and ScalingMode (#1728)

* Decouple recipe and scaling mode
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Expose global QuantizeConfig instance as a getter
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Format and lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Rename UsageType to TensorSource
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)

* add amax input to DBiasQuantizePrimitive and FFI
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make sure amax is init with zero
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fix sharding rule
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121)
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Temporarily remove comm_gemm tests (#2133)
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch] Disable determinism for sm100 (#2130)

* disable determinism for sm100+ and cudnn<9.14
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert some changes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch] ONNX export of FP8 Current Scaling (#2068)

* Compute amax in normalization forward in current scaling in untuned kernels
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* apply Tim's suggestions
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134)

use torch empty for empty shape instead of from_blob
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

build: pull cached wheels (#2127)

* build: pull cached wheels
Signed-off-by: oliver könig <okoenig@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update setup.py
Signed-off-by: oliver könig <okoenig@nvidia.com>

---------
Signed-off-by: oliver könig <okoenig@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

feat: Add support for multiple quantization modes in the UB communicators (#2043)
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)

* add checks to cuda kernel launch and cuda API calls
Signed-off-by: Xin Yao <xiny@nvidia.com>

* Remove exceptions from destructors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix weird dispatch in ln/rmsnorm
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch] Support bf16+fp8 cudagraph (#2098)

* support bf16+fp8 model
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update
Signed-off-by: Robin Zhang <robinz@nvidia.com>

---------
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Dropout with 8-bit RNG (#2014)

* Add dropout kernel with 8-bit RNG
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix license
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid ambiguous types
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Do not enforce dropout prob is representable in 8 bits
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expand error message
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix small statistical bug from using less-equal instead of less-than

Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unnecessary helper function in PyTorch extensions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
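
A reference sketch of the 8-bit-RNG dropout semantics described above (assumptions of this sketch: one byte per element, strict less-than comparison per the statistical fix, and rescaling by the realized keep probability; the PR's actual implementation is a fused CUDA kernel):

```python
import torch

def dropout_8bit_reference(x: torch.Tensor, p: float) -> torch.Tensor:
    """Reference semantics of dropout driven by one random byte per element.

    An element is dropped when its byte is strictly below round(p * 256),
    so the realized drop probability is round(p * 256) / 256 and p itself
    does not have to be exactly representable in 8 bits.
    """
    assert 0.0 <= p < 1.0
    threshold = round(p * 256)
    rng = torch.randint(0, 256, x.shape, device=x.device)  # one byte per element
    keep = rng >= threshold                                # strict less-than drops
    keep_prob = 1.0 - threshold / 256.0
    return x * keep / keep_prob
```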

Create GPU reload buffers on main stream (#2131)

* Create GPU reload buffers on main stream
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed typo
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* Fixed typo
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

mxfp8 unfused quant support, refined unit test, remove unnecessary quantization code
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

missed a quant code removal
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

minor bug fix
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

minor code cleanup
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

minor cosmetics
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Address review comment
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

minor comment update
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Fix CI failures for UB overlap changes (#2149)
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>

minor bug: quantizer should not be None for unfused quantization
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135)

* Fix failing tests for dropout=0.1 and bias for fused attn for Blackwell
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix the skip message
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Assert in fused attn bwd pass for sm100
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

Add check for sm100
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add support to get all devs in the process for jax
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Code clean up
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Represent attn bias using enum instead of string
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

fix linting error
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119)

* add noop to comp amax
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix for fp8 blockwise recipe
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* Update test_multi_process_distributed_grouped_gemm.py

change accidentally added while merging
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* Update dense.py

change accidentally added while merging
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address review comments
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Bug solved: delayed scaling quantization with mxfp8 inputs didn't work
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix the unit test error
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* just to trigger ci
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix merge conflict
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c221909d
@@ -39,16 +39,21 @@ from transformer_engine.pytorch import (
 from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
 from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
 from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
 from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace
 from transformer_engine.pytorch.utils import get_device_compute_capability
 from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from utils import ModelConfig, reset_rng_states, get_available_attention_backends

 # Only run FP8 tests on supported devices.
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
-mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available()
+mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
 fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
 sm_80plus = get_device_compute_capability() >= (8, 0)
@@ -2607,6 +2612,73 @@ def test_grouped_gemm(shape, dtype, layout, accumulate):
     torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


+@pytest.mark.parametrize("N", [32])
+@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
+@pytest.mark.parametrize(
+    "input_quantizer",
+    [
+        Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"),
+        MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3),
+    ],
+)
+@pytest.mark.parametrize(
+    "out_quantizer",
+    [
+        Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"),
+        MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3),
+        Float8Quantizer(
+            torch.ones(1).cuda().squeeze(), torch.ones(1).cuda().squeeze(), tex.DType.kFloat8E4M3
+        ),
+    ],
+)
+def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_quantizer):
+    # For MXFP8 and current scaling, the unfused quantization path should be:
+    # FP8 input --> cuBLAS GEMM --> BF16 output --> quantize to FP8 --> FP8 output
+    # Skip invalid configurations
+    is_mxfp8_needed = isinstance(input_quantizer, MXFP8Quantizer) or isinstance(
+        out_quantizer, MXFP8Quantizer
+    )
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    if is_mxfp8_needed and not mxfp8_available:
+        pytest.skip(reason_for_no_mxfp8)
+
+    inp_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype))
+    weight_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype))
+    outp_type = torch.float32
+    quantized_out, *_ = general_gemm(
+        weight_fp8,
+        inp_fp8,
+        get_workspace(),
+        outp_type,
+        quantization_params=out_quantizer,
+        bias=None,
+        use_split_accumulator=False,
+    )
+    out, *_ = general_gemm(
+        weight_fp8,
+        inp_fp8,
+        get_workspace(),
+        outp_type,
+        quantization_params=None,
+        bias=None,
+        use_split_accumulator=False,
+    )
+    expected_quantized_out = out_quantizer(out)
+
+    # Match results against PyTorch GEMM and allow for quantization tolerance
+    pytorch_out = torch.matmul(
+        inp_fp8.dequantize().to(torch.float64),
+        torch.transpose(weight_fp8.dequantize().to(torch.float64), 0, 1),
+    )
+    fp8_tols = dict(rtol=0.125, atol=0.0675)
+    torch.testing.assert_close(
+        pytorch_out.to(outp_type), expected_quantized_out.dequantize(), **fp8_tols
+    )
+    # Match results between quantization happening inside vs outside general_gemm
+    torch.testing.assert_close(expected_quantized_out.dequantize(), quantized_out.dequantize())


 @pytest.mark.parametrize(
     "shape",
     [
@@ -579,14 +579,19 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
                "Input and output_t must have the same shape for columnwise non-transpose case.");
     }
   }
-  NVTE_CHECK(output.dtype == output_t.dtype, "output and output_t need to have the same dtype.");
+  if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) {
+    // output may not be defined if rowwise quantization is not needed.
+    NVTE_CHECK(output.dtype == output_t.dtype,
+               "output and output_t need to have the same dtype.");
+  }

   NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2.");
   bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT;
   size_t scale_t_k = scale_inv_t.shape[1];
   scale_t_stride_x = columnwise_compact ? 1 : scale_t_k;
   scale_t_stride_y = columnwise_compact ? scale_t_k : 1;
 }

+  auto output_dtype =
+      rowwise_option != FP8BlockwiseRowwiseOption::NONE ? output.dtype : output_t.dtype;
   const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
   const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);

@@ -597,7 +602,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
       input.dtype, InputType,
       TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
-          output.dtype, OutputType,
+          output_dtype, OutputType,
           dim3 grid(num_blocks_x, num_blocks_y, 1);
@@ -93,6 +93,8 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                              bool use_split_accumulator, CommOverlapCore* comm_overlap,
                              std::optional<CommOverlapType> comm_type, MaybeTensor extra_output,
                              bool bulk_overlap, float alpha, std::optional<float> beta) {
+  using namespace transformer_engine::pytorch::detail;
+
   // Input tensors
   NVTE_CHECK(!A.is_none(), "Tensor A has not been provided");
   NVTE_CHECK(!B.is_none(), "Tensor B has not been provided");

@@ -123,10 +125,10 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
                "into D tensor. Beta has nothing to be applied to.");
   }

+  DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype();
+
   // Output tensor
   TensorWrapper D_tensor;
   if (D.is_none()) {
-    DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype();
     std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer);
   } else {
     D_tensor = makeTransformerEngineTensor(D, quantizer);

@@ -139,12 +141,35 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
     }
   }

+  // Maintain unquantized tensor in case we need unfused quantization support.
+  TensorWrapper unquantized_D_tensor;
+  py::object unquantized_out;
+  // Unfused quantization is needed in the following cases:
+  // 1. Inputs: BF16, Output: FP8 (GEMM output has to be BF16, so FP8 quantization is needed after that)
+  // 2. Inputs: FP8, Output: FP8 (for any quantization apart from delayed scaling,
+  //    the GEMM output needs to be in BF16 to allow for unfused quantization)
+  bool unfused_quantization_needed = !quantizer.is_none();
+  if (low_precision) {
+    // At the moment, the only use-case for fused GEMM:
+    // delayed scaling quantizer with per-tensor scaling inputs
+    bool is_per_tensor_scaling_input = IsFloat8Tensor(A.ptr()) || IsFloat8Tensor(B.ptr());
+    if (IsFloat8Quantizers(quantizer.ptr()) && is_per_tensor_scaling_input)
+      unfused_quantization_needed = false;
+  }
+  if (unfused_quantization_needed) {
+    NoneQuantizer q{none};
+    std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype);
+  }
+  TensorWrapper& out_tensor = unfused_quantization_needed ? unquantized_D_tensor : D_tensor;
+
   // Bias tensor
   TensorWrapper bias_tensor;
   MaybeTensor bias_grad = std::nullopt;
   if (bias.has_value()) {
     if (grad) {
-      auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA);
+      auto opts =
+          torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA);
       bias_grad = at::empty({static_cast<int64_t>(B_shape.data[B_shape.ndim - 1])}, opts);
       bias_tensor = makeTransformerEngineTensor(*bias_grad);
     } else {

@@ -157,7 +182,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   // Activation input tensor
   MaybeTensor pre_gelu_out = std::nullopt;
-  DType gelu_type = low_precision ? bias_type : D_tensor.dtype();
+  DType gelu_type = low_precision ? bias_type : out_tensor.dtype();
   if (gelu) {
     if (!grad) {
       auto dtype = GetATenDType(gelu_type);

@@ -210,7 +235,7 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
     // Direct GEMM call to the correct overlap
     if (bulk_overlap) {
       NVTE_SCOPED_GIL_RELEASE({
-        comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor,
+        comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor,
                                    te_pre_gelu_out, te_workspace, grad, accumulate,
                                    use_split_accumulator, comm_type.value(), extra_output_tensor,
                                    main_stream);

@@ -218,14 +243,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
     } else if (comm_type.value() == CommOverlapType::AG) {
       if (comm_overlap->is_atomic_gemm()) {
         NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
+          comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor,
                                                bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                                accumulate, use_split_accumulator,
                                                extra_output_tensor, main_stream);
         });
       } else {
         NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor,
+          comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor,
                                          bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                          accumulate, use_split_accumulator, extra_output_tensor,
                                          main_stream);

@@ -234,14 +259,14 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
     } else {
       if (comm_overlap->is_atomic_gemm()) {
         NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
+          comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor,
                                                bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                                accumulate, use_split_accumulator,
                                                extra_output_tensor, main_stream);
         });
       } else {
         NVTE_SCOPED_GIL_RELEASE({
-          comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor,
+          comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor,
                                          bias_tensor, te_pre_gelu_out, te_workspace, grad,
                                          accumulate, use_split_accumulator, extra_output_tensor,
                                          main_stream);

@@ -251,15 +276,15 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
   } else {
     // Launch GEMM
     NVTE_SCOPED_GIL_RELEASE({
-      nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(),
+      nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(),
                               bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad,
                               te_workspace.data(), alpha, *beta, use_split_accumulator,
                               num_math_sms, main_stream);
     });
   }
 } else {
-  if (D_tensor.numel() != 0 && !accumulate) {
-    D_tensor.zero_(main_stream);
+  if (out_tensor.numel() != 0 && !accumulate) {
+    out_tensor.zero_(main_stream);
   }
   if (bias.has_value()) {
     if (bias->numel() != 0 && grad) {

@@ -267,7 +292,11 @@ std::vector<py::object> gemm(py::handle A, bool transa, py::handle B, bool trans
       }
     }
   }

+  if (unfused_quantization_needed) {
+    // Quantize the output
+    std::unique_ptr<Quantizer> my_quantizer = convert_quantizer(quantizer);
+    my_quantizer->quantize(unquantized_D_tensor, D_tensor);
+  }
+
   // Pack outputs
   std::vector<py::object> out;
   out.emplace_back(std::move(D));
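For readability, the unfused-quantization dispatch added in `gemm()` above amounts to the following Python paraphrase (illustrative only; the boolean parameters stand in for the C++ helpers `IsFloat8Tensor` and `IsFloat8Quantizers`):

```python
def needs_unfused_quantization(
    quantizer,                           # output quantizer requested by the caller, or None
    inputs_are_per_tensor_fp8: bool,     # IsFloat8Tensor(A) or IsFloat8Tensor(B) in the C++
    quantizer_is_delayed_scaling: bool,  # IsFloat8Quantizers(quantizer) in the C++
    low_precision: bool,
) -> bool:
    """Python paraphrase of the dispatch in gemm() above; illustrative only."""
    if quantizer is None:
        return False  # no output quantization requested
    if low_precision and quantizer_is_delayed_scaling and inputs_are_per_tensor_fp8:
        # The one fused case cuBLAS supports: per-tensor-scaled FP8 inputs
        # with a delayed-scaling FP8 output quantizer.
        return False
    # Otherwise: GEMM into a BF16/FP32 buffer, then quantize that buffer.
    return True
```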
@@ -96,16 +96,6 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
   at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
   tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                    getTensorShape(amax));
-  auto rowwise_data = tensor->get_rowwise_data();
-  rowwise_data.dtype = static_cast<NVTEDType>(dtype);
-  auto columnwise_data = tensor->get_columnwise_data();
-  columnwise_data.dtype = static_cast<NVTEDType>(dtype);
-  tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
-                           rowwise_data.shape);
-  tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
-                              columnwise_data.shape);
 }

 std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(

@@ -318,17 +308,6 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso
   at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA);
   tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()),
                    getTensorShape(amax));
-  // quantize output and its transpose
-  auto rowwise_data = tensor->get_rowwise_data();
-  rowwise_data.dtype = static_cast<NVTEDType>(dtype);
-  auto columnwise_data = tensor->get_columnwise_data();
-  columnwise_data.dtype = static_cast<NVTEDType>(dtype);
-  tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
-                           rowwise_data.shape);
-  tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
-                              columnwise_data.shape);
 }

 std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(

@@ -562,20 +541,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
   this->all_gather_usage = quantizer.attr("all_gather_usage").cast<bool>();
 }

-void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {
-  // Change the rowwise and columnwise_data to the configured dtype.
-  // May be a switch between E5M2 and E4M3.
-  auto rowwise_data = tensor->get_rowwise_data();
-  rowwise_data.dtype = static_cast<NVTEDType>(dtype);
-  auto columnwise_data = tensor->get_columnwise_data();
-  columnwise_data.dtype = static_cast<NVTEDType>(dtype);
-  tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
-                           rowwise_data.shape);
-  tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
-                              columnwise_data.shape);
-}
+void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {}

 std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
     const std::vector<size_t>& shape, DType dtype) const {

@@ -917,18 +883,7 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize
   this->dtype = quantizer.attr("dtype").cast<DType>();
 }

-void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
-  auto rowwise_data = tensor->get_rowwise_data();
-  rowwise_data.dtype = static_cast<NVTEDType>(dtype);
-  auto columnwise_data = tensor->get_columnwise_data();
-  columnwise_data.dtype = static_cast<NVTEDType>(dtype);
-  tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast<DType>(rowwise_data.dtype),
-                           rowwise_data.shape);
-  tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast<DType>(columnwise_data.dtype),
-                              columnwise_data.shape);
-}
+void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {}

 std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::vector<size_t>& shape,
                                                                    DType dtype) const {