[JAX] Added unit tests for distributed LayernormMLP (#878)
* added distributed test for ln_mlp primitive
* added distributed test for LayerNorm layer
* changed error messages
---------
Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment