Commit 7fa0f554 authored by vthumbe1503, committed by GitHub

[PyTorch] Support for SwiGLU activation used in GPT OSS (#2161)



* Test working as I think it should work
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

revert accidental change
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

fix merge conflict
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

bug: missed a } in the code
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Add cuBLASMp-backed GEMM-like API to TE common (#1824)

* Pick up cuBLASMp during build
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Change lib order to fix link error
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Context creation, incomplete...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Test fixture
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* A sanity AgGemm test, failing...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix axes
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Take care of uneven distribution
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use MPI to get position of local matrices
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor & fixes
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-RS
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-AR, not working...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fixes
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Setting all-reduce epilogue for gemm-ar
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use supported shapes for GEMM-AR
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tolerance
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* First shot at fp8
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use TensorHolder in tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More test configs
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Support comm_sm_count
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Parametrize dtypes for A, B and D separately
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak scaling
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Amax ptr
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Flags parity with cublas_gemm, saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Bias tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix bias test
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Aux, saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* aux_ld
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* A fix
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use test::Tensor
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Set scale inv
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unsupported test configs
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Replace libcal with NCCL
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add NVTX markers to API functions
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak GemmAr tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More test config
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix merge fallout
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove MPI dependency, comment API, add algo parameter
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem dependency
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem build
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Exclude CommGemm tests from L0_cppunittest
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add cpp_distributed sh file for CI
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Adapt to TensorAllocator
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Skip GemmAr test on unsupported HW
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Oversubscribe is needed on some clusters
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix incomplete libcal removal
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Move CI tests to L1
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Rename context to include NVTE prefix
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove leftover code
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* NVTE_WITH_CUBLASMP off by default
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed NVTE_CHECK diag
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Comment API
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Include stdbool header for legacy C compilers
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove now unused argument
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Abstract away cuBLASMp algo behind our own enum
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed shape diag messages
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/common/include/transformer_engine/comm_gemm.h
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com>

* Add license
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com>
Co-authored-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)

* FP8 AllGather in FP8 GroupedGEMM

1. Support current scaling FP8 quantization with a given amax.
2. Support FP8 AG in fwd and BF16 RS in bwd.
3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Slightly refactor
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding documents of new args.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding unit-tests.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding license.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Move unit-tests to L1.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Move quantizer store/reset into FP8 only.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding all layout support for Blackwell+
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adopt the feedback from code-review.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fixed the wrong stream used by d2d in groupedGEMM FFI.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] Delay MeshResource validation until first usage (#2124)

Delay MeshResource validation until first usage
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] Decouple Recipe and ScalingMode (#1728)

* Decouple recipe and scaling mode
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Expose global QuantizeConfig instance as a getter
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Format and lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Rename UsageType to TensorSource
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)

* add amax input to DBiasQuantizePrimitive and FFI
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make sure amax is init with zero
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fix sharding rule
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121)
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Temporarily remove comm_gemm tests (#2133)
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch] Disable determinism for sm100 (#2130)

* disable determinism for sm100+ and cudnn<9.14
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert some changes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch] ONNX export of FP8 Current Scaling (#2068)

* Compute amax in normalization forward in current scaling in untuned kernels
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* apply Tim's suggestions
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134)

use torch empty for empty shape instead of from_blob
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

build: pull cached wheels (#2127)

* build: pull cached wheels
Signed-off-by: oliver könig <okoenig@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update setup.py
Signed-off-by: oliver könig <okoenig@nvidia.com>

---------
Signed-off-by: oliver könig <okoenig@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

feat: Add support for multiple quantization modes in the UB communicators (#2043)
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)

* add checks to cuda kernel launch and cuda API calls
Signed-off-by: Xin Yao <xiny@nvidia.com>

* Remove exceptions from destructors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix weird dispatch in ln/rmsnorm
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch] Support bf16+fp8 cudagraph (#2098)

* support bf16+fp8 model
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update
Signed-off-by: Robin Zhang <robinz@nvidia.com>

---------
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Dropout with 8-bit RNG (#2014)

* Add dropout kernel with 8-bit RNG
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix license
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid ambiguous types
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Do not enforce dropout prob is representable in 8 bits
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expand error message
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix small statistical bug from using less-equal instead of less-than

Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unnecessary helper function in PyTorch extensions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Create GPU reload buffers on main stream (#2131)

* Create GPU reload buffers on main stream
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed typo
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* Fixed typo
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

mxfp8 unfused quant support, refined unit test, remove unnecessary quantization code
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

missed a quant code removal
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

minor bug fix
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>


minor code cleanup
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

minor cosmetics
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Address review comment
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

minor comment update
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Fix CI failures for UB overlap changes (#2149)
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

minor bug: quantizer should not be none for unfused quantization
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135)

* Fix failing tests for dropout=0.1 and bias for fused attn for blackwell
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix the skip message
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Assert in fused attn bwd pass for sm100
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

Add check for sm100
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add support to get all devs in the process for jax
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Code clean up
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Represent attn bias using enum instead of string
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

fix linting error
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* initial draft of changes to get GPT-OSS-based SwiGLU integrated; gated kernels need to be fixed
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* redundant implementation for the PyTorch-to-TE hook-up, refactoring to be done later
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* all gated kernels modified, pytest working for oss swiglu
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
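For reference, the gated activation these kernels implement can be sketched in a few lines of NumPy. This is a minimal sketch of a GPT-OSS-style SwiGLU, assuming the commonly described variant with a sigmoid slope `alpha`, a clamp `limit` on both branches, and a +1 offset on the linear branch; the parameter values here are illustrative defaults, not taken from this PR's kernels.

```python
import numpy as np

def gpt_oss_swiglu(x_glu, x_linear, alpha=1.702, limit=7.0):
    # Clamp the gate branch from above and the linear branch on both sides
    # (alpha/limit defaults are assumptions for illustration).
    x_glu = np.minimum(x_glu, limit)
    x_linear = np.clip(x_linear, -limit, limit)
    # Sigmoid-weighted gate: x * sigmoid(alpha * x).
    gate = x_glu / (1.0 + np.exp(-alpha * x_glu))
    # The linear branch carries a +1 offset in this variant.
    return gate * (x_linear + 1.0)
```

With large inputs both branches saturate at `limit`, which is the behavioral difference from a plain SwiGLU.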

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* fix the merge conflict
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Add cuBLASMp-backed GEMM-like API to TE common (#1824)

* Pick up cuBLASMp during build
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Change lib order to fix link error
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Context creation, incomplete...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Test fixture
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* A sanity AgGemm test, failing...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix axes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Take care of uneven distribution
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use MPI to get position of local matrices
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor & fixes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-RS
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-AR, not working...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fixes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Setting all-reduce epilogue for gemm-ar
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use supported shapes for GEMM-AR
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tolerance
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* First shot at fp8
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use TensorHolder in tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More test configs
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Support comm_sm_count
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Parametrize dtypes for A, B and D separately
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak scaling
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Amax ptr
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Flags parity with cublas_gemm, saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Bias tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix bias test
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Aux, saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* aux_ld
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* A fix
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use test::Tensor
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Set scale inv
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unsupported test configs
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Replace libcal with NCCL
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Add NVTX markers to API functions
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak GemmAr tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More test config
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix merge fallout
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove MPI dependency, comment API, add algo parameter
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem dependency
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem build
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Exclude CommGemm tests from L0_cppunittest
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Add cpp_distributed sh file for CI
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Adapt to TensorAllocator
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Skip GemmAr test on unsupported HW
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Oversubscription is needed on some clusters
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix incomplete libcal removal
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Move CI tests to L1
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Rename context to include NVTE prefix
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove leftover code
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* NVTE_WITH_CUBLASMP off by default
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed NVTE_CHECK diag
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Comment API
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Include stdbool header for legacy C compilers
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove now unused argument
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Abstract away cuBLASMp algo behind our own enum
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed shape diag messages
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/common/include/transformer_engine/comm_gemm.h
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarVladimir Cherepanov <56651474+mk-61@users.noreply.github.com>

* Add license
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: default avatarVladimir Cherepanov <56651474+mk-61@users.noreply.github.com>
Co-authored-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119)

* add noop to comp amax
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* fix for fp8 blockwise recipe
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* resolve comments
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[PyTorch] fix cross entropy vanishing gradients (#2139)

* fix cross entropy
Signed-off-by: default avatarCasper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarCasper <casperbh.96@gmail.com>

* fix comments
Signed-off-by: default avatarCasper <casperbh.96@gmail.com>

* fix: few more style issues
Signed-off-by: default avatarCasper <casperbh.96@gmail.com>

* fix: remove grad_output_stride (unnecessary)
Signed-off-by: default avatarCasper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix: only backward was broken
Signed-off-by: default avatarCasper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Generalize cross entropy backward kernel to handle reduced and unreduced loss
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
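The generalization above amounts to broadcasting the incoming gradient over the vocabulary dimension, whether it is a scalar (reduced loss) or per-token (unreduced loss). A minimal NumPy sketch of that idea, under the standard softmax-minus-one-hot formulation (not the PR's Triton kernel):

```python
import numpy as np

def cross_entropy_backward(logits, targets, grad_output):
    # Numerically stable softmax over the vocab dimension.
    z = logits - logits.max(axis=-1, keepdims=True)
    p = np.exp(z)
    p /= p.sum(axis=-1, keepdims=True)
    # d(loss_i)/d(logits_i) = softmax(logits_i) - one_hot(target_i).
    p[np.arange(len(targets)), targets] -= 1.0
    go = np.asarray(grad_output, dtype=p.dtype)
    if go.ndim == 0:  # reduced loss: one scalar grad shared by all tokens
        go = np.full(len(targets), go)
    return p * go[:, None]  # broadcast per-token grad over the vocab dim
```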

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarCasper <casperbh.96@gmail.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Fix bug when enabling --overlap-grad-reduce in mcore (#2142)

* fix bugs when enabling --overlap-grad-reduce in mcore
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix CI
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

* format
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>
Co-authored-by: default avatarHongbin Liu <hongbinl@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Fix CUDA version in setup.py (#2132)

* Fix CUDA version in setup.py
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Re-enable building comm-gemm tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* WAR for nvidia-nvshmem package
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[JAX] NoScaleTensor wrapper for non-quantized data (#2136)

* Custom call tests passing
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix test_layer.py
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix comments
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix shardy issue with amax being shape 1,1,1 instead of shape (1,)
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Add higher-precision VJP tests to test_distributed_layernorm_mlp
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Cast non-quantized kernels to input dtype in VJPs
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Rename HighPrecisionTensor to NoScaleTensor
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Use NoScaleTensor in pure JAX impls where it was missing
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix tests
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[JAX] Fix GroupedScaledTensor creation with keyword arg (#2154)

Fix GroupedScaledTensor creation
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Fixing a few issues with multi-process launching. (#2155)

* Fixing a few issues with multi-process launching.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Update list of authorized CI users (#2152)
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

a bit of cleanup
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* accidentally removed some activations; fix a minor bug in the templated function
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* parent de9ef2fe450daae0d4ea1b647a37219f72814f66
author Varun Thumbe <vthumbe@nvidia.com> 1757373536 +0000
committer Varun Thumbe <vthumbe@nvidia.com> 1758262513 +0000


merge conflict
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)

* FP8 AllGather in FP8 GroupedGEMM

1. Support current scaling FP8 quantization with a given amax.
2. Support FP8 AG in fwd and BF16 RS in bwd.
3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
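The quantization step in that workflow takes an amax produced by the preceding all-reduce rather than computing it locally. A minimal sketch of current scaling with a precomputed amax, assuming the FP8 E4M3 range (this simulates scale-and-saturate only; real FP8 encoding would also round the mantissa):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest finite magnitude in FP8 E4M3

def quantize_with_amax(x, amax):
    # Current scaling: map the (all-reduced) amax onto the FP8 max.
    scale = FP8_E4M3_MAX / amax if amax > 0 else 1.0
    # Scale and saturate into the representable FP8 range.
    q = np.clip(x * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return q, 1.0 / scale  # values plus scale-inverse for dequantization
```

Because every rank receives the same amax, all ranks quantize with the same scale, which is what makes the subsequent FP8 all-gather consistent.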

* Slightly refactor
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding documents of new args.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding unit-tests.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding license.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Move unit-tests to L1.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Move quantizer store/reset into FP8 only.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding all layout support for Blackwell+
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adopt the feedback from code-review.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fixed the wrong stream used by d2d in groupedGEMM FFI.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[JAX] Delay MeshResource validation until first usage (#2124)

Delay MeshResource validation until first usage
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)

* add amax input to DBiasQuantizePrimitive and FFI
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make sure amax is init with zero
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fix sharding rule
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121)
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

Temporarily remove comm_gemm tests (#2133)
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

[PyTorch] Disable determinism for sm100 (#2130)

* disable determinism for sm100+ and cudnn<9.14
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert some changes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[PyTorch] ONNX export of FP8 Current Scaling (#2068)

* Compute amax in normalization forward in current scaling in untuned kernels
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* apply Tim's suggestions
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarJan Bielak <jbielak@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134)

use torch empty for empty shape instead of from_blob
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

build: pull cached wheels (#2127)

* build: pull cached wheels
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update setup.py
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>

---------
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)

* add checks to cuda kernel launch and cuda API calls
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* Remove exceptions from destructors
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* fix weird dispatch in ln/rmsnorm
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

[PyTorch] Support bf16+fp8 cudagraph (#2098)

* support bf16+fp8 model
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

---------
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

Dropout with 8-bit RNG (#2014)

* Add dropout kernel with 8-bit RNG
Co-authored-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix license
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Avoid ambiguous types
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Do not enforce dropout prob is representable in 8 bits
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Expand error message
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix small statistical bug from using less-equal instead of less-than

Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
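The statistical bug above is easy to see by enumerating all 256 byte values: with a threshold derived from the dropout probability, `<=` drops one extra byte value and skews the rate by 1/256. A small sketch under an assumed threshold convention (`threshold = round(p * 256)`; the actual kernel's convention may differ):

```python
import numpy as np

def drop_rate(drop_prob, strict):
    # Hypothetical convention: byte r in [0, 256) is dropped when
    # r < threshold (strict) or r <= threshold (the buggy variant).
    threshold = int(round(drop_prob * 256))
    r = np.arange(256)  # all 256 byte values, equally likely
    mask = (r < threshold) if strict else (r <= threshold)
    return mask.mean()
```

For `drop_prob=0.25` the strict comparison drops exactly 64 of 256 values (0.25), while `<=` drops 65 (about 0.2539).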

* Fix linter warning
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unnecessary helper function in PyTorch extensions
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Create GPU reload buffers on main stream (#2131)

* Create GPU reload buffers on main stream
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed typo
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* Fixed typo
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

---------
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>


Update list of authorized CI users (#2152)
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

Fused RoPE with combined QKV input. (#2122)

* Fused RoPE with combined QKV input.

Initial commit for Dropout with 8-bit RNG

Fix documentation

Initial commit for Fused QKV RoPE

WIP

Initial tests passing

Enable rotary percent and margin

Enable CP2, start_positions, interleaved

Cleanup test

Revert "Fix documentation"

This reverts commit 53df10044e7769982bd4af2ae2628e6b7717e715.

Revert "Initial commit for Dropout with 8-bit RNG"

This reverts commit 301505e24031cbcd679069e1c2cd4d00eedf2dca.

Cleanup.

Minor cleanup
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Optimize kernels
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Misc. Cleanup
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Optimize kernel performance
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Move fused_qkv_rope test to test_fused_rope.py
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* apply shared memory optimization to separate fused rope kernels
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* fix lint
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* restore accidentally removed copyright header
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* fix linting issue
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* minor issue in comments
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* Commit is for another PR
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* revert changes since this belongs to another PR
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revert change back since belongs to another PR
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Changes belong to another PR
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revert changes here
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

Add bf16/fp32 tokens-per-expert to the MoE aux loss kernel (#2162)

* add bf16/fp32 tokens-per-expert to the MoE loss computation in the router fusion
Signed-off-by: default avatartongliu <tongliu@nvidia.com>
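
The aux-loss quantity involved can be sketched with the standard load-balancing formulation (Switch-Transformer style, loss = E · Σᵢ fᵢ·Pᵢ). The function below is an illustrative reference in plain numpy, not the fused kernel; the float cast of the tokens-per-expert count mirrors the bf16/fp32 support this commit adds.

```python
import numpy as np

def moe_aux_loss(router_probs: np.ndarray, expert_index: np.ndarray,
                 num_experts: int) -> float:
    """Reference load-balancing aux loss: E * sum_i f_i * P_i, where f_i is
    the fraction of tokens routed to expert i and P_i the mean router
    probability assigned to it. Illustrative only, not the fused kernel."""
    num_tokens = router_probs.shape[0]
    # tokens-per-expert, kept in float (the quantity the kernel can now
    # consume in bf16/fp32 rather than integer form)
    f = np.bincount(expert_index, minlength=num_experts).astype(np.float32) / num_tokens
    p = router_probs.mean(axis=0)
    return num_experts * float(np.sum(f * p))
```

With a perfectly uniform router (probabilities 1/E and balanced assignment) the loss evaluates to exactly 1.0, its minimum for this formulation.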

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatartongliu <tongliu@nvidia.com>
Co-authored-by: default avatartongliu <tongliu@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[JAX] Scale swizzling via JAX transpose op (#2163)

* add swizzle in jax
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* added outer_impl
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* clean up FFI
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

Extract cpp distributed tests into a separate project (#2165)

* Extract cpp distributed tests into a separate project
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove obsolete exclusion
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Run L1_cpp_distributed tests if at least 4 GPUs
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

Adds context parallelism utilities: moving CP shards to different ranks and padding sequences to a divisibility factor (#2129)

* test - adds unit tests for the CP utilities and the utilities themselves
Signed-off-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>

* assert line change
Signed-off-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* address review comments
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* cleanup
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix linting error
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



[PyTorch Debug] Fix issue with negative underflow% stat. (#2107)

* fix underflows log issue
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Address review comments; fix mxfp8 kernel bug: the clamped SwiGLU parameter was not passed correctly
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
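
The clamped SwiGLU activation referenced above can be sketched as follows. The `limit`/`alpha` defaults and the `(x_linear + 1)` bias follow the GPT-OSS formulation, but the parameter names and values here are illustrative assumptions, not TE's kernel API.

```python
import numpy as np

def clamped_swiglu(x_glu: np.ndarray, x_linear: np.ndarray,
                   limit: float = 7.0, alpha: float = 1.702) -> np.ndarray:
    """Sketch of a GPT-OSS-style clamped SwiGLU (illustrative, not TE's kernel).

    The gated branch is clamped from above and the linear branch on both
    sides before the gated multiplication is applied.
    """
    g = np.minimum(x_glu, limit)          # clamp gate from above only
    u = np.clip(x_linear, -limit, limit)  # clamp linear branch on both sides
    sigmoid = 1.0 / (1.0 + np.exp(-alpha * g))
    return g * sigmoid * (u + 1.0)
```

Because the gate is clamped, inputs above `limit` behave identically to a gate of exactly `limit`, which is what the fixed kernel must reproduce.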

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Lower precision gated-act to accelerate FP8 current-scaling. (#2153)

* Applying the original precision to norm outputs and activation computations.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding knob to control norm output precision.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Removing the knob and applying lower-precision norm with current-scaling only.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix the error when quantizer==None
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

[PyTorch] Support activation CPU offloading in fusible ops (#2158)

* Add CPU offloading logic to ops. Fix test to compute dgrad.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Make sure grads are contiguous in op backwards
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add op-based MLP to CPU offloading tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Handle different weight cache behavior on Hopper/Blackwell

Add MXFP8 to CPU offload tests.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove MXFP8 test
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Do not use normalization forward + amax fusion if cuDNN backend is requested (#2174)

* Do not use norm fwd + amax fusion if cudnn backend is requested
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Read environment variable directly to avoid include error
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Fix unjoined comm stream in UB communicator (#2160)
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

FP8 Output Quantization for GEMM (#2123)

* Test working as I think it should work
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* revert accidental change
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

fix merge conflict
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

bug: missed a } in the code
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Add cuBLASMp-backed GEMM-like API to TE common (#1824)

* Pick up cuBLASMp during build
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Change lib order to fix link error
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Context creation, incomplete...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Test fixture
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* A sanity AgGemm test, failing...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix axes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Take care of uneven distribution
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use MPI to get position of local matrices
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor & fixes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-RS
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-AR, not working...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fixes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Setting all-reduce epilogue for gemm-ar
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use supported shapes for GEMM-AR
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tolerance
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* First shot at fp8
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use TensorHolder in tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More test configs
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Support comm_sm_count
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Parametrize dtypes for A, B and D separately
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak scaling
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Amax ptr
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Flags parity with cublas_gemm, saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Bias tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix bias test
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Aux, saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* aux_ld
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* A fix
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use test::Tensor
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Set scale inv
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unsupported test configs
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Replace libcal with NCCL
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Add NVTX markers to API functions
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak GemmAr tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More test config
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix merge fallout
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove MPI dependency, comment API, add algo parameter
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem dependency
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem build
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Exclude CommGemm tests from L0_cppunittest
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Add cpp_distributed sh file for CI
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Adapt to TensorAllocator
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Skip GemmAr test on unsupported HW
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Oversubscribe is needed on some clusters
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix incomplete libcal removal
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Move CI tests to L1
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Rename context to include NVTE prefix
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove leftover code
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* NVTE_WITH_CUBLASMP off by default
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed NVTE_CHECK diag
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Comment API
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Include stdbool header for legacy C compilers
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove now unused argument
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Abstract away cuBLASMp algo behind our own enum
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed shape diag messages
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/common/include/transformer_engine/comm_gemm.h
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarVladimir Cherepanov <56651474+mk-61@users.noreply.github.com>

* Add license
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: default avatarVladimir Cherepanov <56651474+mk-61@users.noreply.github.com>
Co-authored-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)

* FP8 AllGather in FP8 GroupedGEMM

1. Support current scaling FP8 quantization with a given amax.
2. Support FP8 AG in fwd and BF16 RS in bwd.
3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
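
The AR-max → FP8 Quant → FP8 AG workflow can be sketched as below. The helper and `FP8_E4M3_MAX` constant are illustrative (no actual 8-bit rounding is performed); the point is that a pre-all-reduced amax gives every rank the same scale, so quantized shards can be all-gathered directly in FP8.

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest representable magnitude in FP8 E4M3

def current_scale_quantize(x: np.ndarray, amax: float):
    """Quantize to a simulated FP8 range using a precomputed (all-reduced) amax.

    Because every rank derives the same scale from the shared amax, the
    resulting shards are directly all-gatherable in FP8. Illustrative sketch:
    values are scaled and clipped but not rounded to true 8-bit.
    """
    scale = FP8_E4M3_MAX / max(amax, 1e-12)  # map [-amax, amax] onto FP8 range
    x_q = np.clip(x * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return x_q, 1.0 / scale                  # quantized values + dequant scale

# Workflow sketch: AR-max -> quantize -> (FP8 all-gather) -> grouped GEMM
local = np.array([0.5, -2.0, 3.0])
amax = 4.0  # pretend this came from an all-reduce over all ranks
q, scale_inv = current_scale_quantize(local, amax)
restored = q * scale_inv
```

In the real path the gather happens on the FP8 payload `q` plus the shared `scale_inv`, halving (or better) the all-gather traffic versus BF16.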

* Slightly refactor
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding documents of new args.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding unit-tests.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding license.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Move unit-tests to L1.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Move quantizer store/reset into FP8 only.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding all layout support for Blackwell+
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adopt the feedback from code-review.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fixed the wrong stream used by d2d in groupedGEMM FFI.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[JAX] Delay MeshResource validation until first usage (#2124)

Delay MeshResource validation until first usage
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[JAX] Decouple Recipe and ScalingMode (#1728)

* Decouple recipe and scaling mode
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Expose global QuantizeConfig instance as a getter
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Format and lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Rename UsageType to TensorSource
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)

* add amax input to DBiasQuantizePrimitive and FFI
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make sure amax is init with zero
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fix sharding rule
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121)
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Temporarily remove comm_gemm tests (#2133)
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[PyTorch] Disable determinism for sm100 (#2130)

* disable determinism for sm100+ and cudnn<9.14
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert some changes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[PyTorch] ONNX export of FP8 Current Scaling (#2068)

* Compute amax in normalization forward in current scaling in untuned kernels
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* apply tims suggestions
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarJan Bielak <jbielak@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134)

use torch empty for empty shape instead of from_blob
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

build: pull cached wheels (#2127)

* build: pull cached wheels
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update setup.py
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>

---------
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

feat: Add support for multiple quantization modes in the UB communicators (#2043)
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)

* add checks to cuda kernel launch and cuda API calls
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* Remove exceptions from destructors
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* fix weird dispatch in ln/rmsnorm
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[PyTorch] Support bf16+fp8 cudagraph (#2098)

* support bf16+fp8 model
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

---------
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Dropout with 8-bit RNG (#2014)

* Add dropout kernel with 8-bit RNG
Co-authored-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix license
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Avoid ambiguous types
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Do not enforce dropout prob is representable in 8 bits
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Expand error message
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix small statistical bug from using less-equal instead of less-than

Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
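
The off-by-one this fixes can be illustrated directly: with an 8-bit RNG every draw r is uniform over 0..255, so `r < t` keeps a value with probability t/256 while `r <= t` keeps with (t + 1)/256. The snippet below is a sketch of the statistics only, not the kernel.

```python
import numpy as np

# Enumerate every possible 8-bit draw; each is equally likely under the RNG.
rng = np.arange(256)
keep_prob = 0.75
t = int(round(keep_prob * 256))  # threshold t = 192 for keep_prob = 0.75

p_lt = np.mean(rng < t)   # 192/256 = 0.75     -> unbiased keep probability
p_le = np.mean(rng <= t)  # 193/256 ~ 0.7539   -> small systematic bias
```

The bias is only 1/256 per element, but over billions of activations it shifts the effective dropout rate, which is why the comparison was changed to strict less-than.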

* Fix linter warning
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unnecessary helper function in PyTorch extensions
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
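
The less-than fix above matters because a keep test of `byte <= t` keeps t+1 of the 256 possible byte values instead of t, skewing the keep probability. A minimal NumPy sketch of dropout driven by one random byte per element (illustration only, not the TE kernel; `dropout_8bit` is a hypothetical name):

```python
import numpy as np

def dropout_8bit(x, p_drop, rng):
    # Draw one random byte per element; keep an element when its byte is
    # strictly less than the keep threshold. Using `<` (not `<=`) makes
    # P(keep) == threshold / 256 exactly, avoiding an off-by-one bias.
    threshold = int(round((1.0 - p_drop) * 256))
    r = rng.integers(0, 256, size=x.shape, dtype=np.uint8)
    mask = r < threshold
    # Rescale surviving elements by the exact keep probability; the dropout
    # probability itself need not be representable in 8 bits.
    scale = 1.0 / (1.0 - p_drop)
    return np.where(mask, x * scale, 0.0), mask
```

With `p_drop = 0.0` the threshold is 256, so every byte (0..255) passes the strict comparison and nothing is dropped.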

Create GPU reload buffers on main stream (#2131)

* Create GPU reload buffers on main stream
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed typo
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* Fixed typo
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

---------
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

mxfp8 unfused quant support, refined unit test, remove unnecessary quantization code
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

missed a quant code removal
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

minor bug fix
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



Add cuBLASMp-backed GEMM-like API to TE common (#1824)

* Pick up cuBLASMp during build
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Change lib order to fix link error
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Context creation, incomplete...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Test fixture
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* A sanity AgGemm test, failing...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix axes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Take care of uneven distribution
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use MPI to get position of local matrices
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor & fixes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-RS
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-AR, not working...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fixes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Setting all-reduce epilogue for gemm-ar
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use supported shapes for GEMM-AR
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tolerance
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* First shot at fp8
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use TensorHolder in tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More test configs
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Support comm_sm_count
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Parametrize dtypes for A, B and D separately
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak scaling
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Amax ptr
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Flags parity with cublas_gemm, saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Bias tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix bias test
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Aux, saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* aux_ld
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* A fix
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use test::Tensor
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Set scale inv
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unsupported test configs
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Replace libcal with NCCL
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Add NVTX markers to API functions
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak GemmAr tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More test config
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix merge fallout
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove MPI dependency, comment API, add algo parameter
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem dependency
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem build
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Exclude CommGemm tests from L0_cppunittest
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Add cpp_distributed sh file for CI
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Adapt to TensorAllocator
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Skip GemmAr test on unsupported HW
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Oversubscribe is needed on some clusters
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix incomplete libcal removal
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Move CI tests to L1
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Rename context to include NVTE prefix
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove leftover code
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* NVTE_WITH_CUBLASMP off by default
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed NVTE_CHECK diag
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Comment API
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Include stdbool header for legacy C compilers
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove now unused argument
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Abstract away cuBLASMp algo behind our own enum
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed shape diag messages
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/common/include/transformer_engine/comm_gemm.h
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarVladimir Cherepanov <56651474+mk-61@users.noreply.github.com>

* Add license
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: default avatarVladimir Cherepanov <56651474+mk-61@users.noreply.github.com>
Co-authored-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)

* FP8 AllGather in FP8 GroupedGEMM

1. Support current-scaling FP8 quantization with a given amax.
2. Support FP8 AG in fwd and BF16 RS in bwd.
3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Slightly refactor
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding documents of new args.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding unit-tests.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding license.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Move unit-tests to L1.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Move quantizer store/reset into FP8 only.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding all layout support for Blackwell+
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adopt the feedback from code-review.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fixed the wrong stream used by d2d in groupedGEMM FFI.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
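
The AR-max -> FP8 Quant -> FP8 AG workflow above relies on current-scaling quantization with an externally supplied amax. A schematic NumPy sketch of that step (hypothetical helper; real FP8 rounding to E4M3 is omitted):

```python
import numpy as np

FP8_E4M3_MAX = 448.0  # largest representable magnitude in FP8 E4M3

def current_scale_quantize(x, amax):
    # Map the observed amax to the edge of the FP8 representable range,
    # then clamp. In the distributed workflow, `amax` would come from an
    # all-reduce (AR-max) so every rank quantizes with the same scale.
    scale = FP8_E4M3_MAX / amax if amax > 0 else 1.0
    x_q = np.clip(x * scale, -FP8_E4M3_MAX, FP8_E4M3_MAX)
    return x_q, 1.0 / scale  # quantized values and dequant scale_inv
```

Reducing amax across ranks before quantizing means the FP8 shards gathered afterwards all share one consistent `scale_inv`, which is what makes the subsequent FP8 all-gather and grouped GEMM well-defined.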

[JAX] Delay MeshResource validation until first usage (#2124)

Delay MeshResource validation until first usage
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[JAX] Decouple Recipe and ScalingMode (#1728)

* Decouple recipe and scaling mode
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Expose global QuantizeConfig instance as a getter
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Format and lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Rename UsageType to TensorSource
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)

* add amax input to DBiasQuantizePrimitive and FFI
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make sure amax is init with zero
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fix sharding rule
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121)
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

Temporarily remove comm_gemm tests (#2133)
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

[PyTorch] Disable determinism for sm100 (#2130)

* disable determinism for sm100+ and cudnn<9.14
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert some changes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[PyTorch] ONNX export of FP8 Current Scaling (#2068)

* Compute amax in normalization forward in current scaling in untuned kernels
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* apply Tim's suggestions
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarJan Bielak <jbielak@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134)

use torch empty for empty shape instead of from_blob
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

build: pull cached wheels (#2127)

* build: pull cached wheels
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update setup.py
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>

---------
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

feat: Add support for multiple quantization modes in the UB communicators (#2043)

[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)

* add checks to cuda kernel launch and cuda API calls
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* Remove exceptions from destructors
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* fix weird dispatch in ln/rmsnorm
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

[PyTorch] Support bf16+fp8 cudagraph (#2098)

* support bf16+fp8 model
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

---------
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

minor code cleanup
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



minor cosmetics
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



Address review comment
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



minor comment update
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Fix CI failures for UB overlap changes (#2149)
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

minor bug: quantizer should not be None for unfused quantization
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135)

* Fix failing tests for dropout=0.1 and bias for fused attn for blackwell
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix the skip message
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Assert in fused attn bwd pass for sm100
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

Add check for sm100
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add support to get all devs in the process for jax
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Code clean up
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Represent attn bias using enum instead of string
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

fix linting error
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119)

* add noop to comp amax
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* fix for fp8 blockwise recipe
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* resolve comments
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

address review comments
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* Update test_multi_process_distributed_grouped_gemm.py

change accidentally added while merging
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* Update dense.py

change accidentally added while merging
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* address review comments
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* address review comments
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Bug solved: delayed-scaling quantization with mxfp8 inputs didn't work
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix the unit test error
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* just to trigger ci
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* address review comments: quantization inside the GEMM and outside should match exactly with fp32 accumulation
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* fix merge conflict
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

TE Gemma tutorial attempt#2 (#1839)

* add tutorial files and other local changes
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* remove extraneous code for easy debug
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* make cuda graphs work with non-paged and paged attention
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* perf imp for kv cache ops
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add code for calibration
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* optimize kv_cache reindex and copy kernels
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* changes to make quantizers work with fp8_calibration
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* avoid reindexing from python side
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename variable from previous commit
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use quantizer only if needed
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* functionality of the tutorial tested and perf checked
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* remove files and update headers/licenses
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* update header/license
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update tutorial for review
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make weights downloadable on the fly; remove extra print statements
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint and update comments
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add comma back, typo
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* sequence_start_positions should be None for training
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add paged attention numbers and update requirements.txt file
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* more fixes
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make tutorial work on blackwell
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* remove gemma FT tutorial for now
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fixing the headings placement and rewording attention -> kv caching
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fixes from comments
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix the images
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* misc fixes
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add more comments to te_gemma.py and cleanup utils.py
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add more information about the hierarchy of the classes used in the tutorial
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add better cuda graphs picture
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add updated cuda graphs pictures
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add illustrated cuda graphs
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fix
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* small fixes in documentation
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add torch.no_grad() to force reduced memory usage
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* some fixes from recent comments
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* more fixes from remaining comments
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add te_rope_emb to class desc
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fix tutorial wording; add calibration fix to grouped_linear.py
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

Fix memory overhead of linear layer when all gather from sequence parallel (#2125)

* fix memory overhead of all gather from sequence parallel
Signed-off-by: default avatarYuzhong Wang <yuzhongw@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* quick fix for errors with UB buffers
Signed-off-by: default avatarYuzhong Wang <yuzhongw@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/module/linear.py
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* Avoid deallocating FP8 scale-invs since they are reused
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarYuzhong Wang <yuzhongw@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>

Fix incorrect TP rank calculation when using data parallel (#2179)
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

[Pytorch] Add Cutlass Grouped GEMM Support for fine-grained MoE Model (#2045)

* feat: add cutlass group gemm support
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactor: refactor multi tensor gemm interface
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactor: refactor nvte_multi_stream_cublas_gemm func and add license info
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feat: add unit test for cutlass group gemm
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feat: add cutlass support type protect
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* add tests and fix lint
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feat: fix unit tests error
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feat: refactor host workspace malloc
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* update cutlass
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* update cutlass
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* further relax threshold and add an env var to warn on fallback
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Signed-off-by: default avataralan yang <89962857+cassiewilliam@users.noreply.github.com>
Co-authored-by: default avatarMin Yang <min.yang@shopee.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[PyTorch] Support FA3 for MLA and with CP (#1907)

feature(FA3,MLA,CP):
1. Update FA3 to commit-id 3ba6f82 (tag 2.8.0.post2 with compile error fixed), PR-1604 support hdimQK != hdimV backward
2. Update get_attention_backend method because FA3 support MLA now
3. Add CP MLA support for FA3
4. Add unit tests for FA3 MLA CP
5. Update attention doc
Signed-off-by: default avatarzhujian <zhujian.whu.cs@gmail.com>

Fix cuDNN version checks when getting backend and for sm89 kv cache (#2185)

* Fix cudnn version checks for kv cache for sm89. Add cudnn version check in preparation for 9.14 when getting backend
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor fix for cuDNN version condition check
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Use limit=0.75 in clamped SwiGLU test
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* Address review comments
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* accidentally removed a line while resolving merge conflict
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* match pytorch implementation: dclamp should be 1 for borders of clamping limits as well
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* fix dswiglu quantization fusion bug
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* pass param by reference as much as possible
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* float should rather be bool: fix by copilot
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* { missed in activation.cpp
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* minor formatting change
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* nvfp4 change
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

---------
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
parent 25252e9f
......@@ -1736,6 +1736,80 @@ class TestBasicOps:
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("quantization", _quantization_list)
@pytest.mark.parametrize("quantize_forward", (False, True))
@pytest.mark.parametrize("quantize_backward", (False, True))
def test_clamped_swiglu(
self,
*,
out_shape: Iterable[int] = (32, 32),
dtype: torch.dtype,
device: torch.device = "cuda",
quantization: Optional[str],
quantize_forward: bool,
quantize_backward: bool,
limit: float = 0.75,
alpha: float = 1.702,
):
# Test SwiGLU variant used in GPT OSS.
# Tensor dimensions
in_shape = list(out_shape)
in_shape[-1] *= 2
# Skip invalid configurations
quantized_compute = quantization is not None
if not quantized_compute and (quantize_forward or quantize_backward):
pytest.skip("Quantization scheme has not been provided")
maybe_skip_quantization(quantization, dims=in_shape, device=device)
# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)
# Plain PyTorch implementation
x_glu, x_linear = x_ref.chunk(2, dim=-1)
x_glu = x_glu.clamp(min=None, max=limit)
x_linear = x_linear.clamp(min=-limit, max=limit)
out_glu = x_glu * torch.sigmoid(alpha * x_glu)
y_ref = out_glu * (x_linear + 1)
y_ref.backward(dy_ref)
# Implementation with fusible operation
recipe = make_recipe(quantization)
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=quantize_backward),
te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
te_ops.Quantize(forward=quantize_forward, backward=False),
)
with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
y_test = forward(x_test)
y_test.backward(dy_test)
# Expected numerical error
tols = dtype_tols(dtype)
if quantized_compute and quantization == "nvfp4":
tols = dtype_tols(tex.DType.kFloat4E2M1)
elif quantized_compute:
tols = dtype_tols(tex.DType.kFloat8E4M3)
# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)
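For reference, the plain-PyTorch path in the test above reduces to simple scalar math. A minimal pure-Python sketch (the function name is illustrative, not part of the test suite):

```python
import math

def clamped_swiglu(x_glu: float, x_linear: float,
                   limit: float = 0.75, alpha: float = 1.702) -> float:
    """Scalar version of the clamped SwiGLU used in the test's reference path."""
    # The activation input is clamped from above only; the gate from both sides.
    x_glu = min(x_glu, limit)
    x_linear = min(max(x_linear, -limit), limit)
    # SiLU with a scaled sigmoid: x * sigmoid(alpha * x).
    out_glu = x_glu / (1.0 + math.exp(-alpha * x_glu))
    # The +1 shifts the gate so that a zero gate still passes the activation through.
    return out_glu * (x_linear + 1.0)

# Values beyond the limit saturate, matching the torch.clamp calls above.
assert clamped_swiglu(2.0, 0.2) == clamped_swiglu(0.75, 0.2)
assert clamped_swiglu(0.5, 3.0) == clamped_swiglu(0.5, 0.75)
```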
@pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
@pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
@pytest.mark.parametrize("dtype", _dtypes)
......
......@@ -51,22 +51,20 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = false;
constexpr NVTETensor grad = nullptr;
quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, stream);
quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
}
template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
ComputeType (*DActOP)(ComputeType, const Param &)>
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
cudaStream_t stream) {
using namespace detail;
constexpr bool IS_DGATED = true;
quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, stream);
quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
}
} // namespace transformer_engine
......
......@@ -23,14 +23,16 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_geglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
Empty e = {};
gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, e, stream);
}
void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dgeglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, stream);
Empty e = {};
dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, e, stream);
}
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
......@@ -49,12 +51,14 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_qgeglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
Empty e = {};
gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, e, stream);
}
void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgeglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, stream);
Empty e = {};
dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, e, stream);
}
......@@ -23,14 +23,16 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_reglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
Empty e = {};
gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, e, stream);
}
void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dreglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, stream);
Empty e = {};
dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, e, stream);
}
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
......@@ -49,12 +51,14 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_sreglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
Empty e = {};
gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, e, stream);
}
void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dsreglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, stream);
Empty e = {};
dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, e, stream);
}
......@@ -23,12 +23,31 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_swiglu);
using namespace transformer_engine;
gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
Empty e = {};
gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, e, stream);
}
void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dswiglu);
using namespace transformer_engine;
dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, stream);
Empty e = {};
dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, e, stream);
}
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_swiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha};
gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
}
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, cudaStream_t stream) {
NVTE_API_CALL(nvte_clamped_dswiglu);
using namespace transformer_engine;
ClampedSwiGLUParam param = {limit, alpha};
dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
grad, input, output, param, stream);
}
......@@ -173,6 +173,26 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the gated Swish (SwiGLU) activation of the input, as used in GPT OSS.
 *
 * See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
 * This gated activation differs from the original SwiGLU in two ways:
 * 1. Both the gate and the pre-activation are clipped according to the parameter limit.
 * 2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) of the Swish
 *    activation, following the original GELU paper https://arxiv.org/pdf/1606.08415
 * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
 * the block quantization (MXFP8) of the specified shape of the block will be used.
 *
 *  \param[in]     input     Input tensor of shape [N, H * 2].
 *  \param[in,out] output    Output tensor of shape [N, H].
 *                           It computes Act(input[N, :H]) x input[N, H:]
 *  \param[in]     limit     Clipping limit for the gate and pre-activation values.
 *  \param[in]     alpha     Scaling factor for the sigmoid used inside the activation.
 *  \param[in]     stream    CUDA stream used for the operation.
 */
void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
cudaStream_t stream);
/*! \brief Computes the gated ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -230,6 +250,26 @@ void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the gradient of the gated Swish (SwiGLU) activation used in GPT OSS.
 *
 * See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
 * This gated activation differs from the original SwiGLU in two ways:
 * 1. Both the gate and the pre-activation are clipped according to the parameter limit.
 * 2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) of the Swish
 *    activation, following the original GELU paper https://arxiv.org/pdf/1606.08415
 * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
 * the block quantization (MXFP8) of the specified shape of the block will be used.
 *
 *  \param[in]     grad     Incoming gradient of shape [N, H].
 *  \param[in]     input    Forward input tensor of shape [N, H * 2].
 *  \param[in,out] output   Outgoing gradient of shape [N, H * 2].
 *  \param[in]     limit    Clipping limit for the gate and pre-activation values.
 *  \param[in]     alpha    Scaling factor for the sigmoid used inside the activation.
 *  \param[in]     stream   CUDA stream used for the operation.
 */
void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
float limit, float alpha, cudaStream_t stream);
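The derivative this entry point computes follows directly from the forward definition. A pure-Python sketch of the activation-branch math, cross-checked with finite differences (names are illustrative; the convention at x == limit follows the kernel, which keeps the unclamped derivative at the clamp boundary):

```python
import math

def sigmoid(z: float) -> float:
    return 1.0 / (1.0 + math.exp(-z))

def act(x: float, limit: float = 0.75, alpha: float = 1.702) -> float:
    # Forward: clamp from above, then SiLU with a scaled sigmoid.
    x = min(x, limit)
    return x * sigmoid(alpha * x)

def dact(x: float, limit: float = 0.75, alpha: float = 1.702) -> float:
    # Analytic derivative: zero once x is clamped (x > limit); at x == limit
    # the boundary is included, so the unclamped derivative is used there.
    if x > limit:
        return 0.0
    s = sigmoid(alpha * x)
    return s + s * (1.0 - s) * alpha * x

# Finite-difference check away from the clamp boundary.
x, h = 0.3, 1e-6
fd = (act(x + h) - act(x - h)) / (2 * h)
assert abs(fd - dact(x)) < 1e-5
# Past the limit, the forward saturates and the gradient vanishes.
assert act(2.0) == act(0.75)
assert dact(1.0) == 0.0
```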
/*! \brief Computes the gated ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......
......@@ -55,7 +55,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const __grid_constant__ CUtensorMap tensor_map_output_act,
const __grid_constant__ CUtensorMap tensor_map_output_gate,
float *const amax_ptr, float *const scale_inv_ptr,
const float *const scale_ptr, const size_t rows, const size_t cols) {
const float *const scale_ptr, const size_t rows, const size_t cols,
const ParamOP p) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
......@@ -161,7 +162,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems;
OType *out_act_sh_curr = out_act_sh + buff * buff_elems;
OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems;
#pragma unroll
for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
......@@ -171,6 +171,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
}
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]);
......@@ -178,18 +184,27 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
if (act_elt <= p.limit) {
dact_x = s + s * (1 - s) * p.alpha * x;
} else {
dact_x = 0.0f;
}
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, {});
dact_x = DActOP(x, {});
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
float after_dact = dact_x * grad_elt * gate_elt;
float after_dgate = act_x * grad_elt;
float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dact);
out_gate_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dgate);
......@@ -197,7 +212,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
amax = fmaxf(amax, fabsf(after_dact));
amax = fmaxf(amax, fabsf(after_dgate));
} else {
const float after_act = ActOP(act_elt, {}) * gate_elt;
const float after_act = ActOP(act_elt, p) * gate_elt;
out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_act);
amax = fmaxf(amax, fabsf(after_act));
}
......@@ -300,7 +315,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise,
e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise,
const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
const size_t scale_stride_colwise) {
const size_t scale_stride_colwise, const ParamOP p) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;
......@@ -476,25 +491,37 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float gate_elt = static_cast<float>(in_gate_sh[shmem_offset_colwise]);
float after_act_elt;
float after_gate_elt;
bool dgate_elt = true; // gating is ideally an identity function
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
}
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f;
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, {});
dact_x = DActOP(x, {});
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
after_act_elt = dact_x * grad_elt * gate_elt;
after_gate_elt = act_x * grad_elt;
after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f;
} else {
after_act_elt = ActOP(act_elt, {}) * gate_elt;
after_act_elt = ActOP(act_elt, p) * gate_elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
......@@ -720,27 +747,39 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
float gate_elt = static_cast<float>(in_gate.data.elt[e]);
float after_act_elt;
float after_gate_elt;
bool dgate_elt = true;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
// In case of GPT OSS, clamp the activation and gate values
dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit; // Derivative of clamp
gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
}
if constexpr (IS_DGATED) {
float grad_elt = static_cast<float>(in_grad.data.elt[e]);
const float x = act_elt;
float act_x;
float dact_x;
if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
const float x = min(act_elt, p.limit);
const float s = sigmoidf(p.alpha * x);
act_x = x * s;
dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f;
} else {
if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
const float s = sigmoidf(x);
act_x = x * s;
dact_x = x * s * (1 - s) + s;
} else {
act_x = ActOP(x, {});
dact_x = DActOP(x, {});
act_x = ActOP(x, p);
dact_x = DActOP(x, p);
}
}
after_act_elt = dact_x * grad_elt * gate_elt;
after_gate_elt = act_x * grad_elt;
after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f;
after_act_rowwise[j] = after_act_elt;
after_gate_rowwise[j] = after_gate_elt;
} else {
after_act_elt = ActOP(act_elt, {}) * gate_elt;
after_act_elt = ActOP(act_elt, p) * gate_elt;
after_act_rowwise[j] = after_act_elt;
}
......@@ -885,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
cudaStream_t stream) {
checkCuDriverContext(stream);
......@@ -960,15 +999,14 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>
<<<grid_dim, block_dim, shmem_size, stream>>>(
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act,
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows,
cols);
tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p);
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
}
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
cudaStream_t stream) {
checkCuDriverContext(stream);
......@@ -1099,7 +1137,8 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
case ScalingType::COLWISE:
......@@ -1116,7 +1155,8 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
case ScalingType::BIDIMENSIONAL:
......@@ -1125,7 +1165,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
OType, true, true,
THREADS_PER_CHUNK_NON_COLWISE>,
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
true, true, THREADS_PER_CHUNK_NON_COLWISE>
<<<grid, block_size, shmem_size, stream>>>(
......@@ -1133,7 +1172,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
scale_stride_colwise, p);
NVTE_CHECK_CUDA(cudaGetLastError());
break;
}); // NOLINT(*)
......@@ -1141,12 +1180,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
}
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) {
CheckInputTensor(input, "gated_act_input");
CheckOutputTensor(*output, "gated_act_output");
NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(),
"Wrong output shape. Expected (after flattening) [", input.flat_first_dim(),
", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
NVTE_CHECK(input.flat_last_dim() % 2 == 0,
"Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
......@@ -1168,7 +1204,7 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), input.flat_first_dim(),
output->flat_last_dim(), {}, stream);
output->flat_last_dim(), p, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
......@@ -1177,7 +1213,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p,
cudaStream_t stream) {
CheckInputTensor(grad, "dgated_act_grad");
CheckInputTensor(input, "dgated_act_input");
CheckOutputTensor(*output, "dgated_act_output");
......@@ -1206,7 +1243,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt
reinterpret_cast<const fp32 *>(output->scale.dptr),
reinterpret_cast<fp32 *>(output->amax.dptr),
reinterpret_cast<fp32 *>(output->scale_inv.dptr), grad.flat_first_dim(),
grad.flat_last_dim(), {}, stream);
grad.flat_last_dim(), p, stream);
} else {
NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
}); // NOLINT(*)
@@ -1215,7 +1252,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
cudaStream_t stream) {
constexpr bool allow_empty = false;
CheckInputTensor(gated_input, "gated_input");
@@ -1255,17 +1292,17 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
if (is_delayed_tensor_scaling(output->scaling_mode)) {
if (use_tma_kernels) {
cast_fp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
cast_fp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
} else {
if constexpr (IS_DGATED) {
cast_dgated<ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
cast_dgated<ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
} else {
cast_gated<ParamOP, ActOP>(gated_input, output, stream);
cast_gated<ParamOP, ActOP>(gated_input, output, p, stream);
}
}
} else if (is_mxfp8_scaling(output->scaling_mode)) {
if (use_tma_kernels) {
cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
} else {
NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ",
"by 32, got input of shape ", gated_input.data.shape);
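The branching above selects between the fused TMA kernels and the unfused fallback per scaling mode. A simplified Python mirror of that decision (illustrative only; the real `use_tma_kernels` condition is not visible in this hunk, so divisibility by 32 is an assumption inferred from the MXFP8 error message):

```python
def select_gated_kernel(scaling_mode, last_dim, is_dgated):
    # Assumption: the TMA path wants the per-half row length divisible by 32,
    # inferred from the MXFP8 error message; the real condition may differ.
    use_tma = last_dim % 32 == 0
    if scaling_mode == "delayed":
        if use_tma:
            return "cast_fp8_gated"
        return "cast_dgated" if is_dgated else "cast_gated"
    if scaling_mode == "mxfp8":
        if use_tma:
            return "cast_mxfp8_gated"
        raise ValueError("last dimension must be divisible by 32 for MXFP8")
    raise NotImplementedError(f"scaling mode: {scaling_mode}")

assert select_gated_kernel("delayed", 64, True) == "cast_fp8_gated"
assert select_gated_kernel("delayed", 30, False) == "cast_gated"
```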
@@ -1281,7 +1318,7 @@ namespace detail {
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
float (*DActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
cudaStream_t stream) {
ParamOP &p, cudaStream_t stream) {
using namespace gated_kernels;
Tensor grad_empty_tensor;
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
@@ -1290,13 +1327,14 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input,
if (is_supported_by_CC_100()) {
quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
output_tensor, stream);
output_tensor, p, stream);
} else {
if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) {
if constexpr (IS_DGATED) {
cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, stream);
cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, p,
stream);
} else {
cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, stream);
cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, p, stream);
}
} else {
// MX scaling
@@ -11,6 +11,11 @@ namespace transformer_engine {
struct Empty {};
struct ClampedSwiGLUParam {
float limit;
float alpha = 1.702f; // Default value for QuickGELU
};
template <typename OType, typename IType>
__device__ inline OType gelu(const IType val, const Empty&) {
const float cval = val;
@@ -38,17 +43,29 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) {
return s * (1.f - s);
}
template <typename OType, typename IType>
__device__ inline OType qgelu_with_alpha(const IType val, const float alpha) {
const float cval = val;
Empty e = {};
return cval * sigmoid<float, float>(alpha * cval, e);
}
template <typename OType, typename IType>
__device__ inline OType qgelu(const IType val, const Empty& e) {
return qgelu_with_alpha<OType, IType>(val, 1.702f);
}
template <typename OType, typename IType>
__device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) {
const float cval = val;
return cval * sigmoid<float, float>(1.702f * cval, e);
Empty e = {};
return alpha * cval * dsigmoid<float, float>(alpha * cval, e) +
sigmoid<float, float>(alpha * cval, e);
}
template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) {
const float cval = val;
return 1.702f * cval * dsigmoid<float, float>(1.702f * cval, e) +
sigmoid<float, float>(1.702f * cval, e);
return dqgelu_with_alpha<OType, IType>(val, 1.702f);
}
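The refactor above factors the hard-coded 1.702 constant out into an `alpha` parameter and expresses `dqgelu` via `dqgelu_with_alpha`. As a sanity sketch (pure Python, not part of the patch), the identity d/dx[x·σ(αx)] = αx·σ'(αx) + σ(αx) implemented by `dqgelu_with_alpha` can be checked by finite differences:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def qgelu_with_alpha(x, alpha):
    # Mirrors the device function: x * sigmoid(alpha * x)
    return x * sigmoid(alpha * x)

def dqgelu_with_alpha(x, alpha):
    # d/dx [x * sigmoid(alpha*x)] = alpha*x*sigmoid'(alpha*x) + sigmoid(alpha*x)
    s = sigmoid(alpha * x)
    return alpha * x * s * (1.0 - s) + s

# Finite-difference check of the analytic derivative at a few points
for x in (-2.0, -0.5, 0.0, 1.0, 3.0):
    h = 1e-5
    fd = (qgelu_with_alpha(x + h, 1.702) - qgelu_with_alpha(x - h, 1.702)) / (2 * h)
    assert abs(fd - dqgelu_with_alpha(x, 1.702)) < 1e-6
```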
template <typename OType, typename IType>
@@ -57,12 +74,26 @@ __device__ inline OType silu(const IType val, const Empty& e) {
return cval * sigmoid<float, float>(cval, e);
}
template <typename OType, typename IType>
__device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam& p) {
const float cval = min(p.limit, static_cast<float>(val)); // Clamping
return qgelu_with_alpha<OType, float>(cval, p.alpha);
}
template <typename OType, typename IType>
__device__ inline OType dsilu(const IType val, const Empty& e) {
const float cval = val;
return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
}
template <typename OType, typename IType>
__device__ inline OType clamped_dsilu(const IType val, const ClampedSwiGLUParam& p) {
const bool dclamp_val = static_cast<float>(val) <= p.limit;
const float clamp_val = min(static_cast<float>(val), p.limit);
const float dsilu_val = dqgelu_with_alpha<OType, float>(clamp_val, p.alpha);
return dclamp_val ? dsilu_val : 0.0f;
}
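Note that `clamped_silu` clamps the activation input from above only (`min(p.limit, val)`), and `clamped_dsilu` zeroes the gradient past the limit. A pure-Python reference of the same math (illustrative, not the device code):

```python
import math

def _sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def clamped_silu(x, limit, alpha):
    # The activation input is clamped from above only, as in the device code
    c = min(x, limit)
    return c * _sigmoid(alpha * c)

def clamped_dsilu(x, limit, alpha):
    # The clamp's derivative zeroes the gradient above the limit
    if x > limit:
        return 0.0
    s = _sigmoid(alpha * x)
    return alpha * x * s * (1.0 - s) + s

limit, alpha = 7.0, 1.702
assert clamped_silu(10.0, limit, alpha) == clamped_silu(limit, limit, alpha)
assert clamped_dsilu(10.0, limit, alpha) == 0.0
```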
template <typename OType, typename IType>
__device__ inline OType relu(IType value, const Empty&) {
return fmaxf(value, 0.f);
@@ -11,7 +11,7 @@
#include "../common.h"
#include "../utils.cuh"
#include "math.h"
namespace transformer_engine {
/* \brief Helper class that enables storing multiple values of type DType
@@ -338,7 +338,7 @@ template <int nvec, typename Param, fp32 (*OP)(const fp32, const Param &), typen
typename OutputType>
void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output,
const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N,
const Param params, cudaStream_t stream) {
const Param &params, cudaStream_t stream) {
if (N != 0) {
auto align = CheckAlignment(N, nvec, input, output);
@@ -372,7 +372,7 @@ template <int nvec, typename Param, fp32 (*OP)(fp32, const Param &), typename In
typename InputTypeGrad, typename OutputType>
void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input,
OutputType *output, const fp32 *scale, fp32 *amax,
fp32 *scale_inv, const size_t N, const Param params,
fp32 *scale_inv, const size_t N, const Param &params,
cudaStream_t stream) {
if (N != 0) {
auto align = CheckAlignment(N, nvec, input, grad, output);
@@ -431,7 +431,13 @@ __launch_bounds__(unary_kernel_threads) __global__
#pragma unroll
for (int i = 0; i < nvec; ++i) {
const ComputeType val = static_cast<ComputeType>(loader0.separate()[i]);
const ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
// Clamp the gated value and add 1 at the end
ComputeType limit = p.limit;
val2 = std::min(std::max(-limit, val2), limit) + 1;
}
ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
if (requires_amax) {
__builtin_assume(max >= 0);
@@ -532,10 +538,18 @@ __launch_bounds__(unary_kernel_threads) __global__
for (int i = 0; i < nvec; ++i) {
const ComputeType grad_val = static_cast<ComputeType>(grad_loader.separate()[i]);
const ComputeType gelu_in = static_cast<ComputeType>(input_loader0.separate()[i]);
const ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
bool dgate_in = true;
if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
// For GPT-OSS clamped SwiGLU, clamp the gate value here; the activation
// input's clamp is applied inside Activation/Dactivation
const ComputeType limit = p.limit;
dgate_in = gate_in <= limit && gate_in >= -limit; // Derivative of clamp
gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f;
}
ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
ComputeType after_dgate = grad_val * Activation(gelu_in, p);
ComputeType after_dgate = dgate_in ? grad_val * Activation(gelu_in, p) : 0.0f;
if (requires_amax) {
__builtin_assume(max >= 0);
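Per element, the backward kernel above masks the gate gradient outside `[-limit, limit]` while the activation path's clamp is handled inside `Activation`/`Dactivation`. A hedged pure-Python sketch of one element (`dclamped_swiglu_elem` is an illustrative name, not part of the patch):

```python
import math

def _sig(x):
    return 1.0 / (1.0 + math.exp(-x))

def dclamped_swiglu_elem(grad, act_in, gate_in, limit=7.0, alpha=1.702):
    # Gate path: two-sided clamp plus 1; gradient masked outside [-limit, limit]
    gate_mask = -limit <= gate_in <= limit
    g = min(max(-limit, gate_in), limit) + 1.0
    # Activation path: upper clamp and its derivative (clamped_silu/clamped_dsilu)
    c = min(act_in, limit)
    s = _sig(alpha * c)
    act = c * s
    dact = 0.0 if act_in > limit else alpha * c * s * (1.0 - s) + s
    d_act_in = dact * grad * g                     # corresponds to after_dgelu
    d_gate_in = grad * act if gate_mask else 0.0   # corresponds to after_dgate
    return d_act_in, d_gate_in

d_act, d_gate = dclamped_swiglu_elem(1.0, 2.0, 10.0)
assert d_gate == 0.0  # gate gradient is masked beyond the clamp limit
assert d_act > 0.0
```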
@@ -205,6 +205,10 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer);
py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha);
py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
float limit, float alpha);
/***************************************************************************************************
* LayerNorm
**************************************************************************************************/
@@ -3,7 +3,6 @@
*
* See LICENSE for license information.
************************************************************************/
#include "../extensions.h"
#include "common.h"
#include "pybind.h"
@@ -12,10 +11,12 @@ namespace transformer_engine {
namespace pytorch {
namespace {
using FuncType = void(const NVTETensor, NVTETensor, cudaStream_t);
using DFuncType = void(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t);
py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t),
const at::Tensor& input, py::handle quantizer,
int shape_divisor = 1) {
template <FuncType* act_func, auto act_func_with_args, typename... Args>
py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1,
Args&&... args) {
init_extension();
// Input tensor
@@ -56,14 +57,28 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud
// Compute activation in high precision, then quantize
{
auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
NVTE_SCOPED_GIL_RELEASE({
if constexpr (act_func == nullptr) {
act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward<Args>(args)...,
stream);
} else {
act_func(input_nvte.data(), temp_nvte.data(), stream);
}
});
quantizer_cpp->quantize(temp_nvte, out_nvte);
}
break;
case Impl::FULLY_FUSED:
// Compute activation directly
{
NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), out_nvte.data(), stream); });
NVTE_SCOPED_GIL_RELEASE({
if constexpr (act_func == nullptr) {
act_func_with_args(input_nvte.data(), out_nvte.data(), std::forward<Args>(args)...,
stream);
} else {
act_func(input_nvte.data(), out_nvte.data(), stream);
}
});
}
break;
case Impl::FUSED_ACTIVATION_AMAX_FP8:
@@ -73,7 +88,14 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud
NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
auto [temp_nvte, _] =
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
NVTE_SCOPED_GIL_RELEASE({
if constexpr (act_func == nullptr) {
act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward<Args>(args)...,
stream);
} else {
act_func(input_nvte.data(), temp_nvte.data(), stream);
}
});
fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
}
break;
@@ -84,7 +106,14 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud
static_cast<NVFP4Quantizer*>(quantizer_cpp.get()); // Already checked cast is valid
auto [temp_nvte, _] =
nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
NVTE_SCOPED_GIL_RELEASE({
if constexpr (act_func == nullptr) {
act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward<Args>(args)...,
stream);
} else {
act_func(input_nvte.data(), temp_nvte.data(), stream);
}
});
nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
}
break;
@@ -95,10 +124,9 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud
return out_py;
}
py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor,
cudaStream_t),
const at::Tensor& grad_output, const at::Tensor& input,
py::handle quantizer) {
template <DFuncType* dact_func, auto dact_func_with_args, typename... Args>
py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
py::handle quantizer, Args&&... args) {
init_extension();
// Grad output and input tensors
@@ -142,8 +170,12 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen
{
auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE({
dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
at::cuda::getCurrentCUDAStream());
if constexpr (dact_func == nullptr) {
dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
std::forward<Args>(args)..., stream);
} else {
dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream);
}
});
quantizer_cpp->quantize(temp_nvte, grad_input_nvte);
}
@@ -152,7 +184,12 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen
// Compute activation backward directly
{
NVTE_SCOPED_GIL_RELEASE({
if constexpr (dact_func == nullptr) {
dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(),
std::forward<Args>(args)..., stream);
} else {
dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream);
}
});
}
break;
@@ -163,8 +200,14 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen
NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
auto [temp_nvte, _] =
fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
NVTE_SCOPED_GIL_RELEASE({
if constexpr (dact_func == nullptr) {
dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
std::forward<Args>(args)..., stream);
} else {
dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream);
}
});
fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
}
break;
@@ -175,8 +218,14 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen
static_cast<NVFP4Quantizer*>(quantizer_cpp.get()); // Already checked cast is valid
auto [temp_nvte, _] =
nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype);
NVTE_SCOPED_GIL_RELEASE(
{ dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
NVTE_SCOPED_GIL_RELEASE({
if constexpr (dact_func == nullptr) {
dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
std::forward<Args>(args)..., stream);
} else {
dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream);
}
});
nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
}
break;
@@ -186,90 +235,98 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen
return grad_input_py;
}
} // namespace
/* GELU and variants */
py::object gelu(const at::Tensor& input, py::handle quantizer) {
return activation_forward(nvte_gelu, input, quantizer);
return activation_helper<nvte_gelu, nullptr>(input, quantizer);
}
py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return activation_backward(nvte_dgelu, grad, input, quantizer);
return dactivation_helper<nvte_dgelu, nullptr>(grad, input, quantizer);
}
py::object geglu(const at::Tensor& input, py::handle quantizer) {
return activation_forward(nvte_geglu, input, quantizer, 2);
return activation_helper<nvte_geglu, nullptr>(input, quantizer, 2);
}
py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return activation_backward(nvte_dgeglu, grad, input, quantizer);
return dactivation_helper<nvte_dgeglu, nullptr>(grad, input, quantizer);
}
py::object qgelu(const at::Tensor& input, py::handle quantizer) {
return activation_forward(nvte_qgelu, input, quantizer);
return activation_helper<nvte_qgelu, nullptr>(input, quantizer);
}
py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return activation_backward(nvte_dqgelu, grad, input, quantizer);
return dactivation_helper<nvte_dqgelu, nullptr>(grad, input, quantizer);
}
py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
return activation_forward(nvte_qgeglu, input, quantizer, 2);
return activation_helper<nvte_qgeglu, nullptr>(input, quantizer, 2);
}
py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return activation_backward(nvte_dqgeglu, grad, input, quantizer);
return dactivation_helper<nvte_dqgeglu, nullptr>(grad, input, quantizer);
}
/* ReLU and variants */
py::object relu(const at::Tensor& input, py::handle quantizer) {
return activation_forward(nvte_relu, input, quantizer);
return activation_helper<nvte_relu, nullptr>(input, quantizer);
}
py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return activation_backward(nvte_drelu, grad, input, quantizer);
return dactivation_helper<nvte_drelu, nullptr>(grad, input, quantizer);
}
py::object reglu(const at::Tensor& input, py::handle quantizer) {
return activation_forward(nvte_reglu, input, quantizer, 2);
return activation_helper<nvte_reglu, nullptr>(input, quantizer, 2);
}
py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return activation_backward(nvte_dreglu, grad, input, quantizer);
return dactivation_helper<nvte_dreglu, nullptr>(grad, input, quantizer);
}
py::object srelu(const at::Tensor& input, py::handle quantizer) {
return activation_forward(nvte_srelu, input, quantizer);
return activation_helper<nvte_srelu, nullptr>(input, quantizer);
}
py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return activation_backward(nvte_dsrelu, grad, input, quantizer);
return dactivation_helper<nvte_dsrelu, nullptr>(grad, input, quantizer);
}
py::object sreglu(const at::Tensor& input, py::handle quantizer) {
return activation_forward(nvte_sreglu, input, quantizer, 2);
return activation_helper<nvte_sreglu, nullptr>(input, quantizer, 2);
}
py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return activation_backward(nvte_dsreglu, grad, input, quantizer);
return dactivation_helper<nvte_dsreglu, nullptr>(grad, input, quantizer);
}
/* Silu and variants */
py::object silu(const at::Tensor& input, py::handle quantizer) {
return activation_forward(nvte_silu, input, quantizer);
return activation_helper<nvte_silu, nullptr>(input, quantizer);
}
py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return activation_backward(nvte_dsilu, grad, input, quantizer);
return dactivation_helper<nvte_dsilu, nullptr>(grad, input, quantizer);
}
py::object swiglu(const at::Tensor& input, py::handle quantizer) {
return activation_forward(nvte_swiglu, input, quantizer, 2);
return activation_helper<nvte_swiglu, nullptr>(input, quantizer, 2);
}
py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
return activation_backward(nvte_dswiglu, grad, input, quantizer);
return dactivation_helper<nvte_dswiglu, nullptr>(grad, input, quantizer);
}
/* clamped functions */
py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) {
return activation_helper<nullptr, nvte_clamped_swiglu>(input, quantizer, 2, limit, alpha);
}
py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer,
float limit, float alpha) {
return dactivation_helper<nullptr, nvte_clamped_dswiglu>(grad, input, quantizer, limit, alpha);
}
} // namespace pytorch
@@ -155,6 +155,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("quantizer"));
m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
py::arg("quantizer"));
m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu,
"SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"),
py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
/* Backward of GELU and variants */
m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
@@ -178,6 +181,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("fwd_input"), py::arg("quantizer"));
m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
py::arg("fwd_input"), py::arg("quantizer"));
m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu,
"Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"),
py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
/* DBias + DAct fusions*/
m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
@@ -4,7 +4,19 @@
"""Single tensor operations supported by the operation fuser."""
from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU
from .activation import (
GELU,
GEGLU,
QGELU,
QGEGLU,
ReLU,
ReGLU,
SReLU,
SReGLU,
SiLU,
SwiGLU,
ClampedSwiGLU,
)
from .add_extra_input import AddExtraInput
from .all_gather import AllGather
from .all_reduce import AllReduce
@@ -28,6 +28,7 @@ __all__ = [
"SReGLU",
"SiLU",
"SwiGLU",
"ClampedSwiGLU",
]
@@ -392,3 +393,38 @@ class SwiGLU(_ActivationOperation):
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.dswiglu(*args, **kwargs)
class ClampedSwiGLU(_ActivationOperation):
    r"""Clamped SwiGLU activation used in GPT-OSS

    Implementation based on `GPT-OSS <https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>`__.

    This activation differs from the original SwiGLU in two ways:

    1. Both the gate and the pre-activation are clamped based on the ``limit``
       parameter: the gate is clamped to ``[-limit, limit]`` and the
       pre-activation is clamped from above at ``limit``.
    2. The activation uses ``sigmoid(alpha * x)`` instead of the ``sigmoid(x)``
       used in the Swish activation.

    .. warning::

       The input tensor is chunked along the last dimension to obtain the gate
       and pre-activation values, which is different from the GPT-OSS
       implementation, where they are assumed to be interleaved in the input
       tensor.

    Parameters
    ----------
    limit: float
        The clamp limit.
    alpha: float
        The scaling factor for the sigmoid function used in the activation.
    cache_quantized_input: bool, default = False
        Quantize input tensor when caching for use in the backward pass.
    """
def __init__(
self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False
):
super().__init__(cache_quantized_input=cache_quantized_input)
self.limit = limit
self.alpha = alpha
def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)