[JAX] Make SR rng state always 2D (num_devices, 4) to fix partitioning issue (#2294)
* Make SR rng state always 2D (num_devices, 4) Signed-off-by:Jeremy Berchtold <jberchtold@nvidia.com> * fix pure-jax impl Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> * fix test shape Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com> --------- Signed-off-by:
Jeremy Berchtold <jberchtold@nvidia.com>
Showing
Please register or sign in to comment