[JAX] Remove unneccessary MXFP8 scale_inv padding (#1954)
* remove unnecessary padding Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com> * adapt the test_distributed_layernorm byte count Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment