[JAX] Fixed the shape miss-matching issue in MLP. (#859)
* Fixed the shape mismatching issue in MLP. Signed-off-by:Ming Huang <mingh@nvidia.com> * Add a corresponding test Signed-off-by:
Ming Huang <mingh@nvidia.com> --------- Signed-off-by:
Ming Huang <mingh@nvidia.com> Co-authored-by:
Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com>
Showing
Please register or sign in to comment