Teddy Do authored
* branch off of initial permutation jax-triton PR

Signed-off-by: tdophung <tdophung@nvidia.com>

* Set 0 as the size of dummy tensors to reduce memory usage.

Signed-off-by: tdophung <tdophung@nvidia.com>

* Correct setting of permuted_probs_stride_token, unpermuted_probs_stride_token and unpermuted_probs_stride_expert in unpermutation

Signed-off-by: tdophung <tdophung@nvidia.com>

* Implement primitives, wrapper, test for wrapper, edit triton binding to accommodate scalars

Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Change implementation of VJP functions to match correct pattern. Deduce some static scalar args from shapes of inputs. Accept B, S instead of num_tokens. Change test to use value_and_grad to test vjp funcs properly

Signed-off-by: tdophung <tdophung@nvidia.com>

* formatting

Signed-off-by: tdophung <tdophung@nvidia.com>

* fix pylint

Signed-off-by: tdophung <tdophung@nvidia.com>

* fix test to compare to the correct reference impl, relax 1 tol for grad compare, fix lint the right way

Signed-off-by: tdophung <tdophung@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test_permutation to use value_and_grad for reference impl, tighten tols, and add unpermute with probs for token combine bwd rule

Signed-off-by: tdophung <tdophung@nvidia.com>

* added forgotten file in prev commit

Signed-off-by: tdophung <tdophung@nvidia.com>

* format

Signed-off-by: tdophung <tdophung@nvidia.com>

* merge with_probs to without_probs

Signed-off-by: tdophung <tdophung@nvidia.com>

* add asserts and fix lint

Signed-off-by: tdophung <tdophung@nvidia.com>

---------

Signed-off-by: tdophung <tdophung@nvidia.com>
Co-authored-by: Ming Huang <mingh@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
46c6ef31