Add primary weighs fp8 support for mxfp8 (#2055)
* Add primary weighs fp8 support for mxfp8 Signed-off-by:kunlunl <kunlunl@nvidia.com> * Fix unit test and add better error log to unit test Signed-off-by:
kunlunl <kunlunl@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Move post all-gather processing out of for loop Signed-off-by:
kunlunl <kunlunl@nvidia.com> * Add descriptions and ASCII diagrams for partial cast and partial amax functions Signed-off-by:
kunlunl <kunlunl@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor fix based on greptile bot Signed-off-by:
kunlunl <kunlunl@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix compilation errors due to arch-specific PTX instructions Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Remove unused noop flag from C API Signed-off-by:
Tim Moon <tmoon@nvidia.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Expose test_partial_cast Signed-off-by:
kunlunl <kunlunl@nvidia.com> * Skip mxfp8 partial cast test if mxfp8 is not available Signed-off-by:
kunlunl <kunlunl@nvidia.com> * Fix pytest error Signed-off-by:
kunlunl <kunlunl@nvidia.com> * pylint ignore unused manual_post_all_gather_processing Signed-off-by:
kunlunl <kunlunl@nvidia.com> * Fix error when using is_mxfp8_available Signed-off-by:
kunlunl <kunlunl@nvidia.com> --------- Signed-off-by:
kunlunl <kunlunl@nvidia.com> Signed-off-by:
Tim Moon <tmoon@nvidia.com> Co-authored-by:
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by:
Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by:
Tim Moon <tmoon@nvidia.com>
Showing
Please register or sign in to comment