[Flax] adding support for batch norm layers (#21581)
* [flax] adding support for batch norm layers * fixing bugs related to pt+flax integration * cleanup, batchnorm support in sharded pt to flax * support for batchnorm tests in pt+flax integration * simplifying checking batch norm layer
Showing
Please register or sign in to comment