[JAX] Remove `dot_1_output_axes` usage in LayerNormMLP (#2029)
* remove dot1_output_axes Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment