jberchtold-nvidia authored
* Make SR rng state always 2D (num_devices, 4)

  Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix pure-jax impl

  Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* fix test shape

  Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------

Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
e2f2a0b4