[JAX] Scaling Enum Abstracting (#1655)
* abstract the scaling enum
* remove NVTE_ from ScalingMode names
* rework scaling mode enum in grouped gemm
* fix norm sharding
---------
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>