Commit 78a38212 (unverified), authored by Tim Moon, committed by GitHub
Browse files

[PyTorch] Reset FP8 weight workspace if usages are invalid (#1972)



Reset FP8 weight workspace if usages are invalid
Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 5ba7953f
......@@ -42,7 +42,7 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor._internal.float8_tensor_base import Float8TensorBase
from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from ..utils import torch_get_autocast_gpu_dtype
from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype
from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
from ...common.recipe import DelayedScaling, Recipe
from ...debug.pytorch.debug_state import TEDebugState
......@@ -1293,22 +1293,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
# Try getting workspace from cache
out = None
if cache_name is not None:
out = self._fp8_workspaces.get(cache_name, None)
if quantizer is not None and isinstance(out, MXFP8TensorBase):
# Reset cache if workspace is invalid
if out is not None and quantizer is not None:
reset_cache = False
if isinstance(out, Float8TensorBase):
if (
not is_non_tn_fp8_gemm_supported()
and quantizer.columnwise_usage
and out._transpose is None
):
reset_cache = True
elif isinstance(out, MXFP8TensorBase):
if quantizer.rowwise_usage and out._rowwise_data is None:
out = None
del self._fp8_workspaces[cache_name]
reset_cache = True
elif quantizer.columnwise_usage and out._columnwise_data is None:
reset_cache = True
if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
reset_cache = True
if reset_cache:
out = None
del self._fp8_workspaces[cache_name]
is_debug = isinstance(quantizer, DebugQuantizer)
is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor)
if is_debug != is_out_debug_tensor:
out = None
# Gather cached Fp8 workspace if it's distributed
# NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
# for models initialized with Fp8 primary weights.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment