Phuong Nguyen authored
* removes unnecessary reshapes for FP8 GEMM
* use nn.jax.scaled_matmul

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
5bee81e2