[PyTorch] Refactor FP8 workspaces in linear modules (#820)
* Initial refactor of FP8 workspaces in Linear module Signed-off-by:Tim Moon <tmoon@nvidia.com> * Remove extra kernel launch Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Minor perf optimizations Tensor base class functions in Float8Tensor have significant overhead. Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Debug FP8 recipe test Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Refactor FP8 workspaces in LayerNormLinear and LayerNormMLP Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Document FP8 workspace function Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Revert changes to FP8 recipe tests Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Add support for lazy FP8 transpose caching Previous caching behavior (always fill cache) incorrectly filled cache during CUDA graph warmup steps. Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Fix Pylint warnings Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Debug ONNX export ONNX FP8 cast ops assumed that FP8 scales were created during model export (i.e. not initialized during training). Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Debug fused attention tests Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Make sure Float8Tensor.transpose_2d is backward compatible Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Revert changes to ONNX export operations Work around ONNX test failures by filling FP8 scale tensors instead of copying into them. Signed-off-by:
Tim Moon <tmoon@nvidia.com> * Debug scale factor update in Float8Tensor transpose_2d Signed-off-by:
Tim Moon <tmoon@nvidia.com> --------- Signed-off-by:
Tim Moon <tmoon@nvidia.com>
Showing
This diff is collapsed.
Please register or sign in to comment