Unverified commit abbdd769 authored by guyueh1, committed by GitHub

Fix mxfp8 columnwise data missing (#1593)



* Fix mxfp8 columnwise data missing when switching from validation to training
Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>

* Fix when interleaving training and inference
Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>

* Refactor
Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Remove unused code
Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>

* Update transformer_engine/pytorch/module/base.py
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Signed-off-by: guyueh1 <140554423+guyueh1@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix linter warnings
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>

---------
Signed-off-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>
Signed-off-by: guyueh1 <140554423+guyueh1@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
Co-authored-by: Guyue Huang <guyueh@login-preos02.a51.clusters.nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent cf00d537
@@ -999,6 +999,13 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         out = None
         if cache_name is not None:
             out = self._fp8_workspaces.get(cache_name, None)
+            if quantizer is not None and isinstance(out, MXFP8TensorBase):
+                if quantizer.rowwise_usage and out._rowwise_data is None:
+                    out = None
+                    del self._fp8_workspaces[cache_name]
+                elif quantizer.columnwise_usage and out._columnwise_data is None:
+                    out = None
+                    del self._fp8_workspaces[cache_name]
         # Gather cached Fp8 workspace if it's distributed
         # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
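The diff's cache-invalidation pattern can be sketched in isolation: if a cached MXFP8 workspace is missing the row-wise or column-wise buffer the current quantizer needs (e.g. because only one usage was materialized during validation), the stale entry is evicted so a fresh tensor gets quantized. `MockMXFP8Tensor`, `MockQuantizer`, and `get_cached_workspace` below are simplified stand-ins for illustration, not the real Transformer Engine classes.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class MockMXFP8Tensor:
    # Stand-in for MXFP8TensorBase: separate row-wise and column-wise
    # quantized buffers, either of which may be absent.
    _rowwise_data: Optional[str] = None
    _columnwise_data: Optional[str] = None


@dataclass
class MockQuantizer:
    # Which usages the current forward/backward pass requires.
    rowwise_usage: bool = True
    columnwise_usage: bool = True


def get_cached_workspace(workspaces, cache_name, quantizer):
    """Return the cached workspace, evicting it when it lacks data
    the quantizer requires (mirrors the check added in the diff)."""
    out = workspaces.get(cache_name, None)
    if quantizer is not None and isinstance(out, MockMXFP8Tensor):
        if quantizer.rowwise_usage and out._rowwise_data is None:
            out = None
            del workspaces[cache_name]
        elif quantizer.columnwise_usage and out._columnwise_data is None:
            out = None
            del workspaces[cache_name]
    return out


# Cached during validation: only row-wise data was produced.
ws = {"weight": MockMXFP8Tensor(_rowwise_data="q_row")}
# Training also needs column-wise data, so the stale entry is evicted.
assert get_cached_workspace(ws, "weight", MockQuantizer()) is None
assert "weight" not in ws
```

A complete cached tensor is returned unchanged, so the hot path (training step after training step) still hits the cache.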