"backend/apps/vscode:/vscode.git/clone" did not exist on "3a737af190873f5a611bb8dd178cf076372f83dd"
Unverified commit dbd0197e, authored by Jinhang Choi, committed by GitHub
Browse files

Reset cache logic of weight workspace for NVFP4TensorStorage (#2524)



reset weight ws cache for NVFP4TensorStorage
Signed-off-by: Jinhang Choi <jinhangc@nvidia.com>
parent eac8af6a
......@@ -45,6 +45,7 @@ from ..tensor.mxfp8_tensor import MXFP8Quantizer
from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ..tensor.storage.float8_tensor_storage import Float8TensorStorage
from ..tensor.storage.mxfp8_tensor_storage import MXFP8TensorStorage
from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage
from ..utils import (
is_non_tn_fp8_gemm_supported,
torch_get_autocast_gpu_dtype,
......@@ -1388,6 +1389,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
reset_cache = True
elif quantizer.columnwise_usage and out._columnwise_data is None:
reset_cache = True
elif isinstance(out, NVFP4TensorStorage):
if quantizer.rowwise_usage and out._rowwise_data is None:
reset_cache = True
elif quantizer.columnwise_usage and out._columnwise_data is None:
reset_cache = True
if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer):
reset_cache = True
if reset_cache:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.