[JAX] Fix grouped GEMM error on CUDA 12.9.1 & later (#1925)
* Fix JAX grouped gemm error on CUDA 12.9.1 & later by using 16B alignment for scale ptr Signed-off-by:Hua Huang <huah@nvidia.com> * Pad MXFP8 scales with 2*-127 instead of NaNs Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> Co-authored-by:
Hua Huang <huah@nvidia.com>
Showing
Please register or sign in to comment