[JAX] GroupedQuantizer and GroupedScaledTensor (#1666)

* refactor the multi_stream utils + implement nvte_multi_tensor_quantize in TE/Common
* implement GroupedQuantizer and grouped_quantize in JAX
* fix logical_axes_names for transpose tensor in ScaledTensor

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>

---------

Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
Co-authored-by: Hua Huang <huah@nvidia.com>
Co-authored-by: Ming Huang <mingh@nvidia.com>