[JAX] Load modules during initialize for Norm and Act primitives (#2219)
Load modules during initialize Signed-off-by:Jeremy Berchtold <jberchtold@nvidia.com> Co-authored-by:
JAX Toolbox <jax@nvidia.com>
Showing
Please register or sign in to comment