[JAX] Bugfix for softmax primitives accepting invalid input sharding (#664)
* Softmax now forces XLA to unshard the hidden dimension with a warning. Unittests updated to check for numerics and warning with bad sharding Signed-off-by:Alp Dener <adener@nvidia.com> * correcting cudnn-frontend version Signed-off-by:
Alp Dener <adener@nvidia.com> * fixed mismatched output sharding Signed-off-by:
Alp Dener <adener@nvidia.com> * combined softmax tests and fixed code style/linting issues Signed-off-by:
Alp Dener <adener@nvidia.com> --------- Signed-off-by:
Alp Dener <adener@nvidia.com>
Showing
Please register or sign in to comment