- 27 Dec, 2025 1 commit
-
-
xiaoxi-wangfj authored
* [PyTorch] Fuse permute+pad and unpermute+unpad ops for FP8 optimization 1.Fused `moe_permute_with_probs` + `Fp8Padding` and fused `moe_unpermute` + `Fp8Unpadding`, that can remove the explicit padding/unpadding of moe expert, improved performance and reduced peak gpu memory usage. 2.Add tests of fused permute/pad and unpermute/unpad. Signed-off-by:
xiaoxi-wangfj <690912414@qq.com> * [PyTorch/Common] Fuse permute+pad and unpermute+unpad support with_merging_probs Signed-off-by:
xiaoxi-wangfj <690912414@qq.com> * [PyTorch]format code Signed-off-by:
xiaoxi-wangfj <690912414@qq.com> * [Common]perf expert_idx loaded once Signed-off-by:
xiaoxi-wangfj <690912414@qq.com> * fix: pad_offsets can be None Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
xiaoxi-wangfj <690912414@qq.com> * add padding + merging probs bwd support. Not tested Signed-off-by:
tdophung <tdophung@nvidia.com> * Fix garbage initialized act grad Signed-off-by:
tdophung <tdophung@nvidia.com> * all test passing for jax permutation + pad Signed-off-by:
tdophung <tdophung@nvidia.com> * change tokens_per_experts APIs to num_out_tokens with conservative allocation of worst case padding for output buffer Signed-off-by:
tdophung <tdophung@nvidia.com> * change test permutation to reduce test time Signed-off-by:
tdophung <tdophung@nvidia.com> * triggering PR refresh Signed-off-by:
tdophung <tdophung@nvidia.com> * format code Signed-off-by:
tdophung <tdophung@nvidia.com> * Remove some tests cases from pytorch side. Add a separate toekn_dispatch test for sanity in case combine accidentally undo an error on dispatch in the roundtrip test. Add distinction between L0 and L2 in test cases in jax Signed-off-by:
tdophung <tdophung@nvidia.com> * format code Signed-off-by:
tdophung <tdophung@nvidia.com> * remove chance for inefficiency in moving between CPU and GPU, remove redundant primitive using a new static bool for padding, add assert for align size Signed-off-by:
tdophung <tdophung@nvidia.com> * fix lint in jax Signed-off-by:
tdophung <tdophung@nvidia.com> * account for both jax newer and older than version 0.8.2. Adjusted gpu triton binding accordingly Signed-off-by:
tdophung <tdophung@nvidia.com> * format code Signed-off-by:
tdophung <tdophung@nvidia.com> * fix typo Signed-off-by:
tdophung <tdophung@nvidia.com> --------- Signed-off-by:
xiaoxi-wangfj <690912414@qq.com> Signed-off-by:
tdophung <tdophung@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
tdophung <tdophung@nvidia.com>
-
- 09 Dec, 2025 1 commit
-
-
Teddy Do authored
* branch off of initial permutation jax-triton PR Signed-off-by:
tdophung <tdophung@nvidia.com> * Set 0 as the size of dummy tensors to reduce memory usage. Signed-off-by:
tdophung <tdophung@nvidia.com> * Correct setting of permuted_probs_stride_token, unpermuted_probs_stride_token and unpermuted_probs_stride_expert in unpermutation Signed-off-by:
tdophung <tdophung@nvidia.com> * Implement primitives, wrapper, test for wrapper, edit trit on binding to accomodate scalars Signed-off-by:
tdophung <tdophung@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Change implemementation of VJP functions to match correct pattern. Deduce some static scalar args from shapes of inputs. Accept B, S instead of num_tokens. Change test to use value_and_grad to test vjp funcs properly Signed-off-by:
tdophung <tdophung@nvidia.com> * formatting Signed-off-by:
tdophung <tdophung@nvidia.com> * fix pylint Signed-off-by:
tdophung <tdophung@nvidia.com> * fix test to compare to the correct reference impl. relax 1 tol for grad compare, fix lint the rightway Signed-off-by:
tdophung <tdophung@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix test_permutation to use value_and_grad for reference impl, tighten tols, and add unpermute with probs for token combine bwd rule Signed-off-by:
tdophung <tdophung@nvidia.com> * added forgotten file in prev commit Signed-off-by:
tdophung <tdophung@nvidia.com> * format Signed-off-by:
tdophung <tdophung@nvidia.com> * merge with_probs to without_probs Signed-off-by:
tdophung <tdophung@nvidia.com> * add aserts and fix lint Signed-off-by:
tdophung <tdophung@nvidia.com> --------- Signed-off-by:
tdophung <tdophung@nvidia.com> Co-authored-by:
Ming Huang <mingh@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
-
- 02 Dec, 2025 1 commit
-
-
Phuong Nguyen authored
* init triton binding with test case/example * added Triton as TE-JAX test dependency * grid with blocksize from autotune Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
-