[JAX] `ScaledTensor1x` to store `amax` (#2117)
* added amax as an optional arg Signed-off-by:Phuong Nguyen <phuonguyen@nvidia.com> --------- Signed-off-by:
Phuong Nguyen <phuonguyen@nvidia.com>
Showing
Please register or sign in to comment