Unverified Commit 78a38212 authored by Tim Moon, committed by GitHub

[PyTorch] Reset FP8 weight workspace if usages are invalid (#1972)



Signed-off-by: Tim Moon <tmoon@nvidia.com>
parent 5ba7953f
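
For context on the bug being fixed: FP8 weight workspaces are cached in `self._fp8_workspaces` and reused across iterations, but a workspace quantized with only some usages enabled (e.g. rowwise-only) cannot serve a later call that requires a usage it lacks. Below is a minimal, self-contained sketch of the invariant this commit enforces; `CachedFP8Weight`, `QuantizerUsage`, and `workspace_is_valid` are hypothetical stand-ins for illustration, not Transformer Engine classes.

# Hypothetical stand-ins for illustration; the real check operates on
# Float8TensorBase/MXFP8TensorBase workspaces cached in self._fp8_workspaces.
from dataclasses import dataclass
from typing import Optional

@dataclass
class CachedFP8Weight:                 # stand-in for a cached quantized workspace
    rowwise_data: Optional[object]     # data usable by row-wise GEMMs
    columnwise_data: Optional[object]  # data usable by column-wise GEMMs

@dataclass
class QuantizerUsage:                  # stand-in for a quantizer's requested usages
    rowwise: bool
    columnwise: bool

def workspace_is_valid(cached: CachedFP8Weight, usage: QuantizerUsage) -> bool:
    # A cached workspace is reusable only if it holds data for every
    # usage the current quantizer requires.
    if usage.rowwise and cached.rowwise_data is None:
        return False
    if usage.columnwise and cached.columnwise_data is None:
        return False
    return True

# A workspace cached for rowwise-only use cannot serve a columnwise request,
# so the cache entry must be dropped and the weight re-quantized:
stale = CachedFP8Weight(rowwise_data=object(), columnwise_data=None)
assert not workspace_is_valid(stale, QuantizerUsage(rowwise=True, columnwise=True))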
@@ -42,7 +42,7 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
 from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
 from ..tensor._internal.float8_tensor_base import Float8TensorBase
 from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
-from ..utils import torch_get_autocast_gpu_dtype
+from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype
 from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase
 from ...common.recipe import DelayedScaling, Recipe
 from ...debug.pytorch.debug_state import TEDebugState
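
The newly imported `is_non_tn_fp8_gemm_supported` helper gates whether a column-wise FP8 operand needs an explicit transpose. Below is a rough sketch of the kind of capability check involved; the body is an assumption (FP8 GEMMs on Hopper require TN operand layouts, while newer architectures relax this), not the actual implementation of the Transformer Engine utility.

import torch

def is_non_tn_fp8_gemm_supported_sketch() -> bool:
    # Assumption for illustration: Hopper-class (sm9x) FP8 GEMMs require TN
    # layouts, so column-wise use of a cached Float8 weight needs a stored
    # transpose there; later architectures can consume non-TN layouts directly.
    major, _minor = torch.cuda.get_device_capability()
    return major >= 10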
@@ -1293,22 +1293,30 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         # Try getting workspace from cache
         out = None
         if cache_name is not None:
             out = self._fp8_workspaces.get(cache_name, None)
-            if quantizer is not None and isinstance(out, MXFP8TensorBase):
-                if quantizer.rowwise_usage and out._rowwise_data is None:
-                    out = None
-                    del self._fp8_workspaces[cache_name]
-                elif quantizer.columnwise_usage and out._columnwise_data is None:
-                    out = None
-                    del self._fp8_workspaces[cache_name]
-
-        is_debug = isinstance(quantizer, DebugQuantizer)
-        is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor)
-        if is_debug != is_out_debug_tensor:
-            out = None
+
+        # Reset cache if workspace is invalid
+        if out is not None and quantizer is not None:
+            reset_cache = False
+            if isinstance(out, Float8TensorBase):
+                if (
+                    not is_non_tn_fp8_gemm_supported()
+                    and quantizer.columnwise_usage
+                    and out._transpose is None
+                ):
+                    reset_cache = True
+            elif isinstance(out, MXFP8TensorBase):
+                if quantizer.rowwise_usage and out._rowwise_data is None:
+                    reset_cache = True
+                elif quantizer.columnwise_usage and out._columnwise_data is None:
+                    reset_cache = True
+            if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
+                reset_cache = True
+            if reset_cache:
+                out = None
+                del self._fp8_workspaces[cache_name]
 
         # Gather cached Fp8 workspace if it's distributed
         # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
         # for models initialized with Fp8 primary weights.
...
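
Read as a whole, the new code folds every invalidation condition into a single `reset_cache` flag before evicting the entry. Below is a condensed paraphrase as a free function; the class and attribute names come from the diff itself, but the function, the absolute import paths, and the boolean debug parameters (used here to avoid guessing the DebugQuantizer import path) are assumptions and may differ across Transformer Engine versions.

# Condensed paraphrase of the reset decision above; not Transformer Engine API.
# Import paths are inferred from the relative imports in the diff.
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported

def should_reset_workspace(out, quantizer, out_is_debug: bool, quantizer_is_debug: bool) -> bool:
    if out is None or quantizer is None:
        return False
    if isinstance(out, Float8TensorBase):
        # Plain FP8: without non-TN GEMM support, column-wise use requires
        # an explicit transpose to have been computed and cached.
        if (
            not is_non_tn_fp8_gemm_supported()
            and quantizer.columnwise_usage
            and out._transpose is None
        ):
            return True
    elif isinstance(out, MXFP8TensorBase):
        # MXFP8 stores row-wise and column-wise data separately, so each
        # requested usage must actually be present in the cached tensor.
        if quantizer.rowwise_usage and out._rowwise_data is None:
            return True
        if quantizer.columnwise_usage and out._columnwise_data is None:
            return True
    # Debug-wrapped and plain workspaces are not interchangeable.
    if out_is_debug != quantizer_is_debug:
        return True
    return False

In the module itself, a positive decision feeds `out = None` and `del self._fp8_workspaces[cache_name]`, forcing the workspace to be rebuilt with the currently requested usages.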