[JAX] Bug fix for distributed normalization (#1366)
* fix ctx.aval_out indexing for workspace
* add cudnn init to prepare phase of norm custom calls
* add thread_local for norm registry instance
---------
Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment