[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)
* add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment