[JAX] Fixes for the grouped_gemm with MXFP8 (#1945)
* memset for the mxfp8 scale padding Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment