    [Flax] adding support for batch norm layers (#21581) · f7ca656f
    Shubhamai authored
    * [flax] adding support for batch norm layers
    
    * fixing bugs related to pt+flax integration
    
    * cleanup, batchnorm support in sharded pt to flax
    
    * support for batchnorm tests in pt+flax integration
    
    * simplifying the batch norm layer check
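Batch norm support in PT-to-Flax conversion mostly comes down to where each PyTorch tensor lands, because Flax's `nn.BatchNorm` keeps its learned affine parameters (`scale`, `bias`) in the `params` collection and its running statistics (`mean`, `var`) in a separate `batch_stats` collection. The following is a minimal sketch, not the transformers conversion code. It assumes a flat PyTorch-style state dict that contains only BatchNorm layers. The helper `split_pt_batchnorm_state` is hypothetical, and it drops PyTorch's `num_batches_tracked` counter, which has no Flax counterpart.

```python
from typing import Any, Dict, Tuple

# Hypothetical helper: assumes every key in pt_state belongs to a BatchNorm layer.
def split_pt_batchnorm_state(
    pt_state: Dict[str, Any],
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split flat PyTorch BatchNorm keys into nested Flax `params` and `batch_stats`."""
    params: Dict[str, Any] = {}
    batch_stats: Dict[str, Any] = {}
    # Learned affine parameters: PyTorch "weight" is called "scale" in Flax.
    param_renames = {"weight": "scale", "bias": "bias"}
    # Running statistics live in Flax's separate "batch_stats" collection.
    stat_renames = {"running_mean": "mean", "running_var": "var"}

    for key, value in pt_state.items():
        *prefix, leaf = key.split(".")
        if leaf == "num_batches_tracked":
            # Flax BatchNorm keeps no step counter, so this entry is dropped.
            continue
        if leaf in stat_renames:
            target, new_leaf = batch_stats, stat_renames[leaf]
        elif leaf in param_renames:
            target, new_leaf = params, param_renames[leaf]
        else:
            raise KeyError(f"not a BatchNorm key: {key}")
        # Walk/create the nested module path, e.g. "encoder.bn" -> {"encoder": {"bn": ...}}.
        node = target
        for part in prefix:
            node = node.setdefault(part, {})
        node[new_leaf] = value
    return params, batch_stats


pt_state = {
    "bn.weight": [1.0, 1.0],
    "bn.bias": [0.0, 0.0],
    "bn.running_mean": [0.5, -0.5],
    "bn.running_var": [1.0, 2.0],
    "bn.num_batches_tracked": 10,
}
params, batch_stats = split_pt_batchnorm_state(pt_state)
# params      -> {"bn": {"scale": [1.0, 1.0], "bias": [0.0, 0.0]}}
# batch_stats -> {"bn": {"mean": [0.5, -0.5], "var": [1.0, 2.0]}}
```

Keeping the two collections separate matters at apply time. In Flax, the running statistics are passed as their own `batch_stats` collection alongside `params` rather than being mixed into the trainable parameters.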
test_modeling_flax_common.py 58.3 KB