[JAX] Flatten_axis for quantization and Sharding propagation fixes (#1644)
* rename QuantizeAxis to QuantizeLayout, get_layout to get_data_layout, q_axis to q_layout
* add fatten_axis option
* added gated act to test encoder
* sharding constraint fixes
* fix padding when flattening first dim needs to be padded
* update test sizes so that padding is tested
* rm output sharding as it can be done in the flax module
* sharding scale_inv for mxfp8
---------
Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment