[JAX] Fix grouped GEMM error on CUDA 12.9.1 & later (#1925)
* Fix JAX grouped GEMM error on CUDA 12.9.1 & later by using 16B alignment for scale ptr

Signed-off-by: Hua Huang <huah@nvidia.com>

* Pad MXFP8 scales with 2^-127 instead of NaNs

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Hua Huang <huah@nvidia.com>
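
The two changes described above can be illustrated with a minimal Python sketch. It is not the code from this commit. It assumes MXFP8 scales are stored as one E8M0 byte each (exponent bias 127, with `0xFF` reserved for NaN), and it models the 16B pointer alignment as rounding a scale buffer's byte length up to a multiple of 16. The helper names `align_up`, `decode_e8m0`, and `pad_scales` are hypothetical.

```python
import math

E8M0_BIAS = 127   # E8M0 stores an unsigned exponent with bias 127
E8M0_NAN = 0xFF   # the all-ones byte is reserved for NaN
E8M0_MIN = 0x00   # the all-zeros byte decodes to 2**-127


def align_up(nbytes: int, alignment: int = 16) -> int:
    """Round nbytes up to the next multiple of alignment (a power of two)."""
    return (nbytes + alignment - 1) & ~(alignment - 1)


def decode_e8m0(byte: int) -> float:
    """Decode one E8M0 scale byte into a Python float."""
    if byte == E8M0_NAN:
        return math.nan
    return math.ldexp(1.0, byte - E8M0_BIAS)


def pad_scales(scales: bytes, alignment: int = 16) -> bytes:
    """Pad a scale buffer to an aligned length with 2**-127 rather than NaN."""
    padded_len = align_up(len(scales), alignment)
    return scales + bytes([E8M0_MIN]) * (padded_len - len(scales))
```

In this sketch the padding bytes decode to a tiny finite value, so any code that reads the padded region sees finite scales rather than NaNs.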