Ming-Xu Huang authored
* Support FP8 Meta Dtype (FM32) and Align FP8 Scale Update with PyTorch.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Address code review feedback.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Hide FlaxFloatMeta32 inside fp8.py.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Make functions JAX-traceable.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Rebase onto main.

Signed-off-by: Ming Huang <mingh@nvidia.com>

* Update JAX images for the GitHub workflow.

Signed-off-by: Ming Huang <mingh@nvidia.com>

---------

Signed-off-by: Ming Huang <mingh@nvidia.com>
eed4dfc6