[PyTorch] Change arguments order in triton kernels to make jax-triton work (#2416)
* Change order of arguments to make jax works Signed-off-by:tdophung <tdophung@nvidia.com> * make num_experts a tl.constepxr again Signed-off-by:
tdophung <tdophung@nvidia.com> --------- Signed-off-by:
tdophung <tdophung@nvidia.com>
Showing
Please register or sign in to comment