Unverified commit 7fa0f554, authored by vthumbe1503, committed by GitHub

[PyTorch] Support for Swiglu Activation used in GPT OSS (#2161)
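
For reference, a minimal PyTorch sketch of the gated activation this PR wires up. The clamp `limit` and sigmoid scale `alpha` follow the public GPT-OSS reference implementation (limit=7.0, alpha=1.702); treat those values and the packing convention as assumptions, not TE's exact kernel:

```python
import torch

def gpt_oss_swiglu(x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702) -> torch.Tensor:
    # Assumes the fused GEMM output packs [gate, linear] halves on the last dim.
    gate, linear = x.chunk(2, dim=-1)
    gate = gate.clamp(max=limit)                  # gate branch clamped from above only
    linear = linear.clamp(min=-limit, max=limit)  # linear branch clamped both ways
    # Scaled-sigmoid gating with a +1 shift on the linear branch.
    return gate * torch.sigmoid(alpha * gate) * (linear + 1)
```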



* Test working as I think it should work
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

revert accidental change
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Restrict the number of cases for unfused quantization; some FP8->FP8 cases are handled by cuBLAS
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

fix merge conflict
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

bug: missed a } in the code
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Add cuBLASMp-backed GEMM-like API to TE common (#1824)

* Pick up cuBLASMp during build
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Change lib order to fix link error
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Context creation, incomplete...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Test fixture
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* A sanity AgGemm test, failing...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix axes
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Take care of uneven distribution
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use MPI to get position of local matrices
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor & fixes
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-RS
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-AR, not working...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fixes
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Setting all-reduce epilogue for gemm-ar
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use supported shapes for GEMM-AR
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tolerance
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* First shot at fp8
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use TensorHolder in tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More test configs
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Support comm_sm_count
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Parametrize dtypes for A, B and D separately
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak scaling
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Amax ptr
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Flags parity with cublas_gemm, saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Bias tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix bias test
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Aux, saving...
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* aux_ld
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* A fix
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Use test::Tensor
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Set scale inv
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unsupported test configs
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Replace libcal with NCCL
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add NVTX markers to API functions
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak GemmAr tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More test config
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix merge fallout
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove MPI dependency, comment API, add algo parameter
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem dependency
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem build
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Exclude CommGemm tests from L0_cppunittest
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Add cpp_distributed sh file for CI
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Adapt to TensorAllocator
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Skip GemmAr test on unsupported HW
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Oversubscribe is needed on some clusters
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Fix incomplete libcal removal
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Move CI tests to L1
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Rename context to include NVTE prefix
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove leftover code
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* NVTE_WITH_CUBLASMP off by default
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed NVTE_CHECK diag
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Comment API
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Include stdbool header for legacy C compilers
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Remove now unused argument
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Abstract away cuBLASMp algo behind our own enum
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed shape diag messages
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update transformer_engine/common/include/transformer_engine/comm_gemm.h
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com>

* Add license
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com>
Co-authored-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)

* FP8 AllGather in FP8 GroupedGEMM

1. Support current-scaling FP8 quantization with a given amax.
2. Support FP8 AG in fwd and BF16 RS in bwd.
3. The workflow is amax all-reduce -> FP8 quant -> FP8 all-gather -> FP8 GroupedGEMM (sketched after this block).
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Slightly refactor
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding documentation of new args.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding unit-tests.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding license.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Move unit-tests to L1.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Move quantizer store/reset into FP8 only.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adding all layout support for Blackwell+
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Adopt the feedback from code-review.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* Fixed the wrong stream used by d2d in groupedGEMM FFI.
Signed-off-by: Ming Huang <mingh@nvidia.com>

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
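
A reference-math sketch of the #2086 forward workflow (amax all-reduce -> FP8 quant -> FP8 all-gather -> grouped GEMM), assuming plain torch.distributed collectives; helper names and the uint8-view gather are illustrative, not the TE implementation:

```python
import torch
import torch.distributed as dist

def fp8_ag_grouped_gemm(x_local, expert_weights, splits, group=None):
    # 1. Amax all-reduce: all ranks agree on one scale before quantizing.
    amax = x_local.float().abs().amax().reshape(1)
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=group)
    scale = 448.0 / amax.clamp(min=1e-12)          # 448 = e4m3 max-normal
    # 2. Current-scaling FP8 quantization with the given (global) amax.
    x_fp8 = (x_local.float() * scale).to(torch.float8_e4m3fn)
    # 3. FP8 all-gather: half the wire traffic of a BF16 gather.
    world = dist.get_world_size(group)
    shards = [torch.empty_like(x_fp8, dtype=torch.uint8) for _ in range(world)]
    dist.all_gather(shards, x_fp8.view(torch.uint8), group=group)
    x = torch.cat(shards).view(torch.float8_e4m3fn).float() / scale  # dequant, reference only
    # 4. Grouped GEMM: one matmul per expert over its contiguous token slice.
    out, start = [], 0
    for w, n in zip(expert_weights, splits):
        out.append(x[start:start + n] @ w.float())
        start += n
    return torch.cat(out)
```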

[JAX] Delay MeshResource validation until first usage (#2124)

Delay MeshResource validation until first usage
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] Decouple Recipe and ScalingMode (#1728)

* Decouple recipe and scaling mode
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Expose global QuantizeConfig instance as a getter
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Format and lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Rename UsageType to TensorSource
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)

* add amax input to DBiasQuantizePrimitive and FFI
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* make sure amax is init with zero
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

* fix sharding rule
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121)
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Temporarily remove comm_gemm tests (#2133)
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch] Disable determinism for sm100 (#2130)

* disable determinism for sm100+ and cudnn<9.14
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* revert some changes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch] ONNX export of FP8 Current Scaling (#2068)

* Compute amax in normalization forward in current scaling in untuned kernels
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

* apply Tim's suggestions
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134)

use torch.empty for an empty shape instead of from_blob (sketched below)
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
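
A sketch of the pattern #2134 fixes, using a hypothetical helper: wrapping a zero-byte region of the workspace (at::from_blob in the extension) leaves a tensor around a pointer that owns nothing, while torch.empty yields a safe, owning zero-sized tensor:

```python
import torch

def expert_out_buffer(workspace: torch.Tensor, tokens: int, hidden: int) -> torch.Tensor:
    # Hypothetical helper: an expert that received zero tokens has no bytes in
    # the shared workspace, so there is no valid pointer to wrap; return an
    # owning zero-sized tensor instead of aliasing the workspace.
    if tokens == 0:
        return torch.empty(0, hidden, dtype=workspace.dtype, device=workspace.device)
    # Non-empty experts keep aliasing the preallocated workspace as before.
    return workspace[: tokens * hidden].view(tokens, hidden)
```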

build: pull cached wheels (#2127)

* build: pull cached wheels
Signed-off-by: oliver könig <okoenig@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update setup.py
Signed-off-by: oliver könig <okoenig@nvidia.com>

---------
Signed-off-by: oliver könig <okoenig@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

feat: Add support for multiple quantization modes in the UB communicators (#2043)
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)

* add checks to cuda kernel launch and cuda API calls
Signed-off-by: Xin Yao <xiny@nvidia.com>

* Remove exceptions from destructors
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* fix weird dispatch in ln/rmsnorm
Signed-off-by: Xin Yao <xiny@nvidia.com>

---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch] Support bf16+fp8 cudagraph (#2098)

* support bf16+fp8 model
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update
Signed-off-by: Robin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* update
Signed-off-by: Robin Zhang <robinz@nvidia.com>

---------
Signed-off-by: Robin Zhang <robinz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
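
For orientation, a hedged sketch of capturing a bf16 module with FP8 GEMMs via TE's CUDA Graphs helper; `make_graphed_callables` is TE's entry point, but treat the exact argument names here as assumptions about the installed version:

```python
import torch
import transformer_engine.pytorch as te

# A bf16 module whose GEMMs run in FP8 inside the captured graph.
model = te.Linear(1024, 1024, params_dtype=torch.bfloat16).cuda()
sample_args = (torch.randn(32, 1024, device="cuda", dtype=torch.bfloat16),)

# fp8_enabled switches FP8 execution on for the graphed region
# (argument name is an assumption; check your TE version's signature).
graphed_model = te.make_graphed_callables(model, sample_args, fp8_enabled=True)
out = graphed_model(*sample_args)  # replays the captured graph
```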

Dropout with 8-bit RNG (#2014)

* Add dropout kernel with 8-bit RNG
Co-authored-by: Vasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix license
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Avoid ambiguous types
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Do not enforce dropout prob is representable in 8 bits
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Expand error message
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix small statistical bug from using less-equal instead of less-than (see the check after this block)

Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints.
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* Fix linter warning
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove unnecessary helper function in PyTorch extensions
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
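
To make the less-equal vs less-than fix above concrete: an 8-bit RNG yields 256 equally likely byte values, so comparing strictly less-than a threshold t drops exactly t/256 of elements, while less-equal drops (t+1)/256, a bias of 1/256. A quick NumPy check (illustrative, not the kernel code):

```python
import numpy as np

p = 0.30                                  # dropout probability
t = int(round(p * 256))                   # 8-bit threshold, 77 here
r = np.random.randint(0, 256, size=1_000_000, dtype=np.uint8)

print((r < t).mean())   # ~ t/256     = 0.3008, the intended rate
print((r <= t).mean())  # ~ (t+1)/256, one extra byte value -> +1/256 bias
```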

Create GPU reload buffers on main stream (#2131)

* Create GPU reload buffers on main stream
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed typo
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* Fixed typo
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: Paweł Gadziński <62263673+pggPL@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

mxfp8 unfused quant support, refined unit test, remove unnecessary quantization code
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

missed a quant code removal
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

minor bug fix
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

minor code cleanup
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

minor cosmetics
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Address review comment
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

minor comment update
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Fix CI failures for UB overlap changes (#2149)
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

minor bug: quantizer should not be None for unfused quantization
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135)

* Fix failing tests for dropout=0.1 and bias for fused attn for Blackwell
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix the skip message
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Assert in fused attn bwd pass for sm100
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

Add check for sm100
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add support to get all devs in the process for jax
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Code clean up
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

* Represent attn bias using enum instead of string
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

fix linting error
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* initial draft of changes to get GPT-OSS-based swiglu integrated; gated kernels need to be fixed
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* redundant implementation for the PyTorch-to-TE hook-up; refactoring to be done later
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* all gated kernels modified, pytest working for OSS swiglu
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix the merge conflict
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119)

* add noop to comp amax
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix for fp8 blockwise recipe
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[PyTorch] fix cross entropy vanishing gradients (#2139)

* fix cross entropy
Signed-off-by: Casper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Casper <casperbh.96@gmail.com>

* fix comments
Signed-off-by: Casper <casperbh.96@gmail.com>

* fix: few more style issues
Signed-off-by: Casper <casperbh.96@gmail.com>

* fix: remove grad_output_stride (unnecessary)
Signed-off-by: Casper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: only backward was broken
Signed-off-by: Casper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Generalize cross entropy backward kernel to handle reduced and unreduced loss (reference math sketched below)
Signed-off-by: Tim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Casper <casperbh.96@gmail.com>
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
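
The backward rule being generalized above, written as reference math in plain PyTorch (not the fused kernel): dL/dlogits = softmax(logits) - onehot(target), scaled by a scalar incoming grad for reduced loss or a per-token grad for unreduced loss:

```python
import torch

def ce_backward(logits: torch.Tensor, target: torch.Tensor, grad_output: torch.Tensor):
    # dL/dlogits = softmax(logits) - onehot(target), scaled by incoming grad.
    grad = torch.softmax(logits.float(), dim=-1)
    grad[torch.arange(logits.size(0)), target] -= 1.0
    if grad_output.dim() == 0:              # reduced loss: one scalar for the batch
        return grad * grad_output
    return grad * grad_output[:, None]      # unreduced loss: one grad per token
```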

Fix bug when enabling --overlap-grad-reduce in mcore (#2142)

* fix bugs when enabling --overlap-grad-reduce in mcore
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix CI
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* format
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Hongbin Liu <hongbinl@nvidia.com>
Co-authored-by: Hongbin Liu <hongbinl@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Fix CUDA version in setup.py (#2132)

* Fix CUDA version in setup.py
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* Re-enable building comm-gemm tests
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

* WAR for nvidia-nvshmem package
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] NoScaleTensor wrapper for non-quantized data (#2136)

* Custom call tests passing
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix test_layer.py
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Lint
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix comments
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix shardy issue with amax being shape (1, 1, 1) instead of shape (1,)
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Add higher-precision VJP tests to test_distributed_layernorm_mlp
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Cast non-quantized kernels to input dtype in VJPs
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Rename HighPrecisionTensor to NoScaleTensor
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Use NoScaleTensor in pure JAX impls where it was missing
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* Fix tests
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[JAX] Fix GroupedScaledTensor creation with keyword arg (#2154)

Fix GroupedScaledTensor creation
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Fixing a few issues with multi-process launching. (#2155)

* Fixing a few issues with multi-process launching.
Signed-off-by: Ming Huang <mingh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------
Signed-off-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Phuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

Update list of authorized CI users (#2152)
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

a bit of cleanup
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* accidentally removed some activations; minor bug in the templated function
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

* fix merge conflict
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)

* FP8 AllGather in FP8 GroupedGEMM

1. Support current scaling FP8 quantation with a given amax.
2. Support FP8 AG in fwd and BF16 RS in bwd.
3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Slightly refactor
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding documents of new args.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding unit-tests.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding license.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Move unit-tests to L1.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Move quantizaer store/reset into FP8 only.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding all layout support for Blackwell+
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adopt the feedback from code-review.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fixed the wrong stream used by d2d in groupedGEMM FFI.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[JAX] Delay MeshResource validation until first usage (#2124)

Delay MeshResource validation until first usage
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)

* add amax input to DBiasQuantizePrimitive and FFI
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make sure amax is init with zero
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fix sharding rule
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121)
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

Temporarily remove comm_gemm tests (#2133)
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

[PyTorch] Disable determinism for sm100 (#2130)

* disable determinism for sm100+ and cudnn<9.14
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert some changes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[PyTorch] ONNX export of FP8 Current Scaling (#2068)

* Compute amax in normalization forward in current scaling in untuned kernels
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* apply tims suggestions
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarJan Bielak <jbielak@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134)

use torch empty for empty shape instead of from_blob
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

build: pull cached wheels (#2127)

* build: pull cached wheels
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update setup.py
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>

---------
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)

* add checks to cuda kernel launch and cuda API calls
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* Remove exceptions from destructors
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* fix weired dispatch in ln/rmsnorm
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

[PyTorch] Support bf16+fp8 cudagraph (#2098)

* support bf16+fp8 model
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

---------
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

Dropout with 8-bit RNG (#2014)

* Add dropout kernel with 8-bit RNG
Co-authored-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix license
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Avoid ambiguous types
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Do not enforce dropout prob is representable in 8 bits
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Expand error message
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix small statistical bug from using less-equal instead of less-than

Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix linter warning
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unnecessary helper function in PyTorch extensions
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
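
Taken together, these commits pin down the kernel's contract; here is a plain-PyTorch reference sketch of the sampling logic (assumed semantics, not the fused kernel):

```python
import torch

def dropout_8bit_reference(x: torch.Tensor, p: float) -> torch.Tensor:
    # One uniformly random byte per element, values in [0, 256).
    r = torch.randint(0, 256, x.shape, device=x.device)
    # Keep threshold rounded to a byte; p itself need not be exactly
    # representable in 8 bits -- the realized keep probability is
    # threshold / 256.
    threshold = round((1.0 - p) * 256)
    assert threshold > 0, "p too close to 1 for an 8-bit threshold"
    # Strict less-than keeps exactly `threshold` of the 256 byte values;
    # less-equal would keep one extra value, a small statistical bias.
    mask = r < threshold
    keep_prob = threshold / 256.0
    return torch.where(mask, x / keep_prob, torch.zeros_like(x))
```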

Create GPU reload buffers on main stream (#2131)

* Create GPU reload buffers on main stream
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed typo
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* Fixed typo
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

---------
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>

Fix CI failures for UB overlap changes (#2149)
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135)

* Fix failing tests for dropout=0.1 and bias for fused attn for blackwell
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix the skip message
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Assert in fused attn bwd pass for sm100
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

Add check for sm100
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add support to get all devs in the process for jax
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Code clean up
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Represent attn bias using enum instead of string
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119)

* add noop to comp amax
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* fix for fp8 blockwise recipe
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* resolve comments
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

[PyTorch] fix cross entropy vanishing gradients (#2139)

* fix cross entropy
Signed-off-by: default avatarCasper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarCasper <casperbh.96@gmail.com>

* fix comments
Signed-off-by: default avatarCasper <casperbh.96@gmail.com>

* fix: few more style issues
Signed-off-by: default avatarCasper <casperbh.96@gmail.com>

* fix: remove grad_output_stride (unnecessary)
Signed-off-by: default avatarCasper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix: only backward was broken
Signed-off-by: default avatarCasper <casperbh.96@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Generalize cross entropy backward kernel to handle reduced and unreduced loss
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarCasper <casperbh.96@gmail.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
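
A sketch of what the generalized backward has to handle, assuming the usual softmax-minus-one-hot gradient and mean reduction over tokens (illustrative, not the actual kernel):

```python
import torch

def cross_entropy_dgrad(logits: torch.Tensor, target: torch.Tensor,
                        grad_output: torch.Tensor) -> torch.Tensor:
    # d(loss)/d(logits) = softmax(logits) - one_hot(target), scaled by the
    # incoming gradient.
    grad = torch.softmax(logits.float(), dim=-1)
    grad.scatter_add_(-1, target.unsqueeze(-1),
                      -torch.ones_like(target, dtype=grad.dtype).unsqueeze(-1))
    if grad_output.dim() == 0:
        # Reduced (mean) loss: one scalar gradient shared by all tokens.
        grad = grad * (grad_output / logits.shape[0])
    else:
        # Unreduced loss: one gradient per token.
        grad = grad * grad_output.unsqueeze(-1)
    return grad.to(logits.dtype)
```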

Fix bug when enabling --overlap-grad-reduce in mcore (#2142)

* fix bugs when enabling --overlap-grad-reduce in mcore
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix CI
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

* format
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>
Co-authored-by: default avatarHongbin Liu <hongbinl@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Fix CUDA version in setup.py (#2132)

* Fix CUDA version in setup.py
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Re-enable building comm-gemm tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* WAR for nvidia-nvshmem package
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

[JAX] NoScaleTensor wrapper for non-quantized data (#2136)

* Custom call tests passing
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix test_layer.py
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix comments
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix shardy issue with amax being shape 1,1,1 instead of shape (1,)
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Add higher-precision VJP tests to test_distributed_layernorm_mlp
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Cast non-quantized kernels to input dtype in VJPs
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Rename HighPrecisionTensor to NoScaleTensor
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Use NoScaleTensor in pure JAX impls where it was missing
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Fix tests
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

[JAX] Fix GroupedScaledTensor creation with keyword arg (#2154)

Fix GroupedScaledTensor creation
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

Fixing a few issues with multi-process launching. (#2155)

* Fixing a few issues with multi-process launching.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

Update list of authorized CI users (#2152)
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

Fused RoPE with combined QKV input. (#2122)

* Fused RoPE with combined QKV input.

Initial commit for Dropout with 8-bit RNG

Fix documentation

Initial commit for Fused QKV RoPE

WIP

Initial tests passing

Enable rotary percent and margin

Enable CP2, start_positions, interleaved

Cleanup test

Revert "Fix documentation"

This reverts commit 53df10044e7769982bd4af2ae2628e6b7717e715.

Revert "Initial commit for Dropout with 8-bit RNG"

This reverts commit 301505e24031cbcd679069e1c2cd4d00eedf2dca.

Cleanup.

Minor cleanup
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Optimize kernels
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Misc. Cleanup
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Optimize kernel performance
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* Move fused_qkv_rope test to test_fused_rope.py
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* apply shared memory optimization to separate fused rope kernels
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* fix lint
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
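
For orientation, a non-fused reference of the operation the kernel fuses, with hypothetical shapes and the interleaved/start_positions/CP options elided:

```python
import torch

def rope_qkv_reference(qkv: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
    # qkv:   [seq, 3, heads, dim], query/key/value packed along dim 1
    # freqs: [seq, dim] rotary angles
    def rotate_half(t: torch.Tensor) -> torch.Tensor:
        t1, t2 = t.chunk(2, dim=-1)
        return torch.cat((-t2, t1), dim=-1)

    cos = freqs.cos()[:, None, :]  # broadcast over heads
    sin = freqs.sin()[:, None, :]
    out = qkv.clone()
    for i in (0, 1):  # rotate Q (0) and K (1); V (2) passes through
        out[:, i] = qkv[:, i] * cos + rotate_half(qkv[:, i]) * sin
    return out
```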

* accidentally removed the copyright
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* fix linting issue
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* minor issue in comments
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* Commit is for another PR
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* revert changes since this belongs to another PR
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revert change back since belongs to another PR
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Changes belong to another PR
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Revert changes here
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

Add bf16/fp32 token-per-expert to the MoE aux loss kernel (#2162)

* add bf16/fp32 token-per-expert on the moe-loss-computation on router fusion
Signed-off-by: default avatartongliu <tongliu@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatartongliu <tongliu@nvidia.com>
Co-authored-by: default avatartongliu <tongliu@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
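
A sketch of where the token-per-expert count enters, assuming the standard switch-style auxiliary loss (the fused kernel's exact formula is not spelled out in the commit):

```python
import torch

def moe_aux_loss(router_probs: torch.Tensor, expert_assignment: torch.Tensor,
                 num_experts: int) -> torch.Tensor:
    # router_probs:      [tokens, experts], bf16 or fp32
    # expert_assignment: [tokens], top-1 expert index per token
    tokens = router_probs.shape[0]
    # Tokens-per-expert accumulated in the probability dtype (bf16/fp32)
    # rather than as an integer histogram.
    tokens_per_expert = torch.zeros(num_experts, dtype=router_probs.dtype,
                                    device=router_probs.device)
    tokens_per_expert.scatter_add_(
        0, expert_assignment,
        torch.ones(tokens, dtype=router_probs.dtype, device=router_probs.device))
    frac_tokens = tokens_per_expert / tokens
    frac_probs = router_probs.mean(dim=0)
    return num_experts * torch.sum(frac_tokens * frac_probs)
```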

[JAX] Scale swizzling via JAX transpose op (#2163)

* add swizzle in jax
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* added outer_impl
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* clean up FFI
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

Extract cpp distributed tests into a separate project (#2165)

* Extract cpp distributed tests into a separate project
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove obsolete exclusion
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Run L1_cpp_distributed tests if at least 4 GPUs
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

Adds context parallelism utilities: moving cp shards to different ranks and padding sequences to a divisibility factor (#2129)

* test - adds unit test for cp utilities and the utilities
Signed-off-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>

* assert line change
Signed-off-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
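
The padding utility's contract, as a small sketch (the function name and the choice of divisor are illustrative, not the PR's API):

```python
import torch

def pad_to_divisible(x: torch.Tensor, divisor: int, dim: int = 0) -> torch.Tensor:
    # Right-pad `dim` with zeros so its length is divisible by `divisor`
    # (e.g. a multiple of the CP world size for even sharding).
    length = x.size(dim)
    pad = (-length) % divisor
    if pad == 0:
        return x
    pad_shape = list(x.shape)
    pad_shape[dim] = pad
    return torch.cat([x, x.new_zeros(pad_shape)], dim=dim)
```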

* address review comments
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* cleanup
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix linting error
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



[PyTorch Debug] Fix issue with negative underflow% stat. (#2107)

* fix underflows log issue
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* Address review comments; fix mxfp8 kernel bug where the clamped swiglu parameter was not passed correctly
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Lower precision gated-act to accelerate FP8 current-scaling. (#2153)

* Applying the original precision to norm outputs and activation computations.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding knob to control norm output precision.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Removing the knob and applying lower-precision norm with current-scaling only.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fix the error when quantizer==None
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
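
For context, per-tensor current scaling derives the scale from the tensor's own amax at quantization time, so the producer can stay in lower precision right up to this point; a minimal sketch assuming a PyTorch build with torch.float8_e4m3fn:

```python
import torch

FP8_E4M3_MAX = 448.0  # largest finite float8_e4m3fn value

def quantize_current_scaling(x: torch.Tensor):
    # Per-tensor "current" scaling: the scale comes from this tensor's own
    # amax, so the producer (norm output / gated activation) can stay in its
    # original lower precision until quantization.
    amax = x.float().abs().amax()
    scale = FP8_E4M3_MAX / torch.clamp(amax, min=1e-12)
    x_fp8 = (x.float() * scale).clamp(-FP8_E4M3_MAX, FP8_E4M3_MAX)
    return x_fp8.to(torch.float8_e4m3fn), scale.reciprocal()  # data + scale_inv
```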

[PyTorch] Support activation CPU offloading in fusible ops (#2158)

* Add CPU offloading logic to ops. Fix test to compute dgrad.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Make sure grads are contiguous in op backwards
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Add op-based MLP to CPU offloading tests
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Handle different weight cache behavior on Hopper/Blackwell

Add MXFP8 to CPU offload tests.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove MXFP8 test
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
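
Generic activation offloading can be emulated with PyTorch's stock saved-tensor hooks; this sketch shows the idea only and is not TE's implementation, which additionally manages streams and fusible-op state:

```python
import torch

# Tensors saved for backward are parked on the CPU and reloaded when the
# backward pass needs them.
def pack_to_cpu(t: torch.Tensor):
    return t.device, t.to("cpu", non_blocking=True)

def unpack_from_cpu(packed):
    device, t = packed
    return t.to(device, non_blocking=True)

def run_with_offload(module: torch.nn.Module, x: torch.Tensor) -> torch.Tensor:
    with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu):
        return module(x)
```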

Do not use normalization forward + amax fusion if cuDNN backend is requested (#2174)

* Do not use norm fwd + amax fusion if cudnn backend is requested
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Read environment variable directly to avoid include error
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Fix unjoined comm stream in UB communicator (#2160)
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

FP8 Output Quantization for GEMM (#2123)

* Test working as I think it should work
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* revert accidental change
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

fix merge conflict
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

bug: missed a } in the code
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Add cuBLASMp-backed GEMM-like API to TE common (#1824)

* Pick up cuBLASMp during build
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Change lib order to fix link error
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Context creation, incomplete...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Test fixture
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* A sanity AgGemm test, failing...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix axes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Take care of uneven distribution
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use MPI to get position of local matrices
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor & fixes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-RS
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-AR, not working...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fixes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Setting all-reduce epilogue for gemm-ar
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use supported shapes for GEMM-AR
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tolerance
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* First shot at fp8
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use TensorHolder in tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More test configs
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Support comm_sm_count
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Parametrize dtypes for A, B and D separately
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak scaling
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Amax ptr
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Flags parity with cublas_gemm, saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Bias tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix bias test
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Aux, saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* aux_ld
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* A fix
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use test::Tensor
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Set scale inv
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unsupported test configs
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Replace libcal with NCCL
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Add NVTX markers to API functions
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak GemmAr tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More test config
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix merge fallout
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove MPI dependency, comment API, add algo parameter
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem dependency
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem build
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Exclude CommGemm tests from L0_cppunittest
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Add cpp_distributed sh file for CI
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Adapt to TensorAllocator
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Skip GemmAr test on unsupported HW
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Oversubscription is needed on some clusters
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix incomplete libcal removal
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Move CI tests to L1
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Rename context to include NVTE prefix
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove leftover code
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* NVTE_WITH_CUBLASMP off by default
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed NVTE_CHECK diag
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Comment API
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Include stdbool header for legacy C compilers
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove now unused argument
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Abstract away cuBLASMp algo behind our own enum
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed shape diag messages
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/common/include/transformer_engine/comm_gemm.h
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarVladimir Cherepanov <56651474+mk-61@users.noreply.github.com>

* Add license
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: default avatarVladimir Cherepanov <56651474+mk-61@users.noreply.github.com>
Co-authored-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
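
The feature in #2123, reduced to a sketch: when cuBLAS does not fuse the FP8 epilogue, quantize the GEMM output afterwards with a per-tensor scale (assumes torch.float8_e4m3fn; the real path goes through the cuBLAS GEMM, elided here):

```python
import torch

FP8_E4M3_MAX = 448.0

def gemm_with_fp8_output(a: torch.Tensor, b: torch.Tensor):
    # Unfused path: run the GEMM at input precision, then quantize D to FP8.
    # Per the commit notes, cuBLAS handles some fp8->fp8 cases directly, so
    # this fallback covers the rest.
    d = a @ b
    amax = d.float().abs().amax()
    scale = FP8_E4M3_MAX / torch.clamp(amax, min=1e-12)
    d_fp8 = (d.float() * scale).to(torch.float8_e4m3fn)
    return d_fp8, scale.reciprocal()
```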

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)

* FP8 AllGather in FP8 GroupedGEMM

1. Support current scaling FP8 quantization with a given amax.
2. Support FP8 AG in fwd and BF16 RS in bwd.
3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Slightly refactor
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding documents of new args.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding unit-tests.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding license.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Move unit-tests to L1.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Move quantizer store/reset into FP8 only.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding all layout support for Blackwell+
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adopt the feedback from code-review.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fixed the wrong stream used by d2d in groupedGEMM FFI.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
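
The three-step workflow named above, sketched with torch.distributed (an initialized process group is assumed and the consuming grouped GEMM is elided; bytes are gathered as uint8 for portability):

```python
import torch
import torch.distributed as dist

FP8_E4M3_MAX = 448.0

def fp8_allgather(x_local: torch.Tensor, world_size: int) -> torch.Tensor:
    # 1. AR-max: agree on a global amax so every rank uses the same scale.
    amax = x_local.float().abs().amax()
    dist.all_reduce(amax, op=dist.ReduceOp.MAX)
    scale = FP8_E4M3_MAX / torch.clamp(amax, min=1e-12)
    # 2. FP8 quant with the agreed scale.
    x_fp8 = (x_local.float() * scale).to(torch.float8_e4m3fn)
    # 3. FP8 AG: gather raw bytes (half the wire traffic of bf16), then view
    #    the result back as FP8 for the grouped GEMM.
    payload = x_fp8.view(torch.uint8)
    gathered = torch.empty((world_size, *payload.shape), dtype=torch.uint8,
                           device=payload.device)
    dist.all_gather_into_tensor(gathered, payload)
    return gathered.view(torch.float8_e4m3fn)
```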

[JAX] Delay MeshResource validation until first usage (#2124)

Delay MeshResource validation until first usage
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[JAX] Decouple Recipe and ScalingMode (#1728)

* Decouple recipe and scaling mode
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Expose global QuantizeConfig instance as a getter
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Format and lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Rename UsageType to TensorSource
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)

* add amax input to DBiasQuantizePrimitive and FFI
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make sure amax is init with zero
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fix sharding rule
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121)
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Temporarily remove comm_gemm tests (#2133)
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[PyTorch] Disable determinism for sm100 (#2130)

* disable determinism for sm100+ and cudnn<9.14
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert some changes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[PyTorch] ONNX export of FP8 Current Scaling (#2068)

* Compute amax in normalization forward for current scaling in untuned kernels
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* apply Tim's suggestions
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarJan Bielak <jbielak@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

feat: Add support for multiple quantization modes in the UB communicators (#2043)
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

mxfp8 unfused quant support, refined unit test, remove unnecessary quantization code
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

missed a quant code removal
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

minor bug fix
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



Add cuBLASMp-backed GEMM-like API to TE common (#1824)

* Pick up cuBLASMp during build
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Change lib order to fix link error
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Context creation, incomplete...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Test fixure
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* A sanity AgGemm test, failing...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix axes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Take care of uneven distribution
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use MPI to get position of local matrices
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Refactor & fixes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-RS
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Gemm-AR, not working...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fixes
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Setting all-reduce epilogue for gemm-ar
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use supported shapes for GEMM-AR
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tolerance
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* First shot at fp8
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use TensorHolder in tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More test configs
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Support comm_sm_count
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Parametrize dtypes for A, B and D separately
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak scaling
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Amax ptr
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Flags parity with cublas_gemm, saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Cleanup
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Bias tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix bias test
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Aux, saving...
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* aux_ld
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* A fix
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Use test::Tensor
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Set scale inv
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove unsupported test configs
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Replace libcal with NCCL
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Add NVTX markers to API functions
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Tweak GemmAr tests
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More test config
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix merge fallout
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove MPI dependency, comment API, add algo parameter
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem dependency
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix nvshmem build
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Excluse CommGemm tests from L0_cppunittest
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Add cpp_distributed sh file for CI
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Adapt tp TensorAllocator
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Skip GemmAr test on unsupported HW
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Oversibscribe is needed on some clusters
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Fix incomplete libcal removal
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Move CI tests to L1
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Rename context to include NVTE prefix
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove leftover code
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* NVTE_WITH_CUBLASMP off by default
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed NVTE_CHECK diag
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Comment API
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Include stdbool header for legacy C compilers
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Remove now unused argument
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* Abstract away cuBLASMp algo behind our own enum
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* More detailed shape diag messages
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/common/include/transformer_engine/comm_gemm.h
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>
Signed-off-by: default avatarVladimir Cherepanov <56651474+mk-61@users.noreply.github.com>

* Add license
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

---------
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Signed-off-by: default avatarVladimir Cherepanov <56651474+mk-61@users.noreply.github.com>
Co-authored-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPrzemyslaw Tredak <ptrendx@gmail.com>

FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)

* FP8 AllGather in FP8 GroupedGEMM

1. Support current scaling FP8 quantation with a given amax.
2. Support FP8 AG in fwd and BF16 RS in bwd.
3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Slightly refactor
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding documents of new args.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding unit-tests.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding license.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Move unit-tests to L1.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Move quantizaer store/reset into FP8 only.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adding all layout support for Blackwell+
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Adopt the feedback from code-review.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

* Fixed the wrong stream used by d2d in groupedGEMM FFI.
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>

---------
Signed-off-by: default avatarMing Huang <mingh@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[JAX] Delay MeshResource validation until first usage (#2124)

Delay MeshResource validation until first usage
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[JAX] Decouple Recipe and ScalingMode (#1728)

* Decouple recipe and scaling mode
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Expose global QuantizeConfig instance as a getter
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Format and lint
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Rename UsageType to TensorSource
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

* Update test_layer.py
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: default avatarJeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: default avatarjberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com>

[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)

* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)

* add amax input to DBiasQuantizePrimitive and FFI
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make sure amax is init with zero
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

* fix sharding rule
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

---------
Signed-off-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121)
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

Temporarily remove comm_gemm tests (#2133)
Signed-off-by: default avatarVladimir Cherepanov <vcherepanov@nvidia.com>

[PyTorch] Disable determinism for sm100 (#2130)

* disable determinism for sm100+ and cudnn<9.14
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* fix remaining CI failures
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* revert some changes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* revert more changes
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* remove sm100 from determinism table
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

---------
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[PyTorch] ONNX export of FP8 Current Scaling (#2068)

* Compute amax in normalization forward in current scaling in untuned kernels
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* code drop
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

* apply tims suggestions
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarJan Bielak <jbielak@nvidia.com>
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarJan Bielak <jbielak@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134)

use torch empty for empty shape instead of from_blob
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

build: pull cached wheels (#2127)

* build: pull cached wheels
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update setup.py
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>

---------
Signed-off-by: default avataroliver könig <okoenig@nvidia.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

feat: Add support for multiple quantization modes in the UB communicators (#2043)

[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)

* add checks to cuda kernel launch and cuda API calls
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* Remove exceptions from destructors
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* fix weired dispatch in ln/rmsnorm
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

[PyTorch] Support bf16+fp8 cudagraph (#2098)

* support bf16+fp8 model
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>

---------
Signed-off-by: default avatarRobin Zhang <robinz@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

Dropout with 8-bit RNG (#2014)

* Add dropout kernel with 8-bit RNG
Co-authored-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix license
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Avoid ambiguous types
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Do not enforce dropout prob is representable in 8 bits
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Expand error message
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix small statistical bug from using less-equal instead of less-than

Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints.
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* Fix linter warning
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unnecessary helper function in PyTorch extensions
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
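
(Aside: the "less-equal instead of less-than" bug mentioned above is easy to see numerically. A sketch of the idea, not the actual kernel: with 8-bit random values r in [0, 256) and threshold t = round(p * 256), dropping on r < t gives drop probability t/256, while r <= t gives (t + 1)/256, a 1/256 bias.)

import torch

p = 0.25
t = int(p * 256)  # 8-bit threshold
r = torch.randint(0, 256, (10_000_000,), dtype=torch.uint8)

drop_lt = (r.int() < t).float().mean().item()   # ~0.2500, as intended
drop_le = (r.int() <= t).float().mean().item()  # ~0.2539, biased by 1/256
print(f"< : {drop_lt:.4f}   <= : {drop_le:.4f}")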

Create GPU reload buffers on main stream (#2131)

* Create GPU reload buffers on main stream
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fixed typo
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

* Fixed typo
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>

---------
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-preos01.a51.clusters.nvidia.com>
Co-authored-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>

minor code cleanup
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



minor cosmetics
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



Address review comment
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



minor comment update
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

Fix CI failures for UB overlap changes (#2149)
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

minor bug: quantizer should not be None for unfused quantization
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135)

* Fix failing tests for dropout=0.1 and bias for fused attn for blackwell
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix the skip message
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Assert in fused attn bwd pass for sm100
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

Add check for sm100
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Add support to get all devs in the process for jax
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Code clean up
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* Represent attn bias using enum instead of string
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

---------
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

fix linting error
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119)

* add noop to comp amax
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* fix for fp8 blockwise recipe
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* resolve comments
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

address review comments
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* Update test_multi_process_distributed_grouped_gemm.py

change accidentally added while merging
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* Update dense.py

change accidentally added while merging
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* address review comments
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* address review comments
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Bug solved: delayed scaling quantization with mxfp8 inputs didn't work
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix the unit test error
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* just to trigger ci
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* fix merge conflict
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

TE Gemma tutorial attempt#2 (#1839)

* add tutorial files and other local changes
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* remove extraneous code for easy debugging
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* make cuda graphs work with non-paged and paged attention
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* perf improvements for kv cache ops
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add code for calibration
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* optimize kv_cache reindex and copy kernels
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* changes to make quantizers work with fp8_calibration
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* avoid reindexing from python side
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* rename variable from previous commit
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* minor fix
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

* use quantizer only if needed
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* functionality of the tutorial tested and perf checked
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* remove files and update headers/licenses
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* update header/license
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* update tutorial for review
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make weights downloadable on the fly; remove extra print statements
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix lint and update comments
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add comma back, typo
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* sequence_start_positions should be None for training
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add paged attention numbers and update requirements.txt file
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* more fixes
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* make tutorial work on blackwell
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* remove gemma FT tutorial for now
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fixing the headings placement and rewording attention -> kv caching
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fixes from comments
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix the images
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* misc fixes
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add more comments to te_gemma.py and cleanup utils.py
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add more information about the hierarchy of the classes used in the tutorial
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* add better cuda graphs picture
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add updated cuda graphs pictures
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add illustrated cuda graphs
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fix
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* small fixes in documentation
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add torch.no_grad() to force reduced memory usage
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* some fixes from recent comments
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* more fixes from remaining comments
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* add te_rope_emb to class desc
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

* fix tutorial wording; add calibration fix to grouped_linear.py
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>

---------
Signed-off-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarCharlene Yang <8636796+cyanguwa@users.noreply.github.com>

Fix memory overhead of linear layer when all gather from sequence parallel (#2125)

* fix memory overhead of all gather from sequence parallel
Signed-off-by: default avatarYuzhong Wang <yuzhongw@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* quick fix for the errors with UB buffers
Signed-off-by: default avatarYuzhong Wang <yuzhongw@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/pytorch/module/linear.py
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* Avoid deallocating FP8 scale-invs since they are reused
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>

---------
Signed-off-by: default avatarYuzhong Wang <yuzhongw@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: default avatarTim Moon <tmoon@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <tmoon@nvidia.com>

Fix incorrect TP rank calculation when using data parallel (#2179)
Signed-off-by: default avatardjns99 <40156487+djns99@users.noreply.github.com>

[Pytorch] Add Cutlass Grouped GEMM Support for fine-grained MoE Model (#2045)

* feat: add cutlass group gemm support
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactor: refactor multi tensor gemm interface
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* refactor: refactor nvte_multi_stream_cublas_gemm func and add license info
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feat: add unit test for cutlass group gemm
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feat: add cutlass support type protect
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* add tests and fix lint
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feat: fix unit tests error
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* feat: refactor host workspace malloc
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>

* update cutlass
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* update cutlass
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* further relax threshold and add an env var to warn on fallback
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarMin Yang <min.yang@shopee.com>
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Signed-off-by: default avataralan yang <89962857+cassiewilliam@users.noreply.github.com>
Co-authored-by: default avatarMin Yang <min.yang@shopee.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarPhuong Nguyen <phuonguyen@nvidia.com>

[PyTorch] Support FA3 for MLA and with CP (#1907)

feature(FA3,MLA,CP):
1. Update FA3 to commit-id 3ba6f82 (tag 2.8.0.post2 with compile error fixed), PR-1604 support hdimQK != hdimV backward
2. Update get_attention_backend method because FA3 support MLA now
3. Add CP MLA support for FA3
4. Add unit tests for FA3 MLA CP
5. Update attention doc
Signed-off-by: default avatarzhujian <zhujian.whu.cs@gmail.com>

Fix cuDNN version checks when getting backend and for sm89 kv cache (#2185)

* Fix cudnn version checks for kv cache for sm89. Add cudnn version check in preparation for 9.14 when getting backend
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Minor fix for cuDNN version condition check
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Use limit=0.75 in clamped SwiGLU test
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>

* Address review comments
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* accidentally removed a line while resolving merge conflict
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* match pytorch implementation: dclamp should be 1 for borders of clamping limits as well
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* fix dswiglu quantization fusion bug
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* pass param by reference as much as possible
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* float should rather be bool: fix by copilot
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* { missed in activation.cpp
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* minor formatting change
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

* nvfp4 change
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>

---------
Signed-off-by: default avatarVarun Thumbe <vthumbe@nvidia.com>
Signed-off-by: default avatarVasudevan Rengasamy <vrengasamy@nvidia.com>
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
Signed-off-by: default avatarvthumbe1503 <vthumbe@nvidia.com>
Signed-off-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: default avatarKshitij Lakhani <klakhani@nvidia.com>
Signed-off-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarXin Yao <xiny@nvidia.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: default avatarJonathan Mitchell <jomitchell@login-eos02.eos.clusters.nvidia.com>
Co-authored-by: default avatarSudhakar Singh <sudhakars@nvidia.com>
parent 25252e9f
...
@@ -1736,6 +1736,80 @@ class TestBasicOps:
         torch.testing.assert_close(y_test, y_ref, **tols)
         torch.testing.assert_close(dx_test, x_ref.grad, **tols)

+    @pytest.mark.parametrize("dtype", _dtypes)
+    @pytest.mark.parametrize("quantization", _quantization_list)
+    @pytest.mark.parametrize("quantize_forward", (False, True))
+    @pytest.mark.parametrize("quantize_backward", (False, True))
+    def test_clamped_swiglu(
+        self,
+        *,
+        out_shape: Iterable[int] = (32, 32),
+        dtype: torch.dtype,
+        device: torch.device = "cuda",
+        quantization: Optional[str],
+        quantize_forward: bool,
+        quantize_backward: bool,
+        limit: float = 0.75,
+        alpha: float = 1.702,
+    ):
+        # Test SwiGLU variant used in GPT OSS.
+
+        # Tensor dimensions
+        in_shape = list(out_shape)
+        in_shape[-1] *= 2
+
+        # Skip invalid configurations
+        quantized_compute = quantization is not None
+        if not quantized_compute and (quantize_forward or quantize_backward):
+            pytest.skip("Quantization scheme has not been provided")
+        maybe_skip_quantization(quantization, dims=in_shape, device=device)
+
+        # Random data
+        x_ref, x_test = make_reference_and_test_tensors(
+            in_shape,
+            test_dtype=dtype,
+            test_device=device,
+        )
+        dy_ref, dy_test = make_reference_and_test_tensors(
+            out_shape,
+            test_dtype=dtype,
+            test_device=device,
+            requires_grad=False,
+        )
+
+        # Plain PyTorch implementation
+        x_glu, x_linear = x_ref.chunk(2, dim=-1)
+        x_glu = x_glu.clamp(min=None, max=limit)
+        x_linear = x_linear.clamp(min=-limit, max=limit)
+        out_glu = x_glu * torch.sigmoid(alpha * x_glu)
+        y_ref = out_glu * (x_linear + 1)
+        y_ref.backward(dy_ref)
+
+        # Implementation with fusible operation
+        recipe = make_recipe(quantization)
+        forward = te_ops.Sequential(
+            te_ops.Quantize(forward=False, backward=quantize_backward),
+            te_ops.ClampedSwiGLU(limit=limit, alpha=alpha),
+            te_ops.Quantize(forward=quantize_forward, backward=False),
+        )
+        with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe):
+            y_test = forward(x_test)
+        y_test.backward(dy_test)
+
+        # Expected numerical error
+        tols = dtype_tols(dtype)
+        if quantized_compute and quantization == "nvfp4":
+            tols = dtype_tols(tex.DType.kFloat4E2M1)
+        elif quantized_compute:
+            tols = dtype_tols(tex.DType.kFloat8E4M3)
+
+        # Check results
+        y_test = y_test.to(dtype=torch.float64, device="cpu")
+        dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
+        torch.testing.assert_close(y_test, y_ref, **tols)
+        torch.testing.assert_close(dx_test, x_ref.grad, **tols)
+
     @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5))
     @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2)))
     @pytest.mark.parametrize("dtype", _dtypes)
...
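
(A minimal standalone usage sketch of the new fusible op, assuming it can be invoked like other te_ops modules; this mirrors what the test above exercises.)

import torch
from transformer_engine.pytorch import ops as te_ops

x = torch.randn(32, 64, device="cuda", requires_grad=True)  # last dim = 2 * H
act = te_ops.ClampedSwiGLU(limit=0.75, alpha=1.702)         # GPT-OSS-style SwiGLU
y = act(x)                                                  # shape [32, 32]
y.sum().backward()                                          # fused backward path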
...
@@ -51,22 +51,20 @@ void dact_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
 }

 template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &)>
-void gated_act_fn(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
+void gated_act_fn(const NVTETensor input, NVTETensor output, Param &p, cudaStream_t stream) {
   using namespace detail;
   constexpr bool IS_DGATED = false;
   constexpr NVTETensor grad = nullptr;
-  quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, stream);
+  quantize_gated_helper<IS_DGATED, Param, ActOP, nullptr>(grad, input, output, p, stream);
 }

 template <typename ComputeType, typename Param, ComputeType (*ActOP)(ComputeType, const Param &),
           ComputeType (*DActOP)(ComputeType, const Param &)>
-void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+void dgated_act_fn(const NVTETensor grad, const NVTETensor input, NVTETensor output, Param &p,
                    cudaStream_t stream) {
   using namespace detail;
   constexpr bool IS_DGATED = true;
-  quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, stream);
+  quantize_gated_helper<IS_DGATED, Param, ActOP, DActOP>(grad, input, output, p, stream);
 }

 }  // namespace transformer_engine
...
@@ -23,14 +23,16 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
 void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_geglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, gelu<fp32, fp32>>(input, output, e, stream);
 }

 void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
   NVTE_API_CALL(nvte_dgeglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, gelu<fp32, fp32>, dgelu<fp32, fp32>>(grad, input, output, e, stream);
 }

 void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
...
@@ -49,12 +51,14 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
 void nvte_qgeglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_qgeglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, qgelu<fp32, fp32>>(input, output, e, stream);
 }

 void nvte_dqgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dqgeglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, qgelu<fp32, fp32>, dqgelu<fp32, fp32>>(grad, input, output, e, stream);
 }
...
@@ -23,14 +23,16 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
 void nvte_reglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_reglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, relu<fp32, fp32>>(input, output, e, stream);
 }

 void nvte_dreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                  cudaStream_t stream) {
   NVTE_API_CALL(nvte_dreglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, relu<fp32, fp32>, drelu<fp32, fp32>>(grad, input, output, e, stream);
 }

 void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
...
@@ -49,12 +51,14 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
 void nvte_sreglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_sreglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, srelu<fp32, fp32>>(input, output, e, stream);
 }

 void nvte_dsreglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dsreglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, srelu<fp32, fp32>, dsrelu<fp32, fp32>>(grad, input, output, e, stream);
 }
...
@@ -23,12 +23,31 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
 void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
   NVTE_API_CALL(nvte_swiglu);
   using namespace transformer_engine;
-  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, stream);
+  Empty e = {};
+  gated_act_fn<fp32, Empty, silu<fp32, fp32>>(input, output, e, stream);
 }

 void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream) {
   NVTE_API_CALL(nvte_dswiglu);
   using namespace transformer_engine;
-  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, stream);
+  Empty e = {};
+  dgated_act_fn<fp32, Empty, silu<fp32, fp32>, dsilu<fp32, fp32>>(grad, input, output, e, stream);
+}
+
+void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
+                         cudaStream_t stream) {
+  NVTE_API_CALL(nvte_clamped_swiglu);
+  using namespace transformer_engine;
+  ClampedSwiGLUParam param = {limit, alpha};
+  gated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>>(input, output, param, stream);
+}
+
+void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+                          float limit, float alpha, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_clamped_dswiglu);
+  using namespace transformer_engine;
+  ClampedSwiGLUParam param = {limit, alpha};
+  dgated_act_fn<fp32, ClampedSwiGLUParam, clamped_silu<fp32, fp32>, clamped_dsilu<fp32, fp32>>(
+      grad, input, output, param, stream);
 }
...
@@ -173,6 +173,26 @@ void nvte_geglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
  */
 void nvte_swiglu(const NVTETensor input, NVTETensor output, cudaStream_t stream);

+/*! \brief Computes the gated Swish activation of the input, as used in GPT OSS.
+ *
+ * See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
+ * This gated activation differs from the original SwiGLU in two ways:
+ * 1. Both the gate and the pre-activation are clipped according to the parameter limit.
+ * 2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) of the Swish
+ *    activation, inspired by the original GELU paper https://arxiv.org/pdf/1606.08415
+ * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
+ * block quantization (MXFP8) with the specified block shape will be used.
+ *
+ * \param[in]     input     Input tensor of shape [N, H * 2].
+ * \param[in,out] output    Output tensor of shape [N, H].
+ *                          It computes Act(input[N, :H]) x input[N, H:]
+ * \param[in]     limit     Clipping limit for the gate and pre-activation.
+ * \param[in]     alpha     Scaling factor for the sigmoid function used in the activation.
+ * \param[in]     stream    CUDA stream used for the operation.
+ */
+void nvte_clamped_swiglu(const NVTETensor input, NVTETensor output, float limit, float alpha,
+                         cudaStream_t stream);
+
 /*! \brief Computes the gated ReLU activation of the input.
  * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
  * the block quantization (MXFP8) of the specified shape of the block will be used.
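
(In plain PyTorch terms, the forward documented above reduces to the following reference; a sketch consistent with the unit test earlier in this diff, not the CUDA implementation itself.)

import torch

def clamped_swiglu_ref(x: torch.Tensor, limit: float, alpha: float) -> torch.Tensor:
    """Reference for nvte_clamped_swiglu: x is [N, 2*H], output is [N, H]."""
    x_act, x_gate = x.chunk(2, dim=-1)
    x_act = x_act.clamp(max=limit)                # pre-activation clipped from above only
    x_gate = x_gate.clamp(min=-limit, max=limit)  # gate clipped on both sides
    return x_act * torch.sigmoid(alpha * x_act) * (x_gate + 1)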
...
@@ -230,6 +250,26 @@ void nvte_dgeglu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
 void nvte_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
                   cudaStream_t stream);

+/*! \brief Computes the gradient of the gated Swish activation used in GPT OSS.
+ *
+ * See https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250
+ * This activation differs from the original SwiGLU in two ways:
+ * 1. Both the gate and the pre-activation are clipped according to the parameter limit.
+ * 2. The activation uses sigmoid(alpha * x) instead of the sigmoid(x) of the Swish
+ *    activation, inspired by the original GELU paper https://arxiv.org/pdf/1606.08415
+ * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
+ * block quantization (MXFP8) with the specified block shape will be used.
+ *
+ * \param[in]     grad      Incoming gradient of shape [N, H].
+ * \param[in]     input     Forward input tensor of shape [N, H * 2].
+ * \param[in,out] output    Outgoing gradient of shape [N, H * 2].
+ * \param[in]     limit     Clipping limit for the gate and pre-activation.
+ * \param[in]     alpha     Scaling factor for the sigmoid function used in the activation.
+ * \param[in]     stream    CUDA stream used for the operation.
+ */
+void nvte_clamped_dswiglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
+                          float limit, float alpha, cudaStream_t stream);
+
 /*! \brief Computes the gated ReLU activation gradient.
  * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
  * the block quantization (MXFP8) of the specified shape of the block will be used.
...
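
(The corresponding backward has a simple closed form. A reference sketch, not the kernel: comparisons are inclusive at the clamp borders to match torch.clamp's subgradient, per the commit notes above, and the output layout matches nvte_clamped_dswiglu, with the first H columns holding the gradient w.r.t. the activation half and the last H w.r.t. the gate half.)

import torch

def clamped_dswiglu_ref(grad: torch.Tensor, x: torch.Tensor, limit: float, alpha: float):
    """Reference for nvte_clamped_dswiglu: grad is [N, H], x is [N, 2*H]; returns [N, 2*H]."""
    x_act, x_gate = x.chunk(2, dim=-1)
    xc = x_act.clamp(max=limit)
    s = torch.sigmoid(alpha * xc)
    act = xc * s
    # d/dx of x * sigmoid(alpha * x); zero where the upper clamp is saturated
    dact = torch.where(x_act <= limit, s + s * (1 - s) * alpha * xc, torch.zeros_like(xc))
    gate = x_gate.clamp(min=-limit, max=limit) + 1
    # Derivative of clamp is 1 inside the limits, inclusive at the borders
    dgate = ((x_gate >= -limit) & (x_gate <= limit)).to(x.dtype)
    return torch.cat([grad * dact * gate, grad * act * dgate], dim=-1)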
...
@@ -55,7 +55,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
                     const __grid_constant__ CUtensorMap tensor_map_output_act,
                     const __grid_constant__ CUtensorMap tensor_map_output_gate,
                     float *const amax_ptr, float *const scale_inv_ptr,
-                    const float *const scale_ptr, const size_t rows, const size_t cols) {
+                    const float *const scale_ptr, const size_t rows, const size_t cols,
+                    const ParamOP p) {
 #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)

   const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
...
@@ -161,7 +162,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems;
       OType *out_act_sh_curr = out_act_sh + buff * buff_elems;
       OType *out_gate_sh_curr = out_gate_sh + buff * buff_elems;
-
 #pragma unroll
       for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
         const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
...
@@ -171,6 +171,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
         float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
         float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);

+        bool dgate_elt = true;  // gating is ideally an identity function
+        if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+          // In case of GPT OSS, clamp the activation and gate values
+          dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;  // Derivative of clamp
+          gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1;
+        }
         if constexpr (IS_DGATED) {
           float grad_elt = static_cast<float>(in_grad_sh_curr[shmem_idx]);
...
@@ -178,18 +184,27 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
           const float x = act_elt;
           float act_x;
           float dact_x;
-          if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
-            const float s = sigmoidf(x);
-            act_x = x * s;
-            dact_x = x * s * (1 - s) + s;
+          if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+            const float x = min(act_elt, p.limit);
+            const float s = sigmoidf(p.alpha * x);
+            act_x = x * s;
+            if (act_elt <= p.limit) {
+              dact_x = s + s * (1 - s) * p.alpha * x;
+            } else {
+              dact_x = 0.0f;
+            }
           } else {
-            act_x = ActOP(x, {});
-            dact_x = DActOP(x, {});
+            if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
+              const float s = sigmoidf(x);
+              act_x = x * s;
+              dact_x = x * s * (1 - s) + s;
+            } else {
+              act_x = ActOP(x, p);
+              dact_x = DActOP(x, p);
+            }
           }

           float after_dact = dact_x * grad_elt * gate_elt;
-          float after_dgate = act_x * grad_elt;
+          float after_dgate = dgate_elt ? act_x * grad_elt : 0.0f;

           out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dact);
           out_gate_sh_curr[shmem_idx] = static_cast<OType>(scale * after_dgate);
...
@@ -197,7 +212,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
           amax = fmaxf(amax, fabsf(after_dact));
           amax = fmaxf(amax, fabsf(after_dgate));
         } else {
-          const float after_act = ActOP(act_elt, {}) * gate_elt;
+          const float after_act = ActOP(act_elt, p) * gate_elt;
           out_act_sh_curr[shmem_idx] = static_cast<OType>(scale * after_act);
           amax = fmaxf(amax, fabsf(after_act));
         }
...
@@ -300,7 +315,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
                     const __grid_constant__ CUtensorMap tensor_map_output_gate_colwise,
                     e8m0_t *const scales_rowwise, e8m0_t *const scales_colwise,
                     const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
-                    const size_t scale_stride_colwise) {
+                    const size_t scale_stride_colwise, const ParamOP p) {
 #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
   using IType2 = typename ptx::FPx2<IType>;
   using OType2 = typename ptx::FPx2<OType>;
...
@@ -476,25 +491,37 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
       float gate_elt = static_cast<float>(in_gate_sh[shmem_offset_colwise]);
       float after_act_elt;
       float after_gate_elt;

+      bool dgate_elt = true;  // gating is ideally an identity function
+      if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+        // In case of GPT OSS, clamp the activation and gate values
+        dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;  // Derivative of clamp
+        gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
+      }
       if constexpr (IS_DGATED) {
         float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
         const float x = act_elt;
         float act_x;
         float dact_x;
-        if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
-          const float s = sigmoidf(x);
-          act_x = x * s;
-          dact_x = x * s * (1 - s) + s;
+        if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+          const float x = min(act_elt, p.limit);
+          const float s = sigmoidf(p.alpha * x);
+          act_x = x * s;
+          dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f;
         } else {
-          act_x = ActOP(x, {});
-          dact_x = DActOP(x, {});
+          if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
+            const float s = sigmoidf(x);
+            act_x = x * s;
+            dact_x = x * s * (1 - s) + s;
+          } else {
+            act_x = ActOP(x, p);
+            dact_x = DActOP(x, p);
+          }
         }
         after_act_elt = dact_x * grad_elt * gate_elt;
-        after_gate_elt = act_x * grad_elt;
+        after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f;
       } else {
-        after_act_elt = ActOP(act_elt, {}) * gate_elt;
+        after_act_elt = ActOP(act_elt, p) * gate_elt;
       }

       // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
       if constexpr (!std::is_same_v<IType, float>) {
...
@@ -720,27 +747,39 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
         float gate_elt = static_cast<float>(in_gate.data.elt[e]);
         float after_act_elt;
         float after_gate_elt;

+        bool dgate_elt = true;
+        if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+          // In case of GPT OSS, clamp the activation and gate values
+          dgate_elt = gate_elt <= p.limit && gate_elt >= -p.limit;  // Derivative of clamp
+          gate_elt = min(max(-p.limit, gate_elt), p.limit) + 1.0f;
+        }
         if constexpr (IS_DGATED) {
           float grad_elt = static_cast<float>(in_grad.data.elt[e]);
           const float x = act_elt;
           float act_x;
           float dact_x;
-          if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
-            const float s = sigmoidf(x);
-            act_x = x * s;
-            dact_x = x * s * (1 - s) + s;
+          if constexpr (std::is_same<ParamOP, ClampedSwiGLUParam>::value) {
+            const float x = min(act_elt, p.limit);
+            const float s = sigmoidf(p.alpha * x);
+            act_x = x * s;
+            dact_x = act_elt <= p.limit ? s + s * (1 - s) * p.alpha * x : 0.0f;
           } else {
-            act_x = ActOP(x, {});
-            dact_x = DActOP(x, {});
+            if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
+              const float s = sigmoidf(x);
+              act_x = x * s;
+              dact_x = x * s * (1 - s) + s;
+            } else {
+              act_x = ActOP(x, p);
+              dact_x = DActOP(x, p);
+            }
           }
           after_act_elt = dact_x * grad_elt * gate_elt;
-          after_gate_elt = act_x * grad_elt;
+          after_gate_elt = dgate_elt ? act_x * grad_elt : 0.0f;
           after_act_rowwise[j] = after_act_elt;
           after_gate_rowwise[j] = after_gate_elt;
         } else {
-          after_act_elt = ActOP(act_elt, {}) * gate_elt;
+          after_act_elt = ActOP(act_elt, p) * gate_elt;
           after_act_rowwise[j] = after_act_elt;
         }
...
@@ -885,7 +924,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
 template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
           float (*DActOP)(float, const ParamOP &)>
-void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
+void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
                     cudaStream_t stream) {
   checkCuDriverContext(stream);
...
@@ -960,15 +999,14 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
           cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>
               <<<grid_dim, block_dim, shmem_size, stream>>>(
                   tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act,
-                  tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows,
-                  cols);
+                  tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols, p);
           NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
       ); // NOLINT(*)
 }

 template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
           float (*DActOP)(float, const ParamOP &)>
-void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
+void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
                       cudaStream_t stream) {
   checkCuDriverContext(stream);
...
@@ -1099,7 +1137,8 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
                     tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
                     tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
                     scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
-                    scale_stride_colwise);
+                    scale_stride_colwise, p);
         NVTE_CHECK_CUDA(cudaGetLastError());
         break;
       case ScalingType::COLWISE:
...
@@ -1116,7 +1155,8 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
                     tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
                     tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
                     scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
-                    scale_stride_colwise);
+                    scale_stride_colwise, p);
         NVTE_CHECK_CUDA(cudaGetLastError());
         break;
       case ScalingType::BIDIMENSIONAL:
...
@@ -1125,7 +1165,6 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
                                                   OType, true, true,
                                                   THREADS_PER_CHUNK_NON_COLWISE>,
             cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
-
         mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
                                               true, true, THREADS_PER_CHUNK_NON_COLWISE>
             <<<grid, block_size, shmem_size, stream>>>(
...
@@ -1133,7 +1172,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
                 tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
                 tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
                 scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
-                scale_stride_colwise);
+                scale_stride_colwise, p);
         NVTE_CHECK_CUDA(cudaGetLastError());
         break;
     }); // NOLINT(*)
...
@@ -1141,12 +1180,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
 }

 template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>
-void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
+void cast_gated(const Tensor &input, Tensor *output, ParamOP &p, cudaStream_t stream) {
   CheckInputTensor(input, "gated_act_input");
   CheckOutputTensor(*output, "gated_act_output");
-  NVTE_CHECK(output->flat_first_dim() == input.flat_first_dim(),
-             "Wrong output shape. Expected (after flattening) [", input.flat_first_dim(),
-             ", *], got [", output->flat_first_dim(), ", ", output->flat_last_dim(), "].");
   NVTE_CHECK(input.flat_last_dim() % 2 == 0,
              "Wrong input shape. Expected (after flattening) last dimension to be even, ", "got [",
              input.flat_first_dim(), ", ", input.flat_last_dim(), "].");
...
@@ -1168,7 +1204,7 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
                 reinterpret_cast<const fp32 *>(output->scale.dptr),
                 reinterpret_cast<fp32 *>(output->amax.dptr),
                 reinterpret_cast<fp32 *>(output->scale_inv.dptr), input.flat_first_dim(),
-                output->flat_last_dim(), {}, stream);
+                output->flat_last_dim(), p, stream);
   } else {
     NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
   }); // NOLINT(*)
...
@@ -1177,7 +1213,8 @@ void cast_gated(const Tensor &input, Tensor *output, cudaStream_t stream) {
 template <typename ParamOP, float (*ActOP)(float, const ParamOP &),
           float (*DActOP)(float, const ParamOP &)>
-void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaStream_t stream) {
+void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, ParamOP &p,
+                 cudaStream_t stream) {
   CheckInputTensor(grad, "dgated_act_grad");
   CheckInputTensor(input, "dgated_act_input");
   CheckOutputTensor(*output, "dgated_act_output");
...
@@ -1206,7 +1243,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt
                 reinterpret_cast<const fp32 *>(output->scale.dptr),
                 reinterpret_cast<fp32 *>(output->amax.dptr),
                 reinterpret_cast<fp32 *>(output->scale_inv.dptr), grad.flat_first_dim(),
-                grad.flat_last_dim(), {}, stream);
+                grad.flat_last_dim(), p, stream);
   } else {
     NVTE_ERROR("Not implemented scaling mode: " + to_string(output->scaling_mode) + ".");
   }); // NOLINT(*)
...
@@ -1215,7 +1252,7 @@ void cast_dgated(const Tensor &grad, const Tensor &input, Tensor *output, cudaSt
 template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
           float (*DActOP)(float, const ParamOP &)>
-void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
+void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, ParamOP &p,
                     cudaStream_t stream) {
   constexpr bool allow_empty = false;
   CheckInputTensor(gated_input, "gated_input");
...
@@ -1255,17 +1292,17 @@ void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
   if (is_delayed_tensor_scaling(output->scaling_mode)) {
     if (use_tma_kernels) {
-      cast_fp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
+      cast_fp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
     } else {
       if constexpr (IS_DGATED) {
-        cast_dgated<ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
+        cast_dgated<ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
       } else {
-        cast_gated<ParamOP, ActOP>(gated_input, output, stream);
+        cast_gated<ParamOP, ActOP>(gated_input, output, p, stream);
       }
     }
   } else if (is_mxfp8_scaling(output->scaling_mode)) {
     if (use_tma_kernels) {
-      cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, stream);
+      cast_mxfp8_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad, gated_input, output, p, stream);
     } else {
       NVTE_ERROR("Invalid input shape. Expected the last dimension to be divisible ",
                  "by 32, got input of shape ", gated_input.data.shape);
...
@@ -1281,7 +1318,7 @@ namespace detail {
 template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
           float (*DActOP)(float, const ParamOP &)>
void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output, void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, NVTETensor output,
cudaStream_t stream) { ParamOP &p, cudaStream_t stream) {
using namespace gated_kernels; using namespace gated_kernels;
Tensor grad_empty_tensor; Tensor grad_empty_tensor;
const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor; const Tensor &grad_tensor = IS_DGATED ? *(convertNVTETensorCheck(grad)) : grad_empty_tensor;
...@@ -1290,13 +1327,14 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input, ...@@ -1290,13 +1327,14 @@ void quantize_gated_helper(const NVTETensor grad, const NVTETensor gated_input,
if (is_supported_by_CC_100()) { if (is_supported_by_CC_100()) {
quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, quantize_gated<IS_DGATED, ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor,
output_tensor, stream); output_tensor, p, stream);
} else { } else {
if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) { if (is_delayed_tensor_scaling(output_tensor->scaling_mode)) {
if constexpr (IS_DGATED) { if constexpr (IS_DGATED) {
cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, stream); cast_dgated<ParamOP, ActOP, DActOP>(grad_tensor, gated_input_tensor, output_tensor, p,
stream);
} else { } else {
cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, stream); cast_gated<ParamOP, ActOP>(gated_input_tensor, output_tensor, p, stream);
} }
} else { } else {
// MX scaling // MX scaling
......
@@ -11,6 +11,11 @@ namespace transformer_engine {
 struct Empty {};
 
+struct ClampedSwiGLUParam {
+  float limit;
+  float alpha = 1.702f;  // Default value for QuickGELU
+};
+
 template <typename OType, typename IType>
 __device__ inline OType gelu(const IType val, const Empty&) {
   const float cval = val;
@@ -38,17 +43,29 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) {
   return s * (1.f - s);
 }
 
+template <typename OType, typename IType>
+__device__ inline OType qgelu_with_alpha(const IType val, const float alpha) {
+  const float cval = val;
+  Empty e = {};
+  return cval * sigmoid<float, float>(alpha * cval, e);
+}
+
 template <typename OType, typename IType>
 __device__ inline OType qgelu(const IType val, const Empty& e) {
-  const float cval = val;
-  return cval * sigmoid<float, float>(1.702f * cval, e);
+  return qgelu_with_alpha<OType, IType>(val, 1.702f);
+}
+
+template <typename OType, typename IType>
+__device__ inline OType dqgelu_with_alpha(const IType val, const float alpha) {
+  const float cval = val;
+  Empty e = {};
+  return alpha * cval * dsigmoid<float, float>(alpha * cval, e) +
+         sigmoid<float, float>(alpha * cval, e);
 }
 
 template <typename OType, typename IType>
 __device__ inline OType dqgelu(const IType val, const Empty& e) {
-  const float cval = val;
-  return 1.702f * cval * dsigmoid<float, float>(1.702f * cval, e) +
-         sigmoid<float, float>(1.702f * cval, e);
+  return dqgelu_with_alpha<OType, IType>(val, 1.702f);
 }
 
 template <typename OType, typename IType>
@@ -57,12 +74,26 @@ __device__ inline OType silu(const IType val, const Empty& e) {
   return cval * sigmoid<float, float>(cval, e);
 }
 
+template <typename OType, typename IType>
+__device__ inline OType clamped_silu(const IType val, const ClampedSwiGLUParam& p) {
+  const float cval = min(p.limit, static_cast<float>(val));  // Clamping
+  return qgelu_with_alpha<OType, float>(cval, p.alpha);
+}
+
 template <typename OType, typename IType>
 __device__ inline OType dsilu(const IType val, const Empty& e) {
   const float cval = val;
   return cval * dsigmoid<float, float>(cval, e) + sigmoid<float, float>(cval, e);
 }
 
+template <typename OType, typename IType>
+__device__ inline OType clamped_dsilu(const IType val, const ClampedSwiGLUParam& p) {
+  const bool dclamp_val = static_cast<float>(val) <= p.limit;
+  const float clamp_val = min(static_cast<float>(val), p.limit);
+  const float dsilu_val = dqgelu_with_alpha<OType, float>(clamp_val, p.alpha);
+  return dclamp_val ? dsilu_val : 0.0f;
+}
+
 template <typename OType, typename IType>
 __device__ inline OType relu(IType value, const Empty&) {
   return fmaxf(value, 0.f);
...
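For reference, the element-wise math introduced above reduces to a few lines of plain PyTorch. This is a minimal unfused sketch of `clamped_silu`/`clamped_dsilu` (with `limit` and `alpha` mirroring `ClampedSwiGLUParam`), not the device code itself:

```python
import torch

def clamped_silu_ref(x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702) -> torch.Tensor:
    # clamped_silu: one-sided clamp on the input, then x * sigmoid(alpha * x)
    xc = x.clamp(max=limit)
    return xc * torch.sigmoid(alpha * xc)

def clamped_dsilu_ref(x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702) -> torch.Tensor:
    # clamped_dsilu: derivative of clamped_silu; zero wherever the clamp saturates
    xc = x.clamp(max=limit)
    s = torch.sigmoid(alpha * xc)
    d = s + alpha * xc * s * (1.0 - s)  # d/dx [x * sigmoid(alpha * x)]
    return torch.where(x <= limit, d, torch.zeros_like(d))
```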
@@ -11,7 +11,7 @@
 #include "../common.h"
 #include "../utils.cuh"
+#include "math.h"
 
 namespace transformer_engine {
 
 /* \brief Helper class that enables storing multiple values of type DType
@@ -338,7 +338,7 @@ template <int nvec, typename Param, fp32 (*OP)(const fp32, const Param &), typen
           typename OutputType>
 void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, OutputType *output,
                                    const fp32 *scale, fp32 *amax, fp32 *scale_inv, const size_t N,
-                                   const Param params, cudaStream_t stream) {
+                                   const Param &params, cudaStream_t stream) {
   if (N != 0) {
     auto align = CheckAlignment(N, nvec, input, output);
@@ -372,7 +372,7 @@ template <int nvec, typename Param, fp32 (*OP)(fp32, const Param &), typename In
           typename InputTypeGrad, typename OutputType>
 void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputType *input,
                                        OutputType *output, const fp32 *scale, fp32 *amax,
-                                       fp32 *scale_inv, const size_t N, const Param params,
+                                       fp32 *scale_inv, const size_t N, const Param &params,
                                        cudaStream_t stream) {
   if (N != 0) {
     auto align = CheckAlignment(N, nvec, input, grad, output);
@@ -431,7 +431,13 @@ __launch_bounds__(unary_kernel_threads) __global__
 #pragma unroll
     for (int i = 0; i < nvec; ++i) {
       const ComputeType val = static_cast<ComputeType>(loader0.separate()[i]);
-      const ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
+      ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
+      if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
+        // Clamp the gated value and add 1 at the end
+        ComputeType limit = p.limit;
+        val2 = std::min(std::max(-limit, val2), limit) + 1;
+      }
       ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
       if (requires_amax) {
         __builtin_assume(max >= 0);
@@ -532,10 +538,18 @@ __launch_bounds__(unary_kernel_threads) __global__
     for (int i = 0; i < nvec; ++i) {
       const ComputeType grad_val = static_cast<ComputeType>(grad_loader.separate()[i]);
      const ComputeType gelu_in = static_cast<ComputeType>(input_loader0.separate()[i]);
-      const ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
+      ComputeType gate_in = static_cast<ComputeType>(input_loader1.separate()[i]);
+      bool dgate_in = true;
+      if constexpr (std::is_same<Param, ClampedSwiGLUParam>::value) {
+        // In case of GPT OSS, clamp the activation and gate values
+        const ComputeType limit = p.limit;
+        dgate_in = gate_in <= limit && gate_in >= -limit;  // Derivative of clamp
+        gate_in = std::min(std::max(-limit, gate_in), limit) + 1.0f;
+      }
       ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
-      ComputeType after_dgate = grad_val * Activation(gelu_in, p);
+      ComputeType after_dgate = dgate_in ? grad_val * Activation(gelu_in, p) : 0.0f;
       if (requires_amax) {
         __builtin_assume(max >= 0);
...
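Combining the clamp handling in the two kernel loops above, the fused gated forward/backward pair is equivalent to the following unfused sketch. The act/gate halves are assumed to be chunked along the last dimension (matching these kernels and the `ClampedSwiGLU` op further below); the names are illustrative only:

```python
import torch

def clamped_swiglu_ref(x: torch.Tensor, limit: float = 7.0, alpha: float = 1.702) -> torch.Tensor:
    act, gate = x.chunk(2, dim=-1)
    act = act.clamp(max=limit)                # one-sided clamp on the activation half
    gate = gate.clamp(min=-limit, max=limit)  # two-sided clamp, then "+ 1", on the gate half
    return act * torch.sigmoid(alpha * act) * (gate + 1)

def clamped_dswiglu_ref(grad: torch.Tensor, x: torch.Tensor,
                        limit: float = 7.0, alpha: float = 1.702) -> torch.Tensor:
    act, gate = x.chunk(2, dim=-1)
    act_c = act.clamp(max=limit)
    gate_c = gate.clamp(min=-limit, max=limit)
    s = torch.sigmoid(alpha * act_c)
    silu = act_c * s
    dsilu = s + alpha * act_c * s * (1.0 - s)
    dsilu = torch.where(act <= limit, dsilu, torch.zeros_like(dsilu))    # clamp derivative
    d_act = grad * dsilu * (gate_c + 1)                                  # "after_dgelu"
    in_range = (gate >= -limit) & (gate <= limit)                        # "dgate_in"
    d_gate = torch.where(in_range, grad * silu, torch.zeros_like(grad))  # "after_dgate"
    return torch.cat([d_act, d_gate], dim=-1)
```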
@@ -205,6 +205,10 @@ py::object swiglu(const at::Tensor &input, py::handle quantizer);
 py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer);
 
+py::object clamped_swiglu(const at::Tensor &input, py::handle quantizer, float limit, float alpha);
+
+py::object clamped_dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer,
+                           float limit, float alpha);
+
 /***************************************************************************************************
  * LayerNorm
  **************************************************************************************************/
...
@@ -3,7 +3,6 @@
  *
  * See LICENSE for license information.
  ************************************************************************/
-
 #include "../extensions.h"
 #include "common.h"
 #include "pybind.h"
@@ -12,10 +11,12 @@ namespace transformer_engine {
 namespace pytorch {
 
 namespace {
 
+using FuncType = void(const NVTETensor, NVTETensor, cudaStream_t);
+using DFuncType = void(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t);
+
-py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cudaStream_t),
-                              const at::Tensor& input, py::handle quantizer,
-                              int shape_divisor = 1) {
+template <FuncType* act_func, auto act_func_with_args, typename... Args>
+py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1,
+                             Args&&... args) {
   init_extension();
 
   // Input tensor
@@ -56,14 +57,28 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud
       // Compute activation in high precision, then quantize
       {
         auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype);
-        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
+        NVTE_SCOPED_GIL_RELEASE({
+          if constexpr (act_func == nullptr) {
+            act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward<Args>(args)...,
+                               stream);
+          } else {
+            act_func(input_nvte.data(), temp_nvte.data(), stream);
+          }
+        });
         quantizer_cpp->quantize(temp_nvte, out_nvte);
       }
       break;
     case Impl::FULLY_FUSED:
       // Compute activation directly
       {
-        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), out_nvte.data(), stream); });
+        NVTE_SCOPED_GIL_RELEASE({
+          if constexpr (act_func == nullptr) {
+            act_func_with_args(input_nvte.data(), out_nvte.data(), std::forward<Args>(args)...,
+                               stream);
+          } else {
+            act_func(input_nvte.data(), out_nvte.data(), stream);
+          }
+        });
       }
       break;
     case Impl::FUSED_ACTIVATION_AMAX_FP8:
@@ -73,7 +88,14 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud
        NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
        auto [temp_nvte, _] =
            fp8_quantizer_cpp->create_unquantized_tensor_with_amax(output_shape, fake_dtype);
-        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
+        NVTE_SCOPED_GIL_RELEASE({
+          if constexpr (act_func == nullptr) {
+            act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward<Args>(args)...,
+                               stream);
+          } else {
+            act_func(input_nvte.data(), temp_nvte.data(), stream);
+          }
+        });
        fp8_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
      }
      break;
@@ -84,7 +106,14 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud
            static_cast<NVFP4Quantizer*>(quantizer_cpp.get());  // Already checked cast is valid
        auto [temp_nvte, _] =
            nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(out_nvte, fake_dtype);
-        NVTE_SCOPED_GIL_RELEASE({ act_func(input_nvte.data(), temp_nvte.data(), stream); });
+        NVTE_SCOPED_GIL_RELEASE({
+          if constexpr (act_func == nullptr) {
+            act_func_with_args(input_nvte.data(), temp_nvte.data(), std::forward<Args>(args)...,
+                               stream);
+          } else {
+            act_func(input_nvte.data(), temp_nvte.data(), stream);
+          }
+        });
        nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, out_nvte);
      }
      break;
@@ -95,10 +124,9 @@ py::object activation_forward(void (*act_func)(const NVTETensor, NVTETensor, cud
   return out_py;
 }
 
-py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor,
-                                                 cudaStream_t),
-                               const at::Tensor& grad_output, const at::Tensor& input,
-                               py::handle quantizer) {
+template <DFuncType* dact_func, auto dact_func_with_args, typename... Args>
+py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input,
+                              py::handle quantizer, Args&&... args) {
   init_extension();
 
   // Grad output and input tensors
@@ -142,8 +170,12 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen
      {
        auto [temp_nvte, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype);
        NVTE_SCOPED_GIL_RELEASE({
-          dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
-                    at::cuda::getCurrentCUDAStream());
+          if constexpr (dact_func == nullptr) {
+            dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
+                                std::forward<Args>(args)..., stream);
+          } else {
+            dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream);
+          }
        });
        quantizer_cpp->quantize(temp_nvte, grad_input_nvte);
      }
@@ -152,7 +184,12 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen
      // Compute activation backward directly
      {
        NVTE_SCOPED_GIL_RELEASE({
-          dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream);
+          if constexpr (dact_func == nullptr) {
+            dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(),
+                                std::forward<Args>(args)..., stream);
+          } else {
+            dact_func(grad_output_nvte.data(), input_nvte.data(), grad_input_nvte.data(), stream);
+          }
        });
      }
      break;
@@ -163,8 +200,14 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen
        NVTE_CHECK(fp8_quantizer_cpp != nullptr, "Could not cast to FP8 current scaling quantizer");
        auto [temp_nvte, _] =
            fp8_quantizer_cpp->create_unquantized_tensor_with_amax(input_shape, fake_dtype);
-        NVTE_SCOPED_GIL_RELEASE(
-            { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
+        NVTE_SCOPED_GIL_RELEASE({
+          if constexpr (dact_func == nullptr) {
+            dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
+                                std::forward<Args>(args)..., stream);
+          } else {
+            dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream);
+          }
+        });
        fp8_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
      }
      break;
@@ -175,8 +218,14 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen
            static_cast<NVFP4Quantizer*>(quantizer_cpp.get());  // Already checked cast is valid
        auto [temp_nvte, _] =
            nvfp4_quantizer_cpp->create_unquantized_tensor_with_amax(grad_input_nvte, fake_dtype);
-        NVTE_SCOPED_GIL_RELEASE(
-            { dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream); });
+        NVTE_SCOPED_GIL_RELEASE({
+          if constexpr (dact_func == nullptr) {
+            dact_func_with_args(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(),
+                                std::forward<Args>(args)..., stream);
+          } else {
+            dact_func(grad_output_nvte.data(), input_nvte.data(), temp_nvte.data(), stream);
+          }
+        });
        nvfp4_quantizer_cpp->quantize_with_amax(temp_nvte, grad_input_nvte);
      }
      break;
@@ -186,90 +235,98 @@ py::object activation_backward(void (*dact_func)(const NVTETensor, const NVTETen
   return grad_input_py;
 }
 
 }  // namespace
 
 /* GELU and variants */
 py::object gelu(const at::Tensor& input, py::handle quantizer) {
-  return activation_forward(nvte_gelu, input, quantizer);
+  return activation_helper<nvte_gelu, nullptr>(input, quantizer);
 }
 
 py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return activation_backward(nvte_dgelu, grad, input, quantizer);
+  return dactivation_helper<nvte_dgelu, nullptr>(grad, input, quantizer);
 }
 
 py::object geglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_forward(nvte_geglu, input, quantizer, 2);
+  return activation_helper<nvte_geglu, nullptr>(input, quantizer, 2);
 }
 
 py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return activation_backward(nvte_dgeglu, grad, input, quantizer);
+  return dactivation_helper<nvte_dgeglu, nullptr>(grad, input, quantizer);
 }
 
 py::object qgelu(const at::Tensor& input, py::handle quantizer) {
-  return activation_forward(nvte_qgelu, input, quantizer);
+  return activation_helper<nvte_qgelu, nullptr>(input, quantizer);
 }
 
 py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return activation_backward(nvte_dqgelu, grad, input, quantizer);
+  return dactivation_helper<nvte_dqgelu, nullptr>(grad, input, quantizer);
 }
 
 py::object qgeglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_forward(nvte_qgeglu, input, quantizer, 2);
+  return activation_helper<nvte_qgeglu, nullptr>(input, quantizer, 2);
 }
 
 py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return activation_backward(nvte_dqgeglu, grad, input, quantizer);
+  return dactivation_helper<nvte_dqgeglu, nullptr>(grad, input, quantizer);
 }
 
 /* ReLU and variants */
 py::object relu(const at::Tensor& input, py::handle quantizer) {
-  return activation_forward(nvte_relu, input, quantizer);
+  return activation_helper<nvte_relu, nullptr>(input, quantizer);
 }
 
 py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return activation_backward(nvte_drelu, grad, input, quantizer);
+  return dactivation_helper<nvte_drelu, nullptr>(grad, input, quantizer);
 }
 
 py::object reglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_forward(nvte_reglu, input, quantizer, 2);
+  return activation_helper<nvte_reglu, nullptr>(input, quantizer, 2);
 }
 
 py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return activation_backward(nvte_dreglu, grad, input, quantizer);
+  return dactivation_helper<nvte_dreglu, nullptr>(grad, input, quantizer);
 }
 
 py::object srelu(const at::Tensor& input, py::handle quantizer) {
-  return activation_forward(nvte_srelu, input, quantizer);
+  return activation_helper<nvte_srelu, nullptr>(input, quantizer);
 }
 
 py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return activation_backward(nvte_dsrelu, grad, input, quantizer);
+  return dactivation_helper<nvte_dsrelu, nullptr>(grad, input, quantizer);
 }
 
 py::object sreglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_forward(nvte_sreglu, input, quantizer, 2);
+  return activation_helper<nvte_sreglu, nullptr>(input, quantizer, 2);
 }
 
 py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return activation_backward(nvte_dsreglu, grad, input, quantizer);
+  return dactivation_helper<nvte_dsreglu, nullptr>(grad, input, quantizer);
 }
 
 /* Silu and variants */
 py::object silu(const at::Tensor& input, py::handle quantizer) {
-  return activation_forward(nvte_silu, input, quantizer);
+  return activation_helper<nvte_silu, nullptr>(input, quantizer);
 }
 
 py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return activation_backward(nvte_dsilu, grad, input, quantizer);
+  return dactivation_helper<nvte_dsilu, nullptr>(grad, input, quantizer);
 }
 
 py::object swiglu(const at::Tensor& input, py::handle quantizer) {
-  return activation_forward(nvte_swiglu, input, quantizer, 2);
+  return activation_helper<nvte_swiglu, nullptr>(input, quantizer, 2);
 }
 
 py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) {
-  return activation_backward(nvte_dswiglu, grad, input, quantizer);
+  return dactivation_helper<nvte_dswiglu, nullptr>(grad, input, quantizer);
 }
 
+/* clamped functions */
+py::object clamped_swiglu(const at::Tensor& input, py::handle quantizer, float limit, float alpha) {
+  return activation_helper<nullptr, nvte_clamped_swiglu>(input, quantizer, 2, limit, alpha);
+}
+
+py::object clamped_dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer,
+                           float limit, float alpha) {
+  return dactivation_helper<nullptr, nvte_clamped_dswiglu>(grad, input, quantizer, limit, alpha);
+}
+
 }  // namespace pytorch
...
@@ -155,6 +155,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("quantizer"));
   m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"),
         py::arg("quantizer"));
+  m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu,
+        "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"),
+        py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
 
   /* Backward of GELU and variants */
   m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
@@ -178,6 +181,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("fwd_input"), py::arg("quantizer"));
   m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"),
         py::arg("fwd_input"), py::arg("quantizer"));
+  m.def("clamped_dswiglu", transformer_engine::pytorch::clamped_dswiglu,
+        "Backward of SwiGLU used in GPT OSS", py::arg("grad"), py::arg("fwd_input"),
+        py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f);
 
   /* DBias + DAct fusions*/
   m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize",
        py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer"));
...
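With the bindings registered, the raw extension entry points can be exercised directly. A sketch, assuming the extension is importable as `transformer_engine_torch` and that passing `quantizer=None` yields an unquantized output:

```python
import torch
import transformer_engine_torch as tex  # assumed module name of the built extension

x = torch.randn(16, 128, device="cuda")
y = tex.clamped_swiglu(x, None, limit=7.0, alpha=1.702)  # last dim halves: (16, 64)
dx = tex.clamped_dswiglu(torch.ones_like(y), x, None,
                         limit=7.0, alpha=1.702)         # same shape as x
```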
@@ -4,7 +4,19 @@
 
 """Single tensor operations supported by the operation fuser."""
 
-from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU
+from .activation import (
+    GELU,
+    GEGLU,
+    QGELU,
+    QGEGLU,
+    ReLU,
+    ReGLU,
+    SReLU,
+    SReGLU,
+    SiLU,
+    SwiGLU,
+    ClampedSwiGLU,
+)
 from .add_extra_input import AddExtraInput
 from .all_gather import AllGather
 from .all_reduce import AllReduce
...
@@ -28,6 +28,7 @@ __all__ = [
     "SReGLU",
     "SiLU",
     "SwiGLU",
+    "ClampedSwiGLU",
 ]
@@ -392,3 +393,38 @@ class SwiGLU(_ActivationOperation):
 
     def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
         return tex.dswiglu(*args, **kwargs)
+
+
+class ClampedSwiGLU(_ActivationOperation):
+    r"""GPT-OSS clamped SwiGLU activation
+
+    Implementation based on `GPT-OSS <https://github.com/openai/gpt-oss/blob/a0a84273e9e0c14a233cb9befdfd159c2bcfa6cd/gpt_oss/torch/model.py#L250>`__.
+
+    This activation differs from the original SwiGLU in two ways:
+
+    1. Both the gate and the pre-activations are clipped based on the ``limit`` parameter.
+    2. The activation uses ``sigmoid(alpha * x)`` instead of the ``sigmoid(x)`` used in the
+       Swish activation.
+
+    .. warning:: The input tensor is chunked along the last dimension to obtain the
+       gates/pre-activations, which is different from the GPT-OSS implementation, where the
+       gates/pre-activations are assumed to be interleaved in the input tensor.
+
+    Parameters
+    ----------
+    limit: float
+        The clamp limit.
+    alpha: float
+        The scaling factor for the sigmoid function used in the activation.
+    cache_quantized_input: bool, default = False
+        Quantize input tensor when caching for use in the backward pass.
+
+    """
+
+    def __init__(
+        self, *, limit: float = 7.0, alpha: float = 1.702, cache_quantized_input: bool = False
+    ):
+        super().__init__(cache_quantized_input=cache_quantized_input)
+        self.limit = limit
+        self.alpha = alpha
+
+    def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.clamped_swiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
+
+    def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor:
+        return tex.clamped_dswiglu(*args, limit=self.limit, alpha=self.alpha, **kwargs)
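A usage sketch for the new op, assuming it is re-exported from `transformer_engine.pytorch.ops` like the other basic activations:

```python
import torch
from transformer_engine.pytorch.ops import ClampedSwiGLU  # assumed re-export path

op = ClampedSwiGLU(limit=7.0, alpha=1.702)
x = torch.randn(8, 256, device="cuda", requires_grad=True)
y = op(x)            # shape (8, 128): the input is chunked, not interleaved
y.sum().backward()   # backward routes through tex.clamped_dswiglu
```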