JAX small changes (#251)
* Use the same default in the function to what the class default. Signed-off-by:Frederic Bastien <fbastien@nvidia.com> * Assert instead of silently ignoring not supported variation. Small doc correction, amax_compute_algo is partially supported. Signed-off-by:
Frederic Bastien <fbastien@nvidia.com> * Fix line lenght to fix the CI. Signed-off-by:
Frederic Bastien <fbastien@nvidia.com> * Apply suggestions from code review Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by:
Frédéric Bastien <frederic.bastien@gmail.com> * grammar Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> * Clarify that it is only TE/JAX that don't support that faeture. Signed-off-by:
Frederic Bastien <fbastien@nvidia.com> * Update transformer_engine/jax/fp8.py Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Signed-off-by:
Frédéric Bastien <frederic.bastien@gmail.com> * Update the test following the change in default value Signed-off-by:
Frederic Bastien <fbastien@nvidia.com> * Fix ci Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> --------- Signed-off-by:
Frederic Bastien <fbastien@nvidia.com> Signed-off-by:
Frédéric Bastien <frederic.bastien@gmail.com> Signed-off-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Showing
Please register or sign in to comment