[JAX] Custom partitioning for Permutation primitives (#2591)
* initial impl, not tested Signed-off-by:tdophung <tdophung@nvidia.com> * consolidate different unpermute primitives with with_pad and with_merging_probs booleans. Implement partitioning for all permutation primitives Signed-off-by:
tdophung <tdophung@nvidia.com> * Add distributed test for non-padding permutation Signed-off-by:
tdophung <tdophung@nvidia.com> * fix issues in distributed test for padding permutation. Make common kernel zero intiialize output permuted scales, permuted probs and output tokens Signed-off-by:
tdophung <tdophung@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert zeroing in triton common kernel as it is a race condition. Instead, add extra input (aliased wiuth output) buffer to inner primitive of permutation on jax side to pass in zero intitiated buffers done with jnp zeros Signed-off-by:
tdophung <tdophung@nvidia.com> * fix utils to handle input output aliasing in autotuned kernels Signed-off-by:
tdophung <tdophung@nvidia.com> * Clean up comments, and add more comments explaining input output alias in utils Signed-off-by:
tdophung <tdophung@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint and greptile comment Signed-off-by:
tdophung <tdophung@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix issues that lint fixing introduced Signed-off-by:
tdophung <tdophung@nvidia.com> --------- Signed-off-by:
tdophung <tdophung@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Showing
Please register or sign in to comment